1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
|
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
"""
import itertools
import pytest
import torch
from array_api_compat import torch as xp
class TestResultType:
    """Tests for ``xp.result_type`` cases the Array API spec leaves unspecified.

    Most expected values are taken from ``torch.result_type`` itself, so these
    tests check that the compat wrapper agrees with torch's own promotion.
    """

    def test_empty(self):
        """Calling ``result_type`` with no arguments raises ``ValueError``."""
        with pytest.raises(ValueError):
            xp.result_type()

    def test_one_arg(self):
        """Single-argument behaviour.

        Bare Python scalars, strings and ``None`` are rejected; a dtype is
        returned unchanged; an array yields its own dtype.
        """
        for x in [1, 1.0, 1j, '...', None]:
            with pytest.raises((ValueError, AttributeError)):
                xp.result_type(x)

        for x in [xp.float32, xp.int64, torch.complex64]:
            assert xp.result_type(x) == x

        for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
            assert xp.result_type(x) == x.dtype

    def test_two_args(self):
        """Two-argument promotion must match ``torch.result_type``."""
        # Only include here things "unspecified" in the spec

        # scalar/scalar, scalar/tensor and tensor/tensor pairs
        for x, y in [
            (1., 1j),
            (1j, xp.arange(3)),
            (True, xp.asarray(3.)),
            (xp.ones(3) == 1, 1j*xp.ones(3)),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, y)

        # scalar, dtype: torch.result_type takes tensors/numbers rather than
        # dtypes, so the reference uses a 0-d tensor of that dtype
        for x, y in [
            (1j, xp.int64),
            (True, xp.float64),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))

        # dtype, dtype
        for x, y in [
            (xp.bool, xp.complex64),
        ]:
            xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
            assert xp.result_type(x, y) == torch.result_type(xt, yt)

    def test_multi_arg(self):
        """Promotion over many mixed arguments (scalars, dtypes and arrays)."""
        # Pin the default dtype so the expectations below don't depend on
        # whatever an earlier test left behind. Restore it afterwards so this
        # test doesn't leak global torch state into later tests.
        prev_default = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        try:
            args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
            assert xp.result_type(*args) == torch.float16

            args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
            assert xp.result_type(*args) == xp.complex64

            args = [1, 2, 3j, xp.float64, 4, 5, 6]
            assert xp.result_type(*args) == xp.complex128

            args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
            assert xp.result_type(*args) == xp.complex128

            # The result must not depend on the order of the arguments.
            i64 = xp.ones(1, dtype=xp.int64)
            f16 = xp.ones(1, dtype=xp.float16)
            for i in itertools.permutations([i64, f16, 1.0, 1.0]):
                assert xp.result_type(*i) == xp.float16, f"{i}"

            # Python scalars only (no array or dtype) are rejected.
            with pytest.raises(ValueError):
                xp.result_type(1, 2, 3, 4)
        finally:
            torch.set_default_dtype(prev_default)

    @pytest.mark.parametrize("default_dt", ['float32', 'float64'])
    @pytest.mark.parametrize("dtype_a",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    @pytest.mark.parametrize("dtype_b",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    def test_gh_273(self, default_dt, dtype_a, dtype_b):
        """``result_type(a, b, 1.0)`` must equal ``result_type(b, a, 1.0)``.

        The check runs for every dtype pair under both default float dtypes.
        """
        # Regression test for https://github.com/data-apis/array-api-compat/issues/273
        # Capture the current default *before* the try block, so that `finally`
        # can always restore it and never hits an unbound name.
        prev_default = torch.get_default_dtype()
        try:
            default_dtype = getattr(torch, default_dt)
            torch.set_default_dtype(default_dtype)
            a = xp.asarray([2, 1], dtype=dtype_a)
            b = xp.asarray([1, -1], dtype=dtype_b)
            dtype_1 = xp.result_type(a, b, 1.0)
            dtype_2 = xp.result_type(b, a, 1.0)
            assert dtype_1 == dtype_2
        finally:
            torch.set_default_dtype(prev_default)
|