1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
import pytest
from autoray import do, autojit, infer_backend, to_numpy, shape
from .test_autoray import BACKENDS, gen_rand
from numpy.testing import assert_allclose
# Narrow the shared BACKENDS list (imported above from test_autoray) down to
# the jax, torch and tensorflow entries; these are the backends parametrized
# into the autojit tests below.
# NOTE(review): entries appear to be ``pytest.param`` objects whose first
# value is the backend name -- confirm against test_autoray.BACKENDS.
BACKENDS = [
    param
    for param in BACKENDS
    if param.values[0] in ("jax", "torch", "tensorflow")
]
def modified_gram_schmidt(X):
    """Orthonormalize the rows of ``X`` with modified Gram-Schmidt.

    Every array operation goes through ``autoray.do`` (or ``shape``), so the
    function can be traced by ``autojit`` as well as run eagerly.

    Parameters
    ----------
    X : array
        Indexed as ``X[j, :]``, so presumably a 2D array whose rows are the
        vectors to orthonormalize, processed in row order.

    Returns
    -------
    array
        The normalized vectors stacked along axis 0.
    """
    basis = []
    num_rows = shape(X)[0]
    for row_idx in range(num_rows):
        vec = X[row_idx, :]
        # Subtract the projection onto each earlier basis vector from the
        # *running* vector. Updating sequentially, rather than projecting
        # the original row, is what makes this the "modified" variant.
        for prev in basis:
            overlap = do("tensordot", do("conj", prev), vec, axes=1)
            vec = vec - overlap * prev
        vec_norm = do("linalg.norm", vec, 2)
        basis.append(vec / vec_norm)
    return do("stack", tuple(basis), axis=0)
@pytest.fixture
def mgs_case():
    """Return a random 10x10 numpy input and its eager MGS result."""
    inp = gen_rand((10, 10), "numpy")
    return inp, modified_gram_schmidt(inp)
@pytest.mark.parametrize("share_intermediates", [False, True])
@pytest.mark.parametrize("nested", [False, True])
def test_compile_python(mgs_case, share_intermediates, nested):
    """autojit on numpy input matches the eager result.

    The ``"python"`` compiler options toggle ``share_intermediates``. When
    ``nested`` is set, ``autojit`` is applied to an already-jitted function.
    """
    x, expected = mgs_case
    opts = {"python": {"share_intermediates": share_intermediates}}
    fn = autojit(modified_gram_schmidt, compiler_opts=opts)
    if nested:
        # wrap the jitted function a second time
        fn = autojit(fn, compiler_opts=opts)
    assert_allclose(expected, fn(x))
@pytest.mark.parametrize("backend", BACKENDS)
def test_others_numpy(backend, mgs_case):
    """Numpy input compiled with another ``backend`` still returns numpy."""
    x, expected = mgs_case
    result = autojit(modified_gram_schmidt)(x, backend=backend)
    assert infer_backend(result) == "numpy"
    assert_allclose(expected, result)
@pytest.mark.parametrize("backend", BACKENDS)
def test_autodispatch(backend, mgs_case):
    """Input converted to ``backend`` yields output on that same backend."""
    x_np, expected = mgs_case
    x_backend = do("array", x_np, like=backend)
    result = autojit(modified_gram_schmidt)(x_backend, backend=backend)
    assert infer_backend(result) == backend
    # bring the result back to numpy for comparison with the eager reference
    assert_allclose(expected, to_numpy(result))
def test_complicated_signature():
    """autojit accepts arrays nested in tuples and dicts, including by keyword."""

    @autojit
    def foo(a, b, c):
        # pull the arrays out of a tuple, a dict, and a tuple inside a dict
        (a1, a2), b1, (c1, c2) = a, b["1"], c["sub"]
        return do("sum", do("stack", (a1, a2, b1, c1, c2)), axis=0)

    x = do("random.uniform", size=(5, 7), like="numpy")
    rows = [x[i, :] for i in range(5)]
    y = foo((rows[0], rows[1]), {"1": rows[2]}, c={"sub": (rows[3], rows[4])})
    # summing all five rows must equal the column sums of x
    assert_allclose(y, x.sum(0))
def test_multi_output():
    """An autojit-compiled function can return a tuple of arrays."""

    @autojit
    def foo(a, b, c):
        a_new = a - do("sum", b)
        # the update to ``b`` deliberately uses the already-updated ``a``
        b_new = b - do("sum", a_new)
        return a_new + c, b_new - c

    a = gen_rand((2, 3), "numpy")
    b = gen_rand((4, 5), "numpy")
    out_a, out_b = foo(a, b, 1)
    a_shifted = a - b.sum()
    assert_allclose(out_a, a_shifted + 1)
    assert_allclose(out_b, b - a_shifted.sum() - 1)
|