1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
import subprocess
import sys
import warnings
import numpy as np
import pytest
import torch
import array_api_compat
from array_api_compat import array_namespace
from ._helpers import import_, all_libraries, wrapped_libraries
@pytest.mark.parametrize("use_compat", [True, False, None])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
@pytest.mark.parametrize("library", all_libraries)
def test_array_namespace(library, api_version, use_compat):
    """Check which namespace ``array_namespace`` returns for an array.

    Parametrized over every library in ``all_libraries``, each requested
    ``api_version`` and each ``use_compat`` setting:

    * ``use_compat`` truthy for a library outside ``wrapped_libraries``
      must raise ``ValueError``.
    * ``use_compat is False``, or ``use_compat is None`` for a library
      outside ``wrapped_libraries``, must return the library's own
      namespace (for JAX with ``use_compat=None``: ``jax.numpy`` or
      ``jax.experimental.array_api`` depending on the JAX version).
    * Otherwise the matching ``array_api_compat`` wrapper is expected.

    For NumPy, scalars must map to the same namespace as arrays. For JAX,
    the lookup is also re-checked in a fresh interpreter.
    """
    xp = import_(library)
    array = xp.asarray([1.0, 2.0, 3.0])

    if use_compat and library not in wrapped_libraries:
        with pytest.raises(ValueError):
            array_namespace(array, use_compat=use_compat)
        return

    if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
        pytest.skip("Unsupported API version")

    namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)

    # Parenthesized to spell out the intended precedence (``and`` binds
    # tighter than ``or``): explicitly opting out, or a library we do not wrap.
    expects_native_namespace = use_compat is False or (
        use_compat is None and library not in wrapped_libraries
    )

    if expects_native_namespace:
        if library == "jax.numpy" and use_compat is None:
            import jax.numpy

            if hasattr(jax.numpy, "__array_api_version__"):
                # JAX v0.4.32 or later uses jax.numpy directly
                assert namespace == jax.numpy
            else:
                # JAX v0.4.31 or earlier uses jax.experimental.array_api
                import jax.experimental.array_api

                assert namespace == jax.experimental.array_api
        else:
            assert namespace == xp
    elif library == "dask.array":
        assert namespace == array_api_compat.dask.array
    else:
        assert namespace == getattr(array_api_compat, library)

    if library == "numpy":
        # check that the same namespace is returned for NumPy scalars
        scalar_namespace = array_namespace(
            xp.float64(0.0), api_version=api_version, use_compat=use_compat
        )
        assert scalar_namespace == namespace

    # Check that array_namespace works even if jax.experimental.array_api
    # hasn't been imported yet (it monkeypatches __array_namespace__
    # onto JAX arrays, but we should support them regardless). The only way to
    # do this is to use a subprocess, since we cannot un-import it and another
    # test probably already imported it.
    if library == "jax.numpy" and sys.version_info >= (3, 9):
        # The script below is executed by a separate interpreter, so the
        # indentation of its if/else bodies is significant: without it the
        # child fails with IndentationError and check=True raises.
        code = f"""\
import sys
import jax.numpy
import array_api_compat
array = jax.numpy.asarray([1.0, 2.0, 3.0])
assert 'jax.experimental.array_api' not in sys.modules
namespace = array_api_compat.array_namespace(array, api_version={api_version!r})
if hasattr(jax.numpy, '__array_api_version__'):
    assert namespace == jax.numpy
else:
    import jax.experimental.array_api
    assert namespace == jax.experimental.array_api
"""
        subprocess.run([sys.executable, "-c", code], check=True)
def test_jax_zero_gradient():
    """``array_namespace`` must treat JAX zero-gradient arrays like other JAX arrays.

    ``jax.grad(..., allow_int=True)`` is mapped over an integer array. Per
    the JAX documentation, gradients w.r.t. integer inputs get the special
    ``float0`` dtype. The namespace resolved for that result must be the
    same object as for an ordinary JAX array.
    """
    # Previously this test returned immediately (and referenced an
    # un-imported ``jax``), so it always passed without checking anything.
    # importorskip makes it run when JAX is available and report a skip
    # otherwise.
    jax = pytest.importorskip("jax")
    jnp = pytest.importorskip("jax.numpy")
    jx = jnp.arange(4)
    jax_zero = jax.vmap(jax.grad(jnp.float32, allow_int=True))(jx)
    assert array_namespace(jax_zero) is array_namespace(jx)
def test_array_namespace_errors():
    """Non-array inputs make ``array_namespace`` raise ``TypeError``.

    Covered cases: a plain Python list, no arguments at all, a tuple of
    arrays as the only argument, and a tuple of arrays next to a real array.
    """
    arr = np.asarray([1, 2])
    invalid_calls = (
        ([1],),
        (),
        ((arr, arr),),
        (arr, (arr, arr)),
    )
    for call_args in invalid_calls:
        with pytest.raises(TypeError):
            array_namespace(*call_args)
def test_array_namespace_errors_torch():
    """Mixing a NumPy array and a torch tensor in one call raises ``TypeError``."""
    tensor = torch.asarray([1, 2])
    ndarray = np.asarray([1, 2])
    mixed_inputs = (ndarray, tensor)
    with pytest.raises(TypeError):
        array_namespace(*mixed_inputs)
def test_api_version_torch():
    """Check how ``api_version`` is handled for torch tensors.

    "2023.12", ``None`` and an omitted ``api_version`` all return the
    wrapped torch namespace. "2021.12" and "2022.12" also return it, but
    must emit exactly one warning that mentions the requested version. An
    unknown version ("2020.12") must raise ``ValueError``.
    """
    x = torch.asarray([1, 2])
    torch_ = import_("torch", wrapper=True)
    assert array_namespace(x, api_version="2023.12") == torch_
    assert array_namespace(x, api_version=None) == torch_
    assert array_namespace(x) == torch_

    for old_version in ("2021.12", "2022.12"):
        # Should issue exactly one warning. simplefilter("always") makes the
        # recording independent of ambient filters: an "ignore" filter would
        # otherwise leave ``w`` empty, and an "error" filter would raise
        # instead of recording.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            assert array_namespace(x, api_version=old_version) == torch_
        assert len(w) == 1
        assert old_version in str(w[0].message)

    with pytest.raises(ValueError):
        array_namespace(x, api_version="2020.12")
def test_get_namespace():
    """``get_namespace`` is kept as a backwards-compatible alias of ``array_namespace``."""
    legacy_alias = array_api_compat.get_namespace
    assert legacy_alias is array_namespace
def test_python_scalars():
    """Python scalars and ``None`` carry no namespace of their own.

    On their own they make ``array_namespace`` raise ``TypeError``. Passed
    alongside a torch tensor, they are ignored and the wrapped torch
    namespace is returned.
    """
    tensor = torch.asarray([1, 2])
    wrapped_torch = import_("torch", wrapper=True)
    scalar_values = (1, 1.0, 1j, True, None)

    for value in scalar_values:
        with pytest.raises(TypeError):
            array_namespace(value)

    for value in scalar_values:
        assert array_namespace(tensor, value) == wrapped_torch
|