File: arrays.py

package info (click to toggle)
python-xarray 2025.08.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 11,796 kB
  • sloc: python: 115,416; makefile: 258; sh: 47
file content (218 lines) | stat: -rw-r--r-- 7,143 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access.
"""

from collections.abc import Callable, Iterable
from typing import Any, Self

import numpy as np

from xarray.core import utils
from xarray.core.indexing import ExplicitlyIndexed


class UnexpectedDataAccess(Exception):
    """Raised by the array wrappers in this module when something tries to read their values."""


class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed):
    """Disallows any loading.

    Holds the wrapped object in ``self.array``. Every route to the values
    defined here -- ``get_duck_array``, NumPy coercion via ``__array__`` and
    ``__getitem__`` -- raises ``UnexpectedDataAccess``. Other attributes are
    presumably supplied by ``utils.NDArrayMixin`` (not visible here; confirm).
    """

    def __init__(self, array):
        self.array = array

    def __getitem__(self, key):
        # Note: this message ends with a period, unlike the other two.
        raise UnexpectedDataAccess("Tried accessing data.")

    def __array__(
        self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        message = "Tried accessing data"
        raise UnexpectedDataAccess(message)

    def get_duck_array(self):
        message = "Tried accessing data"
        raise UnexpectedDataAccess(message)


class FirstElementAccessibleArray(InaccessibleArray):
    """``InaccessibleArray`` variant whose ``__getitem__`` allows short indexers.

    ``key`` is expected to expose a ``.tuple`` attribute (presumably an xarray
    explicit indexer -- confirm against callers). Indexers with at most one
    entry are forwarded to the wrapped array; longer ones raise.
    """

    def __getitem__(self, key):
        # NOTE(review): this guards the number of indexer entries, not the
        # number of elements selected; a single full slice would still pass.
        indexer = key.tuple
        if len(indexer) <= 1:
            return self.array[indexer]
        raise UnexpectedDataAccess("Tried accessing more than one element.")


class DuckArrayWrapper(utils.NDArrayMixin):
    """Array-like that prevents casting to array.

    Modeled after cupy: indexing returns another wrapper, values can only be
    obtained explicitly through ``to_numpy``, and implicit coercion via
    ``__array__`` raises ``UnexpectedDataAccess``.
    """

    def __init__(self, array: np.ndarray):
        self.array = array

    def __getitem__(self, key):
        selected = self.array[key]
        return type(self)(selected)

    def to_numpy(self) -> np.ndarray:
        """Allow explicit conversions to numpy in `to_numpy`, but disallow np.asarray etc."""
        wrapped = self.array
        return wrapped

    def __array__(
        self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        raise UnexpectedDataAccess("Tried accessing data")

    def __array_namespace__(self):
        """Return ``xarray.tests.namespace``; present to satisfy the is_duck_array test."""
        from xarray.tests import namespace as test_namespace

        return test_namespace


# Maps NumPy functions (e.g. ``np.concatenate``) to the implementations that
# ``ConcatenatableArray.__array_function__`` dispatches to. Keys are the
# function objects themselves, because that is what NumPy hands to
# ``__array_function__`` as ``func``.
CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function: Callable) -> Callable[[Callable], Callable]:
    """Register an __array_function__ implementation for ConcatenatableArray objects.

    Parameters
    ----------
    numpy_function
        The NumPy function being overridden, e.g. ``np.concatenate``.

    Returns
    -------
    A decorator that stores the decorated function in
    ``CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS`` under ``numpy_function``
    and returns it unchanged.
    """

    def decorator(func: Callable) -> Callable:
        CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[numpy_function] = func
        return func

    return decorator


@implements(np.concatenate)
def concatenate(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Concatenate ``ConcatenatableArray`` objects along an existing axis.

    Parameters
    ----------
    arrays
        The arrays to join. Any iterable is accepted; it is materialized once
        so that a one-shot iterator (e.g. a generator) is not consumed by the
        type check before the concatenation itself runs.
    axis
        Axis along which to concatenate, forwarded to ``np.concatenate``.

    Returns
    -------
    A new ``ConcatenatableArray`` wrapping the concatenated data.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ``ConcatenatableArray``.
    """
    materialized = list(arrays)
    if not all(isinstance(arr, ConcatenatableArray) for arr in materialized):
        raise TypeError("concatenate only supports ConcatenatableArray inputs")

    result = np.concatenate([arr._array for arr in materialized], axis=axis)
    return ConcatenatableArray(result)


