"""
Functions for detrending xarray data.
"""
import numpy as np
import xarray as xr
import scipy.signal as sps
import scipy.linalg as spl
def detrend(da, dim, detrend_type="constant"):
    """
    Detrend a DataArray

    Parameters
    ----------
    da : xarray.DataArray
        The data to detrend
    dim : str, list or None
        Dimensions along which to apply detrend.
        Can be either one dimension or a list with two dimensions.
        Higher-dimensional linear detrending is not supported.
        If ``None``, all dimensions of ``da`` are used.
        If dask data are passed and ``detrend_type`` is ``linear``, each
        dimension in ``dim`` must be contained in a single chunk.
    detrend_type : {'constant', 'linear', None}
        If ``constant``, the mean over ``dim`` will be removed from the data.
        If ``linear``, a linear least-squares fit will be estimated and removed
        from the data.
        If ``None``, ``da`` is returned unchanged.

    Returns
    -------
    da : xarray.DataArray
        The detrended data.

    Raises
    ------
    NotImplementedError
        If ``detrend_type`` is not one of the valid options, or if linear
        detrending is requested along more than two dimensions.
    ValueError
        If linear detrending is requested on chunked data that is split into
        more than one chunk along any dimension in ``dim``.

    Notes
    -----
    This function will act lazily in the presence of dask arrays on the
    input.
    """
    if dim is None:
        dim = list(da.dims)
    elif isinstance(dim, str):
        dim = [dim]

    if detrend_type not in ("constant", "linear", None):
        raise NotImplementedError(
            "%s is not a valid detrending option. Valid "
            "options are: 'constant','linear', or None." % detrend_type
        )

    if detrend_type is None:
        return da
    if detrend_type == "constant":
        return da - da.mean(dim=dim)

    # From here on detrend_type == "linear".
    axis_num = [da.get_axis_num(d) for d in dim]

    # The fit needs every sample along the detrended axes at once, so chunked
    # (dask) data must hold each of those axes in a single chunk.
    chunks = getattr(da.data, "chunks", None)
    if chunks and any(len(chunks[a]) != 1 for a in axis_num):
        raise ValueError("Contiguous chunks required for detrending.")

    # Removing a least-squares fit yields floating-point (or complex) values.
    # Declaring an integer/boolean input dtype as the output dtype would
    # mis-describe lazy results and, on the vectorized 2D path, cast the
    # detrended values back to integers. Float and complex inputs keep their
    # own dtype, as before.
    if np.issubdtype(da.dtype, np.inexact):
        out_dtype = da.dtype
    else:
        out_dtype = np.dtype(np.float64)

    if len(dim) == 1:
        # The axis number is forwarded positionally as scipy's ``axis``.
        dt = xr.apply_ufunc(
            sps.detrend,
            da,
            axis_num[0],
            output_dtypes=[out_dtype],
            dask="parallelized",
        )
    elif len(dim) == 2:
        dt = xr.apply_ufunc(
            _detrend_2d_ufunc,
            da,
            input_core_dims=[dim],
            output_core_dims=[dim],
            output_dtypes=[out_dtype],
            vectorize=True,
            dask="parallelized",
        )
    else:  # pragma: no cover
        raise NotImplementedError(
            "Only 1D and 2D detrending are implemented so far."
        )
    return dt
def _detrend_2d_ufunc(arr):
    """
    Remove the least-squares best-fit plane from a 2D array.

    Parameters
    ----------
    arr : numpy.ndarray
        Two-dimensional array to detrend.

    Returns
    -------
    numpy.ndarray
        ``arr`` minus the fitted plane ``a + b * row + c * col``; same shape
        as ``arr``.

    Raises
    ------
    ValueError
        If ``arr`` is not two-dimensional.
    """
    if arr.ndim != 2:
        raise ValueError(
            "_detrend_2d_ufunc expects a 2D array, got %d dimension(s)." % arr.ndim
        )
    n_rows, n_cols = arr.shape
    n_points = n_rows * n_cols

    # Design matrix: one row per sample (C order) with columns
    # [1, row index + 1, column index + 1].
    row_idx = np.repeat(np.arange(n_rows), n_cols) + 1
    col_idx = np.tile(np.arange(n_cols), n_rows) + 1
    design = np.stack([np.ones(n_points), row_idx, col_idx]).transpose()

    observed = np.reshape(arr, (n_points, 1))

    # Solve with a least-squares routine rather than inverting the normal
    # equations: the normal matrix is singular whenever a dimension has
    # length 1 (an index column then duplicates the constant column), and
    # lstsq still returns a valid (minimum-norm) fit in that case.
    coeffs = np.linalg.lstsq(design, observed, rcond=None)[0]

    plane = np.reshape(np.dot(design, coeffs), arr.shape)
    return arr - plane