# mypy: allow-untyped-defs
import contextlib
from typing import Optional
import torch
# Tracks whether init_torchbind_implementations() has already run, so repeated
# calls are no-ops.
_TORCHBIND_IMPLS_INITIALIZED = False
# Populated by init_torchbind_implementations() with an empty _TensorQueue
# script object; None until initialization has happened.
_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
def init_torchbind_implementations():
    """Load the torchbind test library and install fake ops/classes.

    Idempotent: the module-level flag guards against repeated initialization.
    Also creates the shared global tensor-queue test object.
    """
    global _TORCHBIND_IMPLS_INITIALIZED
    global _TENSOR_QUEUE_GLOBAL_TEST

    if not _TORCHBIND_IMPLS_INITIALIZED:
        load_torchbind_test_lib()
        register_fake_operators()
        register_fake_classes()
        _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
        _TORCHBIND_IMPLS_INITIALIZED = True
def _empty_tensor_queue() -> torch.ScriptObject:
    """Construct a fresh, empty _TensorQueue torchbind object for tests."""
    init_tensor = torch.empty(0).fill_(-1)
    return torch.classes._TorchScriptTesting._TensorQueue(init_tensor)
def register_fake_operators():
    """Register fake (meta) kernels for the _TorchScriptTesting custom ops.

    Kept inside a function because the corresponding C++ library might not
    be loaded yet at module import time.
    """

    def fake_takes_foo(foo, z):
        return foo.add_tensor(z)

    def fake_queue_pop(tq):
        return tq.pop()

    def fake_queue_push(tq, x):
        return tq.push(x)

    def fake_queue_size(tq):
        return tq.size()

    # Explicit registration calls instead of decorator syntax.
    torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")(
        fake_takes_foo
    )
    torch.library.register_fake("_TorchScriptTesting::queue_pop")(fake_queue_pop)
    torch.library.register_fake("_TorchScriptTesting::queue_push")(fake_queue_push)
    torch.library.register_fake("_TorchScriptTesting::queue_size")(fake_queue_size)

    def meta_takes_foo_list_return(foo, x):
        a = foo.add_tensor(x)
        b = foo.add_tensor(a)
        c = foo.add_tensor(b)
        return [a, b, c]

    def meta_takes_foo_tuple_return(foo, x):
        a = foo.add_tensor(x)
        b = foo.add_tensor(a)
        return (a, b)

    # make signature match original cpp implementation to support kwargs
    def meta_takes_foo(foo, x):
        return foo.add_tensor(x)

    meta_key = torch._C.DispatchKey.Meta
    torch.ops._TorchScriptTesting.takes_foo_list_return.default.py_impl(meta_key)(
        meta_takes_foo_list_return
    )
    torch.ops._TorchScriptTesting.takes_foo_tuple_return.default.py_impl(meta_key)(
        meta_takes_foo_tuple_return
    )
    torch.ops._TorchScriptTesting.takes_foo.default.py_impl(meta_key)(meta_takes_foo)
def register_fake_classes():
    """Register Python fake classes for the _TorchScriptTesting torchbind classes."""

    # noqa: F841
    @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
    class FakeFoo:
        """Fake of _Foo: holds two ints and scales tensors by their sum."""

        def __init__(self, x: int, y: int):
            self.x = x
            self.y = y

        @classmethod
        def __obj_unflatten__(cls, flattend_foo):
            attrs = dict(flattend_foo)
            return cls(**attrs)

        def add_tensor(self, z):
            return (self.x + self.y) * z

    @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
    class FakeContainsTensor:
        """Fake of _ContainsTensor: wraps a single tensor."""

        def __init__(self, t: torch.Tensor):
            self.t = t

        @classmethod
        def __obj_unflatten__(cls, flattend_foo):
            attrs = dict(flattend_foo)
            return cls(**attrs)

        def get(self):
            return self.t
def load_torchbind_test_lib():
    """Load the compiled torchbind test library for the current platform.

    Raises:
        unittest.SkipTest: on macOS, where this load_library call is not
            portable.
    """
    import unittest

    from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
        find_library_location,
        IS_FBCODE,
        IS_MACOS,
        IS_SANDCASTLE,
        IS_WINDOWS,
    )

    if IS_SANDCASTLE or IS_FBCODE:
        # Internal builds load the library from the buck target.
        torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
    elif IS_MACOS:
        raise unittest.SkipTest("non-portable load_library call used in test")
    else:
        lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
        torch.ops.load_library(str(find_library_location(lib_name)))
@contextlib.contextmanager
def _register_py_impl_temporarily(op_overload, key, fn):
try:
op_overload.py_impl(key)(fn)
yield
finally:
del op_overload.py_kernels[key]
op_overload._dispatch_cache.clear()