File: extension_array.py

package info (click to toggle)
python-xarray 2025.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 12,160 kB
  • sloc: python: 118,690; makefile: 269
file content (324 lines) | stat: -rw-r--r-- 12,057 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from __future__ import annotations

import copy
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, cast

import numpy as np
import pandas as pd
from packaging.version import Version
from pandas.api.extensions import ExtensionArray, ExtensionDtype
from pandas.api.types import is_scalar as pd_is_scalar

from xarray.core.types import DTypeLikeSave, T_ExtensionArray
from xarray.core.utils import (
    NDArrayMixin,
    is_allowed_extension_array,
    is_allowed_extension_array_dtype,
)

# Registry mapping numpy functions to their pandas-extension-array
# implementations.  Populated by the ``implements`` decorator below and
# consulted in ``PandasExtensionArray.__array_function__``.
HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


if TYPE_CHECKING:
    from typing import Any

    from pandas._typing import DtypeObj, Scalar


def is_scalar(value: object) -> bool:
    """Return True if *value* is a scalar.

    Workaround: pandas' ``is_scalar`` doesn't recognize Categorical nulls
    for some reason, so the Categorical NA sentinel is special-cased first.
    """
    if value is pd.CategoricalDtype.na_value:
        return True
    return pd_is_scalar(value)


def implements(numpy_function_or_name: Callable | str) -> Callable:
    """Register an __array_function__ implementation.

    Pass a function directly if it's guaranteed to exist in all supported numpy versions, or a
    string to first check for its existence.
    """

    def decorator(func):
        # Resolve a string name lazily so registration is skipped (not an
        # error) on numpy versions that lack the function.
        target = (
            getattr(np, numpy_function_or_name, None)
            if isinstance(numpy_function_or_name, str)
            else numpy_function_or_name
        )
        if target:
            HANDLED_EXTENSION_ARRAY_FUNCTIONS[target] = func
        return func

    return decorator


@implements(np.issubdtype)
def __extension_duck_array__issubdtype(
    extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
) -> bool:
    """Always report False for ``np.issubdtype`` on extension dtypes.

    We never want a function to think a pandas extension dtype is a
    subtype of any numpy dtype.
    """
    return False


@implements("astype")  # np.astype was added in 2.1.0, but we only require >=1.24
def __extension_duck_array__astype(
    array_or_scalar: T_ExtensionArray,
    dtype: DTypeLikeSave,
    order: str = "K",
    casting: str = "unsafe",
    subok: bool = True,
    copy: bool = True,
    device: str | None = None,
) -> ExtensionArray:
    """Cast to *dtype* when pandas extension types are involved.

    Returns NotImplemented (so numpy can try other handlers) when neither
    the input nor the target dtype is an allowed extension type, or when
    non-default ``order``/``casting``/``subok`` semantics are requested.
    ``device`` is accepted for numpy API compatibility but unused.
    """
    extension_involved = is_allowed_extension_array(
        array_or_scalar
    ) or is_allowed_extension_array_dtype(dtype)
    if not extension_involved:
        return NotImplemented
    # Only the default numpy astype semantics are supported here.
    if casting != "unsafe" or not subok or order != "K":
        return NotImplemented

    return as_extension_array(array_or_scalar, dtype, copy=copy)


@implements(np.asarray)
def __extension_duck_array__asarray(
    array_or_scalar: np.typing.ArrayLike | T_ExtensionArray,
    dtype: DTypeLikeSave | None = None,
) -> ExtensionArray:
    """Implement ``np.asarray`` for allowed pandas extension dtypes.

    Returns NotImplemented (deferring to other handlers) unless *dtype*
    is an allowed pandas extension dtype.
    """
    # Fix: the dtype must be validated with the *dtype* predicate.
    # ``is_allowed_extension_array`` expects an ExtensionArray instance and
    # is always False for a dtype object, which made this override
    # unconditionally return NotImplemented (i.e. it was dead code).
    if not is_allowed_extension_array_dtype(dtype):
        return NotImplemented

    return as_extension_array(array_or_scalar, dtype)


