File: test_decompose_mem_bound_mm.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (426 lines) | stat: -rw-r--r-- 14,020 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
# Owner(s): ["module: inductor"]

import logging

import torch
import torch._inductor
from torch._dynamo.utils import counters
from torch._inductor.fx_passes.decompose_mem_bound_mm import check_device
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
from torch.testing._internal.triton_utils import requires_gpu


class MyModule(torch.nn.Module):
    def __init__(
        self, n_input: int, n_output: int, has_bias: bool, device=GPU_TYPE
    ) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(n_input, n_output, bias=has_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class MyModule2(torch.nn.Module):
    """Parameter-free module whose forward is a single ``torch.bmm`` call."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input1, input2):
        """Return the batched matrix product ``input1 @ input2`` via ``torch.bmm``."""
        return torch.bmm(input1, input2)


class MyModule3(torch.nn.Module):
    """Parameter-free module whose forward is a single ``torch.mm`` call."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input1, input2):
        """Return the 2-D matrix product of ``input1`` and ``input2`` via ``torch.mm``."""
        return torch.mm(input1, input2)


@requires_gpu
@skipIfXpu(
    msg="Intel GPU has not enabled decompose_mem_bound_mm PASS in "
    "torch/_inductor/fx_passes/decompose_mem_bound_mm.py"
)
@torch._inductor.config.patch(
    post_grad_fusion_options={
        "decompose_mm_pass": {},
    }
)
@instantiate_parametrized_tests
class TestDecomposeMemMM(TestCase):
    """Tests for inductor's ``decompose_mm_pass`` post-grad pass.

    Each test compiles a small module with ``torch.compile``, checks that the
    compiled results (and, for modules with parameters, the parameters and
    gradients) match eager mode, and then checks the decomposition counters in
    ``counters["inductor"]`` (``decompose_bmm``, ``decompose_mm``,
    ``decompose_addmm``) against expected values. On the GPU tests a non-zero
    count is only expected when ``HAS_CUDA`` is true.
    """

    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        """Return True if every tensor in ``ref_dict`` matches ``res_dict``.

        ``res_dict`` comes from a ``torch.compile``-wrapped module, so each key
        of ``ref_dict`` is looked up with the ``_orig_mod.`` prefix. Entries
        that are ``None`` on both sides (e.g. a ``.grad`` that was never
        populated) compare equal; ``None`` on only one side does not.
        """
        if len(ref_dict) != len(res_dict):
            return False
        for key, ref_val in ref_dict.items():
            res_key = "_orig_mod." + key
            assert res_key in res_dict, f"{key} does not exist in traced module"
            res_val = res_dict[res_key]
            # torch.allclose rejects None, so handle missing tensors explicitly.
            if ref_val is None or res_val is None:
                if ref_val is not res_val:
                    return False
                continue
            if not torch.allclose(ref_val, res_val, rtol=rtol, atol=atol):
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        """Assert that ``module`` and ``traced`` produce the same output.

        ``input`` is a sequence of positional arguments unpacked into both calls.
        """
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        """Assert that the eager and compiled modules hold matching parameters."""
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        """Assert that the eager and compiled modules hold matching ``.grad`` values."""
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    def _check_decompose_linear(self, m, n, k, has_bias, should_decompose):
        """Compile ``MyModule`` (one ``nn.Linear``) and check its decomposition.

        Shared by ``test_decompose_linear`` and
        ``test_decompose_linear_mixed_precision`` (which calls it under
        ``torch.amp.autocast``). Verifies outputs, parameters and gradients
        against eager mode and checks the decomposition counters after both
        the forward and the backward pass.
        """
        torch._logging.set_logs(inductor=logging.DEBUG)
        x = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule(k, n, has_bias).to(GPU_TYPE)
        traced = torch.compile(module)
        inputs = [x]
        ref = module(*inputs)
        res = traced(*inputs)

        self.compare_pred(module, traced, inputs)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        # The test expects a biased forward to be counted under
        # decompose_addmm and an unbiased one under decompose_mm.
        fwd_counter = "decompose_addmm" if has_bias else "decompose_mm"
        self.assertEqual(
            counters["inductor"][fwd_counter],
            expected_val,
        )
        decompose_mm_fwd = counters["inductor"]["decompose_mm"]

        ref.sum().backward()
        res.sum().backward()

        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        # Only the decompose_mm increments contributed by the backward pass.
        self.assertEqual(
            counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
            expected_val,
        )
        counters.clear()

    def _check_decompose_mm(self, m, n, k, should_decompose):
        """Compile ``MyModule3`` (a plain ``torch.mm``) and check its decomposition.

        Shared by ``test_decompose_mm`` and ``test_decompose_mm_mixed_precision``
        (which calls it under ``torch.amp.autocast``). Verifies the compiled
        output against eager mode and checks that ``decompose_mm`` grows by
        the expected amount in both the forward and the backward pass.
        """
        torch._logging.set_logs(inductor=logging.DEBUG)
        mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)
        mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule3().to(GPU_TYPE)
        traced = torch.compile(module)
        inputs = [mat1, mat2]
        ref = module(*inputs)
        res = traced(*inputs)

        self.compare_pred(module, traced, inputs)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        self.assertEqual(
            counters["inductor"]["decompose_mm"],
            expected_val,
        )
        decompose_mm_fwd = counters["inductor"]["decompose_mm"]

        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        # The backward pass is expected to add as many decompositions as the
        # forward pass did.
        self.assertEqual(
            counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
            expected_val,
        )
        counters.clear()

    @parametrize(
        "b,m,k,n,should_decompose",
        [(10240, 2, 2, 2, True), (10240, 2, 32, 32, False), (2000, 2, 2, 2, False)],
    )
    def test_decompose_bmm(self, b, m, n, k, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        mat1 = torch.randn(b, m, k, device=GPU_TYPE).requires_grad_(True)
        mat2 = torch.randn(b, k, n, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule2().to(GPU_TYPE)
        traced = torch.compile(module)
        inputs = [mat1, mat2]
        ref = module(*inputs)
        res = traced(*inputs)

        self.compare_pred(module, traced, inputs)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        self.assertEqual(
            counters["inductor"]["decompose_bmm"],
            expected_val,
        )

        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        # Cumulative count: the forward bmm plus, presumably, one bmm per
        # input gradient in the backward pass.
        expected_val = 3 if should_decompose and HAS_CUDA else 0
        self.assertEqual(
            counters["inductor"]["decompose_bmm"],
            expected_val,
        )
        counters.clear()

    @parametrize(
        "b,m,k,n,should_decompose",
        [(1, 2, 2, 2, True), (2, 2, 2, 2, False)],
    )
    def test_decompose_bmm_cpu(self, b, m, n, k, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        mat1 = torch.randn(b, m, k)
        mat2 = torch.randn(b, k, n)

        counters.clear()

        module = MyModule2()
        traced = torch.compile(module)
        inputs = [mat1, mat2]
        self.compare_pred(module, traced, inputs)

        # CPU expectations do not depend on HAS_CUDA.
        expected_val = 1 if should_decompose else 0
        self.assertEqual(
            counters["inductor"]["decompose_bmm"],
            expected_val,
        )
        counters.clear()

    @parametrize(
        "m,k,n, should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_linear(self, m, n, k, has_bias, should_decompose):
        self._check_decompose_linear(m, n, k, has_bias, should_decompose)

    @parametrize(
        "m,k,n, should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_linear_mixed_precision(
        self, m, n, k, has_bias, should_decompose
    ):
        with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16):
            self._check_decompose_linear(m, n, k, has_bias, should_decompose)

    @parametrize(
        "m,k,n, should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_mm(self, m, n, k, has_bias, should_decompose):
        # NOTE: has_bias is not used here (MyModule3 has no bias); the
        # parametrization only runs each case twice under different names.
        self._check_decompose_mm(m, n, k, should_decompose)

    @parametrize(
        "m,k,n, should_decompose",
        [(1, 64, 16, True), (2, 64, 16, False), (1, 64, 32, False)],
    )
    def test_decompose_mm_cpu(self, m, n, k, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        mat1 = torch.randn(m, k)
        mat2 = torch.randn(k, n)
        counters.clear()

        module = MyModule3()
        traced = torch.compile(module)
        inputs = [mat1, mat2]
        self.compare_pred(module, traced, inputs)

        # CPU expectations do not depend on HAS_CUDA.
        expected_val = 1 if should_decompose else 0
        self.assertEqual(
            counters["inductor"]["decompose_mm"],
            expected_val,
        )
        counters.clear()

    @parametrize(
        "m,k,n, should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose):
        # NOTE: has_bias is not used here either; see test_decompose_mm.
        with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16):
            self._check_decompose_mm(m, n, k, should_decompose)

    @parametrize("m,k,n, should_decompose", [(20480, 5, 2, True)])
    @parametrize("has_bias", [True, False])
    def test_dynamic_shape(self, m, n, k, has_bias, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        x = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule(k, n, has_bias).to(GPU_TYPE)
        traced = torch.compile(module, dynamic=True)
        inputs = [x]
        ref = module(*inputs)
        res = traced(*inputs)

        self.compare_pred(module, traced, inputs)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        if has_bias:
            self.assertEqual(
                counters["inductor"]["decompose_addmm"],
                expected_val,
            )

        ref.sum().backward()
        res.sum().backward()

        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        # Cumulative decompose_mm count over forward + backward. Without a
        # bias the forward itself is counted as an mm, hence one more.
        expected_val = 0
        if HAS_CUDA:
            expected_val = 1 if has_bias else 2

        self.assertEqual(
            counters["inductor"]["decompose_mm"],
            expected_val,
        )
        counters.clear()

    def test_realize_input(self):
        m = 20480
        k = 5
        n = 2
        torch._logging.set_logs(inductor=logging.DEBUG)
        input1 = torch.randn(m, k, device=GPU_TYPE).T.contiguous()
        input2 = torch.randn(k, n, device=GPU_TYPE)

        @torch.compile()
        def foo(x, y):
            return x.T.contiguous() @ y

        # Only the generated code is inspected; the numeric output is unused.
        _, code = run_and_get_code(foo, input1, input2)

        if GPU_TYPE == "xpu":
            # only 1 kernel generated on the XPU stack
            FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
        else:
            # two kernels generated
            FileCheck().check_count(".run(", 2, exactly=True).run(code[0])

    def test_check_device(self):
        m = 5
        k = 5
        n = 2
        torch._logging.set_logs(inductor=logging.DEBUG)

        # Both inputs on the GPU.
        input1 = torch.randn(m, k, device=GPU_TYPE)
        input2 = torch.randn(k, n, device=GPU_TYPE)
        self.assertTrue(check_device(input1, input2))
        self.assertFalse(check_device(input1, input2, device="cpu"))

        # Both inputs on the CPU.
        input1 = torch.randn(m, k)
        input2 = torch.randn(k, n)
        self.assertTrue(check_device(input1, input2, device="cpu"))
        self.assertFalse(check_device(input1, input2))

        # Mixed devices must be rejected for either candidate device. Use the
        # real GPU device type: "gpu" never matches a tensor's device.type, so
        # the assertion would pass regardless of the mixed-device check.
        input1 = torch.randn(m, k, device=GPU_TYPE)
        input2 = torch.randn(k, n)
        self.assertFalse(check_device(input1, input2, device=GPU_TYPE))
        self.assertFalse(check_device(input1, input2, device="cpu"))

        self.assertFalse(check_device(input1, input2, device="mtia"))


if __name__ == "__main__":
    run_tests()