1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
# Owner(s): ["oncall: jit"]
from torch.testing._internal.common_utils import TestCase
import torch
from torch import nn
r"""
Test TorchScript exception handling.
"""
class TestException(TestCase):
def test_pyop_exception_message(self):
class Foo(torch.jit.ScriptModule):
def __init__(self):
super(Foo, self).__init__()
self.conv = nn.Conv2d(1, 10, kernel_size=5)
@torch.jit.script_method
def forward(self, x):
return self.conv(x)
foo = Foo()
# testing that the correct error message propagates
with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"):
foo(torch.ones([123])) # wrong size
def test_builtin_error_messsage(self):
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
"supported in TorchScript"):
@torch.jit.script
def unknown_op(x):
torch.set_anomaly_enabled(True)
return x
def test_exceptions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
if bool(cond):
raise ValueError(3)
return 1
''')
cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit.Error, "3"):
cu.foo(torch.tensor(1))
def foo(cond):
a = 3
if bool(cond):
raise ArbitraryError(a, "hi")
if 1 == 2:
raise ArbitraryError
return a
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
torch.jit.script(foo)
def exception_as_value():
a = Exception()
print(a)
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
torch.jit.script(exception_as_value)
@torch.jit.script
def foo_no_decl_always_throws():
raise RuntimeError("Hi")
# function that has no declared type but always throws set to None
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "NoneType")
@torch.jit.script
def foo_decl_always_throws():
# type: () -> Tensor
raise Exception("Hi")
output_type = next(foo_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "Tensor")
def foo():
raise 3 + 4
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
torch.jit.script(foo)
# a escapes scope
@torch.jit.script
def foo():
if 1 == 1:
a = 1
else:
if 1 == 1:
raise Exception("Hi")
else:
raise Exception("Hi")
return a
self.assertEqual(foo(), 1)
@torch.jit.script
def tuple_fn():
raise RuntimeError("hello", "goodbye")
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
tuple_fn()
@torch.jit.script
def no_message():
raise RuntimeError
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
no_message()
def test_assertions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
assert bool(cond), "hi"
return 0
''')
cu.foo(torch.tensor(1))
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
cu.foo(torch.tensor(0))
@torch.jit.script
def foo(cond):
assert bool(cond), "hi"
foo(torch.tensor(1))
# we don't currently validate the name of the exception
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
foo(torch.tensor(0))
def test_python_op_exception(self):
@torch.jit.ignore
def python_op(x):
raise Exception("bad!")
@torch.jit.script
def fn(x):
return python_op(x)
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
fn(torch.tensor(4))
def test_dict_expansion_raises_error(self):
def fn(self):
d = {"foo": 1, "bar": 2, "baz": 3}
return {**d}
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
"Dict expansion "):
torch.jit.script(fn)
def test_custom_python_exception(self):
class MyValueError(ValueError):
def __init__(self, msg):
super(MyValueError, self).__init__(msg)
@torch.jit.script
def fn():
raise MyValueError("test custom exception")
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"):
fn()
def test_custom_python_exception_defined_elsewhere(self):
from jit.myexception import MyKeyError
@torch.jit.script
def fn():
raise MyKeyError("This is a user defined key error")
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
fn()
|