def as_extension_array(
    array_or_scalar: np.typing.ArrayLike | T_ExtensionArray,
    dtype: ExtensionDtype | DTypeLikeSave | None,
    copy: bool = False,
) -> ExtensionArray:
    """Convert *array_or_scalar* to an extension array of *dtype*.

    Scalars are wrapped into a length-1 array of the requested dtype;
    array-likes are cast via their own ``astype`` method.
    """
    if not is_scalar(array_or_scalar):
        return array_or_scalar.astype(dtype, copy=copy)  # type: ignore[union-attr]
    array_type = dtype.construct_array_type()  # type: ignore[union-attr]
    return array_type._from_sequence([array_or_scalar], dtype=dtype)


@implements(np.result_type)
def __extension_duck_array__result_type(
    *arrays_and_dtypes: list[
        np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype | ExtensionArray
    ],
) -> DtypeObj:
    """Find a common dtype for a mix of extension arrays/dtypes and scalars.

    Returns NotImplemented when no allowed extension arrays/dtypes are
    present (deferring to numpy), a merged CategoricalDtype when every
    extension dtype is an unordered Categorical, the first dtype when all
    extension dtypes share one dtype class, and raises ValueError otherwise.
    """
    # Keep only the inputs that are allowed pandas extension arrays or dtypes.
    extension_arrays_and_dtypes: list[ExtensionDtype | ExtensionArray] = [
        cast(ExtensionDtype | ExtensionArray, x)
        for x in arrays_and_dtypes
        if is_allowed_extension_array(x) or is_allowed_extension_array_dtype(x)
    ]
    if not extension_arrays_and_dtypes:
        # No extension types involved: let another handler decide.
        return NotImplemented

    # Normalize arrays to their dtypes; bare dtypes pass through unchanged.
    ea_dtypes: list[ExtensionDtype] = [
        getattr(x, "dtype", cast(ExtensionDtype, x))
        for x in extension_arrays_and_dtypes
    ]
    # Scalars excluding the NA/NaN sentinels (membership in a set checks
    # identity first, so pd.NA and np.nan are matched reliably here).
    scalars = [
        x for x in arrays_and_dtypes if is_scalar(x) and x not in {pd.NA, np.nan}
    ]
    # other_stuff could include:
    # - arrays such as pd.ABCSeries, np.ndarray, or other array-api duck arrays
    # - dtypes such as pd.DtypeObj, np.dtype, or other array-api duck dtypes
    other_stuff = [
        x
        for x in arrays_and_dtypes
        if not is_allowed_extension_array_dtype(x) and not is_scalar(x)
    ]
    # We implement one special case: when possible, preserve Categoricals (avoid promoting
    # to object) by merging the categories of all given Categoricals + scalars + NA.
    # Ideally this could be upstreamed into pandas find_result_type / find_common_type.
    if not other_stuff and all(
        isinstance(x, pd.CategoricalDtype) and not x.ordered for x in ea_dtypes
    ):
        return union_unordered_categorical_and_scalar(
            cast(list[pd.CategoricalDtype], ea_dtypes),
            scalars,  # type: ignore[arg-type]
        )
    # If every extension dtype is an instance of the first one's class,
    # the first dtype is representative of them all.
    if not other_stuff and all(
        isinstance(x, type(ea_type := ea_dtypes[0])) for x in ea_dtypes
    ):
        return ea_type
    raise ValueError(
        f"Cannot cast values to shared type, found values: {arrays_and_dtypes}"
    )


def union_unordered_categorical_and_scalar(
    categorical_dtypes: list[pd.CategoricalDtype], scalars: list[Scalar]
) -> pd.CategoricalDtype:
    """Build an unordered CategoricalDtype covering all given categories plus scalars.

    The categorical NA sentinel is dropped from *scalars* since missing
    values are not categories.
    """
    real_scalars = [s for s in scalars if s is not pd.CategoricalDtype.na_value]
    merged: set = set()
    for ct in categorical_dtypes:
        merged.update(ct.categories)
    merged.update(real_scalars)
    return pd.CategoricalDtype(categories=list(merged))


