from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload

from . import duck_array_ops
from .computation import dot
from .options import _get_keep_attrs
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
    from .dataarray import DataArray, Dataset


_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
    Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).

    Parameters
    ----------
    dim : str or sequence of str, optional
        Dimension(s) over which to apply the weighted ``{fcn}``.
    skipna : bool, optional
        If True, skip missing values (as marked by NaN). By default, only
        skips missing values for float dtypes; other dtypes either do not
        have a sentinel missing value (int) or skipna=True has not been
        implemented (object, datetime64 or timedelta64).
    keep_attrs : bool, optional
        If True, the attributes (``attrs``) will be copied from the original
        object to the new one. If False (default), the new object will be
        returned without attributes.

    Returns
    -------
    reduced : {cls}
        New {cls} object with weighted ``{fcn}`` applied to its data and
        the indicated dimension(s) removed.

    Notes
    -----
    Returns {on_zero} if the ``weights`` sum to 0.0 along the reduced
    dimension(s).
    """

_SUM_OF_WEIGHTS_DOCSTRING = """
    Calculate the sum of weights, accounting for missing values in the data.

    Parameters
    ----------
    dim : str or sequence of str, optional
        Dimension(s) over which to sum the weights.
    keep_attrs : bool, optional
        If True, the attributes (``attrs``) will be copied from the original
        object to the new one. If False (default), the new object will be
        returned without attributes.

    Returns
    -------
    reduced : {cls}
        New {cls} object with the sum of the weights over the given dimension.
    """


class Weighted:
    """An object that implements weighted operations.

    You should create a Weighted object by using the ``DataArray.weighted`` or
    ``Dataset.weighted`` methods.
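
    Examples
    --------
    A minimal sketch, assuming ``xarray`` is imported as ``xr``:

    >>> da = xr.DataArray([1.0, 2.0, 3.0], dims="x")
    >>> weights = xr.DataArray([0.5, 0.25, 0.25], dims="x")
    >>> da.weighted(weights).mean()  # (0.5 + 0.5 + 0.75) / 1.0
    <xarray.DataArray ()>
    array(1.75)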

    See Also
    --------
    Dataset.weighted
    DataArray.weighted
    """

    __slots__ = ("obj", "weights")

    @overload
    def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
        ...

    @overload
    def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
        ...

    def __init__(self, obj, weights):
"""
Create a Weighted object
Parameters
----------
obj : DataArray or Dataset
Object over which the weighted reduction operation is applied.
weights : DataArray
An array of weights associated with the values in the obj.
Each value in the obj contributes to the reduction operation
according to its associated weight.
Notes
-----
``weights`` must be a ``DataArray`` and cannot contain missing values.
Missing values can be replaced by ``weights.fillna(0)``.
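
        For example (a hedged sketch; ``xr`` and ``np`` imports assumed)::

            weights = xr.DataArray([1.0, np.nan], dims="x")
            da.weighted(weights.fillna(0))  # NaN weights would raise ValueError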
"""
        from .dataarray import DataArray

        if not isinstance(weights, DataArray):
            raise ValueError("`weights` must be a DataArray")

        def _weight_check(w):
            # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670
            if duck_array_ops.isnull(w).any():
                raise ValueError(
                    "`weights` cannot contain missing values. "
                    "Missing values can be replaced by `weights.fillna(0)`."
                )
            return w

        if is_duck_dask_array(weights.data):
            # assign to copy - else the check is not triggered
            weights = weights.copy(
                data=weights.data.map_blocks(_weight_check, dtype=weights.dtype)
            )
        else:
            _weight_check(weights.data)

        self.obj = obj
        self.weights = weights

    @staticmethod
    def _reduce(
        da: "DataArray",
        weights: "DataArray",
        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
        skipna: Optional[bool] = None,
    ) -> "DataArray":
        """reduce using dot; equivalent to (da * weights).sum(dim, skipna)

        for internal use only
        """
        # need to infer dims as we use `dot`
        if dim is None:
            dim = ...

        # need to mask invalid values in da, as `dot` does not implement skipna
        if skipna or (skipna is None and da.dtype.kind in "cfO"):
            da = da.fillna(0.0)

        # `dot` does not broadcast arrays, so this avoids creating a large
        # DataArray (if `weights` has additional dimensions)
        # maybe add fasttrack (`(da * weights).sum(dim=dim, skipna=skipna)`)
        return dot(da, weights, dims=dim)
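
    # Illustrative sketch of the equivalence noted above (hypothetical
    # values): with da = [1.0, 2.0] and weights = [3.0, 4.0] along "x",
    # dot(da, weights, dims="x") gives 1*3 + 2*4 = 11.0, which matches
    # (da * weights).sum("x").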

    def _sum_of_weights(
        self, da: "DataArray", dim: Optional[Union[Hashable, Iterable[Hashable]]] = None
    ) -> "DataArray":
        """Calculate the sum of weights, accounting for missing values"""

        # we need to mask data values that are nan; else the weights are wrong
        mask = da.notnull()

        # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True
        # (and not 2); GH4074
        if self.weights.dtype == bool:
            sum_of_weights = self._reduce(
                mask, self.weights.astype(int), dim=dim, skipna=False
            )
        else:
            sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False)

        # 0-weights are not valid
        valid_weights = sum_of_weights != 0.0

        return sum_of_weights.where(valid_weights)
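
    # Worked sketch (hypothetical values): da = [1.0, NaN], weights = [8, 1]
    # -> mask = [True, False], so the sum of weights is 8; the weight of the
    # NaN entry is excluded, and an all-NaN slice yields 0, masked to NaN.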

    def _weighted_sum(
        self,
        da: "DataArray",
        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
        skipna: Optional[bool] = None,
    ) -> "DataArray":
        """Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""

        return self._reduce(da, self.weights, dim=dim, skipna=skipna)

    def _weighted_mean(
        self,
        da: "DataArray",
        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
        skipna: Optional[bool] = None,
    ) -> "DataArray":
        """Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""

        weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna)
        sum_of_weights = self._sum_of_weights(da, dim=dim)

        return weighted_sum / sum_of_weights
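
    # Worked sketch continuing the example above: da = [1.0, NaN],
    # weights = [8, 1] -> weighted_sum = 1*8 + 0*1 = 8 and sum_of_weights = 8,
    # so the weighted mean is 8 / 8 = 1.0 (not 8 / 9, since the NaN entry's
    # weight is masked out in _sum_of_weights).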

    def _implementation(self, func, dim, **kwargs):
        raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")

    def sum_of_weights(
        self,
        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
        keep_attrs: Optional[bool] = None,
    ) -> Union["DataArray", "Dataset"]:

        return self._implementation(
            self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
        )

    def sum(
        self,
        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
        skipna: Optional[bool] = None,
        keep_attrs: Optional[bool] = None,
    ) -> Union["DataArray", "Dataset"]:

        return self._implementation(
            self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
        )

    def mean(
        self,
        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
        skipna: Optional[bool] = None,
        keep_attrs: Optional[bool] = None,
    ) -> Union["DataArray", "Dataset"]:

        return self._implementation(
            self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
        )

    def __repr__(self):
        """provide a nice str repr of our Weighted object"""

        klass = self.__class__.__name__
        weight_dims = ", ".join(self.weights.dims)
        return f"{klass} with weights along dimensions: {weight_dims}"


class DataArrayWeighted(Weighted):
    def _implementation(self, func, dim, **kwargs) -> "DataArray":

        keep_attrs = kwargs.pop("keep_attrs")
        if keep_attrs is None:
            keep_attrs = _get_keep_attrs(default=False)

        weighted = func(self.obj, dim=dim, **kwargs)

        if keep_attrs:
            weighted.attrs = self.obj.attrs

        return weighted
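
    # Hedged sketch of the keep_attrs handling above (hypothetical attrs):
    # with da.attrs == {"units": "K"}, da.weighted(w).mean(keep_attrs=True)
    # carries {"units": "K"} over, while the default (False, unless the global
    # keep_attrs option is set) returns a result without attributes.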


class DatasetWeighted(Weighted):
    def _implementation(self, func, dim, **kwargs) -> "Dataset":

        return self.obj.map(func, dim=dim, **kwargs)


def _inject_docstring(cls, cls_name):

    cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name)

    cls.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
        cls=cls_name, fcn="sum", on_zero="0"
    )

    cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
        cls=cls_name, fcn="mean", on_zero="NaN"
    )


_inject_docstring(DataArrayWeighted, "DataArray")
_inject_docstring(DatasetWeighted, "Dataset")