1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
|
import functools
import inspect
import itertools
from inspect import Parameter
import pint
import pint.testing
import xarray as xr
from pint_xarray.accessors import get_registry
from pint_xarray.conversion import extract_units
from pint_xarray.itertools import zip_mappings
# Parameter kinds (``*args`` and ``**kwargs``) that ``expects`` does not
# require a unit specification for.
variable_parameters = tuple(
    [inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD]
)
def _number_of_results(result):
if isinstance(result, tuple):
return len(result)
elif result is None:
return 0
else:
return 1
def expects(*args_units, return_value=None, **kwargs_units):
    """
    Decorator which ensures the inputs and outputs of the decorated
    function are expressed in the expected units.

    Arguments to the decorated function are checked for the specified
    units, converting to those units if necessary, and then stripped
    of their units before being passed into the undecorated
    function. Therefore the undecorated function should expect
    unquantified DataArrays, Datasets, or numpy-like arrays, but with
    the values expressed in specific units.

    Parameters
    ----------
    func : callable
        Function to decorate (passed to the decorator returned by
        ``expects``), which accepts zero or more xarray.DataArrays or
        numpy-like arrays as inputs, and may optionally return one or
        more xarray.DataArrays or numpy-like arrays.
    *args_units : unit-like or mapping of hashable to unit-like, optional
        Units to expect for each positional argument given to func.

        The decorator will first check that arguments passed to the
        decorated function possess these specific units (or will
        attempt to convert the argument to these units), then will
        strip the units before passing the magnitude to the wrapped
        function.

        A value of None indicates that the argument is not expected to
        carry units (suitable for flags and other non-data arguments):
        it is passed through unchanged, and passing a quantity there
        raises an error.
    return_value : unit-like or tuple of unit-like or mapping of hashable to \
    unit-like or tuple of mapping of hashable to unit-like, optional
        The expected units of the returned value(s), either as a
        single unit or as a tuple (or list) of units, one per returned
        value. The decorator will attach these units to the variables
        returned from the function.

        A value of None (the default) returns the result unchanged. A
        None entry within a tuple leaves that return value untouched
        (suitable for flags and other non-data results).
    **kwargs_units : unit-like or mapping of hashable to unit-like, optional
        Unit to expect for each keyword argument given to func.

        The decorator will first check that arguments passed to the decorated
        function possess these specific units (or will attempt to convert the
        argument to these units), then will strip the units before passing the
        magnitude to the wrapped function.

        A value of None indicates that the argument is not expected to carry
        units (suitable for flags and other non-data arguments): it is passed
        through unchanged, and passing a quantity there raises an error.

    Returns
    -------
    return_values : Any
        Return values of the wrapped function, either a single value or a tuple
        of values. These will be given units according to ``return_value``.

    Raises
    ------
    TypeError
        When decorating: if more unit specifications are given than func
        accepts, or a keyword names a parameter func does not have.
    ValueError
        When decorating: if any non-variadic parameter does not have a unit
        specified. When calling: if the number (or tuple-ness) of return
        values does not match ``return_value``.
    ExceptionGroup
        When calling: if any argument fails to be converted to its expected
        units, or if attaching units to any return value fails. Each member
        exception carries a note naming the offending parameter or return
        value index.

    See Also
    --------
    pint.wraps

    Examples
    --------
    Decorating a function which takes one quantified input, but
    returns a non-data value (in this case a boolean).

    >>> @expects("deg C")
    ... def above_freezing(temp):
    ...     return temp > 0
    ...

    Decorating a function whose first argument must not carry units, but
    which also accepts an optional `weights` keyword argument, which must
    be dimensionless.

    >>> @expects(None, weights="dimensionless")
    ... def mean(da, weights=None):
    ...     if weights is not None:
    ...         return da.weighted(weights=weights).mean()
    ...     else:
    ...         return da.mean()
    ...

    Decorating a function whose return value should be given units.

    >>> @expects("m", "s", return_value="m / s")
    ... def speed(distance, duration):
    ...     return distance / duration
    ...
    """
    # Only tuples are recognized as "multiple return values" below; accept a
    # list as well (as documented) by normalizing it once, up front.
    if isinstance(return_value, list):
        return_value = tuple(return_value)

    def outer(func):
        signature = inspect.signature(func)

        # Bind the unit specs like call arguments so they can later be matched
        # to the actual arguments by parameter name.
        params_units = signature.bind(*args_units, **kwargs_units)

        missing_params = [
            name
            for name, p in signature.parameters.items()
            if p.kind not in variable_parameters and name not in params_units.arguments
        ]
        if missing_params:
            raise ValueError(
                "Missing units for the following parameters: "
                + ", ".join(map(repr, missing_params))
            )

        n_expected_results = _number_of_results(return_value)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # NOTE: ``return_value`` is only read here, never rebound, so the
            # decorator's configuration stays the same across calls.
            params = signature.bind(*args, **kwargs)
            # don't apply defaults, as those can't be quantities and thus must
            # already be in the correct units

            # Gather every unit involved in this call (specified and passed in)
            # so a single, consistent registry can be chosen for the results.
            spec_units = dict(
                enumerate(
                    itertools.chain.from_iterable(
                        spec.values() if isinstance(spec, dict) else (spec,)
                        for spec in params_units.arguments.values()
                        if spec is not None
                    )
                )
            )
            params_units_ = dict(
                enumerate(
                    itertools.chain.from_iterable(
                        (
                            # extract_units maps variable names to units: take
                            # the units (values), not the names (keys)
                            extract_units(param).values()
                            if isinstance(param, (xr.DataArray, xr.Dataset))
                            else (param.units,)
                        )
                        for param in params.arguments.values()
                        if isinstance(param, (xr.DataArray, xr.Dataset, pint.Quantity))
                    ),
                    # keep the keys disjoint from ``spec_units`` so merging the
                    # two mappings cannot shadow any unit
                    start=len(spec_units),
                )
            )

            ureg = get_registry(None, spec_units, params_units_)

            # Convert each argument to its expected units and strip them,
            # collecting failures so all of them are reported at once.
            errors = []
            for name, (value, units) in zip_mappings(
                params.arguments, params_units.arguments
            ):
                try:
                    if units is None:
                        # the argument must not carry units at all
                        if isinstance(value, pint.Quantity) or (
                            isinstance(value, (xr.DataArray, xr.Dataset))
                            and value.pint.units
                        ):
                            raise TypeError(
                                "Passed in a quantity where none was expected"
                            )
                        continue
                    if isinstance(value, pint.Quantity):
                        params.arguments[name] = value.m_as(units)
                    elif isinstance(value, (xr.DataArray, xr.Dataset)):
                        params.arguments[name] = value.pint.to(units).pint.dequantify()
                    else:
                        raise TypeError(
                            f"Attempting to convert non-quantity {value} to {units}."
                        )
                except (
                    TypeError,
                    pint.errors.UndefinedUnitError,
                    pint.errors.DimensionalityError,
                ) as e:
                    e.add_note(
                        f"expects: raised while trying to convert parameter {name}"
                    )
                    errors.append(e)

            if errors:
                raise ExceptionGroup("Errors while converting parameters", errors)

            result = func(*params.args, **params.kwargs)

            if return_value is None:
                # no units to attach: hand the result back untouched (this
                # also preserves tuples of any length, including 1-tuples)
                return result

            n_results = _number_of_results(result)
            if (isinstance(result, tuple) ^ isinstance(return_value, tuple)) or (
                n_results != n_expected_results
            ):
                message = "mismatched number of return values:"
                if n_results != n_expected_results:
                    message += f" expected {n_expected_results} but got {n_results}."
                elif isinstance(result, tuple) and not isinstance(return_value, tuple):
                    message += (
                        " expected a single return value but got a 1-sized tuple."
                    )
                else:
                    message += (
                        " expected a 1-sized tuple but got a single return value."
                    )
                raise ValueError(message)

            # Past the check above, ``result`` and ``return_value`` are either
            # both tuples of the same length or both single values.
            returns_tuple = isinstance(return_value, tuple)
            results = result if returns_tuple else (result,)
            results_units = return_value if returns_tuple else (return_value,)

            final_result = []
            errors = []
            for index, (value, units) in enumerate(zip(results, results_units)):
                if units is not None:
                    try:
                        if isinstance(value, (xr.Dataset, xr.DataArray)):
                            value = value.pint.quantify(units)
                        else:
                            value = ureg.Quantity(value, units)
                    except Exception as e:
                        e.add_note(
                            f"expects: raised while trying to convert return value {index}"
                        )
                        errors.append(e)

                final_result.append(value)

            if errors:
                raise ExceptionGroup("Errors while converting return values", errors)

            # mirror the shape the wrapped function returned
            return tuple(final_result) if returns_tuple else final_result[0]

        return wrapper

    return outer
|