File: test_nep50_examples.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (220 lines) | stat: -rw-r--r-- 6,755 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Owner(s): ["module: dynamo"]

"""Test examples for NEP 50."""

import itertools
from unittest import skipIf as skipif, SkipTest


try:
    import numpy as _np

    v = _np.__version__.split(".")
    # HAVE_NUMPY gates the tests below, which toggle NumPy's behaviour through
    # the private ``_np._get_promotion_state`` / ``_np._set_promotion_state``
    # helpers (introduced for NEP 50 testing in NumPy 1.24).
    #
    # Compare ``(major, minor)`` as a tuple. Checking the parts independently
    # (``major >= 1 and minor >= 24``) is not a version comparison: it
    # rejects e.g. 2.0 and accepts e.g. 2.24. The upper bound keeps the
    # window these tests have actually run in explicit: NumPy 2 made NEP 50
    # the default, and later 2.x releases removed the promotion-state helpers.
    HAVE_NUMPY = (1, 24) <= (int(v[0]), int(v[1])) < (2, 0)
except ImportError:
    HAVE_NUMPY = False

import torch._numpy as tnp
from torch._numpy import (  # noqa: F401
    array,
    bool_,
    complex128,
    complex64,
    float32,
    float64,
    inf,
    int16,
    int32,
    int64,
    uint8,
)
from torch._numpy.testing import assert_allclose
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


uint16 = uint8  # can be anything here, see below


# from numpy import array, uint8, uint16, int64, float32, float64, inf
# from numpy.testing import assert_allclose
# import numpy as np
# np._set_promotion_state('weak')

from pytest import raises as assert_raises


unchanged = None

# expression    old result   new_result
examples = {
    "uint8(1) + 2": (int64(3), uint8(3)),
    "array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)),
    "array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)),
    "array([1.], float32) + float64(1.)": (
        array([2.0], float32),
        array([2.0], float64),
    ),
    "array([1.], float32) + array(1., float64)": (
        array([2.0], float32),
        array([2.0], float64),
    ),
    "array([1], uint8) + 1": (array([2], uint8), unchanged),
    "array([1], uint8) + 200": (array([201], uint8), unchanged),
    "array([100], uint8) + 200": (array([44], uint8), unchanged),
    "array([1], uint8) + 300": (array([301], uint16), Exception),
    "uint8(1) + 300": (int64(301), Exception),
    "uint8(100) + 200": (int64(301), uint8(44)),  # and RuntimeWarning
    "float32(1) + 3e100": (float64(3e100), float32(inf)),  # and RuntimeWarning [T7]
    "array([1.0], float32) + 1e-14 == 1.0": (array([True]), unchanged),
    "array([0.1], float32) == float64(0.1)": (array([True]), array([False])),
    "array(1.0, float32) + 1e-14 == 1.0": (array(False), array(True)),
    "array([1.], float32) + 3": (array([4.0], float32), unchanged),
    "array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)),
    "3j + array(3, complex64)": (array(3 + 3j, complex128), array(3 + 3j, complex64)),
    "float32(1) + 1j": (array(1 + 1j, complex128), array(1 + 1j, complex64)),
    "int32(1) + 5j": (array(1 + 5j, complex128), unchanged),
    # additional examples from the NEP text
    "int16(2) + 2": (int64(4), int16(4)),
    "int16(4) + 4j": (complex128(4 + 4j), unchanged),
    "float32(5) + 5j": (complex128(5 + 5j), complex64(5 + 5j)),
    "bool_(True) + 1": (int64(2), unchanged),
    "True + uint8(2)": (uint8(3), unchanged),
}


@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestNEP50Table(TestCase):
    """Evaluate each expression in ``examples`` and check its NEP 50 result."""

    @parametrize("example", examples)
    def test_nep50_exceptions(self, example):
        """Check that ``example`` evaluates to its "new" (NEP 50) result.

        If the expected new result is ``Exception``, the expression must raise
        ``OverflowError``. If it is ``unchanged``, the "old" result is
        expected instead. Both the value (``assert_allclose``) and the dtype
        of the result are checked.
        """
        old, new = examples[example]

        # Identity check against the sentinel: ``new`` is usually an array,
        # and ``==`` would dispatch to the array's own ``__eq__`` instead.
        if new is Exception:
            with assert_raises(OverflowError):
                # ``example`` is a key of the module-level table, not
                # external input, so ``eval`` is safe here.
                eval(example)
            return

        result = eval(example)
        expected = old if new is unchanged else new

        assert_allclose(result, expected, atol=1e-16)
        assert result.dtype == expected.dtype


# ### Directly compare to numpy ###