@implements(np.broadcast_to)
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
    """Broadcast is the identity for a matching 1-d shape; anything else is unsupported."""
    already_matches = shape[0] == len(arr) and len(shape) == 1
    if not already_matches:
        raise NotImplementedError("Cannot broadcast 1d-only pandas extension array.")
    return arr


@implements(np.stack)
def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
    """Stacking would add a dimension, which 1-d-only extension arrays cannot have."""
    raise NotImplementedError("Cannot stack 1d-only pandas extension array.")


@implements(np.concatenate)
def __extension_duck_array__concatenate(
    arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
) -> T_ExtensionArray:
    """Concatenate via pandas' own ``_concat_same_type`` machinery."""
    first = arrays[0]
    return type(first)._concat_same_type(arrays)  # type: ignore[attr-defined]


@implements(np.where)
def __extension_duck_array__where(
    condition: T_ExtensionArray | np.typing.ArrayLike,
    x: T_ExtensionArray,
    y: T_ExtensionArray | np.typing.ArrayLike,
) -> T_ExtensionArray:
    """Select elementwise from *x* where *condition* holds, else from *y*."""
    # pd.where won't broadcast 0-dim arrays across a scalar-like series;
    # scalar y's must be preserved, so unwrap a length-1 array to its element.
    y_is_length1_array = (
        hasattr(y, "shape") and len(y.shape) == 1 and y.shape[0] == 1
    )
    if y_is_length1_array:
        y = y[0]  # type: ignore[index]
    result = pd.Series(x).where(condition, y).array  # type: ignore[arg-type]
    return cast(T_ExtensionArray, result)


def _replace_duck(args, replacer: Callable[[PandasExtensionArray], list]) -> list:
    """Return *args* as a list with every PandasExtensionArray passed through *replacer*.

    Recurses into nested lists and tuples, preserving their container type.
    Other values are kept as-is.
    """
    replaced = []
    for value in args:
        if isinstance(value, PandasExtensionArray):
            replaced.append(replacer(value))
        elif isinstance(value, tuple):  # should handle more than just tuple? iterable?
            replaced.append(tuple(_replace_duck(value, replacer)))
        elif isinstance(value, list):
            replaced.append(_replace_duck(value, replacer))
        else:
            replaced.append(value)
    return replaced


def replace_duck_with_extension_array(args) -> tuple:
    """Return *args* as a tuple with every PandasExtensionArray unwrapped to its raw pandas array."""

    def _unwrap(duck: PandasExtensionArray):
        return duck.array

    return tuple(_replace_duck(args, _unwrap))


def replace_duck_with_series(args) -> tuple:
    """Return *args* as a tuple with every PandasExtensionArray converted to a ``pd.Series``."""

    def _to_series(duck: PandasExtensionArray):
        return pd.Series(duck.array)

    return tuple(_replace_duck(args, _to_series))


@implements(np.ndim)
def __extension_duck_array__ndim(x: PandasExtensionArray) -> int:
    """Return the wrapper's dimensionality (delegates to ``PandasExtensionArray.ndim``)."""
    return x.ndim


@implements(np.reshape)
def __extension_duck_array__reshape(
    arr: T_ExtensionArray, shape: tuple
) -> T_ExtensionArray:
    """Reshape is only the identity for 1-d extension arrays; anything else raises."""
    keeps_1d_length = shape[0] == len(arr) and len(shape) == 1
    if keeps_1d_length or shape == (-1,):
        return arr
    raise NotImplementedError(
        f"Cannot reshape 1d-only pandas extension array to: {shape}"
    )


