File: test_1504_typetracer_like.py

package info (click to toggle)
python-awkward 2.6.5-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 23,088 kB
  • sloc: python: 148,689; cpp: 33,562; sh: 432; makefile: 21; javascript: 8
file content (46 lines) | stat: -rw-r--r-- 1,853 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest

import awkward as ak
from awkward._nplikes.typetracer import TypeTracer

typetracer = TypeTracer.instance()


@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None])
def test_ones_like(dtype, like_dtype):
    array = ak.contents.numpyarray.NumpyArray(
        np.array([99, 88, 77, 66, 66], dtype=dtype)
    )
    ones = ak.ones_like(array.to_typetracer(), dtype=like_dtype, highlevel=False)
    assert ones.to_typetracer().shape == array.shape
    assert ones.to_typetracer().dtype == like_dtype or array.dtype


@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None])
def test_zeros_like(dtype, like_dtype):
    array = ak.contents.numpyarray.NumpyArray(
        np.array([99, 88, 77, 66, 66], dtype=dtype)
    )

    full = ak.zeros_like(array.to_typetracer(), dtype=like_dtype, highlevel=False)
    assert full.to_typetracer().shape == array.shape
    assert full.to_typetracer().dtype == like_dtype or array.dtype


@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("value", [1.0, -20, np.iinfo(np.int64).max])
def test_full_like(dtype, like_dtype, value):
    array = ak.contents.numpyarray.NumpyArray(
        np.array([99, 88, 77, 66, 66], dtype=dtype)
    )
    full = ak.full_like(array.to_typetracer(), value, dtype=like_dtype, highlevel=False)
    assert full.to_typetracer().shape == array.shape
    assert full.to_typetracer().dtype == like_dtype or array.dtype