File: index.py

package info (click to toggle)
python-awkward 2.9.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 35,360 kB
  • sloc: python: 187,941; cpp: 33,672; sh: 432; ansic: 256; makefile: 21; javascript: 8
file content (326 lines) | stat: -rw-r--r-- 10,816 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import copy

import awkward as ak
from awkward._nplikes import to_nplike
from awkward._nplikes.array_like import ArrayLike, maybe_materialize
from awkward._nplikes.cupy import Cupy
from awkward._nplikes.dispatch import nplike_of_obj
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyLike, NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._nplikes.typetracer import TypeTracer
from awkward._nplikes.virtual import VirtualNDArray
from awkward._slicing import normalize_slice
from awkward._typing import Any, DType, Final, Self, cast

# NumpyMetadata singleton: used in this module for dtype objects and scalar
# types (`np.dtype`, `np.int64`, ...).
np: Final = NumpyMetadata.instance()
# Numpy nplike singleton: used in this module to allocate arrays
# (`numpy.zeros`).
numpy: Final = Numpy.instance()


# Short form string for each dtype an Index can hold, keyed by dtype.
_dtype_to_form: Final[dict[DType, str]] = {
    np.dtype(scalar_type): form
    for scalar_type, form in (
        (np.int8, "i8"),
        (np.uint8, "u8"),
        (np.int32, "i32"),
        (np.uint32, "u32"),
        (np.int64, "i64"),
    )
}

# Reverse lookup: form string back to its dtype.
_form_to_dtype: Final[dict[str, DType]] = {
    form: dtype for dtype, form in _dtype_to_form.items()
}


def _form_to_zero_length(form: str) -> Index:
    """Build a length-0 Index whose dtype is the one named by ``form``.

    Args:
        form: One of the keys of ``_form_to_dtype`` ("i8", "u8", "i32",
            "u32", "i64").

    Returns:
        An Index wrapping ``numpy.zeros(0, dtype=...)``.

    Raises:
        AssertionError: If ``form`` is not a recognized form string.
    """
    zero_dtype = _form_to_dtype.get(form)
    if zero_dtype is None:
        raise AssertionError(f"unrecognized Index form: {form!r}")
    empty_buffer = numpy.zeros(0, dtype=zero_dtype)
    return Index(empty_buffer)


