File: pycompat.py

package info (click to toggle)
python-xarray 2025.08.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 11,796 kB
  • sloc: python: 115,416; makefile: 258; sh: 47
file content (161 lines) | stat: -rw-r--r-- 5,532 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from __future__ import annotations

from importlib import import_module
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
from packaging.version import Version

from xarray.core.utils import is_scalar
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array

# Built-in and NumPy integer types bundled as a tuple (the form accepted by
# ``isinstance``).
integer_types = (int, np.integer)

# The names below are bound only while a static type checker analyses the
# file. Together with ``from __future__ import annotations`` they can be used
# in annotations, but they do not exist at runtime.
if TYPE_CHECKING:
    # Module names that DuckArrayModule knows how to handle.
    ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
    # Tuple of array classes, in the form accepted by ``isinstance``.
    DuckArrayTypes = tuple[type[Any], ...]  # TODO: improve this? maybe Generic
    from xarray.namedarray._typing import _DType, _ShapeType, duckarray


class DuckArrayModule:
    """
    Handle on an optional array library that is imported on demand.

    Intended solely for internal isinstance and version checks. Importing
    lazily matters because pint must only be imported when required (pint
    currently imports xarray):
    https://github.com/pydata/xarray/pull/5561#discussion_r664815718

    Attributes
    ----------
    module
        The imported module, or ``None`` if loading it failed.
    version
        Parsed ``__version__`` of the module; ``0.0.0`` if loading failed.
    type
        Tuple holding the array class of the module. Empty for "numbagg"
        and whenever loading failed.
    available
        ``True`` exactly when ``module`` is not ``None``.
    """

    module: ModuleType | None
    version: Version
    type: DuckArrayTypes
    available: bool

    def __init__(self, mod: ModType) -> None:
        """
        Import ``mod`` and record its version and array class.

        An ``ImportError`` or ``AttributeError`` raised while loading marks
        the module as unavailable instead of propagating. A module name that
        imports fine but is not one of the supported names raises
        ``NotImplementedError``.
        """
        # Name of the attribute holding the array class, for the libraries
        # that expose it directly on their top-level module.
        array_class_attr = {
            "pint": "Quantity",
            "cupy": "ndarray",
            "sparse": "SparseArray",
            "cubed": "Array",
        }

        loaded: ModuleType | None
        found_version: Version
        found_types: DuckArrayTypes
        try:
            loaded = import_module(mod)
            found_version = Version(loaded.__version__)

            if mod == "dask":
                # The array class lives in a submodule that must be imported
                # separately.
                found_types = (import_module("dask.array").Array,)
            elif mod in array_class_attr:
                found_types = (getattr(loaded, array_class_attr[mod]),)
            elif mod == "numbagg":
                # Not a duck array module, but using this system regardless,
                # to get lazy imports
                found_types = ()
            else:
                raise NotImplementedError

        except (ImportError, AttributeError):  # pragma: no cover
            loaded, found_version, found_types = None, Version("0.0.0"), ()

        self.module = loaded
        self.version = found_version
        self.type = found_types
        self.available = loaded is not None


# Module-level cache of DuckArrayModule instances keyed by module name, so
# each optional library is imported and inspected at most once.
_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {}


def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule:
    """Return the ``DuckArrayModule`` for ``mod``, creating and caching it on first use."""
    try:
        return _cached_duck_array_modules[mod]
    except KeyError:
        pass

    # First request for this module: build the entry and remember it.
    entry = DuckArrayModule(mod)
    _cached_duck_array_modules[mod] = entry
    return entry


def array_type(mod: ModType) -> DuckArrayTypes:
    """
    Return the array class(es) of module ``mod`` as a tuple.

    The lookup goes through the module cache. The tuple is empty when the
    module could not be loaded.
    """
    duck_module = _get_cached_duck_array_module(mod)
    return duck_module.type


def mod_version(mod: ModType) -> Version:
    """
    Return the version of module ``mod``.

    The lookup goes through the module cache. The result is ``0.0.0`` when
    the module could not be loaded.
    """
    duck_module = _get_cached_duck_array_module(mod)
    return duck_module.version


def is_chunked_array(x: duckarray[Any, Any]) -> bool:
    """Return whether ``x`` is a duck dask array or a duck array with a ``chunks`` attribute."""
    dask_like = is_duck_dask_array(x)
    if dask_like:
        return dask_like
    # Any other duck array counts as chunked if it exposes ``chunks``.
    return is_duck_array(x) and hasattr(x, "chunks")


def is_0d_dask_array(x: duckarray[Any, Any]) -> bool:
    """Return whether ``x`` is a duck dask array that ``is_scalar`` also accepts."""
    dask_like = is_duck_dask_array(x)
    # ``is_scalar`` is only consulted for duck dask arrays.
    return is_scalar(x) if dask_like else dask_like


