File: test_torch.py

package info (click to toggle)
python-array-api-compat 1.11.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 708 kB
  • sloc: python: 3,954; sh: 16; makefile: 15
file content (98 lines) | stat: -rw-r--r-- 3,314 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
"""
import itertools

import pytest
import torch

from array_api_compat import torch as xp


class TestResultType:
    """Checks of ``xp.result_type`` for inputs the Array API spec leaves unspecified.

    Tests that change torch's process-wide default dtype always restore the
    previous value, so the outcome of other tests does not depend on
    execution order.
    """

    def test_empty(self):
        """Calling ``result_type`` with no arguments raises ``ValueError``."""
        with pytest.raises(ValueError):
            xp.result_type()

    def test_one_arg(self):
        """A lone scalar/str/None is rejected; a lone dtype or array is accepted."""
        # Python scalars and non-array objects carry no dtype of their own.
        for x in [1, 1.0, 1j, '...', None]:
            with pytest.raises((ValueError, AttributeError)):
                xp.result_type(x)

        # A single dtype is returned unchanged.
        for x in [xp.float32, xp.int64, torch.complex64]:
            assert xp.result_type(x) == x

        # A single array yields its own dtype.
        for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
            assert xp.result_type(x) == x.dtype

    def test_two_args(self):
        """Two-argument results agree with ``torch.result_type``."""
        # Only include here things "unspecified" in the spec

        # scalar/scalar, scalar/tensor, or tensor/tensor
        for x, y in [
            (1., 1j),
            (1j, xp.arange(3)),
            (True, xp.asarray(3.)),
            (xp.ones(3) == 1, 1j*xp.ones(3)),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, y)

        # scalar, dtype: the reference uses a 0-d tensor of that dtype in
        # place of the bare dtype.
        for x, y in [
            (1j, xp.int64),
            (True, xp.float64),
        ]:
            assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))

        # dtype, dtype: both sides are replaced by 0-d tensors for the reference.
        for x, y in [
            (xp.bool, xp.complex64)
        ]:
            xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
            assert xp.result_type(x, y) == torch.result_type(xt, yt)

    def test_multi_arg(self):
        """Mixed scalars, dtypes and arrays resolve to the expected dtype.

        The default dtype is pinned to float32 while the assertions run and
        is restored afterwards, so the global setting does not leak into
        other tests.
        """
        prev_default = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        try:
            args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
            assert xp.result_type(*args) == torch.float16

            args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
            assert xp.result_type(*args) == xp.complex64

            args = [1, 2, 3j, xp.float64, 4, 5, 6]
            assert xp.result_type(*args) == xp.complex128

            args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
            assert xp.result_type(*args) == xp.complex128

            # The result must not depend on the order of the arguments.
            i64 = xp.ones(1, dtype=xp.int64)
            f16 = xp.ones(1, dtype=xp.float16)
            for perm in itertools.permutations([i64, f16, 1.0, 1.0]):
                assert xp.result_type(*perm) == xp.float16, f"{perm}"

            # Only Python scalars: no array or dtype to anchor the result.
            with pytest.raises(ValueError):
                xp.result_type(1, 2, 3, 4)
        finally:
            torch.set_default_dtype(prev_default)

    @pytest.mark.parametrize("default_dt", ['float32', 'float64'])
    @pytest.mark.parametrize("dtype_a",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    @pytest.mark.parametrize("dtype_b",
        (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
    )
    def test_gh_273(self, default_dt, dtype_a, dtype_b):
        """``result_type(a, b, 1.0)`` equals ``result_type(b, a, 1.0)``.

        Runs for every combination of the parametrized array dtypes under
        both float32 and float64 default dtypes.
        """
        # Regression test for https://github.com/data-apis/array-api-compat/issues/273

        # Capture the current default before entering ``try`` so the
        # ``finally`` clause always has a value to restore.
        prev_default = torch.get_default_dtype()
        try:
            torch.set_default_dtype(getattr(torch, default_dt))

            a = xp.asarray([2, 1], dtype=dtype_a)
            b = xp.asarray([1, -1], dtype=dtype_b)
            dtype_1 = xp.result_type(a, b, 1.0)
            dtype_2 = xp.result_type(b, a, 1.0)
            assert dtype_1 == dtype_2
        finally:
            torch.set_default_dtype(prev_default)