weaks = (True, 1, 2.0, 3j)
non_weaks = (
    tnp.asarray(True),
    tnp.uint8(1),
    tnp.int8(1),
    tnp.int32(1),
    tnp.int64(1),
    tnp.float32(1),
    tnp.float64(1),
    tnp.complex64(1),
    tnp.complex128(1),
)
# ``dtype=`` values for test_direct_compare. ``None`` stands for calling
# without an explicit ``dtype=``. Without a usable NumPy, it is the only entry.
dtypes = (None,)
if HAVE_NUMPY:
    dtypes += tuple(
        getattr(_np, type_name)
        for type_name in (
            "bool_",
            "uint8",
            "int8",
            "int32",
            "int64",
            "float32",
            "float64",
            "complex64",
            "complex128",
        )
    )


# ufunc name -> dtype names for which test_compare_ufuncs skips the
# comparison. The test checks this against the array operand's dtype and
# against the dtype of ``tnp.asarray(scalar)``.
corners = {
    **{
        ufunc_name: ["bool_", "uint8", "int8", "int16", "int32", "int64"]
        for ufunc_name in (
            "true_divide",
            "divide",
            "arctan2",
            "copysign",
            "heaviside",
            "ldexp",
        )
    },
    "power": ["uint8"],
    "nextafter": ["float32"],
}


@skipif(not HAVE_NUMPY, reason="NumPy not found")
@instantiate_parametrized_tests
class TestCompareToNumpy(TestCase):
    """Cross-check torch._numpy type promotion against NumPy in NEP 50 mode.

    Each test switches NumPy's private promotion state to ``"weak"`` and
    restores the previous state afterwards.
    """

    @parametrize("scalar, array, dtype", itertools.product(weaks, non_weaks, dtypes))
    def test_direct_compare(self, scalar, array, dtype):
        """Check that ``add(scalar, array[, dtype=dtype])`` matches NumPy.

        Both the result dtype and the value must agree. ``dtype=None`` means
        no ``dtype=`` argument is passed. Combinations that NumPy itself
        rejects are not compared.
        """
        # Read the state before ``try`` so ``finally`` never sees it unbound.
        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            # Always bind ``kwargs``. If it is only bound when ``dtype`` is
            # given, every ``dtype=None`` case raises a NameError that the
            # ``except Exception`` below swallows, and nothing is compared.
            kwargs = {} if dtype is None else {"dtype": dtype}
            try:
                result_numpy = _np.add(scalar, array.tensor.numpy(), **kwargs)
            except Exception:
                # NumPy rejects this combination: nothing to compare against.
                return

            # Same call through torch._numpy, mapping the NumPy scalar type
            # to the torch._numpy one of the same name.
            kwargs = {} if dtype is None else {"dtype": getattr(tnp, dtype.__name__)}
            result = tnp.add(scalar, array, **kwargs).tensor.numpy()
            assert result.dtype == result_numpy.dtype
            assert result == result_numpy
        finally:
            _np._set_promotion_state(state)

    @parametrize("name", tnp._ufuncs._binary)
    @parametrize("scalar, array", itertools.product(weaks, non_weaks))
    def test_compare_ufuncs(self, name, scalar, array):
        """Check that binary ufunc ``name`` picks the same result dtype as NumPy.

        Only dtypes are compared, and only when both implementations accept
        the inputs.
        """
        if name in corners and (
            array.dtype.name in corners[name]
            or tnp.asarray(scalar).dtype.name in corners[name]
        ):
            raise SkipTest(f"{name}(..., dtype=array.dtype)")

        # Decide on exclusions before touching NumPy's global state, and
        # report them as skips rather than as silent passes.
        if name in ["matmul", "modf", "divmod", "ldexp"]:
            raise SkipTest(f"{name} is not compared")

        # Read the state before ``try`` so ``finally`` never sees it unbound.
        state = _np._get_promotion_state()
        try:
            _np._set_promotion_state("weak")

            ufunc = getattr(tnp, name)
            ufunc_numpy = getattr(_np, name)

            try:
                result = ufunc(scalar, array)
            except RuntimeError:
                # RuntimeError: "bitwise_xor_cpu" not implemented for 'ComplexDouble' etc
                result = None

            try:
                result_numpy = ufunc_numpy(scalar, array.tensor.numpy())
            except TypeError:
                # TypeError: ufunc 'hypot' not supported for the input types
                result_numpy = None

            if result is not None and result_numpy is not None:
                assert result.tensor.numpy().dtype == result_numpy.dtype
        finally:
            _np._set_promotion_state(state)


if __name__ == "__main__":
    # Running this file directly hands off to PyTorch's common test runner
    # (run_tests from torch.testing._internal.common_utils).
    run_tests()