File: test_2637_jax_tracer_error.py

package info (click to toggle)
python-awkward 2.6.5-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 23,088 kB
  • sloc: python: 148,689; cpp: 33,562; sh: 432; makefile: 21; javascript: 8
file content (37 lines) | stat: -rw-r--r-- 961 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations

import pytest

import awkward as ak

jax = pytest.importorskip("jax")


def test():
    """Differentiate through an in-place field assignment on a jax-backed array.

    Presumably a regression test for awkward issue #2637 (per the file name).
    Inside ``jax.value_and_grad``, ``correct_jets`` assigns a traced value into
    a record field with ``jets["pt"] = ...``. The transformed call must succeed
    and produce the expected value and gradient with respect to ``alpha``.
    """
    ak.jax.register_and_check()

    # Every field value is written as a float literal so the input reads
    # uniformly (awkward would unify a stray int to float anyway).
    jets = ak.Array(
        [
            [
                {"pt": 1.0, "eta": 1.1, "phi": 0.1, "mass": 0.01},
                {"pt": 2.0, "eta": 2.2, "phi": 0.2, "mass": 0.02},
            ],
            [
                {"pt": 4.0, "eta": 4.4, "phi": 0.4, "mass": 0.04},
                {"pt": 5.0, "eta": 5.5, "phi": 0.5, "mass": 0.05},
                {"pt": 6.0, "eta": 6.6, "phi": 0.6, "mass": 0.06},
            ],
        ],
        backend="jax",
    )

    def correct_jets(jets, alpha):
        # The in-place ``__setitem__`` with a traced value appears to be the
        # behavior under test; keep it rather than switching to a
        # non-mutating alternative.
        new_pt = jets["pt"] + 25.0 * alpha
        jets["pt"] = new_pt
        return ak.sum(jets["pt"])

    # Differentiate with respect to ``alpha`` only (argnums=1).
    val, grad = jax.value_and_grad(correct_jets, argnums=1)(jets, 0.1)

    # val  = sum(pt) + n_jets * 25 * alpha = (1 + 2 + 4 + 5 + 6) + 5 * 2.5 = 30.5
    # grad = n_jets * 25 = 125
    # Compare approximately: JAX defaults to float32 unless x64 is enabled,
    # so exact float equality would be fragile.
    assert float(val) == pytest.approx(30.5)
    assert float(grad) == pytest.approx(125.0)