@implements(np.stack)
def stack(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Stack ``ConcatenatableArray`` objects along a new axis.

    Parameters
    ----------
    arrays
        The arrays to stack. Any iterable is accepted; it is materialized once
        so that a one-shot iterator (e.g. a generator) is not consumed by the
        type check before the stacking itself runs.
    axis
        Position of the new axis, forwarded to ``np.stack``.

    Returns
    -------
    A new ``ConcatenatableArray`` wrapping the stacked data.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ``ConcatenatableArray``.
    """
    materialized = list(arrays)
    if not all(isinstance(arr, ConcatenatableArray) for arr in materialized):
        raise TypeError("stack only supports ConcatenatableArray inputs")

    result = np.stack([arr._array for arr in materialized], axis=axis)
    return ConcatenatableArray(result)


@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
    """Called by xarray to ensure all arguments to concat have the same dtype.

    Unlike ``np.result_type`` this never promotes: every argument is passed
    through ``np.dtype`` and all resulting dtypes must be equal.

    Returns
    -------
    The single shared dtype.

    Raises
    ------
    ValueError
        If no arguments are given, or if the dtypes are not all identical.
    """
    if not arrays_and_dtypes:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" error.
        raise ValueError("at least one array or dtype is required")

    first_dtype, *other_dtypes = (np.dtype(obj) for obj in arrays_and_dtypes)
    for other_dtype in other_dtypes:
        if other_dtype != first_dtype:
            raise ValueError("dtypes not all consistent")
    return first_dtype


@implements(np.broadcast_to)
def broadcast_to(
    x: "ConcatenatableArray", /, shape: tuple[int, ...]
) -> "ConcatenatableArray":
    """
    Broadcast a ``ConcatenatableArray`` to ``shape``.

    The wrapped array is broadcast with ``np.broadcast_to`` and the result is
    re-wrapped in a new ``ConcatenatableArray``.

    Raises
    ------
    TypeError
        If ``x`` is not a ``ConcatenatableArray``.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError("broadcast_to only supports ConcatenatableArray inputs")

    result = np.broadcast_to(x._array, shape=shape)
    return ConcatenatableArray(result)


@implements(np.full_like)
def full_like(
    x: "ConcatenatableArray", /, fill_value, **kwargs
) -> "ConcatenatableArray":
    """
    Return a new ``ConcatenatableArray`` like ``x`` filled with ``fill_value``.

    Mirrors ``np.full_like`` defaults: when ``dtype`` or ``shape`` is omitted
    (or passed as ``None``) it is taken from ``x``. Any remaining keyword
    arguments are forwarded to ``np.full``.

    Raises
    ------
    TypeError
        If ``x`` is not a ``ConcatenatableArray``.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError("full_like only supports ConcatenatableArray inputs")

    # ``np.full`` infers the dtype from ``fill_value`` and has no ``shape``
    # keyword, so resolve both here the way ``np.full_like`` would.
    dtype = kwargs.pop("dtype", None)
    shape = kwargs.pop("shape", None)
    if dtype is None:
        dtype = x.dtype
    if shape is None:
        shape = x.shape
    return ConcatenatableArray(
        np.full(shape, fill_value=fill_value, dtype=dtype, **kwargs)
    )


@implements(np.all)
def numpy_all(x: "ConcatenatableArray", **kwargs) -> "ConcatenatableArray":
    """Apply ``np.all`` to the wrapped array and re-wrap the result in ``type(x)``.

    Keyword arguments (e.g. ``axis``) are forwarded unchanged to ``np.all``.
    """
    wrapper_cls = type(x)
    reduced = np.all(x._array, **kwargs)
    return wrapper_cls(reduced)


class ConcatenatableArray:
    """Disallows loading or coercing to an index but does support concatenation / stacking."""

    def __init__(self, array):
        # use ._array instead of .array because we don't want this to be accessible even to xarray's internals (e.g. create_default_index_implicit)
        self._array = array

    @property
    def dtype(self: Any) -> np.dtype:
        return self._array.dtype

    @property
    def shape(self: Any) -> tuple[int, ...]:
        return self._array.shape

    @property
    def ndim(self: Any) -> int:
        return self._array.ndim

    def __repr__(self: Any) -> str:
        return f"{type(self).__name__}(array={self._array!r})"

    def get_duck_array(self):
        raise UnexpectedDataAccess("Tried accessing data")

    def __array__(
        self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        raise UnexpectedDataAccess("Tried accessing data")

    def __getitem__(self, key) -> Self:
        """Some cases of concat require supporting expanding dims by dimensions of size 1"""
        # see https://data-apis.org/array-api/2022.12/API_specification/indexing.html#multi-axis-indexing
        arr = self._array
        for axis, indexer_1d in enumerate(key):
            if indexer_1d is None:
                arr = np.expand_dims(arr, axis)
            elif indexer_1d is Ellipsis:
                pass
            else:
                raise UnexpectedDataAccess("Tried accessing data.")
        return type(self)(arr)

    def __eq__(self, other: Self) -> Self:  # type: ignore[override]
        return type(self)(self._array == other._array)

    def __array_function__(self, func, types, args, kwargs) -> Any:
        if func not in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS:
            return NotImplemented

        # Note: this allows subclasses that don't override
        # __array_function__ to handle ManifestArray objects
        if not all(issubclass(t, ConcatenatableArray) for t in types):
            return NotImplemented

        return CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
        """We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs."""
        return NotImplemented

    def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> Self:
        """Needed because xarray will call this even when it's a no-op"""
        if dtype != self.dtype:
            raise NotImplementedError()
        else:
            return self

    def __and__(self, other: Self) -> Self:
        return type(self)(self._array & other._array)

    def __or__(self, other: Self) -> Self:
        return type(self)(self._array | other._array)