File: test_custom_ops.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (159 lines) | stat: -rw-r--r-- 5,609 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Owner(s): ["module: unknown"]

import os.path
import sys
import tempfile
import unittest

from model import get_custom_op_library_path, Model

import torch
import torch._library.utils as utils
from torch import ops
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase


# Import the "pointwise" Python module via torch.ops.import_module once, at
# module load time, before any test in this file runs.
# NOTE(review): presumably pointwise.py supplies Python-side registrations for
# ops in the custom library; several tests below expect "pointwise" to appear
# in error messages — confirm against pointwise.py.
torch.ops.import_module("pointwise")


class TestCustomOperators(TestCase):
    """Tests for the operators registered by the custom op shared library.

    ``setUp`` loads the library located by ``get_custom_op_library_path()``;
    the tests then exercise the ops it registers under ``torch.ops.custom``:
    plain calls, autograd, TorchScript usage and serialization, and the
    Python "pystub" behaviour where an op's abstract/meta implementation
    lives in a Python module that must be imported first.
    """

    def setUp(self):
        # torch's TestCase.setUp seeds the RNG and applies the disabled-test
        # checks; overriding it without calling super() silently skipped both.
        super().setUp()
        # Keep the path so test_custom_library_is_loaded can look it up.
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        self.assertIn(self.library_path, ops.loaded_libraries)

    def test_op_with_no_abstract_impl_pystub(self):
        # custom::tan on a meta tensor: when requires_set_python_module() is
        # true the call must fail with an error mentioning "pointwise";
        # otherwise it only has to run without raising.
        x = torch.randn(3, device="meta")
        if utils.requires_set_python_module():
            with self.assertRaisesRegex(RuntimeError, "pointwise"):
                torch.ops.custom.tan(x)
        else:
            # Smoke test: only checks that the call does not raise.
            torch.ops.custom.tan(x)

    def test_op_with_incorrect_abstract_impl_pystub(self):
        # custom::cos on a meta tensor must always raise a RuntimeError whose
        # message mentions "pointwise".
        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(RuntimeError, "pointwise"):
            torch.ops.custom.cos(x)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_dynamo_pystub_suggestion(self):
        # Compiling a call to custom::asin with fullgraph=True must fail, and
        # the error must suggest `import nonexistent`.
        x = torch.randn(3)

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return torch.ops.custom.asin(x)

        with self.assertRaisesRegex(
            RuntimeError,
            r"unsupported operator: .* you may need to `import nonexistent`",
        ):
            f(x)

    def test_abstract_impl_pystub_faketensor(self):
        # Canonical location of make_fx (functorch.make_fx re-exports it).
        from torch.fx.experimental.proxy_tensor import make_fx

        x = torch.randn(3, device="cpu")
        # The first trace below is only meaningful if the module has not been
        # imported yet (e.g. by an earlier test in the same process).
        self.assertNotIn("my_custom_ops", sys.modules)

        # Symbolic tracing of custom::nonzero fails until "my_custom_ops" is
        # imported.
        with self.assertRaises(
            torch._subclasses.fake_tensor.UnsupportedOperatorException
        ):
            make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)

        torch.ops.import_module("my_custom_ops")
        gm = make_fx(torch.ops.custom.nonzero.default, tracing_mode="symbolic")(x)
        # assertExpectedInline(actual, expect): the expected text must be the
        # bare string literal in second position so EXPECTTEST_ACCEPT=1 can
        # rewrite it in place.
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1):
    nonzero = torch.ops.custom.nonzero.default(arg0_1);  arg0_1 = None
    return nonzero""",
        )

    def test_abstract_impl_pystub_meta(self):
        x = torch.randn(3, device="meta")
        self.assertNotIn("my_custom_ops2", sys.modules)
        # Before "my_custom_ops2" is imported, custom::sin on a meta tensor
        # raises NotImplementedError naming that module.
        with self.assertRaisesRegex(NotImplementedError, r"'my_custom_ops2'"):
            torch.ops.custom.sin.default(x)
        torch.ops.import_module("my_custom_ops2")
        # Smoke test: after the import the same call must no longer raise.
        torch.ops.custom.sin.default(x)

    def test_calling_custom_op_string(self):
        # custom::op2 is asserted to behave like a three-way comparison:
        # negative for "abc" vs "def", zero for equal strings.
        output = ops.custom.op2("abc", "def")
        self.assertLess(output, 0)
        output = ops.custom.op2("abc", "abc")
        self.assertEqual(output, 0)

    def test_calling_custom_op(self):
        # custom::op(ones(5), 2.0, 3) returns a list of 3 tensors, each equal
        # to ones(5) * 2.
        output = ops.custom.op(torch.ones(5), 2.0, 3)
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 3)
        for tensor in output:
            self.assertTrue(tensor.allclose(torch.ones(5) * 2))

        # With its defaults, custom::op_with_defaults returns one tensor equal
        # to its input.
        output = ops.custom.op_with_defaults(torch.ones(5))
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 1)
        self.assertTrue(output[0].allclose(torch.ones(5)))

    def test_calling_custom_op_with_autograd(self):
        x = torch.randn((5, 5), requires_grad=True)
        y = torch.randn((5, 5), requires_grad=True)
        # Forward: x + 2 * y + x * y (the scalar 2 is the second argument).
        output = ops.custom.op_with_autograd(x, 2, y)
        self.assertTrue(output.allclose(x + 2 * y + x * y))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, retain_graph=False, create_graph=True)
        grad = torch.ones(5, 5)

        # d/dx = 1 + y, d/dy = 2 + x for the forward formula above.
        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)

        # Test with the optional fourth argument, which adds z to the output.
        x.grad.zero_()
        y.grad.zero_()
        z = torch.randn((5, 5), requires_grad=True)
        output = ops.custom.op_with_autograd(x, 2, y, z)
        self.assertTrue(output.allclose(x + 2 * y + x * y + z))

        go = torch.ones((), requires_grad=True)
        output.sum().backward(go, retain_graph=False, create_graph=True)
        self.assertEqual(x.grad, y + grad)
        self.assertEqual(y.grad, x + grad * 2)
        # d/dz = 1.
        self.assertEqual(z.grad, grad)

    def test_calling_custom_op_with_autograd_in_nograd_mode(self):
        # The forward result must be unchanged when grad mode is disabled.
        with torch.no_grad():
            x = torch.randn((5, 5), requires_grad=True)
            y = torch.randn((5, 5), requires_grad=True)
            output = ops.custom.op_with_autograd(x, 2, y)
            self.assertTrue(output.allclose(x + 2 * y + x * y))

    def test_calling_custom_op_inside_script_module(self):
        model = Model()
        output = model.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        model = Model()
        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually.
        file = tempfile.NamedTemporaryFile(delete=False)
        try:
            file.close()
            model.save(file.name)
            loaded = torch.jit.load(file.name)
        finally:
            os.unlink(file.name)

        # The reloaded module must still resolve and run the custom op.
        output = loaded.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))


if __name__ == "__main__":
    # Script entry point: run the tests through run_tests from
    # torch.testing._internal.common_utils.
    run_tests()