File: test_2650_arrow_array_highlevel.py

package info (click to toggle)
python-awkward 2.6.5-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 23,088 kB
  • sloc: python: 148,689; cpp: 33,562; sh: 432; makefile: 21; javascript: 8
file content (48 lines) | stat: -rw-r--r-- 1,424 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest
from packaging.version import parse as parse_version

import awkward as ak

# Import pyarrow via pytest.importorskip rather than a plain ``import`` so that
# this whole test module is skipped (not errored) when pyarrow is unavailable.
pyarrow = pytest.importorskip("pyarrow")
# Type-predicate helpers (is_float64, is_uint8, ...) used by the assertions below.
pyarrow_types = pytest.importorskip("pyarrow.types")


def test_simple_conversion():
    """Convert a two-field record array with ``pyarrow.array`` and compare lists.

    The record has a uint8 field ``x`` and a float64 field ``y``, each holding
    ``arange(6)``. The Arrow result's ``tolist()`` must equal the awkward
    array's ``to_list()``.
    """
    x_field = ak.to_layout(np.arange(6, dtype=np.uint8))
    y_field = ak.to_layout(np.arange(6, dtype=np.float64))
    record_layout = ak.contents.RecordArray([x_field, y_field], ["x", "y"])
    source = ak.Array(record_layout)

    converted = pyarrow.array(source)

    assert converted.tolist() == source.to_list()


@pytest.mark.skipif(
    parse_version(pyarrow.__version__) < parse_version("12.0.0"),
    reason="pyarrow >= 12.0.0 required for casting test",
)
def test_type_cast():
    """Check the Arrow type produced with and without an explicit ``type=``.

    A uint8 array wrapped by ``ak.mask`` is converted twice: once requesting
    ``pyarrow.float64()`` (result must be float64) and once with no type
    requested (result must stay uint8).
    """
    raw_values = np.array([1, 2, 3], dtype=np.uint8)
    masked = ak.mask(raw_values, [True, False, False])

    # Explicit target type: the resulting Arrow array must report float64.
    as_float = pyarrow.array(masked, type=pyarrow.float64())
    assert pyarrow_types.is_float64(as_float.type)

    # No target type: the source uint8 dtype must be carried over.
    as_default = pyarrow.array(masked)
    assert pyarrow_types.is_uint8(as_default.type)


def test_zero_copy():
    """Check that NumPy -> awkward -> Arrow -> NumPy shares the original memory.

    Starts from a uint32 NumPy array, wraps it with ``ak.from_numpy``, converts
    with ``pyarrow.array``, and asserts both the Arrow type (uint32) and that
    ``to_numpy()`` on the Arrow array shares memory with the starting buffer.
    """
    source_buffer = np.array([1, 2, 3], dtype=np.uint32)
    wrapped = ak.from_numpy(source_buffer)
    converted = pyarrow.array(wrapped)

    assert pyarrow_types.is_uint32(converted.type)

    round_tripped = converted.to_numpy()
    assert np.shares_memory(source_buffer, round_tripped)