1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
|
# Owner(s): ["oncall: jit"]
import os
import sys
from textwrap import dedent
import torch
from torch.testing._internal import jit_utils
# test/ holds shared helper modules; add it (the directory two levels above
# this file's resolved path) to sys.path so those helpers can be imported.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # Per the message below, this module is meant to be driven through
    # test/test_jit.py; refuse to run standalone and point at that entry point.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
# Tests various JIT-related utility functions.
class TestJitUtils(JitTestCase):
    """Tests for ``torch._jit_internal.get_callable_argument_names`` and ``jit_utils`` helpers."""

    # Tests that POSITIONAL_OR_KEYWORD arguments are captured.
    def test_get_callable_argument_names_positional_or_keyword(self):
        def fn_positional_or_keyword_args_only(x, y):
            return x + y

        self.assertEqual(
            ["x", "y"],
            torch._jit_internal.get_callable_argument_names(
                fn_positional_or_keyword_args_only
            ),
        )

    # Tests that POSITIONAL_ONLY arguments are ignored.
    def test_get_callable_argument_names_positional_only(self):
        # The function is compiled from source text via jit_utils._get_py3_code
        # rather than defined inline; presumably so this file still parses on
        # interpreters without positional-only (`/`) syntax.
        code = dedent(
            """
            def fn_positional_only_arg(x, /, y):
                return x + y
            """
        )
        fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg")
        self.assertEqual(
            ["y"],
            torch._jit_internal.get_callable_argument_names(fn_positional_only_arg),
        )

    # Tests that VAR_POSITIONAL arguments are ignored.
    def test_get_callable_argument_names_var_positional(self):
        def fn_var_positional_arg(x, *arg):
            return x + arg[0]

        self.assertEqual(
            ["x"],
            torch._jit_internal.get_callable_argument_names(fn_var_positional_arg),
        )

    # Tests that KEYWORD_ONLY arguments are ignored.
    def test_get_callable_argument_names_keyword_only(self):
        def fn_keyword_only_arg(x, *, y):
            return x + y

        self.assertEqual(
            ["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)
        )

    # Tests that VAR_KEYWORD arguments are ignored.
    def test_get_callable_argument_names_var_keyword(self):
        def fn_var_keyword_arg(**args):
            return args["x"] + args["y"]

        self.assertEqual(
            [], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)
        )

    # Tests that, for a signature mixing positional-only, positional-or-keyword,
    # var-positional and var-keyword arguments, only the POSITIONAL_OR_KEYWORD
    # argument (`y`) is captured and all the others are ignored.
    def test_get_callable_argument_names_hybrid(self):
        code = dedent(
            """
            def fn_hybrid_args(x, /, y, *args, **kwargs):
                return x + y + args[0] + kwargs['z']
            """
        )
        fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args")
        self.assertEqual(
            ["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)
        )

    # Exercises checkScriptRaisesRegex with both a Python callable and the
    # equivalent source string. Indexing a 2-tuple at [2] is expected to raise
    # an error whose message matches "range".
    def test_checkscriptassertraisesregex(self):
        def fn():
            tup = (1, 2)
            return tup[2]

        self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")

        s = dedent(
            """
            def fn():
                tup = (1, 2)
                return tup[2]
            """
        )
        self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")

    # Tests that NoTracerWarnContextManager turns the tracer-warn flag off
    # inside the `with` block and turns it back on when the block exits.
    def test_no_tracer_warn_context_manager(self):
        # The flag lives in torch._C (module-level getter/setter), so remember
        # its current value and restore it afterwards. Otherwise this test
        # could leak state into tests that run after it.
        prev_state = torch._C._jit_get_tracer_state_warn()
        torch._C._jit_set_tracer_state_warn(True)
        try:
            with jit_utils.NoTracerWarnContextManager():
                self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
            self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
        finally:
            torch._C._jit_set_tracer_state_warn(prev_state)
|