File: test_custom_post_grad_passes.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (272 lines) | stat: -rw-r--r-- 9,370 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
# Owner(s): ["module: inductor"]
import contextlib
import operator
from collections import defaultdict

import torch
import torch._inductor.pattern_matcher as pattern_matcher
import torch.fx as fx
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
from torch._inductor.lowering import lowerings as L
from torch._inductor.pattern_matcher import Arg, CallFunction, PatternMatcherPass
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CPU


@config.patch({"freezing": True})
class TestCustomPassBase(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):
            if not isinstance(x, torch.Tensor):
                return x
            return x.clone()

        return tuple(clone(x) for x in inputs)

    def _test_common(
        self,
        mod,
        inputs,
        matcher_count,
        matcher_nodes,
        atol=1e-5,
        rtol=1.3e-6,
    ):
        counters.clear()
        maybe_autocast = contextlib.nullcontext()
        with torch.no_grad(), maybe_autocast:
            clone_inputs = self._clone_inputs(inputs)
            expected = mod(*inputs)
            actual = torch.compile(mod)(*clone_inputs)
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
            self.assertEqual(
                counters["inductor"]["pattern_matcher_count"], matcher_count
            )
            self.assertEqual(
                counters["inductor"]["pattern_matcher_nodes"],
                matcher_nodes,
            )


# Short aliases for the ATen and mkldnn operator namespaces referenced by the
# passes and patterns in this file.
aten = torch.ops.aten
mkldnn = torch.ops.mkldnn


def change_cos_pass(graph):
    """Retarget every ``aten.cos.default`` call in ``graph`` to ``aten.sin.default``.

    The graph is modified in place; nothing is returned.
    """
    for candidate in graph.nodes:
        if candidate.op != "call_function":
            continue
        if candidate.target == aten.cos.default:
            candidate.target = aten.sin.default