@dataclass(frozen=True)
class PandasExtensionArray(NDArrayMixin, Generic[T_ExtensionArray]):
    """NEP-18 compliant wrapper for pandas extension arrays.

    Parameters
    ----------
    array : T_ExtensionArray
        The array to be wrapped upon e.g., :py:class:`xarray.Variable` creation.
    """

    # The wrapped pandas ExtensionArray; validated in __post_init__.
    array: T_ExtensionArray

    def __post_init__(self):
        # Reject anything that is not a supported pandas ExtensionArray.
        if not isinstance(self.array, pd.api.extensions.ExtensionArray):
            raise TypeError(f"{self.array} is not a pandas ExtensionArray.")
        # This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
        # we do support extension arrays from datetime, for example, that need
        # duck array support internally via this class.  These can appear from `DatetimeIndex`
        # wrapped by `PandasIndex` internally, for example.
        if not is_allowed_extension_array(self.array):
            raise TypeError(
                f"{self.array.dtype!r} should be converted to a numpy array in `xarray` internally."
            )

    def __array_function__(self, func, types, args, kwargs):
        """Dispatch numpy functions to the handlers registered via ``implements``.

        Unwraps any PandasExtensionArray arguments first, and re-wraps
        ExtensionArray results so they stay NEP-18 compatible.
        """
        if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
            raise KeyError("Function not registered for pandas extension arrays.")
        args = replace_duck_with_extension_array(args)
        res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
        if isinstance(res, ExtensionArray):
            return PandasExtensionArray(res)
        return res

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Apply the ufunc directly to the inputs as given (no unwrapping here).
        return ufunc(*inputs, **kwargs)

    def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
        """Index the wrapped array, always returning a wrapped result.

        Scalar results are re-wrapped as a length-1 extension array so the
        result remains duck-array compatible.
        """
        if (
            isinstance(key, tuple) and len(key) == 1
        ):  # pyarrow type arrays can't handle single-length tuples
            (key,) = key
        item = self.array[key]
        if is_allowed_extension_array(item):
            return PandasExtensionArray(item)
        if is_scalar(item) or isinstance(key, int):
            # Promote a scalar back to a length-1 array of the same type.
            return PandasExtensionArray(type(self.array)._from_sequence([item]))  # type: ignore[call-arg,attr-defined,unused-ignore]
        return PandasExtensionArray(item)

    def __setitem__(self, key, val):
        # Note: mutates the wrapped array in place even though the dataclass
        # itself is frozen (only attribute rebinding is blocked).
        self.array[key] = val

    def __len__(self):
        return len(self.array)

    def __eq__(self, other):
        # Delegates to the wrapped array's comparison; for extension arrays
        # this is typically elementwise, so the result is an array, not a bool.
        if isinstance(other, PandasExtensionArray):
            return self.array == other.array
        return self.array == other

    def __ne__(self, other):
        # Elementwise negation of __eq__ (hence ~, not `not`).
        return ~(self == other)

    @property
    def ndim(self) -> int:
        # pandas extension arrays are always one-dimensional.
        return 1

    def __array__(
        self, dtype: np.typing.DTypeLike | None = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        # numpy >= 2.0 accepts the `copy` keyword in asarray; older versions don't.
        if Version(np.__version__) >= Version("2.0.0"):
            return np.asarray(self.array, dtype=dtype, copy=copy)
        else:
            return np.asarray(self.array, dtype=dtype)

    def __getattr__(self, attr: str) -> Any:
        #  with __deepcopy__ or __copy__, the object is first constructed and then the sub-objects are attached (see https://docs.python.org/3/library/copy.html)
        # Thus, if we didn't have `super().__getattribute__("array")` this method would call `self.array` (i.e., `getattr(self, "array")`) again while looking for `__setstate__`
        # (which is apparently the first thing sought in copy.copy from the under-construction copied object),
        # which would cause a recursion error since `array` is not present on the object when it is being constructed during `__{deep}copy__`.
        # Even though we have defined these two methods now below due to `test_extension_array_copy_arrow_type` (cause unknown)
        # we leave this here as it more robust than self.array
        return getattr(super().__getattribute__("array"), attr)

    def __copy__(self) -> PandasExtensionArray[T_ExtensionArray]:
        # Shallow-copy the wrapped array and re-wrap it.
        return PandasExtensionArray(copy.copy(self.array))

    def __deepcopy__(
        self, memo: dict[int, Any] | None = None
    ) -> PandasExtensionArray[T_ExtensionArray]:
        # Deep-copy the wrapped array (honoring the memo dict) and re-wrap it.
        return PandasExtensionArray(copy.deepcopy(self.array, memo=memo))