1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
import numpy as np
import numpy.testing
import pytest
import awkward as ak
# Skip this whole module when JAX is not installed.
jax = pytest.importorskip("jax")

# Run JAX on the CPU and enable 64-bit types; without the x64 flag JAX would
# not keep the float64 reference arrays below as float64.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

# Register Awkward's JAX integration; this happens before any
# `backend="jax"` arrays are constructed below.
ak.jax.register_and_check()

# #### Shared 3x3 fixtures ####
# The same float64 values serve as primals and as tangents, once as Awkward
# arrays on the JAX backend and once as plain JAX arrays for reference.
# NOTE(review): despite the "regulararray" names, an ak.Array built from nested
# Python lists presumably has a variable-length list layout rather than an
# ak.contents.RegularArray -- confirm if the layout matters to these tests.
_matrix_3x3 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]

test_regulararray = ak.Array(_matrix_3x3, backend="jax")
test_regulararray_tangent = ak.Array(_matrix_3x3, backend="jax")

test_regulararray_jax = jax.numpy.array(_matrix_3x3, dtype=np.float64)
test_regulararray_tangent_jax = jax.numpy.array(_matrix_3x3, dtype=np.float64)
@pytest.mark.parametrize("axis", [0, 1, None])
@pytest.mark.parametrize("func_ak", [ak.sum, ak.prod, ak.min, ak.max])
def test_reducer(func_ak, axis):
    """Compare an Awkward reducer's JVP/VJP with the same-named jax.numpy reducer.

    The primal outputs and the derivatives are computed twice on the same
    float64 data: once through Awkward's JAX backend and once through plain
    ``jax.numpy``. The derivatives are the forward-mode tangent and the
    reverse-mode pullback of the primal output. Both computations must agree
    to a tight tolerance.

    Parameters
    ----------
    func_ak : callable
        One of ``ak.sum``, ``ak.prod``, ``ak.min``, ``ak.max``; its
        ``__name__`` selects the ``jax.numpy`` reference function.
    axis : int or None
        Reduction axis passed to both implementations.
    """
    func_jax = getattr(jax.numpy, func_ak.__name__)

    def func_ak_with_axis(x):
        return func_ak(x, axis=axis)

    def func_jax_with_axis(x):
        return func_jax(x, axis=axis)

    def assert_close(actual, desired):
        # Do not use atol=np.inf here. assert_allclose's criterion is
        # |actual - desired| <= atol + rtol * |desired|, so an infinite atol
        # accepts every finite value and only checks shapes/NaN positions.
        # The tiny atol merely absorbs round-off around exact zeros (e.g. the
        # zero entries of min/max gradients).
        numpy.testing.assert_allclose(actual, desired, rtol=1e-9, atol=1e-12)

    # Forward mode: primal output and tangent.
    value_jvp, jvp_grad = jax.jvp(
        func_ak_with_axis, (test_regulararray,), (test_regulararray_tangent,)
    )
    value_jvp_jax, jvp_grad_jax = jax.jvp(
        func_jax_with_axis, (test_regulararray_jax,), (test_regulararray_tangent_jax,)
    )
    # Reverse mode: primal output and pullback function.
    value_vjp, vjp_func = jax.vjp(func_ak_with_axis, test_regulararray)
    value_vjp_jax, vjp_func_jax = jax.vjp(func_jax_with_axis, test_regulararray_jax)

    assert_close(ak.to_list(value_jvp), value_jvp_jax.tolist())
    assert_close(ak.to_list(value_vjp), value_vjp_jax.tolist())
    assert_close(ak.to_list(jvp_grad), jvp_grad_jax.tolist())
    # The primal output itself is used as the cotangent; [0] selects the
    # gradient with respect to the single primal input.
    assert_close(
        ak.to_list(vjp_func(value_vjp)[0]),
        (vjp_func_jax(value_vjp_jax)[0]).tolist(),
    )
@pytest.mark.parametrize("axis", [None])
@pytest.mark.parametrize("func_ak", [ak.any, ak.all])
def test_bool_returns(func_ak, axis):
    """Check that full (axis=None) ak.any/ak.all reductions are differentiable.

    The boolean results and the tangent dtype produced through Awkward's JAX
    backend must match those of the same-named ``jax.numpy`` function. The
    reverse-mode pullback is also compared.
    """
    reference = getattr(jax.numpy, func_ak.__name__)

    def through_awkward(array):
        return func_ak(array, axis=axis)

    def through_jax(array):
        return reference(array, axis=axis)

    out_ak, tangent_ak = jax.jvp(
        through_awkward, (test_regulararray,), (test_regulararray_tangent,)
    )
    out_ref, tangent_ref = jax.jvp(
        through_jax, (test_regulararray_jax,), (test_regulararray_tangent_jax,)
    )
    primal_ak, pullback_ak = jax.vjp(through_awkward, test_regulararray)
    primal_ref, pullback_ref = jax.vjp(through_jax, test_regulararray_jax)

    # For boolean outputs only the tangent dtype is compared, not its values.
    assert tangent_ak.dtype == tangent_ref.dtype
    assert out_ak.tolist() == out_ref.tolist()
    assert primal_ak.tolist() == primal_ref.tolist()

    # NOTE(review): atol=np.inf makes assert_allclose accept any finite values,
    # so this only checks shapes/NaN positions; the pullback outputs here are
    # presumably float0 zeros -- confirm before tightening the tolerance.
    cotangent_ak = pullback_ak(primal_ak)[0]
    cotangent_ref = pullback_ref(primal_ref)[0]
    numpy.testing.assert_allclose(
        ak.to_list(cotangent_ak),
        cotangent_ref.tolist(),
        rtol=1e-9,
        atol=np.inf,
    )
@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("func_ak", [ak.any, ak.all])
def test_bool_raises(func_ak, axis):
    """Forward-mode differentiation of ak.any/ak.all along an explicit axis must fail.

    ``jax.jvp`` is expected to raise ``TypeError`` whose message mentions
    "Make sure that you are not computing the derivative".
    """
    expected_message = ".*Make sure that you are not computing the derivative.*"

    def reduce_along_axis(array):
        return func_ak(array, axis=axis)

    primals = (test_regulararray,)
    tangents = (test_regulararray_tangent,)
    with pytest.raises(TypeError, match=expected_message):
        jax.jvp(reduce_along_axis, primals, tangents)
|