1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
|
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
import numpy as np # noqa: F401
import pytest
import awkward as ak
# Skip every test in this module when numba is not installed.
numba = pytest.importorskip("numba")
# Register awkward's Numba integration up front so the njit-compiled
# functions below can be called with ak.highlevel.Array arguments.
ak.numba.register_and_check()
def test_string():
    """Indexing and concatenating string elements, in pure Python and under numba.njit."""
    words = ak.highlevel.Array(["one", "two", "three", "four", "five"])

    def getitem(arr, idx):
        return arr[idx]

    expected = ["one", "two", "three"]

    # Plain Python first, then the same function compiled in nopython mode.
    for idx, want in enumerate(expected):
        assert getitem(words, idx) == want

    compiled_getitem = numba.njit(getitem)
    for idx, want in enumerate(expected):
        assert compiled_getitem(words, idx) == want

    def concat(arr, first, second):
        return arr[first] + arr[second]

    assert concat(words, 1, 3) == "twofour"
    assert numba.njit(concat)(words, 1, 3) == "twofour"
def test_bytestring():
    """Indexing bytestring elements, in pure Python and under numba.njit."""
    blobs = ak.highlevel.Array([b"one", b"two", b"three", b"four", b"five"])

    def pick(arr, pos):
        return arr[pos]

    cases = ((0, b"one"), (1, b"two"), (2, b"three"))

    # Interpreted call first; the njit version must agree element for element.
    for pos, want in cases:
        assert pick(blobs, pos) == want

    jitted_pick = numba.njit(pick)
    for pos, want in cases:
        assert jitted_pick(blobs, pos) == want
|