File: test_2512_record_array_carry.py

package info (click to toggle)
python-awkward 2.6.5-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 23,088 kB
  • sloc: python: 148,689; cpp: 33,562; sh: 432; makefile: 21; javascript: 8
file content (162 lines) | stat: -rw-r--r-- 4,974 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np

import awkward as ak


def _reduce_max_masked(array, mask):
    """Keep, per list, the "pair" record whose field "1" is largest.

    This is the custom ``ak.max`` override registered for ``"pair"`` records
    in the module-level ``behavior`` dict. The records are tuple-like
    (``fields=None``), so field ``"1"`` is their second slot.

    NOTE(review): presumably awkward calls this with the records grouped into
    lists and ``mask`` set to the caller's ``mask_identity`` — confirm against
    the awkward custom-reducer behavior docs.

    Only the masked path is implemented: ``mask`` must be truthy.
    """
    assert mask  # the unmasked (identity-filled) variant is not supported
    # Index of the largest field-"1" value in each list; keepdims keeps one
    # slot per list (None for empty lists, because mask_identity=True).
    best = ak.argmax(array["1"], axis=1, keepdims=True, mask_identity=True)
    # Turn the keepdims dimension into a variable-length one so it can be
    # used as a jagged index into ``array``.
    best = ak.from_regular(best)
    # Gather the winning records, then drop the length-1 inner dimension.
    return ak.flatten(array[best], axis=1)


# Behavior mapping passed explicitly to ak.max in the tests below: routes the
# max reduction of "pair" records to the custom reducer above.
behavior = {(ak.max, "pair"): _reduce_max_masked}


def test_axis_0():
    """ak.max over axis=0 of nested "pair" records uses the custom behavior.

    Layout: 2 outer lists -> 4 inner lists (sizes 3, 3, 3, 2) -> 11 records.
    Reducing axis 0 pairs up inner list 0 with 2 and inner list 1 with 3,
    element by element. The record with the larger field "1" wins at each
    slot. The winners are records 6, 7, 8, 9, 10 and the unpaired record 5.
    """
    # Eleven tuple-like "pair" records: (i, float). The float column follows
    # 2*i except at index 6, which holds 1.0.
    ints = ak.contents.NumpyArray(np.arange(11, dtype=np.int64))
    floats = ak.contents.NumpyArray(
        np.array(
            [0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 1.0, 14.0, 16.0, 18.0, 20.0],
            dtype=np.float64,
        )
    )
    records = ak.contents.RecordArray(
        [ints, floats], fields=None, parameters={"__record__": "pair"}
    )
    inner_lists = ak.contents.ListOffsetArray(
        ak.index.Index64([0, 3, 6, 9, 11]), records
    )
    content = ak.contents.ListArray(
        ak.index.Index64([0, 2]), ak.index.Index64([2, 4]), inner_lists
    )

    result = ak.max(
        content,
        axis=0,
        keepdims=True,
        mask_identity=True,
        behavior=behavior,
        highlevel=False,
    )

    # The reduction carries (reorders) the record buffer. The option index
    # then selects the winning records from it: ints 6, 7, 8, 9, 10, 5.
    carried = ak.contents.RecordArray(
        [
            ak.contents.NumpyArray(
                np.array([0, 6, 3, 9, 1, 7, 4, 10, 2, 8, 5], dtype=np.int64)
            ),
            ak.contents.NumpyArray(
                np.array([0, 1, 6, 18, 2, 14, 8, 20, 4, 16, 10], dtype=np.float64)
            ),
        ],
        fields=None,
        parameters={"__record__": "pair"},
    )
    maxima = ak.contents.IndexedOptionArray(
        ak.index.Index64([1, 5, 9, 3, 7, 10]), carried
    )
    reduced_lists = ak.contents.ListArray(
        ak.index.Index64([0, 3]), ak.index.Index64([3, 6]), maxima
    )
    # keepdims=True retains a single outer list wrapping the reduced lists.
    expected_result = ak.contents.ListArray(
        ak.index.Index64([0]), ak.index.Index64([2]), reduced_lists
    )
    assert result.is_equal_to(expected_result)


def test_axis_1():
    """ak.max over axis=1 of nested "pair" records uses the custom behavior.

    Layout: 2 outer lists -> 4 inner lists (sizes 3, 3, 3, 2) -> 11 records.
    Reducing axis 1 combines, within each outer list, its inner lists element
    by element. The record with the larger field "1" wins at each slot. The
    winners are records 3, 4, 5 (outer 0) and 9, 10, 8 (outer 1).
    """
    # Eleven tuple-like "pair" records: (i, float). The float column follows
    # 2*i except at index 6, which holds 1.0.
    ints = ak.contents.NumpyArray(np.arange(11, dtype=np.int64))
    floats = ak.contents.NumpyArray(
        np.array(
            [0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 1.0, 14.0, 16.0, 18.0, 20.0],
            dtype=np.float64,
        )
    )
    records = ak.contents.RecordArray(
        [ints, floats], fields=None, parameters={"__record__": "pair"}
    )
    inner_lists = ak.contents.ListOffsetArray(
        ak.index.Index64([0, 3, 6, 9, 11]), records
    )
    content = ak.contents.ListArray(
        ak.index.Index64([0, 2]), ak.index.Index64([2, 4]), inner_lists
    )

    result = ak.max(
        content,
        axis=1,
        keepdims=True,
        mask_identity=True,
        behavior=behavior,
        highlevel=False,
    )

    # The reduction carries (reorders) the record buffer. The option index
    # then selects the winning records from it: ints 3, 4, 5, 9, 10, 8.
    carried = ak.contents.RecordArray(
        [
            ak.contents.NumpyArray(
                np.array([0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8], dtype=np.int64)
            ),
            ak.contents.NumpyArray(
                np.array([0, 6, 1, 18, 2, 8, 14, 20, 4, 10, 16], dtype=np.float64)
            ),
        ],
        fields=None,
        parameters={"__record__": "pair"},
    )
    maxima = ak.contents.IndexedOptionArray(
        ak.index.Index64([1, 5, 9, 3, 7, 10]), carried
    )
    reduced_lists = ak.contents.ListArray(
        ak.index.Index64([0, 3]), ak.index.Index64([3, 6]), maxima
    )
    # keepdims=True keeps the reduced axis as a regular dimension of size 1.
    expected_result = ak.contents.RegularArray(reduced_lists, size=1)
    assert result.is_equal_to(expected_result)