1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
import enum
import threading
import weakref
import awkward as ak
from awkward import highlevel
from awkward._nplikes.numpy import Numpy
from awkward._typing import TypeVar
# Module-level handle to the NumPy nplike returned by ``Numpy.instance()``.
# NOTE(review): not referenced by the code visible in this module; presumably
# kept for importers or other parts of the file — confirm before removing.
numpy = Numpy.instance()
def assert_never(arg) -> None:
    """
    Fail loudly when supposedly unreachable code is executed.

    Args:
        arg: the unexpected value; it is embedded in the error message.

    Raises:
        AssertionError: always.
    """
    message = f"this should never be run: {arg}"
    raise AssertionError(message)
class _RegistrationState(enum.Enum):
INIT = enum.auto()
SUCCESS = enum.auto()
FAILED = enum.auto()
# Guards ``_registration_state`` and ``_known_highlevel_classes``; in the visible
# code every read or write of those happens while this lock is held. It is an
# RLock, so a thread that already holds it may acquire it again.
# NOTE(review): presumably re-entrant so nested calls made during registration
# cannot self-deadlock — confirm.
_registration_lock = threading.RLock()
# Result of the registration attempt: INIT until ``_register`` runs, then SUCCESS
# or FAILED. Nothing in this file resets it back to INIT.
_registration_state = _RegistrationState.INIT
def register_and_check():
    """
    Register Awkward Array node types with JAX's tree mechanism.

    Imports ``jax``, switches on its ``jax_enable_x64`` option (needed by
    ``ak.from_buffers``), and then registers Awkward's layout classes,
    ``ak.record.Record`` and any queued high-level classes as JAX pytree
    classes.

    Registration is attempted only once per process. If an earlier attempt
    failed, later calls return without retrying; use ``assert_registered`` to
    find out whether integration is actually usable.

    Raises:
        ModuleNotFoundError: if ``jax`` is not installed (with install hints).
        Exception: whatever the first registration attempt raises, re-raised.
    """
    try:
        import jax  # noqa: TID251
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            """install the 'jax' package with:
python3 -m pip install jax jaxlib
or
conda install -c conda-forge jax jaxlib
"""
        ) from None

    # Deliberately outside the ``try`` above: only a failure to import ``jax``
    # itself should be reported with installation instructions.
    # ak.from_buffers needs this
    jax.config.update("jax_enable_x64", True)

    _register()
# Type variable for the high-level classes accepted by
# ``register_behavior_class``: ``highlevel.Array``/``highlevel.Record`` types
# (and, presumably, user behavior subclasses of them — confirm).
HighLevelType = TypeVar(
"HighLevelType", bound="type[highlevel.Array | highlevel.Record]"
)
# High-level classes to be registered as JAX pytree classes. ``_register``
# registers every member when it runs from the INIT state, and
# ``register_behavior_class`` adds to it whenever registration has not (yet)
# succeeded. A WeakSet is used so that membership here does not by itself keep
# a class alive.
_known_highlevel_classes = weakref.WeakSet([highlevel.Array, highlevel.Record])
def register_behavior_class(cls: HighLevelType):
    """
    Register a behavior class with JAX now, or queue it for later.

    If JAX integration has already been registered successfully, ``cls`` is
    registered as a pytree class immediately. Otherwise it is added to
    ``_known_highlevel_classes``. That set is consumed by ``_register`` only
    when registration is attempted from the INIT state.

    Args:
        cls: behavior class to register with JAX
    """
    # Taking the lock means any in-progress registration has finished before
    # the state is inspected.
    with _registration_lock:
        if _registration_state != _RegistrationState.SUCCESS:
            _known_highlevel_classes.add(cls)
            return

        # Registration succeeded, so JAX-dependent code is safe to import here.
        import awkward._connect.jax as jax_connect

        jax_connect.register_pytree_class(cls)
def _register():
    """
    Register Awkward Array node types with JAX's tree mechanism.

    Registers every layout node class and ``ak.record.Record`` as a JAX pytree
    class via ``awkward._connect.jax.register_pytree_class``. It then does the
    same for every high-level class queued in ``_known_highlevel_classes``.

    Only a call made while the state is INIT does any work. The outcome is
    recorded in ``_registration_state`` as SUCCESS or FAILED, and subsequent
    calls return immediately. If registration raises, the state becomes FAILED
    and the exception propagates.
    """
    global _registration_state

    # Hold the lock for the whole attempt; other threads taking the lock thus
    # see either the INIT state or the final outcome.
    with _registration_lock:
        if _registration_state != _RegistrationState.INIT:
            return  # already attempted; never retried

        try:
            import awkward._connect.jax as jax_connect

            node_classes = (
                ak.contents.BitMaskedArray,
                ak.contents.ByteMaskedArray,
                ak.contents.EmptyArray,
                ak.contents.IndexedArray,
                ak.contents.IndexedOptionArray,
                ak.contents.NumpyArray,
                ak.contents.ListArray,
                ak.contents.ListOffsetArray,
                ak.contents.RecordArray,
                ak.contents.UnionArray,
                ak.contents.UnmaskedArray,
                ak.record.Record,
            )
            for node_cls in node_classes:
                jax_connect.register_pytree_class(node_cls)

            for highlevel_cls in _known_highlevel_classes:
                jax_connect.register_pytree_class(highlevel_cls)
        except Exception:
            _registration_state = _RegistrationState.FAILED
            raise

        # Reached only when every registration above succeeded.
        _registration_state = _RegistrationState.SUCCESS
def assert_registered():
    """
    Check that JAX integration has been registered successfully.

    Returns ``None`` when the registration attempt succeeded.

    Raises:
        RuntimeError: if registration has not been attempted yet, or if the
            attempt triggered by ``ak.jax.register_and_check()`` failed.
    """
    with _registration_lock:
        state = _registration_state
        if state == _RegistrationState.SUCCESS:
            return
        if state == _RegistrationState.INIT:
            raise RuntimeError("JAX features require `ak.jax.register_and_check()`")
        if state == _RegistrationState.FAILED:
            raise RuntimeError(
                "JAX features require `ak.jax.register_and_check()`, "
                "but the last call to `ak.jax.register_and_check()` did not succeed. "
                "Please look for a traceback to identify the error."
            )
        # Only reachable if a new state is added without updating this check.
        assert_never(state)
def import_jax():
    """
    Return the ``jax`` module, provided JAX integration is registered.

    Raises:
        RuntimeError: propagated from ``assert_registered`` when registration
            has not been attempted or did not succeed.
    """
    # Refuse to hand out ``jax`` before Awkward's pytree classes are registered.
    assert_registered()

    import jax as jax_module  # noqa: TID251

    return jax_module
|