1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
|
import pytest
import autoray as ar
from .test_autoray import BACKENDS
@pytest.mark.parametrize(
    "backend",
    [
        b
        for b in BACKENDS
        if b.values[0] in ("numpy", "jax", "torch", "cupy", "dask")
    ],
)
@pytest.mark.parametrize(
    "dist,args,kwargs",
    [
        ("binomial", (7, 0.424), {"size": (3, 4)}),
        ("choice", ([11.1 * i for i in range(100)],), {"size": (3, 4)}),
        ("choice", ([11.1 * i for i in range(1000)],), {}),
        ("exponential", (), {"size": (3, 4)}),
        ("exponential", (), {}),
        ("gumbel", (), {"size": (3, 4)}),
        ("gumbel", (), {}),
        ("integers", (100, 1000), {"size": (3, 4)}),
        ("integers", (100, 1000), {}),
        ("normal", (), {"size": (3, 4)}),
        ("normal", (), {}),
        ("permutation", ([11.1 * i for i in range(100)],), {}),
        ("poisson", (100,), {"size": (3, 4, 5)}),
        ("random", (), {"size": (3, 4)}),
        ("random", (), {}),
        ("uniform", (), {"size": (3, 4)}),
        ("uniform", (), {}),
    ],
)
def test_random_default_rng(backend, dist, args, kwargs):
    """Check that ``random.default_rng`` generators behave like seeded RNGs.

    For the given backend and generator method ``dist`` (called with
    ``*args, **kwargs``), verify that:

    * when ``size`` is given, the output has exactly that shape;
    * two successive draws from the same generator differ;
    * a generator built from a different seed gives different samples;
    * a generator rebuilt from the original seed reproduces the first draw.
    """
    if dist in ("choice", "permutation"):
        # These methods sample from / shuffle an array, so build that array
        # on the backend under test.
        args = (ar.do("array", args[0], like=backend), *args[1:])

    if dist == "permutation" and backend == "dask":
        pytest.xfail("bug: https://github.com/dask/dask/issues/12029")

    # Generator methods each backend does not expose yet.
    unsupported = {
        "torch": ("binomial", "choice", "exponential", "gumbel", "poisson"),
        "cupy": ("choice", "gumbel", "normal", "permutation"),
    }
    if dist in unsupported.get(backend, ()):
        # Report the actual backend (previously cupy was mislabelled "torch").
        pytest.xfail(f"{backend}: no {dist} interface yet.")

    def draw(rng):
        """Call ``rng.<dist>(*args, **kwargs)`` and convert the result to numpy."""
        return ar.do("to_numpy", getattr(rng, dist)(*args, **kwargs))

    seed = 42
    seed2 = 43

    rng = ar.do("random.default_rng", seed, like=backend)
    x = draw(rng)
    if "size" in kwargs:
        assert ar.do("shape", x) == kwargs["size"]

    # A second draw from the same generator must not repeat the first.
    y = draw(rng)
    assert not ar.do("allclose", x, y)

    # A different seed must give a different sample.
    rng = ar.do("random.default_rng", seed2, like=backend)
    z = draw(rng)
    assert not ar.do("allclose", x, z)

    # Re-seeding with the original seed must reproduce the first draw.
    rng = ar.do("random.default_rng", seed, like=backend)
    x2 = draw(rng)
    assert ar.do("allclose", x, x2)
def test_jax_jit_random():
    """Seeded ``random.default_rng`` inside an ``autojit``-compiled jax function.

    The compiled function must return the same samples for the same seed and
    different samples for a different seed.
    """
    pytest.importorskip("jax")

    @ar.autojit(backend="jax")
    def sample_normal(seed):
        generator = ar.do("random.default_rng", seed)
        return generator.normal(size=(3, 4))

    def run(seed_value):
        # Wrap the seed with ar.do("array", ...) and pull the result back to numpy.
        seed_array = ar.do("array", seed_value)
        return ar.do("to_numpy", sample_normal(seed_array))

    first = run(42)
    repeat = run(42)
    assert ar.do("allclose", first, repeat)

    other = run(43)
    assert not ar.do("allclose", first, other)
|