File: arithmetic.py

package info (click to toggle)
python-xarray 0.16.2-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 6,568 kB
  • sloc: python: 60,570; makefile: 236; sh: 38
file content (105 lines) | stat: -rw-r--r-- 3,454 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Base classes implementing arithmetic for xarray objects."""
import numbers

import numpy as np

from .options import OPTIONS, _get_keep_attrs
from .pycompat import dask_array_type
from .utils import not_implemented


class SupportsArithmetic:
    """Base class for xarray types that support arithmetic.

    Used by Dataset, DataArray, Variable and GroupBy.

    The class provides ``__array_ufunc__`` so that NumPy ufuncs called on
    instances are routed through ``apply_ufunc``. The binary operator names
    listed at the bottom of the class body are placeholders only.
    """

    __slots__ = ()

    # TODO: implement special methods for arithmetic here rather than injecting
    # them in xarray/core/ops.py. Ideally, do so by inheriting from
    # numpy.lib.mixins.NDArrayOperatorsMixin.

    # Operand types (besides SupportsArithmetic itself) that __array_ufunc__
    # accepts; any other operand type makes it return NotImplemented.
    # TODO: allow extending this with some sort of registration system
    _HANDLED_TYPES = (
        np.ndarray,
        np.generic,
        numbers.Number,
        bytes,
        str,
    ) + dask_array_type

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply a NumPy ufunc to xarray objects via ``apply_ufunc``.

        Parameters
        ----------
        ufunc : numpy.ufunc
            The ufunc being invoked.
        method : str
            Name of the ufunc method being called; only ``"__call__"`` is
            supported.
        *inputs
            Positional operands passed to the ufunc.
        **kwargs
            Keyword arguments passed to the ufunc. They are forwarded
            unchanged to ``apply_ufunc`` through its ``kwargs`` argument.

        Returns
        -------
        The result of ``apply_ufunc``, or ``NotImplemented`` if any operand
        in ``inputs`` or ``out`` is neither one of ``_HANDLED_TYPES`` nor a
        ``SupportsArithmetic`` instance.

        Raises
        ------
        NotImplementedError
            If ``ufunc`` is a generalized ufunc (it has a signature), if
            ``method`` is not ``"__call__"``, or if an xarray object is
            supplied in the ``out`` argument.
        """
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid a circular import -- not verifiable from here).
        from .computation import apply_ufunc

        # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
        out = kwargs.get("out", ())
        # Built once instead of on every loop iteration.
        handled_types = self._HANDLED_TYPES + (SupportsArithmetic,)
        if not all(isinstance(x, handled_types) for x in inputs + out):
            # Defer to the other operands' implementations.
            return NotImplemented

        if ufunc.signature is not None:
            raise NotImplementedError(
                "{} not supported: xarray objects do not directly implement "
                "generalized ufuncs. Instead, use xarray.apply_ufunc or "
                "explicitly convert xarray objects to NumPy arrays "
                "(e.g., with `.values`).".format(ufunc)
            )

        if method != "__call__":
            # TODO: support other methods, e.g., reduce and accumulate.
            raise NotImplementedError(
                "{} method for ufunc {} is not implemented on xarray objects, "
                "which currently only support the __call__ method. As an "
                "alternative, consider explicitly converting xarray objects "
                "to NumPy arrays (e.g., with `.values`).".format(method, ufunc)
            )

        if any(isinstance(o, SupportsArithmetic) for o in out):
            # TODO: implement this with logic like _inplace_binary_op. This
            # will be necessary to use NDArrayOperatorsMixin.
            raise NotImplementedError(
                "xarray objects are not yet supported in the `out` argument "
                "for ufuncs. As an alternative, consider explicitly "
                "converting xarray objects to NumPy arrays (e.g., with "
                "`.values`)."
            )

        # The same global option drives both the regular and the dataset join.
        join = dataset_join = OPTIONS["arithmetic_join"]

        # Every input and output is passed with empty core dimensions.
        return apply_ufunc(
            ufunc,
            *inputs,
            input_core_dims=((),) * ufunc.nin,
            output_core_dims=((),) * ufunc.nout,
            join=join,
            dataset_join=dataset_join,
            dataset_fill_value=np.nan,
            kwargs=kwargs,
            dask="allowed",
            keep_attrs=_get_keep_attrs(default=True),
        )

    # this has no runtime function - these are listed so IDEs know these
    # methods are defined and don't warn on these operations
    __lt__ = __le__ = __ge__ = __gt__ = not_implemented
    __add__ = __sub__ = __mul__ = __truediv__ = not_implemented
    __floordiv__ = __mod__ = __pow__ = not_implemented
    __and__ = __xor__ = __or__ = not_implemented
    __div__ = __eq__ = __ne__ = not_implemented