# Owner(s): ["module: unknown"]
import os.path
import sys
import tempfile
import unittest
from model import get_custom_op_library_path, Model
import torch
import torch._library.utils as utils
from torch import ops
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
# Eagerly import the "pointwise" Python module through the torch.ops
# machinery. NOTE(review): presumably this registers the abstract-impl
# pystubs that the pystub tests below expect to see named in error
# messages ("pointwise") — confirm against the custom-op library.
torch.ops.import_module("pointwise")
class TestCustomOperators(TestCase):
    """Tests for custom C++ operators exposed through ``torch.ops.custom``.

    The operators live in a shared library whose path comes from the
    project-local ``model`` helper module; ``setUp`` loads it before every
    test (loading an already-loaded library is a no-op).
    """

    def setUp(self):
        # Resolve and load the custom-op shared library so that
        # torch.ops.custom.* is populated for each test.
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        self.assertIn(self.library_path, ops.loaded_libraries)

    def test_op_with_no_abstract_impl_pystub(self):
        x = torch.randn(3, device="meta")
        if utils.requires_set_python_module():
            # The error message must point the user at the pystub module
            # ("pointwise") that would provide the missing abstract impl.
            with self.assertRaisesRegex(RuntimeError, "pointwise"):
                torch.ops.custom.tan(x)
        else:
            # Smoketest
            torch.ops.custom.tan(x)

    def test_op_with_incorrect_abstract_impl_pystub(self):
        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(RuntimeError, "pointwise"):
            torch.ops.custom.cos(x)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_dynamo_pystub_suggestion(self):
        x = torch.randn(3)

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return torch.ops.custom.asin(x)

        # Dynamo should suggest importing the pystub module it was told
        # about (here the intentionally bogus "nonexistent").
        with self.assertRaisesRegex(
            RuntimeError,
            r"unsupported operator: .* you may need to `import nonexistent`",
        ):
            f(x)

    def test_abstract_impl_pystub_faketensor(self):
        from functorch import make_fx

        x = torch.randn(3, device="cpu")
        self.assertNotIn("my_custom_ops", sys.modules.keys())
        # Before the pystub module is imported, FakeTensor tracing cannot
        # resolve the abstract impl and must fail.
        with self.assertRaises(
            torch._subclasses.fake_tensor.UnsupportedOperatorException
        ):
            gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)

        torch.ops.import_module("my_custom_ops")
        gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)
        # BUGFIX(review): assertExpectedInline takes (actual, expected) —
        # the traced code goes first so EXPECTTEST_ACCEPT can regenerate the
        # inline string. The equality check itself is order-symmetric, so
        # this does not change what the test verifies.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1):
    nonzero = torch.ops.custom.nonzero.default(arg0_1);  arg0_1 = None
    return nonzero""",
        )

    def test_abstract_impl_pystub_meta(self):
        x = torch.randn(3, device="meta")
        self.assertNotIn("my_custom_ops2", sys.modules.keys())
        # Without the pystub module, the meta kernel is unavailable and the
        # error must name the module to import.
        with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
            y = torch.ops.custom.sin.default(x)
        torch.ops.import_module("my_custom_ops2")
        y = torch.ops.custom.sin.default(x)

    def test_calling_custom_op_string(self):
        # op2 compares strings; negative for "abc" < "def", zero for equal.
        output = ops.custom.op2("abc", "def")
        self.assertLess(output, 0)
        output = ops.custom.op2("abc", "abc")
        self.assertEqual(output, 0)

    def test_calling_custom_op(self):
        output = ops.custom.op(torch.ones(5), 2.0, 3)
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), 3)
        for tensor in output:
            self.assertTrue(tensor.allclose(torch.ones(5) * 2))

        # Same op with its default arguments filled in.
        output = ops.custom.op_with_defaults(torch.ones(5))
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), 1)
        self.assertTrue(output[0].allclose(torch.ones(5)))

    def test_calling_custom_op_with_autograd(self):
        x = torch.randn((5, 5), requires_grad=True)
        y = torch.randn((5, 5), requires_grad=True)
        # op_with_autograd(x, 2, y) computes x + 2*y + x*y (per the
        # assertions below), with a hand-written backward.
        output = ops.custom.op_with_autograd(x, 2, y)
        self.assertTrue(output.allclose(x + 2 * y + x * y))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, False, True)
        grad = torch.ones(5, 5)
        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)

        # Test with optional arg.
        x.grad.zero_()
        y.grad.zero_()
        z = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y, z)
        self.assertTrue(output.allclose(x + 2 * y + x * y + z))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, False, True)
        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)
        self.assertEqual(z.grad, grad)

    def test_calling_custom_op_with_autograd_in_nograd_mode(self):
        # Under no_grad the op must still compute the right forward value.
        with torch.no_grad():
            x = torch.randn((5, 5), requires_grad=True)
            y = torch.randn((5, 5), requires_grad=True)
            output = ops.custom.op_with_autograd(x, 2, y)
            self.assertTrue(output.allclose(x + 2 * y + x * y))

    def test_calling_custom_op_inside_script_module(self):
        model = Model()
        output = model.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        model = Model()
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually.
        file = tempfile.NamedTemporaryFile(delete=False)
        try:
            file.close()
            model.save(file.name)
            loaded = torch.jit.load(file.name)
        finally:
            os.unlink(file.name)

        output = loaded.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))
# Entry point: delegate to PyTorch's test runner, which handles common
# test flags (retries, sharding, etc.) on top of unittest.
if __name__ == "__main__":
    run_tests()