1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
|
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
import numpy as np
import pytest
import awkward as ak
numba = pytest.importorskip("numba")
def test_field_name():
    """Check that a named record survives being extended from numba code.

    One ``"x"``-named record with a float64 ``time`` field is appended from
    ordinary Python, then a second identical-shaped record is appended from
    inside a ``numba.njit``-compiled function. The builder's type must be a
    length-2 array of records named ``"x"`` with that single field.
    """
    outer = ak.ArrayBuilder()

    # First record: appended from plain Python.
    outer.begin_record("x")
    outer.field("time").real(0.0)
    outer.end_record()

    @numba.njit
    def append_in_numba(b):
        # Second record: same record name and field, appended in nopython mode.
        b.begin_record("x")
        b.field("time").real(2.0)
        b.end_record()
        return b

    append_in_numba(outer)

    record_type = ak.types.RecordType(
        [ak.types.NumpyType("float64")],
        ["time"],
        parameters={"__record__": "x"},
    )
    expected_type = ak.types.ArrayType(record_type, 2)
    assert outer.type == expected_type
def test_no_field_name():
    """Check that an unnamed record survives being extended from numba code.

    One unnamed record with a float64 ``time`` field is appended from
    ordinary Python and a second from a ``numba.njit``-compiled function.
    The snapshot must compare equal (via ``ak.almost_equal``) to a
    ``RecordArray`` holding ``time == [0.0, 2.0]``.
    """
    outer = ak.ArrayBuilder()

    # First record: appended from plain Python, no record name given.
    outer.begin_record()
    outer.field("time").real(0.0)
    outer.end_record()

    @numba.njit
    def append_in_numba(b):
        # Second record: appended in nopython mode, still unnamed.
        b.begin_record()
        b.field("time").real(2.0)
        b.end_record()
        return b

    append_in_numba(outer)
    actual = outer.snapshot()

    time_values = np.array([0, 2], dtype=np.float64)
    expected = ak.contents.RecordArray(
        fields=["time"],
        contents=[ak.contents.NumpyArray(time_values)],
    )
    assert ak.almost_equal(actual, expected)
|