File: test_prim_hop_base.py

package info
pytorch-cuda 2.6.0+dfsg-7
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (188 lines):

# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo.test_case
import torch._functorch.config
import torch.utils.checkpoint
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
    EagerAndRecordGraphs,
    normalize_gm,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


def normalize_graph(gm):
    return normalize_gm(gm.print_readable(print_output=False))


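# Minimal higher-order op built on PrimHOPBase, used only by these tests: it
# takes a subgraph, a tuple of operands, and a keyword-only `scheme` argument
# and forwards them to the base class unchanged.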
class InvokeQuantTest(torch._higher_order_ops.PrimHOPBase):
    def __init__(self):
        super().__init__("invoke_quant_test")

    def __call__(self, subgraph, operands, *, scheme):
        return super().__call__(subgraph, operands, scheme=scheme)


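# Module-level instance of the test HOP, shared by all tests below.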
invoke_quant_test = InvokeQuantTest()


class PrimHOPBaseTest(torch._dynamo.test_case.TestCase):
    # TODO: flip to False later, we're landing a refactor PR and don't want to merge conflict
    @torch._dynamo.config.patch(assume_static_by_default=True)
    def test_dynamo(self):
        def inner(x, y):
            return (x @ y).sin().cos()

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)

        backend = EagerAndRecordGraphs()

        @torch.compile(backend=backend)
        def f(x, y):
            return invoke_quant_test(inner, (x, y), scheme="nf4")

        out = f(x, y)
        self.assertEqual(out, inner(x, y))

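        # Dynamo should capture a single graph; the HOP body is lifted into a
        # nested subgraph module (subgraph_0 in the expected output below).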
        assert len(backend.graphs) == 1
        self.assertExpectedInline(
            normalize_graph(backend.graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"):
        l_x_ = L_x_
        l_y_ = L_y_

        subgraph_0 = self.subgraph_0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, (l_x_, l_y_), scheme = 'nf4');  subgraph_0 = l_x_ = l_y_ = None
        getitem: "f32[3, 3]" = invoke_quant_test[0];  invoke_quant_test = None
        return (getitem,)

    class subgraph_0(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"):
            matmul: "f32[3, 3]" = l_x_ @ l_y_;  l_x_ = l_y_ = None
            sin: "f32[3, 3]" = matmul.sin();  matmul = None
            cos: "f32[3, 3]" = sin.cos();  sin = None
            return (cos,)
""",  # NOQA: B950
        )

    @torch._dynamo.config.patch(assume_static_by_default=True)
    def test_aot_eager(self):
        def inner(x, y):
            return (x @ y).sin_().cos()

        x = torch.randn(3, 3, requires_grad=True)
        y = torch.randn(3, 3, requires_grad=True)

        backend = AotEagerAndRecordGraphs()

        @torch.compile(backend=backend)
        def f(x, y):
            return invoke_quant_test(inner, (x, y), scheme="nf4")

        out = f(x, y)
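        # Compare gradients of the compiled HOP against plain eager `inner`.
        # y is passed positionally here, so it serves as grad_outputs and the
        # gradient is taken with respect to x only.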
        result = torch.autograd.grad(out, x, y)
        out = inner(x, y)
        expected = torch.autograd.grad(out, x, y)
        self.assertEqual(result, expected)

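        # AOT autograd should produce a single forward graph that keeps the
        # HOP call and traces its subgraph down to ATen ops.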
        assert len(backend.fw_graphs) == 1
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]"):
        subgraph0 = self.subgraph0
        invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph0, (primals_1, primals_2), scheme = 'nf4');  subgraph0 = None
        getitem: "f32[3, 3]" = invoke_quant_test[0];  invoke_quant_test = None
        return (getitem, primals_1, primals_2)

    class subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"):
            mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm);  mm = None
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin);  sin = None
            return (cos,)
""",  # NOQA: B950
        )

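        # The backward graph also calls the HOP; its subgraph recomputes the
        # forward intermediates and applies the chain rule to the tangents.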
        assert len(backend.bw_graphs) == 1
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[3, 3]", primals_2: "f32[3, 3]", tangents_1: "f32[3, 3]"):
        subgraph1 = self.subgraph1
        invoke_quant_test_1 = torch.ops.higher_order.invoke_quant_test(subgraph1, (primals_1, primals_2, tangents_1), scheme = 'nf4');  subgraph1 = primals_1 = primals_2 = tangents_1 = None
        getitem_1: "f32[3, 3]" = invoke_quant_test_1[0]
        getitem_2: "f32[3, 3]" = invoke_quant_test_1[1];  invoke_quant_test_1 = None
        return (getitem_1, getitem_2)

    class subgraph1(torch.nn.Module):
        def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]", arg2_1: "f32[3, 3]"):
            mm: "f32[3, 3]" = torch.ops.aten.mm.default(arg0_1, arg1_1)
            clone: "f32[3, 3]" = torch.ops.aten.clone.default(mm)
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(mm);  mm = None
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(sin);  cos = None
            sin_1: "f32[3, 3]" = torch.ops.aten.sin.default(sin);  sin = None
            neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin_1);  sin_1 = None
            mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg2_1, neg);  arg2_1 = neg = None
            cos_1: "f32[3, 3]" = torch.ops.aten.cos.default(clone);  clone = None
            mul_1: "f32[3, 3]" = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
            t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
            mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(t, mul_1);  t = None
            t_1: "f32[3, 3]" = torch.ops.aten.t.default(arg1_1);  arg1_1 = None
            mm_2: "f32[3, 3]" = torch.ops.aten.mm.default(mul_1, t_1);  mul_1 = t_1 = None
            return [mm_2, mm_1]
""",  # NOQA: B950
        )

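    # The HOP rejects subgraphs whose outputs alias their inputs (inner) or
    # that mutate their inputs in place (inner2).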
    def test_aliasing_mutation_error(self):
        def inner(x, y):
            return x

        def inner2(x, y):
            x.sin_()
            return x + y

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        @torch.compile(backend="eager", fullgraph=True)
        def f(inner, x, y):
            return invoke_quant_test(inner, (x, y), scheme="nf4")

        with self.assertRaisesRegex(RuntimeError, "aliases of the inputs"):
            out = f(inner, x, y)

        with self.assertRaisesRegex(RuntimeError, "inputs are mutated"):
            out = f(inner2, x, y)

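    # Calling the HOP eagerly (outside torch.compile) requires the subgraph to
    # already be a torch.fx.GraphModule, e.g. one produced by make_fx.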
    def test_eager_call(self):
        def inner(x, y):
            return x + y

        x = torch.randn(3, 3)
        y = torch.randn(3, 3)

        with self.assertRaisesRegex(RuntimeError, "torch.fx.GraphModule"):
            invoke_quant_test(inner, (x, y), scheme="nf4")

        from functorch import make_fx

        result = make_fx(inner)(x, y)
        # smoke test
        invoke_quant_test(result, (x, y), scheme="nf4")


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()