1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
|
"""Required functions for optimized contractions of numpy arrays using jax."""
from opt_einsum.sharing import to_backend_cache_wrap
# Names exported as this module's public interface.
__all__ = [
    "build_expression",
    "evaluate_constants",
]
# Lazily populated ``(jax, to_jax)`` pair; ``None`` until first requested.
_JAX = None


def _get_jax_and_to_jax():
    """Return the ``(jax, to_jax)`` pair, creating it on first use.

    The first call imports ``jax`` (presumably deferred so this module can be
    imported without jax installed) and builds ``to_jax``: an identity
    function compiled with ``jax.jit`` and then wrapped with
    ``to_backend_cache_wrap``. The pair is stored in the module global
    ``_JAX``, so later calls return the same objects.
    """
    global _JAX
    if _JAX is not None:
        return _JAX

    import jax  # type: ignore

    # NOTE: keep the function name ``to_jax`` -- ``to_backend_cache_wrap`` may
    # use the wrapped function's name in its cache key (verify in sharing.py).
    @to_backend_cache_wrap
    @jax.jit
    def to_jax(x):
        return x

    _JAX = (jax, to_jax)
    return _JAX
def build_expression(_, expr):  # pragma: no cover
    """Build a jax-compiled callable that evaluates the contraction ``expr``.

    Parameters
    ----------
    _ : object
        Unused by this backend.
    expr : object
        Contraction expression whose ``_contract`` method is compiled with
        ``jax.jit``.

    Returns
    -------
    callable
        ``jax_contract(*arrays)``: runs the jitted ``_contract`` on the
        operands and converts the result with ``numpy.asarray``.
    """
    jax = _get_jax_and_to_jax()[0]
    # Import once at build time rather than on every contraction call.
    import numpy as np  # type: ignore

    jax_expr = jax.jit(expr._contract)

    def jax_contract(*arrays):
        # The operands are handed to the jitted ``_contract`` as one tuple.
        return np.asarray(jax_expr(arrays))

    return jax_contract
def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Convert constant arguments to jax arrays, and perform any possible
    constant contractions.

    Parameters
    ----------
    const_arrays : iterable
        The constant operands. Each one is passed through the cached,
        jitted ``to_jax`` converter.
    expr : callable
        Called with the converted arrays positionally, plus
        ``backend="jax"`` and ``evaluate_constants=True``.

    Returns
    -------
    object
        Whatever ``expr`` returns for those arguments.
    """
    # Only the converter is needed here; the jax module itself is unused.
    _, to_jax = _get_jax_and_to_jax()
    jax_arrays = [to_jax(x) for x in const_arrays]
    return expr(*jax_arrays, backend="jax", evaluate_constants=True)
|