1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292
|
# Owner(s): ["oncall: jit"]
import os
import sys
import inspect
import unittest
from typing import Dict, List
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestBuiltins(JitTestCase):
"""
Tests for TorchScript support of Python builtin functions.
"""
def test_has_attr(self):
class HasA(torch.nn.Module):
def __init__(self):
super(HasA, self).__init__()
self.a = 0
class HasB(torch.nn.Module):
def __init__(self):
super(HasB, self).__init__()
self.b = 1
class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.mods = torch.nn.ModuleList([HasA(), HasB()])
def forward(self):
# use a list to encode hasattr results
l = torch.jit.annotate(List[int], [])
for mod in self.mods:
l.append(int(hasattr(mod, "a")))
l.append(int(hasattr(mod, "b")))
# actually retrieve the attr to test static refinement
if hasattr(mod, "a"):
l.append(mod.a)
if hasattr(mod, "b"):
l.append(mod.b)
return l
self.checkModule(Mod(), ())
def test_has_attr_invalid_args(self):
class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.mod = torch.nn.Linear(1, 1)
def forward(self, name):
# not allowed, `name` must be static.
return hasattr(self.mod, name)
with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"):
torch.jit.script(Mod())
class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
def forward(self, name):
# not allowed, `torch.rand` is not a class type
return hasattr(torch.rand(2, 3), name)
with self.assertRaisesRegexWithHighlight(RuntimeError, "hasattr", "name"):
torch.jit.script(Mod())
def test_del(self):
def fn(x: List[int]) -> List[int]:
a = x * 2
del a
return x
self.checkScript(fn, ([1, 2, 3],))
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
@torch.jit.script
def fn(x):
a = x ** 2
del a
return a
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
@torch.jit.script
def fn(x):
a = x ** 2
if a:
del a
return a
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"):
@torch.jit.script
def fn(x):
a = x ** 2
del b
return a
def test_del_multiple_operands(self):
def fn(x: List[int]) -> List[int]:
a, b, c = x[0], x[1], x[2]
del a, b, c
return x
self.checkScript(fn, ([1, 2, 3],))
def del_list_multiple_operands(x: List[int]) -> List[int]:
del x[0], x[1]
return x
py_out = del_list_multiple_operands([0, 1, 2])
jit_out = torch.jit.script(del_list_multiple_operands)([0, 1, 2])
self.assertEqual(py_out, jit_out)
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
del x['hi'], x['there']
return x
py_out = del_dict_multiple_operands({"hi": 5, "there": 6})
jit_out = torch.jit.script(del_dict_multiple_operands)({"hi": 5, "there": 6})
self.assertEqual(py_out, jit_out)
class TestTensorBuiltins(JitTestCase):
def test_tensor_properties(self):
def should_keep(tensor, name):
if inspect.isroutine(getattr(tensor, name)):
return False
if name.startswith('_'):
return False
return True
tensor = torch.arange(4, dtype=torch.float).view(2, 2)
keys = dir(tensor)
# real and imag are only implemented for complex tensors.
self.assertRaises(RuntimeError, lambda: should_keep(tensor, 'imag'))
keys.remove('imag')
properties = [p for p in keys if should_keep(tensor, p)]
code_template = """
def fn(x):
return x.{}
"""
EQUALITY_MISMATCH = set([
# TorchScript doesn't have real enums so they return an int instead
# of the actual value
'dtype',
'layout',
])
MISSING_PROPERTIES = set([
'grad_fn',
# This is an undocumented property so it's not included
"output_nr",
# This has a longer implementation, maybe not worth copying to
# TorchScript if named tensors don't work there anyways
'names',
])
for p in properties:
if p in MISSING_PROPERTIES:
continue
code = code_template.format(p)
cu = torch.jit.CompilationUnit()
cu.define(code)
if p in EQUALITY_MISMATCH:
continue
self.assertEqual(getattr(tensor, p), cu.fn(tensor))
def test_tensor_subscript_assign(self):
def fn1(x):
a = torch.zeros_like(x, dtype=torch.uint8)
a[torch.tensor(0)] = torch.tensor(2, dtype=torch.uint8)
return a
def fn2(x):
a = torch.zeros_like(x, dtype=torch.uint8)
a[0] = 2
return a
def fn3(x):
a = torch.zeros_like(x, dtype=torch.uint8)
a[torch.tensor(0)] = 2
return a
def fn4(x):
a = torch.zeros_like(x, dtype=torch.uint8)
a[0] = torch.tensor(2, dtype=torch.uint8)
return a
def fn5(x):
a = torch.zeros_like(x, dtype=torch.float32)
a[torch.tensor(0)] = 2
return a
for fn in (fn1, fn2, fn3, fn4, fn5):
self.checkScript(fn, (torch.zeros(2, dtype=torch.uint8),))
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_tensor_subscript_assign_device(self):
def fn6(x):
a = torch.zeros_like(x, dtype=torch.float32, device="cuda")
a[torch.tensor(0)] = 2
return a
self.checkScript(fn6, (torch.zeros(2, dtype=torch.float32, device="cuda"),))
def test_tensor_item(self):
def test_scalar_cast(x):
scalar = x.item()
return int(scalar), float(scalar)
graph = torch.jit.script(test_scalar_cast).graph
FileCheck().check("(int, float) = prim::TupleConstruct").run(graph)
self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
self.checkScript(test_scalar_cast, (torch.tensor(1),))
def test_method_on_number(self):
def func():
c = 1
return c.add(1)
with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'):
torch.jit.script(func)
# testing implicit conversion of tensors to scalars to match function arguments
def test_scalar_to_num_conversions(self):
@torch.jit.script
def multiple_defs(x):
c = 1
x = x + c
return x
self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))
@torch.jit.script
def tensor_to_int_script(x, tensor):
return x.unsqueeze(tensor)
# location present in error message
with self.assertRaisesRegex(RuntimeError, "x.unsqueeze"):
tensor_to_int_script(torch.tensor([2]), torch.tensor([2, 2]))
def tensor_to_int(x, tensor):
return x.unsqueeze(tensor)
@torch.jit.script
def tensor_to_float_script(x, tensor):
return x.addcmul(tensor, tensor, value=tensor)
def tensor_to_float(x, tensor):
return x.addcmul(tensor, tensor, value=tensor)
x = torch.zeros(10)
# float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
tensors = [torch.tensor(1.1),
torch.tensor(1.1, requires_grad=True),
torch.tensor(0),
torch.tensor([2])]
script_funs = [tensor_to_int_script, tensor_to_float_script]
funs = [tensor_to_int, tensor_to_float]
# return the result, or whether exception was thrown
def test_func(func, x, tensor):
try:
result = func(x, tensor)
except RuntimeError as e:
result = True
except TypeError as e:
result = True
return result
# assert result or exception equal for each (function, inputs)
for tensor in tensors:
for i in range(len(script_funs)):
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
|