class TestPostGradCustomPrePostPass(TestCustomPassBase):
    """Exercise the user-supplied custom pass hooks on the Inductor config.

    The tests patch ``joint_custom_pre_pass``, ``joint_custom_post_pass``,
    ``post_grad_custom_pre_pass``, ``post_grad_custom_post_pass`` and
    ``pre_grad_custom_pass`` and check that the supplied pass took effect.
    """

    # Follows mkldnn fusion's pattern_matcher
    # (torch/_inductor/fx_passes/mkldnn_fusion.py),
    # and applies it to custom post_grad_passes.
    def _register_mkldnn_conv_relu_fusion(self, custom_pass_dict):
        """Register a ``relu(mkldnn conv)`` lowering pattern on ``custom_pass_dict``.

        Args:
            custom_pass_dict: Object forwarded as ``pass_dict`` to
                ``pattern_matcher.register_lowering_pattern``.
        """

        # pattern
        def _mkldnn_conv_relu_pattern():
            # aten.relu applied to a ten-argument _convolution_pointwise call
            # whose result has exactly one user.
            return CallFunction(
                aten.relu,
                CallFunction(
                    mkldnn._convolution_pointwise.default,
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    Arg(),
                    _users=1,
                ),
            )

        # utils of pattern matcher registration
        def _register_fusion_lowering(pattern, custom_pass_dict):
            def dummy_check(m):
                # Accept every match; no extra constraints are needed here.
                return True

            def register_custom_lowering_pattern(
                pattern, extra_check, custom_pass_dict
            ):
                return pattern_matcher.register_lowering_pattern(
                    pattern, extra_check, pass_dict=custom_pass_dict
                )

            @register_custom_lowering_pattern(pattern, dummy_check, custom_pass_dict)
            def fn(match, *args, **kwargs):
                # Swap the trailing three matched arguments for "relu", [], "".
                # NOTE(review): presumably these are the pointwise attr, its
                # scalars and its algorithm -- confirm against the op schema.
                computation_args = list(args)[:-3] + ["relu", [], ""]
                return L[mkldnn._convolution_pointwise.default](*computation_args)

            return fn

        _register_fusion_lowering(_mkldnn_conv_relu_pattern(), custom_pass_dict)

    # custom post grad pass
    class _CustomPass(PatternMatcherPass, CustomGraphPass):
        """Pattern-matcher pass usable as a custom post-grad graph pass."""

        def __init__(self) -> None:
            super().__init__()

        def __call__(self, g: torch.fx.graph.Graph):
            # Run every pattern registered on this pass against the graph.
            self.apply(g)

        def uuid(self) -> bytes:
            # Identifier derived from the contents of this test file.
            return get_hash_for_files((__file__,))

    # case model
    class _ConvReLU(torch.nn.Module):
        """A 3x3 ``Conv2d`` (stride 1, padding 1) followed by ``relu``."""

        def __init__(self, ic, oc):
            super().__init__()
            self.conv = torch.nn.Conv2d(ic, oc, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            x1 = self.conv(x)
            return x1.relu()

    def _check_conv_relu_fusion(self, custom_pass):
        """Register the conv+relu lowering on ``custom_pass`` and verify it fires.

        Must be called while the config patch that installs ``custom_pass``
        is active.

        Args:
            custom_pass: The custom pass currently installed on the config.
        """
        # init mkldnn fusion on custom_matcher
        self._register_mkldnn_conv_relu_fusion(custom_pass)

        mod = self._ConvReLU(16, 16).eval()
        x = torch.randn((1, 16, 56, 56), dtype=torch.float32)

        match_count = 1
        match_nodes = 2
        other_match_count = 1  # conv prepack weight
        other_match_nodes = 1  # conv prepack weight
        self._test_common(
            mod,
            (x,),
            match_count + other_match_count,
            match_nodes + other_match_nodes,
        )

    def test_custom_joint_pass_pre(self):
        """With the cos->sin pass as ``joint_custom_pre_pass``, compiled ``f`` equals ``g``."""
        with config.patch(joint_custom_pre_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_joint_pass_post(self):
        """With the cos->sin pass as ``joint_custom_post_pass``, compiled ``f`` equals ``g``."""
        with config.patch(joint_custom_post_pass=change_cos_pass):

            def g(x):
                return x.sin().sin().sin()

            def f(x):
                return x.cos().cos().cos()

            x = torch.randn(8, dtype=torch.float32)
            torch.testing.assert_close(torch.compile(f)(x), g(x))

    def test_custom_pre_pass(self):
        """The conv+relu pattern fires when installed as ``post_grad_custom_pre_pass``."""
        with config.patch(
            # leave custom pass only in post_grad_passes()
            pattern_matcher=False,
            # define pattern match as custom post grad opt pass
            post_grad_custom_pre_pass=self._CustomPass(),
            post_grad_custom_post_pass=None,
        ):
            self._check_conv_relu_fusion(config.post_grad_custom_pre_pass)

    def test_custom_post_pass(self):
        """The conv+relu pattern fires when installed as ``post_grad_custom_post_pass``."""
        with config.patch(
            # leave custom pass only in post_grad_passes()
            pattern_matcher=False,
            # define pattern match as custom post grad opt pass
            post_grad_custom_pre_pass=None,
            post_grad_custom_post_pass=self._CustomPass(),
        ):
            self._check_conv_relu_fusion(config.post_grad_custom_post_pass)

    def test_custom_pre_grad_pass(self):
        """``pre_grad_custom_pass`` is invoked and its graph rewrite leaves one ``torch.mm``."""
        # Graph handed to the custom pass, captured for inspection afterwards.
        saved_graph = None

        def merge_mm_shared_rhs(graph: fx.Graph):
            """
            Bad POC of merging mm with a shared RHS.
            i.e. [mm(x, W), mm(x2, W)] => mm(cat(x, x2), W).split()

            Isn't actually safe for a couple reasons. For example, it doesn't handle the
            case where the LHS inputs depend on each other
            """
            nonlocal saved_graph
            saved_graph = graph

            # Group the mm nodes by their right-hand operand.
            rhs_vals = defaultdict(set)
            for node in graph.nodes:
                if node.target == torch.mm:
                    rhs_vals[node.args[1]].add(node)

            # Position of each node in the graph, used to order each group.
            order = {node: idx for idx, node in enumerate(graph.nodes)}

            for rhs, group in rhs_vals.items():
                if len(group) == 1:
                    continue
                ordered = sorted(group, key=lambda node: order[node])
                with graph.inserting_before(ordered[0]):
                    lhs_vals = [m.args[0] for m in ordered]
                    new_cat = graph.create_node(
                        "call_function", torch.cat, args=(lhs_vals, 0)
                    )
                    new_mm = graph.create_node(
                        "call_function", torch.mm, args=(new_cat, rhs)
                    )
                    split_vals = graph.create_node(
                        "call_function",
                        torch.split,
                        args=(
                            new_mm,
                            [lhs.meta["example_value"].shape[0] for lhs in lhs_vals],
                        ),
                    )
                # Turn each original mm into an index into the split result.
                for idx, m in enumerate(ordered):
                    m.target = operator.getitem
                    m.args = (split_vals, idx)

        @config.patch(pre_grad_custom_pass=merge_mm_shared_rhs)
        def inner_test():
            @torch.compile
            def f(W, nested_seqs):
                outs = [torch.mm(s, W) for s in nested_seqs]
                return outs

            W = torch.randn(16, 16, dtype=torch.bfloat16)
            nested_seqs = [
                torch.randn(rows, 16, dtype=torch.bfloat16) for rows in [4, 8, 5, 3]
            ]

            f(W, nested_seqs)
            # unittest assertions are used because bare ``assert`` is stripped
            # under ``python -O``.
            self.assertIsNotNone(saved_graph)
            matmuls = [n for n in saved_graph.nodes if n.target == torch.mm]
            self.assertEqual(len(matmuls), 1)

        inner_test()


if __name__ == "__main__":
    # Only run the suite when IS_LINUX and HAS_CPU are set and mkldnn is
    # available; the conv+relu tests in this file reference mkldnn ops.
    if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
        run_tests()