Source code for xrft.detrend

"""
Functions for detrending xarray data.
"""

import numpy as np
import xarray as xr
import scipy.signal as sps
import scipy.linalg as spl


def detrend(da, dim, detrend_type="constant"):
    """
    Detrend a DataArray

    Parameters
    ----------
    da : xarray.DataArray
        The data to detrend
    dim : str, list, tuple or None
        Dimensions along which to apply detrend. Can be either one
        dimension or a sequence with two dimensions. Higher-dimensional
        linear detrending is not supported. If ``None``, all dimensions of
        ``da`` are used. For linear detrending of chunked (dask) data, the
        data must consist of a single chunk along each of ``dim``.
    detrend_type : {'constant', 'linear', None}
        If ``constant``, the mean of ``da`` over ``dim`` is subtracted.
        If ``linear``, a linear least-squares fit will be estimated and
        removed from the data. If ``None``, ``da`` is returned unchanged.

    Returns
    -------
    da : xarray.DataArray
        The detrended data.

    Raises
    ------
    NotImplementedError
        If ``detrend_type`` is not one of the valid options, or if linear
        detrending is requested along more than two dimensions.
    ValueError
        If linear detrending is requested on chunked data that has more
        than one chunk along any of ``dim``.

    Notes
    -----
    This function will act lazily in the presence of dask arrays on the
    input.
    """
    # Normalise ``dim`` to a list so the rest of the function can rely on
    # ``len(dim)`` and pass it directly as a core-dimension list.
    if dim is None:
        dim = list(da.dims)
    elif isinstance(dim, str):
        dim = [dim]
    else:
        dim = list(dim)

    if detrend_type not in ("constant", "linear", None):
        # Wrap the value in a 1-tuple: a bare tuple on the right-hand side of
        # ``%`` would be unpacked as several format arguments and raise
        # TypeError instead of the intended error.
        raise NotImplementedError(
            "%s is not a valid detrending option. Valid "
            "options are: 'constant','linear', or None." % (detrend_type,)
        )

    if detrend_type is None:
        return da
    if detrend_type == "constant":
        return da - da.mean(dim=dim)

    # detrend_type == "linear" from here on.
    data = da.data
    axis_num = [da.get_axis_num(d) for d in dim]

    # Only chunked arrays expose a non-empty ``chunks`` attribute; the fit
    # needs every detrended axis to live in a single chunk.
    chunks = getattr(data, "chunks", None)
    if chunks:
        axis_chunks = [data.chunks[a] for a in axis_num]
        if not all(len(ac) == 1 for ac in axis_chunks):
            raise ValueError("Contiguous chunks required for detrending.")

    if len(dim) == 1:
        # The axis number is forwarded positionally to scipy.signal.detrend.
        dt = xr.apply_ufunc(
            sps.detrend,
            da,
            axis_num[0],
            output_dtypes=[da.dtype],
            dask="parallelized",
        )
    elif len(dim) == 2:
        # The two dims are core dimensions, so the helper receives one 2D
        # slab at a time (vectorize loops over the remaining dimensions).
        dt = xr.apply_ufunc(
            _detrend_2d_ufunc,
            da,
            input_core_dims=[dim],
            output_core_dims=[dim],
            output_dtypes=[da.dtype],
            vectorize=True,
            dask="parallelized",
        )
    else:  # pragma: no cover
        raise NotImplementedError(
            "Only 1D and 2D detrending are implemented so far."
        )

    return dt
def _detrend_2d_ufunc(arr): assert arr.ndim == 2 N = arr.shape col0 = np.ones(N[0] * N[1]) col1 = np.repeat(np.arange(N[0]), N[1]) + 1 col2 = np.tile(np.arange(N[1]), N[0]) + 1 G = np.stack([col0, col1, col2]).transpose() d_obs = np.reshape(arr, (N[0] * N[1], 1)) m_est = np.dot(np.dot(spl.inv(np.dot(G.T, G)), G.T), d_obs) d_est = np.dot(G, m_est) linear_fit = np.reshape(d_est, N) return arr - linear_fit