class Index:
    """One-dimensional integer buffer paired with the nplike that handles it.

    The wrapped array is made contiguous and must have dtype int8, uint8,
    int32, uint32 or int64.  Constructing a plain ``Index`` re-types the new
    instance to the matching subclass (``Index8``, ``IndexU8``, ``Index32``,
    ``IndexU32`` or ``Index64``); constructing a subclass requests that
    subclass's ``_expected_dtype`` from ``asarray``.
    """

    # dtype a subclass pins its buffer to; None on the base class means
    # "choose the subclass from the dtype of the data".
    _expected_dtype: DType | None = None

    def __init__(
        self,
        data,
        *,
        metadata: dict | None = None,
        nplike: NumpyLike | None = None,
    ):
        """Wrap ``data`` as a contiguous, one-dimensional index buffer.

        Args:
            data: Array-like to wrap; must not itself be an ``Index``
                (checked with ``assert``).
            metadata: Optional dict stored by reference on the instance.
            nplike: Backend to use.  When None it is inferred from ``data``
                via ``nplike_of_obj``, defaulting to the Numpy nplike.

        Raises:
            TypeError: If ``metadata`` is neither None nor a dict, if the
                data is not one-dimensional, or (base class only) if the
                dtype is not one of the five supported integer dtypes.
            NotImplementedError: If a subclass ends up with a dtype other
                than its ``_expected_dtype``.
        """
        assert not isinstance(data, Index)
        if nplike is None:
            self._nplike = cast(
                "NumpyLike[ArrayLike]", nplike_of_obj(data, default=Numpy.instance())
            )
        else:
            self._nplike = nplike

        if metadata is not None and not isinstance(metadata, dict):
            raise TypeError("Index metadata must be None or a dict")
        self._metadata = metadata
        # We don't care about F vs. C order (the data is one-dimensional), but
        # we do need the array to be contiguous. This should _not_ return a
        # copy if the data is already contiguous.
        self._data = self._nplike.ascontiguousarray(
            self._nplike.asarray(data, dtype=self._expected_dtype)
        )

        if len(ak._util.maybe_shape_of(self._data)) != 1:
            raise TypeError("Index data must be one-dimensional")

        # Normalize a `longlong` dtype to int64 by re-viewing the same bytes.
        if np.issubdtype(self._data.dtype, np.longlong):
            assert np.dtype(np.longlong).itemsize == 8, (
                "longlong is always 64-bit, right?"
            )

            self._data = self._data.view(np.int64)

        if self._expected_dtype is None:
            # Base-class construction: specialize this instance in place to
            # the subclass that matches the dtype of the data.
            if self._data.dtype == np.dtype(np.int8):
                self.__class__ = Index8
            elif self._data.dtype == np.dtype(np.uint8):
                self.__class__ = IndexU8
            elif self._data.dtype == np.dtype(np.int32):
                self.__class__ = Index32
            elif self._data.dtype == np.dtype(np.uint32):
                self.__class__ = IndexU32
            elif self._data.dtype == np.dtype(np.int64):
                self.__class__ = Index64
            else:
                raise TypeError(
                    "Index data must be int8, uint8, int32, uint32, int64, not "
                    + repr(self._data.dtype)
                )
        else:
            if self._data.dtype != self._expected_dtype:
                # self._data = self._data.astype(self._expected_dtype)   # copy/convert
                raise NotImplementedError(
                    "while developing, we want to catch these errors"
                )

    @classmethod
    def zeros(
        cls, length: ShapeItem, nplike: NumpyLike, dtype: DType | None = None
    ) -> Index:
        """Return an Index over ``nplike.zeros(length, dtype=dtype)``.

        ``dtype`` defaults to ``cls._expected_dtype``.  On the base class that
        is None, so the nplike's own default dtype applies.
        NOTE(review): confirm that default is a dtype ``Index`` accepts.
        """
        if dtype is None:
            dtype = cls._expected_dtype
        return Index(nplike.zeros(length, dtype=dtype), nplike=nplike)

    @classmethod
    def empty(
        cls, length: ShapeItem, nplike: NumpyLike, dtype: DType | None = None
    ) -> Index:
        """Return an Index over ``nplike.empty(length, dtype=dtype)``.

        ``dtype`` defaults to ``cls._expected_dtype``.  On the base class that
        is None, so the nplike's own default dtype applies.
        NOTE(review): confirm that default is a dtype ``Index`` accepts.
        """
        if dtype is None:
            dtype = cls._expected_dtype
        return Index(nplike.empty(length, dtype=dtype), nplike=nplike)

    @property
    def data(self) -> ArrayLike:
        """The wrapped buffer."""
        return self._data

    @property
    def nplike(self) -> NumpyLike:
        """The nplike associated with the buffer."""
        return self._nplike

    @property
    def dtype(self) -> DType:
        """dtype of the wrapped buffer."""
        return self._data.dtype

    @property
    def metadata(self) -> dict:
        """The metadata dict, created (and stored) on first access if absent."""
        if self._metadata is None:
            self._metadata = {}
        return self._metadata

    @property
    def ptr(self):
        """Result of ``nplike.memory_ptr`` for the wrapped buffer."""
        return self._nplike.memory_ptr(self._data)

    @property
    def length(self) -> ShapeItem:
        """First (and only) shape item of the wrapped buffer."""
        return self._data.shape[0]

    def forget_length(self) -> Self:
        """Return a TypeTracer-backed copy built from ``data.forget_length()``.

        If this index is not already on the TypeTracer nplike, the buffer is
        converted with ``raw`` first.  Metadata is passed through as is.
        """
        tt = TypeTracer.instance()
        if isinstance(self._nplike, type(tt)):
            data = self._data
        else:
            data = self.raw(tt)

        assert hasattr(data, "forget_length")
        return type(self)(data.forget_length(), metadata=self._metadata, nplike=tt)

    def raw(self, nplike: NumpyLike) -> ArrayLike:
        """Return the bare buffer converted to ``nplike`` via ``to_nplike``."""
        return to_nplike(self.data, nplike, from_nplike=self._nplike)

    def materialize(self, type_) -> Index:
        """Pass the buffer through ``maybe_materialize`` and re-wrap the result.

        ``type_`` is forwarded unchanged to ``maybe_materialize``.
        """
        (out,) = maybe_materialize(self._data, type_=type_)
        return Index(out, metadata=self.metadata, nplike=self._nplike)

    @property
    def is_all_materialized(self) -> bool:
        """For a ``VirtualNDArray`` buffer, its ``is_materialized``; else True."""
        buffer = self._data
        if isinstance(buffer, VirtualNDArray):
            return buffer.is_materialized
        return True

    @property
    def is_any_materialized(self) -> bool:
        """Same answer as ``is_all_materialized``: there is a single buffer."""
        buffer = self._data
        if isinstance(buffer, VirtualNDArray):
            return buffer.is_materialized
        return True

    def __len__(self) -> int:
        """``length`` coerced to ``int``."""
        return int(self.length)

    @property
    def __cuda_array_interface__(self):
        """Delegates to the wrapped buffer."""
        return self._data.__cuda_array_interface__  # type: ignore[attr-defined]

    @property
    def __array_interface__(self):
        """Delegates to the wrapped buffer."""
        return self._data.__array_interface__  # type: ignore[attr-defined]

    def __dlpack_device__(self) -> tuple[int, int]:
        """Delegates to the wrapped buffer."""
        return self._data.__dlpack_device__()  # type: ignore[attr-defined]

    def __dlpack__(self, stream: Any = None) -> Any:
        """Delegates to the wrapped buffer; ``stream`` is forwarded only if given."""
        if stream is None:
            return self._data.__dlpack__()  # type: ignore[attr-defined]
        else:
            return self._data.__dlpack__(stream=stream)  # type: ignore[attr-defined]

    def __repr__(self) -> str:
        return self._repr("", "", "")

    def _repr(self, indent: str, pre: str, post: str) -> str:
        """Render the index as an ``<Index ...>`` element.

        Args:
            indent: Prefix for the first line and base indent for any
                continuation lines.
            pre: Text emitted right after ``indent``, before the element.
            post: Text emitted after the closing tag.

        Returns:
            A single-line element when the data fits on one 30-column line
            and ``_metadata`` is None; otherwise a multi-line element with
            one ``<metadata>`` line per entry and at most five data lines
            (first two, " ...", last two).
        """
        out = [indent, pre, "<Index dtype="]
        out.append(repr(str(self.dtype)))
        out.append(" len=")
        out.append(repr(str(ak._util.maybe_length_of(self))))

        # Narrow trial rendering: decides between the two layouts.
        arraystr_lines = self._nplike.array_str(self._data, max_line_width=30).split(
            "\n"
        )

        if len(arraystr_lines) > 1 or self._metadata is not None:
            arraystr_lines = self._nplike.array_str(
                self._data, max_line_width=max(80 - len(indent) - 4, 40)
            ).split("\n")
            if len(arraystr_lines) > 5:
                arraystr_lines = [*arraystr_lines[:2], " ...", *arraystr_lines[-2:]]
            out.append(">\n" + indent + "    ")
            if self._metadata is not None:
                for k, v in self._metadata.items():
                    out.append(
                        f"<metadata key={k!r}>{v!r}</metadata>\n" + indent + "    "
                    )
            out.append(("\n" + indent + "    ").join(arraystr_lines))
            out.append("\n" + indent + "</Index>")
        else:
            # Exactly one line here: split() never returns an empty list and
            # every multi-line case took the branch above, so there is
            # nothing to truncate.
            out.append(">")
            out.append(arraystr_lines[0])
            out.append("</Index>")

        out.append(post)
        return "".join(out)

    @property
    def form(self) -> str:
        """Form string for this index's dtype, looked up in ``_dtype_to_form``."""
        return _dtype_to_form[self._data.dtype]

    def __getitem__(self, where):
        """Index into the buffer.

        Returns ``self`` for a full no-op slice on known data, a new ``Index``
        (sharing this index's metadata dict) when the result has at least one
        dimension, and otherwise the selected element (unwrapped with
        ``.item()`` for Jax and Cupy zero-dimensional arrays).
        """
        if isinstance(where, slice):
            where = normalize_slice(where, nplike=self.nplike)

            # in non-typetracer mode (and if all lengths are known) we can check if the slice is a no-op
            # (i.e. slicing the full array) and shortcut to avoid noticeable python overhead
            if self._nplike.known_data and (
                where.step == 1 and where.start == 0 and where.stop == self.length
            ):
                return self

        out = self._data[where]

        if hasattr(out, "shape") and len(out.shape) != 0:
            return Index(out, metadata=self.metadata, nplike=self._nplike)
        elif (Jax.is_own_array(out) or Cupy.is_own_array(out)) and len(out.shape) == 0:
            return out.item()
        else:
            return out

    def __setitem__(self, where, what):
        """Assign ``what`` at ``where``.

        On the Jax nplike the update is done functionally with
        ``.at[where].set(what)`` and the new array is stored back (into the
        ``VirtualNDArray``'s ``_array`` if the buffer is virtual).  On other
        nplikes the assignment goes through ``self._data`` directly, using
        the ``where``/``what`` returned by ``maybe_materialize``.
        """
        (data, where, what) = maybe_materialize(self._data, where, what)
        if isinstance(self._nplike, Jax):
            new_data = data.at[where].set(what)
            if isinstance(self._data, VirtualNDArray):
                self._data._array = new_data
            else:
                self._data = new_data
        else:
            self._data[where] = what

    def to64(self) -> Index:
        """Return an int64 version of this index on the same nplike.

        The buffer is converted with ``nplike.astype``; the base ``Index``
        constructor then specializes the result to ``Index64``.  Metadata is
        not carried over.
        """
        # Pass the nplike explicitly, like every other constructor call in
        # this class, instead of re-inferring it from the converted array.
        return Index(
            self._nplike.astype(self._data, dtype=np.int64), nplike=self._nplike
        )

    def __copy__(self) -> Self:
        """New wrapper around the same buffer and the same metadata object."""
        return type(self)(self._data, metadata=self._metadata, nplike=self._nplike)

    def __deepcopy__(self, memo: dict) -> Self:
        """New wrapper around deep copies of the buffer and the metadata."""
        return type(self)(
            copy.deepcopy(self._data, memo),
            metadata=copy.deepcopy(self._metadata, memo),
            nplike=self._nplike,
        )

    def _nbytes_part(self) -> ShapeItem:
        """``nbytes`` of the wrapped buffer."""
        return self.data.nbytes

    def to_nplike(self, nplike: NumpyLike) -> Self:
        """Return this index converted to ``nplike``, sharing the metadata dict."""
        return type(self)(self.raw(nplike), metadata=self.metadata, nplike=nplike)

    def is_equal_to(
        self, other: Any, index_dtype: bool = True, numpyarray: bool = True
    ) -> bool:
        """Compare this index with ``other``.

        Args:
            other: Object exposing a ``data`` array (e.g. another ``Index``).
            index_dtype: When True, the dtypes must match as well.
            numpyarray: Accepted for signature compatibility; not used here.

        Returns:
            False if ``index_dtype`` is set and the dtypes differ.  Otherwise
            True when this nplike has no known data (there is nothing to
            compare), else the truth value of ``nplike.array_equal``.
        """
        # Cheap dtype test first: avoids the element-wise comparison entirely
        # when it cannot change the answer.
        if index_dtype and self._data.dtype != other.data.dtype:
            return False
        # Without known data the element comparison is meaningless; treat the
        # contents as equal regardless of `index_dtype`.
        if not self._nplike.known_data:
            return True
        return bool(self._nplike.array_equal(self.data, other.data))

    def _touch_data(self):
        """Call ``touch_data()`` on the buffer if it provides one."""
        if hasattr(self._data, "touch_data"):
            self._data.touch_data()

    def _touch_shape(self):
        """Call ``touch_shape()`` on the buffer if it provides one."""
        if hasattr(self._data, "touch_shape"):
            self._data.touch_shape()


