File: test_jit_utils.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (108 lines) | stat: -rw-r--r-- 3,959 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Owner(s): ["oncall: jit"]

import os
import sys
from textwrap import dedent
import unittest

import torch

from torch.testing._internal import jit_utils

# Make the helper files in test/ importable by appending the grandparent
# directory of this file to sys.path.
# NOTE(review): two dirname() calls resolve to test/ only if this file lives
# one level below it (e.g. test/jit/) — confirm against the repository layout.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

# This module only defines test cases. Running it as a script is refused, and
# the message points users at the test/test_jit.py driver instead.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

# Tests various JIT-related utility functions.
class TestJitUtils(JitTestCase):
    # Tests that POSITIONAL_OR_KEYWORD arguments are captured.
    def test_get_callable_argument_names_positional_or_keyword(self):
        def fn_positional_or_keyword_args_only(x, y):
            return x + y
        self.assertEqual(
            ["x", "y"],
            torch._jit_internal.get_callable_argument_names(fn_positional_or_keyword_args_only))

    # Tests that POSITIONAL_ONLY arguments are ignored.
    @unittest.skipIf(sys.version_info < (3, 8), 'POSITIONAL_ONLY arguments are not supported before 3.8')
    def test_get_callable_argument_names_positional_only(self):
        code = dedent('''
            def fn_positional_only_arg(x, /, y):
                return x + y
        ''')

        fn_positional_only_arg = jit_utils._get_py3_code(code, 'fn_positional_only_arg')
        self.assertEqual(
            [],
            torch._jit_internal.get_callable_argument_names(fn_positional_only_arg))

    # Tests that VAR_POSITIONAL arguments are ignored.
    def test_get_callable_argument_names_var_positional(self):
        # Tests that VAR_POSITIONAL arguments are ignored.
        def fn_var_positional_arg(x, *arg):
            return x + arg[0]
        self.assertEqual(
            [],
            torch._jit_internal.get_callable_argument_names(fn_var_positional_arg))

    # Tests that KEYWORD_ONLY arguments are ignored.
    def test_get_callable_argument_names_keyword_only(self):
        def fn_keyword_only_arg(x, *, y):
            return x + y
        self.assertEqual(
            [],
            torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg))

    # Tests that VAR_KEYWORD arguments are ignored.
    def test_get_callable_argument_names_var_keyword(self):
        def fn_var_keyword_arg(**args):
            return args['x'] + args['y']
        self.assertEqual(
            [],
            torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg))

    # Tests that a function signature containing various different types of
    # arguments are ignored.
    @unittest.skipIf(sys.version_info < (3, 8), 'POSITIONAL_ONLY arguments are not supported before 3.8')
    def test_get_callable_argument_names_hybrid(self):
        code = dedent('''
            def fn_hybrid_args(x, /, y, *args, **kwargs):
                return x + y + args[0] + kwargs['z']
        ''')
        fn_hybrid_args = jit_utils._get_py3_code(code, 'fn_hybrid_args')
        self.assertEqual(
            [],
            torch._jit_internal.get_callable_argument_names(fn_hybrid_args))

    def test_checkscriptassertraisesregex(self):
        def fn():
            tup = (1, 2)
            return tup[2]

        self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")

        s = dedent("""
        def fn():
            tup = (1, 2)
            return tup[2]
        """)

        self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")

    def test_no_tracer_warn_context_manager(self):
        torch._C._jit_set_tracer_state_warn(True)
        with jit_utils.NoTracerWarnContextManager() as no_warn:
            self.assertEqual(
                False,
                torch._C._jit_get_tracer_state_warn()
            )
        self.assertEqual(
            True,
            torch._C._jit_get_tracer_state_warn()
        )