1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451
|
# Owner(s): ["oncall: jit"]
import io
import os
import sys
import copy
import unittest
import torch
from typing import Optional
# Make the helper files in test/ importable: start from the directory that
# holds this file, step up one more level, and put that directory on the
# module search path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
find_library_location,
)
from torch.testing import FileCheck
# Running this file as a script is unsupported: fail immediately and point
# the user at the test/test_jit.py entry point named in the message.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestTorchbind(JitTestCase):
    """Tests for custom classes and ops reached via ``torch.classes`` / ``torch.ops``.

    Most tests use the ``_TorchScriptTesting`` namespace; ``setUp`` loads
    ``libtorchbind_test.so`` before every test (presumably the library that
    registers that namespace -- confirm against the C++ test sources).
    """

    def setUp(self):
        # Loading a bare shared-library path is not portable, so the whole
        # suite is skipped on these platforms / build environments.
        if IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE:
            raise unittest.SkipTest("non-portable load_library call used in test")
        lib_file_path = find_library_location('libtorchbind_test.so')
        torch.ops.load_library(str(lib_file_path))

    def test_torchbind(self):
        # Exercise custom classes both eagerly and through torch.jit.script.
        def test_equality(f, cmp_key):
            # Run ``f`` eagerly and scripted; return both results mapped
            # through ``cmp_key``.
            # NOTE(review): none of the callers below inspects the returned
            # pair, so these calls only verify that both executions complete
            # without raising -- confirm whether an equality assertion was
            # intended.
            obj1 = f()
            obj2 = torch.jit.script(f)()
            return (cmp_key(obj1), cmp_key(obj2))

        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment(1)
            return val
        test_equality(f, lambda x: x)

        # Passing a str where the bound method expects an int must be rejected.
        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment('foo')

        def f():
            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            return ss.pop()
        test_equality(f, lambda x: x)

        def f():
            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
            ss1.push(ss2.pop())
            return ss1.pop() + ss2.pop()
        test_equality(f, lambda x: x)

        # test nn module with prepare_scriptable function
        class NonJitableClass:
            def __init__(self, int1, int2):
                self.int1 = int1
                self.int2 = int2

            def return_vals(self):
                return self.int1, self.int2

        class CustomWrapper(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def forward(self) -> None:
                self.foo.increment(1)
                return

            def __prepare_scriptable__(self):
                # Rebuild the wrapper around a torchbind _Foo constructed
                # from the plain-Python object's two values.
                int1, int2 = self.foo.return_vals()
                foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
                return CustomWrapper(foo)

        foo = CustomWrapper(NonJitableClass(1, 2))
        # Only checks that scripting succeeds; the scripted module is not run.
        torch.jit.script(foo)

    def test_torchbind_take_as_arg(self):
        # A custom-class instance can be passed into a scripted function and
        # returned from it. The type comment inside ``foo`` refers to the
        # module-level name bound here.
        global StackString  # see [local resolution in python]
        StackString = torch.classes._TorchScriptTesting._StackString

        def foo(stackstring):
            # type: (StackString)
            stackstring.push("lel")
            return stackstring

        script_input = torch.classes._TorchScriptTesting._StackString([])
        scripted = torch.jit.script(foo)
        script_output = scripted(script_input)
        self.assertEqual(script_output.pop(), "lel")

    def test_torchbind_return_instance(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            return ss

        scripted = torch.jit.script(foo)
        # Ensure we are creating the object and calling __init__
        # rather than calling the __init__wrapper nonsense
        fc = (
            FileCheck()
            .check('prim::CreateObject()')
            .check('prim::CallMethod[name="__init__"]')
        )
        fc.run(str(scripted.graph))
        out = scripted()
        self.assertEqual(out.pop(), "mom")
        self.assertEqual(out.pop(), "hi")

    def test_torchbind_return_instance_from_method(self):
        # ``clone`` is taken before the pop, so the clone still yields both
        # elements while the original yields only the remaining one.
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            clone = ss.clone()
            ss.pop()
            return ss, clone

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out[0].pop(), "hi")
        self.assertEqual(out[1].pop(), "mom")
        self.assertEqual(out[1].pop(), "hi")

    def test_torchbind_def_property_getter_setter(self):
        # Properties with both a getter and a setter must behave the same in
        # eager mode and under scripting (checkScript compares the two).
        def foo_getter_setter_full():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getX method intentionally adds 2 to x
            old = fooGetterSetter.x
            # setX method intentionally adds 2 to x
            fooGetterSetter.x = old + 4
            new = fooGetterSetter.x
            return old, new

        self.checkScript(foo_getter_setter_full, ())

        def foo_getter_setter_lambda():
            foo = torch.classes._TorchScriptTesting._FooGetterSetterLambda(5)
            old = foo.x
            foo.x = old + 4
            new = foo.x
            return old, new

        self.checkScript(foo_getter_setter_lambda, ())

    def test_torchbind_def_property_just_getter(self):
        # A getter-only property is readable, but assignment fails both on
        # the returned object in eager mode and at script-compile time.
        def foo_just_getter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getY method intentionally adds 4 to x
            return fooGetterSetter, fooGetterSetter.y

        scripted = torch.jit.script(foo_just_getter)
        out, result = scripted()
        self.assertEqual(result, 10)

        with self.assertRaisesRegex(RuntimeError, 'can\'t set attribute'):
            out.y = 5

        def foo_not_setter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            old = fooGetterSetter.y
            fooGetterSetter.y = old + 4
            # getY method intentionally adds 4 to x
            return fooGetterSetter.y

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 'Tried to set read-only attribute: y',
                                                 'fooGetterSetter.y = old + 4'):
            torch.jit.script(foo_not_setter)

    def test_torchbind_def_property_readwrite(self):
        # ``x`` is assignable; assigning ``y`` must be rejected when scripting.
        def foo_readwrite():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            old = fooReadWrite.x
            fooReadWrite.x = old + 4
            return fooReadWrite.x, fooReadWrite.y

        self.checkScript(foo_readwrite, ())

        def foo_readwrite_error():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            fooReadWrite.y = 5
            return fooReadWrite

        with self.assertRaisesRegexWithHighlight(RuntimeError,
                                                 'Tried to set read-only attribute: y',
                                                 'fooReadWrite.y = 5'):
            torch.jit.script(foo_readwrite_error)

    def test_torchbind_take_instance_as_method_arg(self):
        # One custom-class instance is passed as an argument to a method of
        # another inside a scripted function.
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out.pop(), "hi")
        self.assertEqual(out.pop(), "mom")

    def test_torchbind_return_tuple(self):
        def f():
            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
            return val.return_a_tuple()

        scripted = torch.jit.script(f)
        tup = scripted()
        self.assertEqual(tup, (1337.0, 123))

    def test_torchbind_save_load(self):
        # Only checks that the export/import round trip of a scripted
        # function using custom classes completes; the copy is not executed.
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        self.getExportImportCopy(scripted)

    def test_torchbind_lambda_method(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            return ss.top()

        scripted = torch.jit.script(foo)
        self.assertEqual(scripted(), "mom")

    def test_torchbind_class_attr_recursive(self):
        # A module holding a custom-class attribute is rebuilt around a fresh
        # _Foo by ``to_ivalue`` and the rebuilt module is scripted.
        class FooBar(torch.nn.Module):
            def __init__(self, foo_model):
                super().__init__()
                self.foo_mod = foo_model

            def forward(self) -> int:
                return self.foo_mod.info()

            def to_ivalue(self):
                torchbind_model = torch.classes._TorchScriptTesting._Foo(self.foo_mod.info(), 1)
                return FooBar(torchbind_model)

        inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
        scripted = torch.jit.script(inst.to_ivalue())
        self.assertEqual(scripted(), 6)

    def test_torchbind_class_attribute(self):
        # After an export/import round trip the stack attribute holds the
        # fixed strings asserted below rather than the constructor input.
        class FooBar1234(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        inst = FooBar1234()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # assertEqual (not bare assert) so the checks survive ``python -O``
        # and report both values on failure.
        self.assertEqual(eic(), "deserialized")
        for expected in ["deserialized", "was", "i"]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_getstate(self):
        class FooBar4321(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
        # return {1, 3, 3, 7}. I tried to make this actually depend on the
        # values at instantiation in the test with some transformation, but
        # because it seems we serialize/deserialize multiple times, that
        # transformation isn't what you would expect it to be.
        self.assertEqual(eic(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_deepcopy(self):
        # copy.deepcopy of a scripted module; the copy holds the same values
        # that the export/import test above expects.
        class FooBar4321(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        copied = copy.deepcopy(scripted)
        self.assertEqual(copied.forward(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_python_deepcopy(self):
        # Same as test_torchbind_deepcopy, but the eager (unscripted) module
        # is deep-copied.
        class FooBar4321(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        copied = copy.deepcopy(inst)
        self.assertEqual(copied(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_tracing(self):
        # A custom op that takes a custom-class module attribute can be traced.
        class TryTracing(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.f)

        traced = torch.jit.trace(TryTracing(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pass_wrong_type(self):
        with self.assertRaisesRegex(RuntimeError, 'but instead found type \'Tensor\''):
            torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))

    def test_torchbind_tracing_nested(self):
        # As test_torchbind_tracing, but the custom-class attribute lives on
        # a submodule of the traced module.
        class TryTracingNest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class TryTracing123(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.nest = TryTracingNest()

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)

        traced = torch.jit.trace(TryTracing123(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pickle_serialization(self):
        # Round-trip a bare custom-class instance through torch.save/torch.load
        # using an in-memory buffer.
        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
        b = io.BytesIO()
        torch.save(nt, b)
        b.seek(0)
        nt_loaded = torch.load(b)
        for exp in [7, 3, 3, 1]:
            self.assertEqual(nt_loaded.pop(), exp)

    def test_torchbind_instantiate_missing_class(self):
        with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
            torch.classes.foo.IDontExist(3, 4, 5)

    def test_torchbind_optional_explicit_attr(self):
        # A module attribute explicitly annotated as Optional[<custom class>]
        # must be scriptable; the scripted module is not executed.
        class TorchBindOptionalExplicitAttr(torch.nn.Module):
            foo: Optional[torch.classes._TorchScriptTesting._StackString]

            def __init__(self):
                super().__init__()
                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])

            def forward(self) -> str:
                foo_obj = self.foo
                if foo_obj is not None:
                    return foo_obj.pop()
                else:
                    return '<None>'

        torch.jit.script(TorchBindOptionalExplicitAttr())

    def test_torchbind_no_init(self):
        # Constructing _NoInit must fail with an error mentioning torch::init.
        with self.assertRaisesRegex(RuntimeError, 'torch::init'):
            torch.classes._TorchScriptTesting._NoInit()

    def test_profiler_custom_op(self):
        # The custom op must show up by name in the autograd profiler events.
        inst = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        with torch.autograd.profiler.profile() as prof:
            torch.ops._TorchScriptTesting.take_an_instance(inst)

        found_event = any(
            e.name == '_TorchScriptTesting::take_an_instance'
            for e in prof.function_events
        )
        self.assertTrue(found_event)

    def test_torchbind_getattr(self):
        # getattr with a default must return the default for a missing field.
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        self.assertEqual(None, getattr(foo, 'bar', None))

    def test_torchbind_attr_exception(self):
        # Plain attribute access on a missing field raises AttributeError.
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        with self.assertRaisesRegex(AttributeError, 'does not have a field'):
            foo.bar

    def test_lambda_as_constructor(self):
        # The third constructor argument flips the sign of diff() for the
        # same (4, 3) inputs.
        obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
        self.assertEqual(obj_no_swap.diff(), 1)

        obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
        self.assertEqual(obj_swap.diff(), -1)

    def test_staticmethod(self):
        def fn(inp: int) -> int:
            return torch.classes._TorchScriptTesting._StaticMethod.staticMethod(inp)

        self.checkScript(fn, (1,))

    def test_default_args(self):
        # Methods and the constructor are called with and without their
        # optional arguments; checkScript compares eager and scripted results.
        def fn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs()
            obj.increment(5)
            obj.decrement()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(5)
            obj.scale_add(3, 2)
            obj.divide(3)
            return obj.increment()

        self.checkScript(fn, ())

        def gn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs(5)
            obj.increment(3)
            obj.increment()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(3)
            obj.scale_add(3, 2)
            obj.divide(2)
            return obj.decrement()

        self.checkScript(gn, ())
|