File: test_jax.py

package info (click to toggle)
nanobind 2.11.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,300 kB
  • sloc: cpp: 12,232; python: 6,315; ansic: 4,813; makefile: 22; sh: 15
file content (97 lines) | stat: -rw-r--r-- 2,547 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import test_ndarray_ext as t
import test_jax_ext as tj
import pytest
import warnings
import importlib
from common import collect

# ``needs_jax`` decorates tests that require JAX: it is a no-op when
# ``jax.numpy`` imports successfully, and a ``pytest.mark.skip`` otherwise.
try:
    import jax.numpy as jnp

    def needs_jax(x):
        """Return the decorated test unchanged (JAX imported successfully)."""
        return x
except Exception:
    # JAX is unavailable: either not installed, or importing it raised some
    # other error. ``Exception`` (rather than a bare ``except:``) keeps the
    # skip-on-failure behavior while letting KeyboardInterrupt/SystemExit
    # propagate.
    needs_jax = pytest.mark.skip(reason="JAX is required")


@needs_jax
def test01_constrain_order():
    """A JAX array created by ``jnp.zeros`` is reported as 'C' order."""
    # Probe JAX with a small allocation first: importing jax.numpy can
    # succeed while array creation still fails, in which case the test is
    # skipped. Warnings raised during this first use are silenced
    # (presumably JAX backend-discovery notices -- TODO confirm).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            jnp.zeros((3, 5))
        except Exception:
            pytest.skip('jax is missing')

    z = jnp.zeros((3, 5, 4, 6))
    assert t.check_order(z) == 'C'


@needs_jax
def test02_implicit_conversion():
    """``t.implicit`` accepts JAX arrays of several dtypes, including slices
    ``[:, :, 4]`` of 3-D arrays; ``t.noimplicit`` rejects int32 and uint8
    arrays with ``TypeError``."""
    # Probe JAX with a small allocation first: importing jax.numpy can
    # succeed while array creation still fails, in which case the test is
    # skipped. Warnings raised during this first use are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            jnp.zeros((3, 5))
        except Exception:
            pytest.skip('jax is missing')

    t.implicit(jnp.zeros((2, 2), dtype=jnp.int32))
    t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.float32)[:, :, 4])
    t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.int32)[:, :, 4])
    t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.bool_)[:, :, 4])

    with pytest.raises(TypeError):
        t.noimplicit(jnp.zeros((2, 2), dtype=jnp.int32))

    with pytest.raises(TypeError):
        t.noimplicit(jnp.zeros((2, 2), dtype=jnp.uint8))


@needs_jax
def test03_return_jax():
    """``tj.ret_jax`` yields a (2, 4) float32 array holding 1..8, and once
    that array is deleted and garbage is collected, ``tj.destruct_count``
    has grown by exactly one."""
    collect()
    baseline = tj.destruct_count()

    arr = tj.ret_jax()
    expected = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=jnp.float32)
    assert arr.shape == (2, 4)
    assert jnp.all(arr == expected)

    del arr
    collect()
    assert tj.destruct_count() == baseline + 1


@needs_jax
def test04_check():
    """``t.check`` returns a truthy result for a one-element JAX array."""
    # ``(1,)`` is a one-element shape tuple; a bare ``(1)`` is just the int 1
    # (same resulting shape, but misleading to read).
    assert t.check(jnp.zeros((1,)))


@needs_jax
def test05_passthrough():
    """``t.passthrough`` hands back the identical object for arrays built by
    the extension and by JAX, rejects ``None`` with a TypeError, while
    ``t.passthrough_arg_none`` accepts ``None`` and returns it."""
    from_ext = tj.ret_jax()
    assert t.passthrough(from_ext) is from_ext

    from_jax = jnp.array([1, 2, 3])
    assert t.passthrough(from_jax) is from_jax

    missing = None
    with pytest.raises(TypeError) as excinfo:
        t.passthrough(missing)
    assert 'incompatible function arguments' in str(excinfo.value)
    assert t.passthrough_arg_none(missing) is missing


@needs_jax
def test06_ro_array():
    """An immutable JAX array is accepted by ``t.accept_ro`` and, currently,
    also by ``t.accept_rw``."""
    # Skip unless jax.numpy advertises an array API version of at least
    # '2024'. The comparison is on strings, which presumably orders
    # 'YYYY.MM'-style version values correctly -- confirm if the format
    # ever changes.
    if not hasattr(jnp, '__array_api_version__'):
        pytest.skip('jax version is too old')
    if jnp.__array_api_version__ < '2024':
        pytest.skip('jax version is too old')

    ro_array = jnp.array([1, 2], dtype=jnp.float32)  # JAX arrays are immutable.
    assert t.accept_ro(ro_array) == 1
    # If the next line fails, delete it, update the array_api_version above,
    # and uncomment the three lines below.
    assert t.accept_rw(ro_array) == 1
    # with pytest.raises(TypeError) as excinfo:
    #     t.accept_rw(ro_array)
    # assert 'incompatible function arguments' in str(excinfo.value)