class Index8(Index):
    """Index pinned to int8.

    ``Index.__init__`` passes ``_expected_dtype`` to ``asarray`` and raises
    ``NotImplementedError`` if the resulting buffer has a different dtype.
    """

    _expected_dtype = np.dtype(np.int8)


class IndexU8(Index):
    """Index pinned to uint8.

    ``Index.__init__`` passes ``_expected_dtype`` to ``asarray`` and raises
    ``NotImplementedError`` if the resulting buffer has a different dtype.
    """

    _expected_dtype = np.dtype(np.uint8)


class Index32(Index):
    """Index pinned to int32.

    ``Index.__init__`` passes ``_expected_dtype`` to ``asarray`` and raises
    ``NotImplementedError`` if the resulting buffer has a different dtype.
    """

    _expected_dtype = np.dtype(np.int32)


class IndexU32(Index):
    """Index pinned to uint32.

    ``Index.__init__`` passes ``_expected_dtype`` to ``asarray`` and raises
    ``NotImplementedError`` if the resulting buffer has a different dtype.
    """

    _expected_dtype = np.dtype(np.uint32)


class Index64(Index):
    """Index pinned to int64.

    ``Index.__init__`` passes ``_expected_dtype`` to ``asarray`` and raises
    ``NotImplementedError`` if the resulting buffer has a different dtype.
    """

    _expected_dtype = np.dtype(np.int64)