File: jax.py

package info (click to toggle)
python-awkward 2.9.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 35,360 kB
  • sloc: python: 187,941; cpp: 33,672; sh: 432; ansic: 256; makefile: 21; javascript: 8
file content (143 lines) | stat: -rw-r--r-- 4,155 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import enum
import threading
import weakref

import awkward as ak
from awkward import highlevel
from awkward._nplikes.numpy import Numpy
from awkward._typing import TypeVar

# NumPy-backed nplike obtained from ``Numpy.instance()``. Nothing else in this
# module references it; presumably kept for importers of this module — TODO confirm.
numpy = Numpy.instance()


def assert_never(arg) -> None:
    """Fail loudly when code that should be unreachable is executed.

    Args:
        arg: the unexpected value; it is interpolated into the error message.

    Raises:
        AssertionError: always.
    """
    message = f"this should never be run: {arg}"
    raise AssertionError(message)


class _RegistrationState(enum.Enum):
    """Status of the one-time JAX registration performed by ``_register``."""

    # Starting state: ``_register`` has not yet run to completion.
    INIT = enum.auto()
    # ``_register`` registered all node types without raising.
    SUCCESS = enum.auto()
    # ``_register`` raised. It is not retried, because ``_register`` only
    # proceeds when the state is INIT.
    FAILED = enum.auto()


# Re-entrant lock (threading.RLock) guarding ``_registration_state`` and the
# registration work; held by _register, register_behavior_class and
# assert_registered.
_registration_lock = threading.RLock()
# Current registration status; within this module only ``_register`` reassigns it.
_registration_state = _RegistrationState.INIT


def register_and_check():
    """
    Enable JAX support in Awkward Array.

    Imports JAX, turns on JAX's 64-bit mode (``jax_enable_x64``), and then
    registers Awkward Array node types with JAX's tree mechanism.

    Raises:
        ModuleNotFoundError: if JAX cannot be imported. When the missing module
            is ``jax`` itself, the original traceback is suppressed. Otherwise
            (JAX is installed but something it imports is missing), the original
            error is chained as the cause so the real missing module stays visible.
    """
    try:
        import jax  # noqa: TID251
    except ModuleNotFoundError as err:
        # Only hide the underlying error when JAX itself is absent; a missing
        # transitive dependency must not be masked by "install jax".
        cause = None if err.name == "jax" else err
        raise ModuleNotFoundError(
            """install the 'jax' package with:

        python3 -m pip install jax jaxlib

    or

        conda install -c conda-forge jax jaxlib
    """
        ) from cause

    # ak.from_buffers needs this
    jax.config.update("jax_enable_x64", True)

    _register()


# Type variable for the high-level classes accepted by
# ``register_behavior_class``: class objects bounded by
# ``highlevel.Array | highlevel.Record``.
HighLevelType = TypeVar(
    "HighLevelType", bound="type[highlevel.Array | highlevel.Record]"
)


# High-level classes that ``_register`` registers with JAX, seeded with the
# built-in Array and Record. Held in a WeakSet, so queuing a class here does
# not by itself keep it alive.
_known_highlevel_classes = weakref.WeakSet([highlevel.Array, highlevel.Record])


def register_behavior_class(cls: HighLevelType):
    """
    Args:
        cls: behavior class to register with JAX

    Register ``cls`` with JAX right away if JAX registration has already
    succeeded. Otherwise, add it to the queue that ``_register`` drains when/if
    JAX is registered. Note that after a failed registration ``_register`` does
    not run again, so classes queued in that state are not registered.
    """
    # The lock is held by _register for its whole run, so taking it here means
    # any in-progress registration has finished before the state is read.
    with _registration_lock:
        if _registration_state != _RegistrationState.SUCCESS:
            _known_highlevel_classes.add(cls)
            return

        # Registration succeeded, so importing the JAX connector is safe.
        import awkward._connect.jax as jax_connect

        jax_connect.register_pytree_class(cls)


def _register():
    """
    Perform the one-time registration of Awkward Array node types with JAX.

    Registers the ``ak.contents`` layout classes listed below and
    ``ak.record.Record``, then every class in ``_known_highlevel_classes``.
    Runs at most once: if the state is no longer INIT it returns immediately.
    If any registration raises, the state becomes FAILED and the exception
    propagates; otherwise the state becomes SUCCESS.
    """
    global _registration_state
    # Hold the lock across the check and the whole registration so no other
    # thread can observe or change the state part-way through.
    with _registration_lock:
        if _registration_state is not _RegistrationState.INIT:
            return

        try:
            import awkward._connect.jax as jax_connect

            layout_types = (
                ak.contents.BitMaskedArray,
                ak.contents.ByteMaskedArray,
                ak.contents.EmptyArray,
                ak.contents.IndexedArray,
                ak.contents.IndexedOptionArray,
                ak.contents.NumpyArray,
                ak.contents.ListArray,
                ak.contents.ListOffsetArray,
                ak.contents.RecordArray,
                ak.contents.UnionArray,
                ak.contents.UnmaskedArray,
                ak.record.Record,
            )
            for node_type in layout_types:
                jax_connect.register_pytree_class(node_type)

            # Queued high-level classes are registered after the layouts.
            for node_type in _known_highlevel_classes:
                jax_connect.register_pytree_class(node_type)
        except Exception:
            _registration_state = _RegistrationState.FAILED
            raise

        _registration_state = _RegistrationState.SUCCESS


def assert_registered():
    """
    Ensure that JAX integration is registered.

    Raises:
        RuntimeError: if registration has not happened yet (state INIT), or if
            the last call to ``ak.jax.register_and_check()`` failed (state FAILED).
    """
    with _registration_lock:
        state = _registration_state

        if state == _RegistrationState.SUCCESS:
            return
        if state == _RegistrationState.INIT:
            raise RuntimeError("JAX features require `ak.jax.register_and_check()`")
        if state == _RegistrationState.FAILED:
            raise RuntimeError(
                "JAX features require `ak.jax.register_and_check()`, "
                "but the last call to `ak.jax.register_and_check()` did not succeed. "
                "Please look for a traceback to identify the error."
            )

        # All enum members are handled above; reaching here is a programming error.
        assert_never(state)


def import_jax():
    """
    Return the ``jax`` module after checking that JAX integration is registered.

    Raises:
        RuntimeError: propagated from ``assert_registered`` when registration
            has not happened or did not succeed.
    """
    assert_registered()

    import jax as jax_module  # noqa: TID251

    return jax_module