1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
|
"""
pint.facets.numpy.unit
~~~~~~~~~~~~~~~~~~~~~~
:copyright: 2022 by Pint Authors, see AUTHORS for more details.
:license: BSD, see LICENSE for more details.
"""
from __future__ import annotations
from ...compat import is_upcast_type
from ..plain import PlainUnit
class NumpyUnit(PlainUnit):
__array_priority__ = 17
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != "__call__":
# Only handle ufuncs as callables
return NotImplemented
# Check types and return NotImplemented when upcast type encountered
types = {
type(arg)
for arg in list(inputs) + list(kwargs.values())
if hasattr(arg, "__array_ufunc__")
}
if any(is_upcast_type(other) for other in types):
return NotImplemented
# Act on limited implementations by conversion to multiplicative identity
# Quantity
if ufunc.__name__ in ("true_divide", "divide", "floor_divide", "multiply"):
return ufunc(
*tuple(
self._REGISTRY.Quantity(1, self._units) if arg is self else arg
for arg in inputs
),
**kwargs,
)
return NotImplemented
|