def to_numpy(
    data: duckarray[Any, Any], **kwargs: dict[str, Any]
) -> np.ndarray[Any, np.dtype[Any]]:
    """
    Convert ``data`` to a numpy array.

    If ``data`` has a ``to_numpy`` attribute, the result of calling it is
    returned directly. Otherwise the following steps are applied in order:

    1. ``ExplicitlyIndexed`` wrappers are unwrapped with ``get_duck_array()``.
    2. Chunked arrays are computed through their chunk manager.
    3. cupy arrays are converted with ``.get()``.
    4. pint quantities are reduced to their ``.magnitude``.
    5. sparse arrays are densified with ``.todense()``.
    6. The result is passed through ``np.asarray``.

    Parameters
    ----------
    data
        The array-like object to convert.
    **kwargs
        Forwarded to the chunk manager's ``compute`` when ``data`` is chunked.

    Returns
    -------
    The data as a numpy array (or whatever ``data.to_numpy()`` returns).

    Raises
    ------
    Any exception raised by ``data.to_numpy()`` itself propagates, including
    ``AttributeError``. Only a missing ``to_numpy`` attribute triggers the
    fallback conversion.
    """
    # for tests only at the moment
    # Look the method up separately from calling it, so that an AttributeError
    # raised *inside* to_numpy() is not mistaken for "method missing".
    converter = getattr(data, "to_numpy", None)
    if converter is not None:
        return converter()  # type: ignore[no-any-return]

    # Imported here rather than at module level, and only once the fast path
    # above did not apply.
    from xarray.core.indexing import ExplicitlyIndexed
    from xarray.namedarray.parallelcompat import get_chunked_array_type

    if isinstance(data, ExplicitlyIndexed):
        data = data.get_duck_array()  # type: ignore[no-untyped-call]

    # TODO first attempt to call .to_numpy() once some libraries implement it
    if is_chunked_array(data):
        chunkmanager = get_chunked_array_type(data)
        # compute() returns a sequence; only its first element is used.
        data, *_ = chunkmanager.compute(data, **kwargs)
    if isinstance(data, array_type("cupy")):
        data = data.get()
    # pint has to be imported dynamically as pint imports xarray
    if isinstance(data, array_type("pint")):
        data = data.magnitude
    if isinstance(data, array_type("sparse")):
        data = data.todense()
    data = np.asarray(data)

    return data


def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, _DType]:
    """
    Return ``data`` as an in-memory duck array.

    The checks run in this order:

    - chunked arrays are computed through their chunk manager, with
      ``kwargs`` forwarded to ``compute``;
    - ``ExplicitlyIndexed`` and ``ImplicitToExplicitIndexingAdapter`` objects
      are unwrapped with ``get_duck_array()``;
    - duck arrays are returned unchanged;
    - anything else is converted with ``np.asarray``.
    """
    from xarray.core.indexing import (
        ExplicitlyIndexed,
        ImplicitToExplicitIndexingAdapter,
    )
    from xarray.namedarray.parallelcompat import get_chunked_array_type

    # Chunked arrays are handled before any other check.
    if is_chunked_array(data):
        manager = get_chunked_array_type(data)
        computed = manager.compute(data, **kwargs)
        # compute() returns a sequence; only its first element is used.
        first, *_ = computed  # type: ignore[var-annotated]
        return first

    indexing_wrappers = (ExplicitlyIndexed, ImplicitToExplicitIndexingAdapter)
    if isinstance(data, indexing_wrappers):
        return data.get_duck_array()  # type: ignore[no-untyped-call, no-any-return]
    if is_duck_array(data):
        return data
    return np.asarray(data)  # type: ignore[return-value]


async def async_to_duck_array(
    data: Any, **kwargs: dict[str, Any]
) -> duckarray[_ShapeType, _DType]:
    """
    Coroutine variant of ``to_duck_array``.

    ``ExplicitlyIndexed`` and ``ImplicitToExplicitIndexingAdapter`` objects
    are awaited through ``async_get_duck_array()``. Every other input is
    delegated to the synchronous ``to_duck_array`` with ``kwargs`` forwarded.
    """
    from xarray.core.indexing import (
        ExplicitlyIndexed,
        ImplicitToExplicitIndexingAdapter,
    )

    indexing_wrappers = (ExplicitlyIndexed, ImplicitToExplicitIndexingAdapter)
    if not isinstance(data, indexing_wrappers):
        # No asynchronous path for this input; fall back to the sync helper.
        return to_duck_array(data, **kwargs)
    return await data.async_get_duck_array()  # type: ignore[union-attr, no-any-return]