File: test_3457_jax_setitems.py

package info (click to toggle)
python-awkward 2.8.9-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 24,932 kB
  • sloc: python: 178,875; cpp: 33,828; sh: 432; makefile: 21; javascript: 8
file content (51 lines) | stat: -rw-r--r-- 1,948 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest

import awkward as ak

# Skip every test in this module when jax is not importable. The returned
# module is bound to ``jax`` but is not referenced by the tests below.
jax = pytest.importorskip("jax")
# Register and check awkward's jax support (per the function name). This runs
# once at import time, before any test below creates a ``backend="jax"`` array.
ak.jax.register_and_check()


def test_jax_ak_firsts():
    """``ak.firsts`` on a jax-backed array must agree with the cpu backend.

    The input mixes one-element sublists with empty sublists, so lists that
    have no first element are compared as well.
    """
    nested = [[1.1], [2.2], [], [3.3], [], [], [4.4], [5.5]]
    on_jax = ak.Array(nested, backend="jax")
    firsts_from_jax = ak.firsts(on_jax)

    # The cpu reference is derived from the jax array itself, so both
    # backends operate on exactly the same data.
    on_cpu = ak.to_backend(on_jax, "cpu")
    firsts_from_cpu = ak.firsts(on_cpu)

    assert firsts_from_jax.to_list() == firsts_from_cpu.to_list()


def test_jax_ak_unflatten():
    """Round-trip ``ak.flatten``/``ak.unflatten`` on jax and compare with cpu.

    The per-list counts and the flat content are produced on the jax backend.
    Both are then copied to cpu, so the two backends unflatten identical
    inputs.
    """
    nested = ak.Array([[0, 1, 2], [], [3, 4], [5], [6, 7, 8, 9]], backend="jax")
    counts_jax = ak.num(nested)
    flat_jax = ak.flatten(nested)

    rebuilt_jax = ak.unflatten(flat_jax, counts_jax)

    counts_cpu = ak.to_backend(counts_jax, "cpu")
    flat_cpu = ak.to_backend(flat_jax, "cpu")
    rebuilt_cpu = ak.unflatten(flat_cpu, counts_cpu)

    assert rebuilt_jax.to_list() == rebuilt_cpu.to_list()


def test_jax_run_lengths():
    """``ak.run_lengths`` must give the same answer on the jax and cpu backends.

    The input holds runs of equal values of lengths 3, 1, 2, 2 and 1.
    """
    source = ak.Array(
        [1.1, 1.1, 1.1, 2.2, 3.3, 3.3, 4.4, 4.4, 5.5], backend="jax"
    )
    jax_result, cpu_result = (
        ak.run_lengths(candidate).to_list()
        for candidate in (source, ak.to_backend(source, "cpu"))
    )
    assert jax_result == cpu_result


def test_jax_listarray_to_listoffsetarray64():
    """``ListArray.to_ListOffsetArray64`` must yield the same lists on jax and cpu.

    The layout is built by hand from explicit starts/stops. Index 1 has
    start == stop == 3 and is therefore an empty list. The layout is then
    copied to the jax backend at the low-level layout level
    (``highlevel=False``).
    """
    flat_values = np.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
    list_starts = np.array([0, 3, 3, 5, 6])
    list_stops = np.array([3, 3, 5, 6, 9])

    layout_cpu = ak.contents.ListArray(
        ak.index.Index64(list_starts),
        ak.index.Index64(list_stops),
        ak.contents.NumpyArray(flat_values),
    )
    layout_jax = ak.to_backend(layout_cpu, "jax", highlevel=False)

    offsets_cpu = ak.Array(layout_cpu.to_ListOffsetArray64())
    offsets_jax = ak.Array(layout_jax.to_ListOffsetArray64())
    assert offsets_cpu.to_list() == offsets_jax.to_list()