# Owner(s): ["oncall: jit"]
import os
import sys
import torch
import warnings
from typing import List, Any, Dict, Tuple, Optional
# Put the parent of this file's directory on sys.path so that the helper
# files living in test/ can be imported.
_this_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # Refuse direct execution and tell the caller how to run these tests.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
# Tests for torch.jit.isinstance
class TestIsinstance(JitTestCase):
def test_int(self):
def int_test(x: Any):
assert torch.jit.isinstance(x, int)
assert not torch.jit.isinstance(x, float)
x = 1
self.checkScript(int_test, (x,))
def test_float(self):
def float_test(x: Any):
assert torch.jit.isinstance(x, float)
assert not torch.jit.isinstance(x, int)
x = 1.0
self.checkScript(float_test, (x,))
def test_bool(self):
def bool_test(x: Any):
assert torch.jit.isinstance(x, bool)
assert not torch.jit.isinstance(x, float)
x = False
self.checkScript(bool_test, (x,))
def test_list(self):
def list_str_test(x: Any):
assert torch.jit.isinstance(x, List[str])
assert not torch.jit.isinstance(x, List[int])
assert not torch.jit.isinstance(x, Tuple[int])
x = ["1", "2", "3"]
self.checkScript(list_str_test, (x,))
def test_list_tensor(self):
def list_tensor_test(x: Any):
assert torch.jit.isinstance(x, List[torch.Tensor])
assert not torch.jit.isinstance(x, Tuple[int])
x = [torch.tensor([1]), torch.tensor([2]), torch.tensor([3])]
self.checkScript(list_tensor_test, (x,))
def test_dict(self):
def dict_str_int_test(x: Any):
assert torch.jit.isinstance(x, Dict[str, int])
assert not torch.jit.isinstance(x, Dict[int, str])
assert not torch.jit.isinstance(x, Dict[str, str])
x = {"a": 1, "b": 2}
self.checkScript(dict_str_int_test, (x,))
def test_dict_tensor(self):
def dict_int_tensor_test(x: Any):
assert torch.jit.isinstance(x, Dict[int, torch.Tensor])
x = {2: torch.tensor([2])}
self.checkScript(dict_int_tensor_test, (x,))
def test_tuple(self):
def tuple_test(x: Any):
assert torch.jit.isinstance(x, Tuple[str, int, str])
assert not torch.jit.isinstance(x, Tuple[int, str, str])
assert not torch.jit.isinstance(x, Tuple[str])
x = ("a", 1, "b")
self.checkScript(tuple_test, (x,))
def test_tuple_tensor(self):
def tuple_tensor_test(x: Any):
assert torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor])
x = (torch.tensor([1]), torch.tensor([[2], [3]]))
self.checkScript(tuple_tensor_test, (x,))
def test_optional(self):
def optional_test(x: Any):
assert torch.jit.isinstance(x, Optional[torch.Tensor])
assert not torch.jit.isinstance(x, Optional[str])
x = torch.ones(3, 3)
self.checkScript(optional_test, (x,))
def test_optional_none(self):
def optional_test_none(x: Any):
assert torch.jit.isinstance(x, Optional[torch.Tensor])
# assert torch.jit.isinstance(x, Optional[str])
# TODO: above line in eager will evaluate to True while in
# the TS interpreter will evaluate to False as the
# first torch.jit.isinstance refines the 'None' type
x = None
self.checkScript(optional_test_none, (x,))
def test_list_nested(self):
def list_nested(x: Any):
assert torch.jit.isinstance(x, List[Dict[str, int]])
assert not torch.jit.isinstance(x, List[List[str]])
x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
self.checkScript(list_nested, (x,))
def test_dict_nested(self):
def dict_nested(x: Any):
assert torch.jit.isinstance(x, Dict[str, Tuple[str, str, str]])
assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])
x = {"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")}
self.checkScript(dict_nested, (x,))
def test_tuple_nested(self):
def tuple_nested(x: Any):
assert torch.jit.isinstance(
x, Tuple[Dict[str, Tuple[str, str, str]], List[bool], Optional[str]]
)
assert not torch.jit.isinstance(x, Dict[str, Tuple[int, int, int]])
assert not torch.jit.isinstance(x, Tuple[str])
assert not torch.jit.isinstance(x, Tuple[List[bool], List[str], List[int]])
x = (
{"a": ("aa", "aa", "aa"), "b": ("bb", "bb", "bb")},
[True, False, True],
None,
)
self.checkScript(tuple_nested, (x,))
def test_optional_nested(self):
def optional_nested(x: Any):
assert torch.jit.isinstance(x, Optional[List[str]])
x = ["a", "b", "c"]
self.checkScript(optional_nested, (x,))
def test_list_tensor_type_true(self):
def list_tensor_type_true(x: Any):
assert torch.jit.isinstance(x, List[torch.Tensor])
x = [torch.rand(3, 3), torch.rand(4, 3)]
self.checkScript(list_tensor_type_true, (x,))
def test_tensor_type_false(self):
def list_tensor_type_false(x: Any):
assert not torch.jit.isinstance(x, List[torch.Tensor])
x = [1, 2, 3]
self.checkScript(list_tensor_type_false, (x,))
def test_in_if(self):
def list_in_if(x: Any):
if torch.jit.isinstance(x, List[int]):
assert True
if torch.jit.isinstance(x, List[str]):
assert not True
x = [1, 2, 3]
self.checkScript(list_in_if, (x,))
def test_if_else(self):
def list_in_if_else(x: Any):
if torch.jit.isinstance(x, Tuple[str, str, str]):
assert True
else:
assert not True
x = ("a", "b", "c")
self.checkScript(list_in_if_else, (x,))
def test_in_while_loop(self):
def list_in_while_loop(x: Any):
count = 0
while torch.jit.isinstance(x, List[Dict[str, int]]) and count <= 0:
count = count + 1
assert count == 1
x = [{"a": 1, "b": 2}, {"aa": 11, "bb": 22}]
self.checkScript(list_in_while_loop, (x,))
def test_type_refinement(self):
def type_refinement(obj: Any):
hit = False
if torch.jit.isinstance(obj, List[torch.Tensor]):
hit = not hit
for el in obj:
# perform some tensor operation
y = el.clamp(0, 0.5)
if torch.jit.isinstance(obj, Dict[str, str]):
hit = not hit
str_cat = ""
for val in obj.values():
str_cat = str_cat + val
assert "111222" == str_cat
assert hit
x = [torch.rand(3, 3), torch.rand(4, 3)]
self.checkScript(type_refinement, (x,))
x = {"1": "111", "2": "222"}
self.checkScript(type_refinement, (x,))
def test_list_no_contained_type(self):
def list_no_contained_type(x: Any):
assert torch.jit.isinstance(x, List)
x = ["1", "2", "3"]
err_msg = "Attempted to use List without a contained type. " \
r"Please add a contained type, e.g. List\[int\]"
with self.assertRaisesRegex(RuntimeError, err_msg,):
torch.jit.script(list_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
list_no_contained_type(x)
def test_tuple_no_contained_type(self):
def tuple_no_contained_type(x: Any):
assert torch.jit.isinstance(x, Tuple)
x = ("1", "2", "3")
err_msg = "Attempted to use Tuple without a contained type. " \
r"Please add a contained type, e.g. Tuple\[int\]"
with self.assertRaisesRegex(RuntimeError, err_msg,):
torch.jit.script(tuple_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
tuple_no_contained_type(x)
def test_optional_no_contained_type(self):
def optional_no_contained_type(x: Any):
assert torch.jit.isinstance(x, Optional)
x = ("1", "2", "3")
err_msg = "Attempted to use Optional without a contained type. " \
r"Please add a contained type, e.g. Optional\[int\]"
with self.assertRaisesRegex(RuntimeError, err_msg,):
torch.jit.script(optional_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
optional_no_contained_type(x)
def test_dict_no_contained_type(self):
def dict_no_contained_type(x: Any):
assert torch.jit.isinstance(x, Dict)
x = {"a": "aa"}
err_msg = "Attempted to use Dict without contained types. " \
r"Please add contained type, e.g. Dict\[int, int\]"
with self.assertRaisesRegex(RuntimeError, err_msg,):
torch.jit.script(dict_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
dict_no_contained_type(x)
def test_tuple_rhs(self):
def fn(x: Any):
assert torch.jit.isinstance(x, (int, List[str]))
assert not torch.jit.isinstance(x, (List[float], Tuple[int, str]))
assert not torch.jit.isinstance(x, (List[float], str))
self.checkScript(fn, (2,))
self.checkScript(fn, (["foo", "bar", "baz"],))
def test_nontuple_container_rhs_throws_in_eager(self):
def fn1(x: Any):
assert torch.jit.isinstance(x, [int, List[str]])
def fn2(x: Any):
assert not torch.jit.isinstance(x, {List[str], Tuple[int, str]})
err_highlight = "must be a type or a tuple of types"
with self.assertRaisesRegex(RuntimeError, err_highlight):
fn1(2)
with self.assertRaisesRegex(RuntimeError, err_highlight):
fn2(2)
def test_empty_container_throws_warning_in_eager(self):
def fn(x: Any):
torch.jit.isinstance(x, List[int])
with warnings.catch_warnings(record=True) as w:
x: List[int] = []
fn(x)
self.assertEqual(len(w), 1)
with warnings.catch_warnings(record=True) as w:
x: int = 2
fn(x)
self.assertEqual(len(w), 0)
def test_empty_container_special_cases(self):
# Should not throw "Boolean value of Tensor with no values is
# ambiguous" error
torch._jit_internal.check_empty_containers(torch.Tensor([]))
# Should not throw "Boolean value of Tensor with more than
# one value is ambiguous" error
torch._jit_internal.check_empty_containers(torch.rand(2, 3))
|