File: test_misc.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (505 lines) | stat: -rw-r--r-- 16,160 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.testing._internal.jit_utils
from jit.test_module_interface import TestModuleInterface  # noqa: F401
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestMisc(JitTestCase):
    """Miscellaneous TorchScript frontend, type-annotation and operator tests."""

    @staticmethod
    def _gen_index_data():
        """Return an ``(input, index, value)`` triple for the indexing tests.

        The tensors are a 10-element float tensor, 20 integer indices drawn
        from ``[0, 10)`` and 20 float values. Generation happens inside
        ``freeze_rng_state()``. The tests below rely on repeated calls
        returning identical tensors, so that eager and reference ops can be
        fed independent copies of the same data.
        """
        with freeze_rng_state():
            return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)

    def test_joined_str(self):
        """Check that a scripted function with f-strings prints exactly what eager Python prints."""

        def func(x):
            hello, test = "Hello", "test"
            print(f"{hello + ' ' + test}, I'm a {test}")
            print("format blank")
            hi = "hi"
            print(f"stuff before {hi}")
            print(f"{hi} stuff after")
            return x + 1

        x = torch.arange(4.0, requires_grad=True)
        # TODO: Add support for f-strings in string parser frontend
        # self.checkScript(func, [x], optimize=True, capture_output=True)

        with self.capture_stdout() as captured:
            out = func(x)

        scripted = torch.jit.script(func)
        with self.capture_stdout() as captured_script:
            # Run the *scripted* function. Calling `func` here would compare
            # eager output against itself and never exercise TorchScript.
            out_script = scripted(x)

        self.assertEqual(out, out_script)
        self.assertEqual(captured, captured_script)

    def test_kwarg_support(self):
        """Check keyword-only ``forward`` parameters.

        A signature with a keyword-only default is expected to be rejected
        with ``NotSupportedError``. Without defaults the module scripts, but
        the arguments must be passed by name.
        """
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError, "variable number of arguments"
        ):

            class M(torch.nn.Module):
                def forward(self, *, n_tokens: int, device_name: str = 2):
                    pass

            torch.jit.script(M())

        class M(torch.nn.Module):
            def forward(self, *, n_tokens: int, device_name: str):
                return n_tokens, device_name

        sm = torch.jit.script(M())

        with self.assertRaisesRegex(
            RuntimeError, "missing value for argument 'n_tokens'"
        ):
            sm()

        with self.assertRaisesRegex(RuntimeError, "positional arg"):
            sm(3, "hello")

        self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))

    def test_tuple_subscripted_assign(self):
        """Check that plain and augmented item assignment into a Tuple argument fail to compile."""
        with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):

            @torch.jit.script
            def foo(a: Tuple[int, int]) -> None:
                a[0] = a[1]

        with self.assertRaisesRegex(RuntimeError, "augmented assignment"):

            @torch.jit.script
            def bar(a: Tuple[int, int]) -> None:
                a[0] += a[1]

    def test_subexpression_List_Future(self):
        """Check that a ``Future[int]`` nested in a ``List`` annotation appears in the graph."""

        @torch.jit.script
        def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
            return x[0]

        FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)

    def test_subexpression_Future_annotate(self):
        """Check that a variable annotation of ``List[Future[int]]`` is honoured."""

        @torch.jit.script
        def fn() -> torch.jit.Future[int]:
            x: List[torch.jit.Future[int]] = []
            return x[0]

        FileCheck().check("Future[int][]").run(fn.graph)

    def test_future_isinstance(self):
        """Check that ``isinstance`` against ``jit.Future[int]`` refines an ``Any`` argument."""

        @torch.jit.script
        def fn(x: Any) -> torch.jit.Future[int]:
            assert isinstance(x, jit.Future[int])
            return x

        FileCheck().check("Future[int]").run(fn.graph)

    def test_str_refine_any(self):
        """Check that ``isinstance(x, str)`` refines an ``Any`` argument to ``str``."""

        def forward(x: Any) -> str:
            if isinstance(x, str):
                return x
            return "foo"

        forward = torch.jit.script(forward)
        self.assertEqual(forward(1), "foo")
        self.assertEqual(forward("bar"), "bar")

    def test_subexpression_Tuple_int_int_Future(self):
        """Check that a ``Future[int]`` inside Tuple argument/return annotations appears in the graph."""

        @torch.jit.script
        def fn(
            x: Tuple[int, int, torch.jit.Future[int]]
        ) -> Tuple[int, torch.jit.Future[int]]:
            return x[0], x[2]

        FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
            fn.graph
        )

    def test_subexpression_Dict_int_Future(self):
        """Check that a ``Future[int]`` as a Dict value type appears in the graph."""

        @torch.jit.script
        def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
            return x[y]

        FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)

    def test_subexpression_Optional(self):
        """Check that a ``Future[int]`` nested in ``Optional[Dict[...]]`` appears in the graph."""

        @torch.jit.script
        def fn(
            x: Optional[Dict[int, torch.jit.Future[int]]]
        ) -> Optional[torch.jit.Future[int]]:
            if x is not None:
                return x[0]
            else:
                return None

        FileCheck().check("Dict(int, Future(int))?").run(fn.graph)

    def test_if_returning_any(self):
        """
        Check that an if statement can return different
        types early from each branch when the return
        type of the function is Any.
        """

        def if_function(inp: torch.Tensor) -> Any:
            if inp.shape[0] == 1:
                return inp * inp
            else:
                return "str"

        self.checkScript(if_function, (torch.randn(5),))

    def test_hacked_twin(self):
        """Check the ``hacked_twin`` overloads of ``aten::index_put`` / ``index_put_``.

        Both overloads must match the public ``torch.index_put`` /
        ``torch.index_put_``.
        """
        inp, index, value = self._gen_index_data()
        inp1, index1, value1 = self._gen_index_data()

        out1 = torch.ops.aten.index_put.hacked_twin(
            inp, [index], value, accumulate=False
        )
        out2 = torch.index_put(inp1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # In-place variants: both mutate their first argument.
        torch.ops.aten.index_put_.hacked_twin(inp, [index], value, accumulate=False)
        torch.index_put_(inp1, [index1], value1, accumulate=False)
        self.assertEqual(inp, inp1)

    def test_unsafe_hacked_twin(self):
        """Check the ``_unsafe_index_put`` / ``_unsafe_index`` overloads.

        Their results must match the regular indexing ops, both eagerly and
        when called from a scripted function.
        """
        inp, index, value = self._gen_index_data()
        inp1, index1, value1 = self._gen_index_data()

        out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
            inp, [index], value, accumulate=False
        )
        out2 = torch.index_put(inp1, [index1], value1, accumulate=False)
        self.assertEqual(out1, out2)

        # `_unsafe_index` is out-of-place, so compare its *result* with plain
        # advanced indexing. Discarding both results and comparing the
        # untouched inputs would pass vacuously.
        out1 = torch.ops.aten._unsafe_index.Tensor_hacked_twin(inp, [index])
        out2 = inp1[index1]
        self.assertEqual(out1, out2)

        def index_put_fn(inp, index, value):
            return torch.ops.aten._unsafe_index_put(
                inp, [index], value, accumulate=False
            )

        inp2, index2, value2 = self._gen_index_data()
        script_index_put_fn = torch.jit.script(index_put_fn)
        expect = index_put_fn(inp2.clone(), index2, value2)
        actual = script_index_put_fn(inp2.clone(), index2, value2)
        self.assertEqual(expect, actual)

        # Script `_unsafe_index` itself. The previous copy of this helper
        # duplicated `index_put_fn` and never covered this op.
        def index_fn(inp, index):
            return torch.ops.aten._unsafe_index(inp, [index])

        script_index_fn = torch.jit.script(index_fn)
        expect = index_fn(inp2.clone(), index2)
        actual = script_index_fn(inp2.clone(), index2)
        self.assertEqual(expect, actual)

    def test_export_opnames_interface(self):
        """Check that ``export_opnames`` follows calls through a module-interface attribute.

        The reported ops must reflect whichever concrete submodule is
        currently assigned.
        """

        @torch.jit.interface
        class OneTwoModule(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                pass

        class FooMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 * x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.one(self.two(x), x)

        class BarMod(nn.Module):
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

            def two(self, x: torch.Tensor) -> torch.Tensor:
                return 2 / x

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.two(self.one(x, x))

        make_global(OneTwoModule)

        class M(nn.Module):
            sub: OneTwoModule

            def __init__(self) -> None:
                super().__init__()
                self.sub = BarMod()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.sub.forward(x)

        # NOTE(review): this looks like a process-wide toggle; nothing in this
        # test switches it back off — confirm that is intended.
        torch._C._enable_mobile_interface_call_export()
        scripted_M_mod = torch.jit.script(M())
        self.assertTrue(
            {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

        # Swapping the interface-typed submodule must change the reported ops.
        scripted_M_mod.sub = torch.jit.script(FooMod())
        self.assertTrue(
            {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
                set(torch.jit.export_opnames(scripted_M_mod))
            )
        )

    def test_math_inf(self):
        """Check that ``math.inf`` captured from the enclosing scope works in script."""
        from math import inf

        def foo():
            return inf

        self.checkScript(foo, ())

    def test_list_literal_infer(self):
        """Check how the element type of an empty list literal is inferred.

        An empty temporary list takes its type from the ``List[int]``
        parameter it is passed to. An explicitly annotated ``List[Tensor]``
        or a list stored in a variable first does not, and fails to compile.
        A bare ``return []`` is typed ``Tensor[]``.
        """

        def expects_intlist(x: List[int]):
            x.append(3)
            return x

        def foo():
            return expects_intlist([])

        self.checkScript(foo, ())

        def annotated_list_fail():
            return expects_intlist(torch.jit.annotate([], List[Tensor]))  # noqa: F821

        with self.assertRaises(RuntimeError):
            torch.jit.script(annotated_list_fail)

        def non_temporary_fail():
            a = []
            return expects_intlist(a)

        with self.assertRaises(RuntimeError):
            torch.jit.script(non_temporary_fail)

        @torch.jit.script
        def test_return():
            return []

        FileCheck().check("Tensor[] = prim::ListConstruct").run(test_return.graph)

    def test_legacy_tensor_constructor(self):
        """Check legacy ``torch.<Type>Tensor(...)`` constructors in script, including rejected schemas."""

        # testing PyObject overload
        def test_all_dtypes():
            return (
                torch.BoolTensor([2]),
                torch.LongTensor([3]),
                torch.ByteTensor([4]),
                torch.CharTensor([5]),
                torch.DoubleTensor([6]),
                torch.FloatTensor([7]),
                torch.IntTensor([8]),
                torch.ShortTensor([1]),
                torch.HalfTensor([1]),
            )

        self.checkScript(test_all_dtypes, ())

        # now test empty overload
        def empty_overload():
            return torch.LongTensor(2, 3, 4)

        eager = empty_overload()
        # Named `scripted` rather than `jit` to avoid shadowing the `jit` module.
        scripted = torch.jit.script(empty_overload)()
        # The contents are uninitialized, so fill both before comparing.
        eager[:] = 1
        scripted[:] = 1
        self.assertEqual(eager, scripted)

        def no_inputs():
            return torch.DoubleTensor()

        self.checkScript(no_inputs, ())

        # bad schema
        def multiple_args():
            return torch.LongTensor(1, [2])

        with self.assertRaisesRegex(
            RuntimeError, "multiple positional arguments that were not all integers"
        ):
            torch.jit.script(multiple_args)

        # kwarg bad schema
        def bad_kwarg():
            return torch.LongTensor(hello="1")

        with self.assertRaisesRegex(RuntimeError, "hello"):
            torch.jit.script(bad_kwarg)

    def test_broadcasting_list(self):
        """
        Test BroadcastingList and torch.nn._size_N_t alias: a scalar argument
        is broadcast to both list elements.
        """
        from torch._jit_internal import BroadcastingList2
        from torch.nn.common_types import _size_2_t

        def sum_i(x: _size_2_t) -> int:
            return x[0] + x[1]

        def sum_f(x: BroadcastingList2[float]) -> float:
            return x[0] + x[1]

        self.assertEqual(torch.jit.script(sum_i)(4), 8)
        self.assertEqual(torch.jit.script(sum_f)(4.5), 9.0)

    def test_parse_ir_annotate(self):
        """Check that ``parse_ir`` accepts an ``annotate(List[int], [])`` constant, which evaluates to ``[]``."""
        ir = """
        graph():
          %3 : int[] = prim::Constant[value=annotate(List[int], [])]()
          return (%3)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret, [])

    def test_parse_ir_single_element_tensor_positive(self):
        """Check that a ``{0}`` tensor constant parses into a 1-D, single-element tensor."""
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={0}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_parse_ir_single_element_tensor_negative(self):
        """Check that a negative ``{-17}`` tensor constant parses into a 1-D, single-element tensor."""
        ir = """
        graph():
          %7 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={-17}]()
          return (%7)
        """
        graph = torch._C.parse_ir(ir, True)
        func = torch._C._create_function_from_graph("forward", graph)
        ret = func()
        self.assertEqual(ret.numel(), 1)
        self.assertEqual(len(ret.size()), 1)

    def test_script_many_decorators(self):
        """Check that scripting works through a stack of no-op Python decorators."""

        def no_op_decorator(f):
            return f

        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        @no_op_decorator
        def foo(x, dim: int):
            return x.unsqueeze(dim)

        x = torch.randn(1)
        expected = foo(x, 0)
        scripted = torch.jit.script(foo)
        actual = scripted(x, 0)
        torch.testing.assert_close(expected, actual)

    @unittest.skipIf(not RUN_CUDA_HALF, "need CUDA half support")
    def test_pow_multiple_dtype(self):
        """Check that scripted ``half_tensor ** float`` matches eager on CUDA."""
        # https://github.com/pytorch/pytorch/issues/75476
        def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
            p = torch.sigmoid(p)
            result = p**gamma
            return result

        x = torch.rand((2, 2), dtype=torch.half, device="cuda")

        ref = fn(x)

        script_fn = torch.jit.script(fn)
        # NOTE(review): the repeated calls presumably let the profiling
        # executor move past its warm-up runs so the optimized graph is what
        # gets compared — confirm against the linked issue.
        for _ in range(4):
            res = script_fn(x)

        self.assertEqual(ref, res)

    def test_jit_get_operation_order(self):
        """Check that ``_jit_get_operation`` lists aten overloads before TorchScript-added ones."""
        # See https://github.com/pytorch/pytorch/pull/107138.
        # Depending on order of operator registration, you can get different
        # order of overloads in the JIT operator registry.
        # This is to verify that the order of operators returned by
        # _jit_get_operation always puts aten ops first (i.e. by sorting
        # to put them first)

        # Make sure that this chooses a "scalar" overload not a "complex" overload
        ret = torch.ops.aten.add(4, 3.3)
        self.assertNotIn("complex", str(ret.dtype))

        # "Scalar" overload is a normal aten op; "complex" is added by torchscript.
        # We want "Scalar" to come before "complex".
        _, override_names = torch._C._jit_get_operation("aten::add")
        complex_indices = [
            i for i, name in enumerate(override_names) if name == "complex"
        ]
        scalar_indices = [
            i for i, name in enumerate(override_names) if name == "Scalar"
        ]

        self.assertGreater(len(complex_indices), 0)
        self.assertGreater(len(scalar_indices), 0)
        self.assertGreater(complex_indices[0], scalar_indices[0])