1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
|
"""
This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access.
"""
from collections.abc import Callable, Iterable
from typing import Any, Self
import numpy as np
from xarray.core import utils
from xarray.core.indexing import ExplicitlyIndexed
class UnexpectedDataAccess(Exception):
    """Raised by the array wrappers in this module when something tries to read their data."""
class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed):
    """Disallows any loading.

    The wrapped object is kept on ``self.array``. Every data-access entry point
    defined here (``__getitem__``, ``__array__``, ``get_duck_array``) raises
    UnexpectedDataAccess instead of returning values.
    """

    def __init__(self, array):
        # Stored untouched; nothing in this class ever reads it.
        self.array = array

    def __getitem__(self, key):
        # Note: this message ends with a period, unlike the other two below.
        raise UnexpectedDataAccess("Tried accessing data.")

    def __array__(
        self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        # Blocks np.asarray / np.array coercion.
        raise UnexpectedDataAccess("Tried accessing data")

    def get_duck_array(self):
        # Blocks xarray's explicit "give me the underlying array" hook.
        raise UnexpectedDataAccess("Tried accessing data")
class FirstElementAccessibleArray(InaccessibleArray):
    """InaccessibleArray variant that allows indexing with at most one indexer.

    ``key`` is expected to expose its per-dimension indexers as ``key.tuple``
    (presumably an xarray explicit indexer object -- confirm against callers).
    Keys with more than one entry raise UnexpectedDataAccess.
    """

    def __getitem__(self, key):
        indexers = key.tuple
        if len(indexers) <= 1:
            return self.array[indexers]
        raise UnexpectedDataAccess("Tried accessing more than one element.")
class DuckArrayWrapper(utils.NDArrayMixin):
    """Array-like that prevents casting to array.

    Modeled after cupy: implicit coercion through ``__array__`` raises
    UnexpectedDataAccess, while ``to_numpy`` returns the wrapped array
    explicitly. Indexing returns a new instance of the same class.
    """

    def __init__(self, array: np.ndarray):
        self.array = array

    def __array_namespace__(self):
        """Present to satisfy is_duck_array test."""
        # Imported lazily so this module does not import the test namespace eagerly.
        from xarray.tests import namespace as test_namespace

        return test_namespace

    def __array__(
        self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        raise UnexpectedDataAccess("Tried accessing data")

    def __getitem__(self, key):
        # Keep the result wrapped so it stays non-coercible.
        selected = self.array[key]
        return type(self)(selected)

    def to_numpy(self) -> np.ndarray:
        """Allow explicit conversions to numpy in `to_numpy`, but disallow np.asarray etc."""
        return self.array
# Registry consulted by ConcatenatableArray.__array_function__: maps a NumPy
# function object (e.g. ``np.concatenate``) -- not its name -- to the
# implementation that handles ConcatenatableArray inputs. Filled by ``implements``.
CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}
def implements(numpy_function):
    """Register an __array_function__ implementation for ConcatenatableArray objects.

    Use as ``@implements(np.some_function)``: the decorated function is stored
    in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS under ``numpy_function`` and
    returned unchanged, so it stays callable directly.
    """

    def register(impl):
        CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS.update({numpy_function: impl})
        return impl

    return register
@implements(np.concatenate)
def concatenate(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Concatenate ConcatenatableArray objects along an existing axis.

    Dispatched to from ``np.concatenate`` via ``__array_function__``.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ConcatenatableArray.
    """
    # Materialize first: a one-shot iterable (e.g. a generator) would otherwise
    # be exhausted by the type check, leaving nothing to concatenate.
    arrays = list(arrays)
    if not all(isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError(
            "np.concatenate is only supported between ConcatenatableArray objects"
        )
    result = np.concatenate([arr._array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)
@implements(np.stack)
def stack(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Stack ConcatenatableArray objects along a new axis.

    Dispatched to from ``np.stack`` via ``__array_function__``.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ConcatenatableArray.
    """
    # Materialize first: a one-shot iterable (e.g. a generator) would otherwise
    # be exhausted by the type check, leaving nothing to stack.
    arrays = list(arrays)
    if not all(isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError(
            "np.stack is only supported between ConcatenatableArray objects"
        )
    result = np.stack([arr._array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)
@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
    """Called by xarray to ensure all arguments to concat have the same dtype.

    Unlike ``np.result_type`` no promotion is performed: every argument is
    converted with ``np.dtype`` and the shared dtype is returned only if they
    are all identical; otherwise ValueError is raised.
    """
    # Star-unpacking converts every argument up front (and raises ValueError
    # when called with no arguments at all).
    reference, *remaining = [np.dtype(obj) for obj in arrays_and_dtypes]
    if any(candidate != reference for candidate in remaining):
        raise ValueError("dtypes not all consistent")
    return reference
@implements(np.broadcast_to)
def broadcast_to(
    x: "ConcatenatableArray", /, shape: tuple[int, ...]
) -> "ConcatenatableArray":
    """Broadcast the wrapped data of ``x`` to ``shape`` and rewrap the result.

    Dispatched to from ``np.broadcast_to`` via ``__array_function__``.

    Raises
    ------
    TypeError
        If ``x`` is not a ConcatenatableArray.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError(
            "np.broadcast_to is only supported on ConcatenatableArray objects"
        )
    result = np.broadcast_to(x._array, shape=shape)
    return ConcatenatableArray(result)
@implements(np.full_like)
def full_like(
    x: "ConcatenatableArray", /, fill_value, **kwargs
) -> "ConcatenatableArray":
    """Return a new ConcatenatableArray shaped like ``x`` and filled with ``fill_value``.

    Dispatched to from ``np.full_like`` via ``__array_function__`` and follows
    its semantics:

    - when ``dtype`` is omitted or ``None`` the result takes ``x.dtype``
      (not the dtype inferred from ``fill_value``);
    - a ``shape`` keyword overrides ``x.shape``.

    Any other keyword arguments are forwarded to ``np.full``.

    Raises
    ------
    TypeError
        If ``x`` is not a ConcatenatableArray.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError(
            "np.full_like is only supported on ConcatenatableArray objects"
        )
    # ``kwargs`` is a fresh dict built for this call, so mutating it is safe.
    shape = kwargs.pop("shape", None)
    if shape is None:
        shape = x.shape
    if kwargs.get("dtype") is None:
        # np.full would infer the dtype from fill_value; np.full_like keeps x's.
        kwargs["dtype"] = x.dtype
    return ConcatenatableArray(np.full(shape, fill_value=fill_value, **kwargs))
@implements(np.all)
def numpy_all(
    x: "ConcatenatableArray", *args, **kwargs
) -> "ConcatenatableArray":
    """Apply ``np.all`` to the wrapped data of ``x`` and rewrap the result.

    Dispatched to from ``np.all`` via ``__array_function__``. Extra positional
    arguments (e.g. ``axis`` in ``np.all(arr, 0)``) are forwarded together
    with keyword arguments.

    Raises
    ------
    TypeError
        If ``x`` is not a ConcatenatableArray.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError("np.all is only supported on ConcatenatableArray objects")
    return type(x)(np.all(x._array, *args, **kwargs))
class ConcatenatableArray:
"""Disallows loading or coercing to an index but does support concatenation / stacking."""
def __init__(self, array):
# use ._array instead of .array because we don't want this to be accessible even to xarray's internals (e.g. create_default_index_implicit)
self._array = array
@property
def dtype(self: Any) -> np.dtype:
return self._array.dtype
@property
def shape(self: Any) -> tuple[int, ...]:
return self._array.shape
@property
def ndim(self: Any) -> int:
return self._array.ndim
def __repr__(self: Any) -> str:
return f"{type(self).__name__}(array={self._array!r})"
def get_duck_array(self):
raise UnexpectedDataAccess("Tried accessing data")
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")
def __getitem__(self, key) -> Self:
"""Some cases of concat require supporting expanding dims by dimensions of size 1"""
# see https://data-apis.org/array-api/2022.12/API_specification/indexing.html#multi-axis-indexing
arr = self._array
for axis, indexer_1d in enumerate(key):
if indexer_1d is None:
arr = np.expand_dims(arr, axis)
elif indexer_1d is Ellipsis:
pass
else:
raise UnexpectedDataAccess("Tried accessing data.")
return type(self)(arr)
def __eq__(self, other: Self) -> Self: # type: ignore[override]
return type(self)(self._array == other._array)
def __array_function__(self, func, types, args, kwargs) -> Any:
if func not in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS:
return NotImplemented
# Note: this allows subclasses that don't override
# __array_function__ to handle ManifestArray objects
if not all(issubclass(t, ConcatenatableArray) for t in types):
return NotImplemented
return CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
"""We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs."""
return NotImplemented
def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> Self:
"""Needed because xarray will call this even when it's a no-op"""
if dtype != self.dtype:
raise NotImplementedError()
else:
return self
def __and__(self, other: Self) -> Self:
return type(self)(self._array & other._array)
def __or__(self, other: Self) -> Self:
return type(self)(self._array | other._array)
|