# File: test_binary_folding.py
#
# Source package: pytorch-cuda 2.6.0+dfsg-7
#   area: contrib
#   in suites: forky, sid, trixie
#   size: 161,620 kB
#   sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754;
#         java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597;
#         ruby: 195; xml: 84; objc: 68
# File content: 350 lines, 11,929 bytes (stat: -rw-r--r--)
# Owner(s): ["module: inductor"]
import functools
import importlib
import itertools
import os
import sys

import torch
from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch.testing._internal.common_cuda import TEST_CUDNN


# Make the helper files in test/ importable.
# Two dirname() calls on this file's resolved path yield the parent of the
# directory that contains this file; appending it to sys.path is what lets the
# `inductor.*` imports below resolve.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from inductor.test_inductor_freezing import (  # @manual=fbcode//caffe2/test/inductor:inductor_freezing-library
    TestCase,
)
from inductor.test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
    check_model,
    check_model_gpu,
    copy_tests,
)
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import skipCUDAIf


# Imported for effect only: the returned module objects are discarded.
# import_module raises if a module cannot be imported, so loading this file
# fails early when either package is missing.
# NOTE(review): presumably a dependency check for the tests below -- confirm.
importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU


# Shorthand for the aten operator namespace; used below to name the binary-op
# overloads (add/sub/mul/div) that test_conv_bn_folding counts.
aten = torch.ops.aten


class BinaryFoldingTemplate(TestCase):
    """Template of inductor tests for folding binary ops into conv/linear layers.

    The tests read ``self.device``, which is not defined on this class; the
    class is handed to ``copy_tests`` together with device-specific test
    classes that supply it.

    Shared logic lives in functions nested inside each test rather than in
    helper methods, so every ``test_*`` method is self-contained: it uses
    nothing from this class except ``self.device`` and inherited assertions.
    """

    @skipCUDAIf(TEST_CUDNN, "CUDNN has accuracy issues for this test")
    def test_conv_binary_folding(self):
        """Check folding of add/sub/mul/div applied to a convolution's output.

        Every case builds a conv + binary-op module, compiles it, compares the
        compiled output with the eager output, and checks the inductor
        ``binary_folding`` counter: 1 when folding is expected, 0 otherwise.
        """

        @torch.no_grad()
        def test_conv_fusion(use_bias, module, op, scalar, add_tensor, expect_success):
            """Run one conv + binary-op case.

            Args:
                use_bias: forwarded as ``bias=`` to the conv constructor.
                module: conv class to instantiate (nn.Conv1d/2d/3d).
                op: binary callable applied to the conv output.
                scalar: if true the second operand is the literal ``2.0``,
                    otherwise the module's ``tensor`` parameter.
                add_tensor: tensor operand, or None to use a random one.
                expect_success: whether exactly one fold is expected.
            """

            class ConvOp(nn.Module):
                __constants__ = ["use_scalar"]

                def __init__(self, in_channels, out_channels, device, **kwargs):
                    super().__init__()
                    self.conv = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    ).to(device)
                    self.use_scalar = scalar
                    # Default operand: size 1 in every dim except dim 1, which
                    # gets the conv weight's size(0).
                    tensor_size = [1] * self.conv.weight.ndim
                    tensor_size[1] = self.conv.weight.size(0)
                    self.tensor = torch.nn.Parameter(
                        add_tensor
                        if add_tensor is not None
                        else torch.rand(tensor_size).to(device)
                    )
                    self.op = op

                def forward(self, x):
                    x = self.conv(x)
                    if self.use_scalar:
                        return self.op(x, 2.0)
                    return self.op(x, self.tensor)

            # Start from a clean dynamo state and zeroed counters so the
            # counter check below reflects this case only.
            torch._dynamo.reset()
            counters.clear()
            mod_eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval()
            mod_optimized = torch.compile(mod_eager)

            # [4, 3, 4] for Conv1d, plus one more dim of size 4 for Conv2d and
            # two more for Conv3d.
            n_spatial = {nn.Conv1d: 1, nn.Conv2d: 2, nn.Conv3d: 3}.get(module, 1)
            inp_shape = [4, 3] + [4] * n_spatial

            torch.manual_seed(1234)
            inp = torch.rand(inp_shape).to(self.device)
            out_eager = mod_eager(inp)
            out_optimized = mod_optimized(inp)
            self.assertEqual(out_optimized, out_eager)
            self.assertEqual(
                counters["inductor"]["binary_folding"], 1 if expect_success else 0
            )

        conv_bias = [True, False]
        modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
        use_scalar = [True, False]
        ops = [torch.add, torch.sub, torch.mul, torch.div]

        # Default operand (scalar or the module's own random tensor): every
        # combination is expected to fold.
        for use_bias, module, pytorch_op, scalar in itertools.product(
            conv_bias, modules, ops, use_scalar
        ):
            test_conv_fusion(
                use_bias,
                module,
                pytorch_op,
                scalar,
                add_tensor=None,
                expect_success=True,
            )

        for use_bias, pytorch_op in itertools.product(conv_bias, ops):
            # Operand of shape (32, 1, 32): expected NOT to be folded.
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                pytorch_op,
                False,
                add_tensor=torch.rand(
                    32,
                    1,
                    32,
                ).to(self.device),
                expect_success=False,
            )

            # Single-element operand of shape (1, 1): expected to be folded.
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                pytorch_op,
                False,
                add_tensor=torch.rand(1, 1).to(self.device),
                expect_success=True,
            )

            # Operand with a different dtype (float64): expected NOT to be
            # folded.
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                pytorch_op,
                False,
                add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
                expect_success=False,
            )

    @inductor_config.patch({"freezing": True})
    def test_conv_bn_folding(self):
        """Check that conv + batch norm compiles to a graph without binary ops.

        Runs with the inductor ``freezing`` config enabled. A wrapping inner
        compile function counts the aten add/sub/mul/div nodes present in each
        graph module after ``compile_fx_inner`` has processed it.
        """

        @torch.no_grad()
        def test_conv_fusion(use_bias, module, expect_success):
            """Run one conv + batch-norm case.

            Args:
                use_bias: forwarded as ``bias=`` to the conv constructor.
                module: ``(conv_class, batch_norm_class)`` pair.
                expect_success: if true, no binary op may remain; otherwise
                    more than one is expected.
            """
            conv_cls, bn_cls = module

            class ConvOp(nn.Module):
                def __init__(self, in_channels, out_channels, device, **kwargs):
                    super().__init__()
                    self.conv = conv_cls(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    ).to(device)
                    self.bn = bn_cls(out_channels).to(device)

                def forward(self, x):
                    x = self.conv(x)
                    return self.bn(x)

            from torch._inductor.compile_fx import compile_fx, compile_fx_inner

            aten_binary = [
                aten.add.Tensor,
                aten.sub.Tensor,
                aten.mul.Tensor,
                aten.div.Tensor,
            ]
            n_binary_ops = 0

            def my_inner_compile(gm, example_inputs, *args, **kwargs):
                """Delegate to compile_fx_inner, then tally binary-op nodes."""
                nonlocal n_binary_ops
                out = compile_fx_inner(gm, example_inputs, *args, **kwargs)
                # Count after compilation, exactly as before: nodes whose
                # target is one of the four aten binary overloads.
                n_binary_ops += sum(
                    node.target in aten_binary for node in gm.graph.nodes
                )
                return out

            torch._dynamo.reset()
            mod_eager = ConvOp(3, 32, self.device, kernel_size=3, stride=2).eval()
            mod_optimized = torch.compile(
                mod_eager,
                backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
            )

            # [4, 3, 4] for Conv1d, plus one more dim of size 4 for Conv2d and
            # two more for Conv3d.
            n_spatial = {nn.Conv1d: 1, nn.Conv2d: 2, nn.Conv3d: 3}.get(conv_cls, 1)
            inp_shape = [4, 3] + [4] * n_spatial

            inp = torch.rand(inp_shape).to(self.device)
            out_eager = mod_eager(inp)
            out_optimized = mod_optimized(inp)
            self.assertEqual(out_optimized, out_eager, atol=2e-04, rtol=1e-5)
            # assertEqual/assertGreater report the actual count on failure.
            if expect_success:
                self.assertEqual(n_binary_ops, 0)
            else:
                self.assertGreater(n_binary_ops, 1)

        conv_bias = [True, False]
        modules = [
            (nn.Conv1d, nn.BatchNorm1d),
            (nn.Conv2d, nn.BatchNorm2d),
            (nn.Conv3d, nn.BatchNorm3d),
        ]
        for use_bias, module in itertools.product(conv_bias, modules):
            test_conv_fusion(
                use_bias,
                module,
                expect_success=True,
            )

    @inductor_config.patch({"enable_linear_binary_folding": True})
    def test_linear_binary_folding(self):
        """Check folding of add/sub/mul/div applied to a linear layer's output.

        Runs with the inductor ``enable_linear_binary_folding`` config
        enabled. Every case compares compiled and eager outputs and checks
        the inductor ``binary_folding`` counter: 1 when folding is expected,
        0 otherwise.
        """

        @torch.no_grad()
        def test_linear_fusion(
            use_bias, op, scalar, add_tensor, expect_success, input_3d=False
        ):
            """Run one linear + binary-op case.

            Args:
                use_bias: forwarded as ``bias=`` to ``nn.Linear``.
                op: binary callable applied to the linear output.
                scalar: if true the second operand is the literal ``2.0``,
                    otherwise the module's ``tensor`` parameter.
                add_tensor: tensor operand, or None to use a random one.
                expect_success: whether exactly one fold is expected.
                input_3d: use a [2, 4, 3] input instead of [4, 3].
            """

            class LinearOp(nn.Module):
                __constants__ = ["use_scalar"]

                def __init__(self, in_channels, out_channels, device, **kwargs):
                    super().__init__()
                    self.linear = nn.Linear(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    ).to(device)
                    self.use_scalar = scalar
                    # Default operand: 1-D, sized by the weight's size(0).
                    tensor_size = [self.linear.weight.size(0)]
                    self.tensor = torch.nn.Parameter(
                        add_tensor
                        if add_tensor is not None
                        else torch.rand(tensor_size).to(device)
                    )
                    self.op = op

                def forward(self, x):
                    x = self.linear(x)
                    if self.use_scalar:
                        return self.op(x, 2.0)
                    return self.op(x, self.tensor)

            # Start from a clean dynamo state and zeroed counters so the
            # counter check below reflects this case only.
            torch._dynamo.reset()
            counters.clear()
            mod_eager = LinearOp(3, 32, self.device).eval()
            mod_optimized = torch.compile(mod_eager)

            torch.manual_seed(1234)
            inp_shape = [2, 4, 3] if input_3d else [4, 3]
            inp = torch.rand(inp_shape).to(self.device)
            out_eager = mod_eager(inp)
            out_optimized = mod_optimized(inp)
            self.assertEqual(out_optimized, out_eager, atol=5e-05, rtol=5e-06)
            self.assertEqual(
                counters["inductor"]["binary_folding"], 1 if expect_success else 0
            )

        linear_bias = [True, False]
        use_scalar = [True, False]
        ops = [torch.add, torch.sub, torch.mul, torch.div]
        add_tensor_size = [[32], [1, 32], [1], [1, 1]]

        # 2-D input: all of these operand shapes are expected to fold.
        for use_bias, pytorch_op, scalar, tensor_size in itertools.product(
            linear_bias, ops, use_scalar, add_tensor_size
        ):
            test_linear_fusion(
                use_bias,
                pytorch_op,
                scalar,
                add_tensor=torch.rand(tensor_size).to(self.device),
                expect_success=True,
            )

        # 3-D input: the same shapes plus two 3-D operand shapes.
        add_tensor_size.extend([[1, 1, 32], [1, 1, 1]])
        for use_bias, pytorch_op, scalar, tensor_size in itertools.product(
            linear_bias, ops, use_scalar, add_tensor_size
        ):
            test_linear_fusion(
                use_bias,
                pytorch_op,
                scalar,
                add_tensor=torch.rand(tensor_size).to(self.device),
                expect_success=True,
                input_3d=True,
            )

        # In the following test, the shape of 'add_tensor' does not satisfy
        # the requirements of binary folding, so it will not be folded.
        for use_bias, pytorch_op in itertools.product(linear_bias, ops):
            test_linear_fusion(
                use_bias,
                pytorch_op,
                False,
                add_tensor=torch.rand(
                    4,
                    32,
                ).to(self.device),
                expect_success=False,
            )

            test_linear_fusion(
                use_bias,
                pytorch_op,
                False,
                add_tensor=torch.rand(
                    4,
                    1,
                ).to(self.device),
                expect_success=False,
            )


# CPU variant: defined only when HAS_CPU is true and MPS is not available.
if HAS_CPU and not torch.backends.mps.is_available():

    class FreezingCpuTests(TestCase):
        # Class-level settings for the CPU run. Within this file only `device`
        # is read (as self.device) by the BinaryFoldingTemplate tests.
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    # Hand the template and the CPU class to copy_tests with the "cpu" tag.
    # NOTE(review): presumably this attaches the template's test_* methods to
    # FreezingCpuTests under device-suffixed names -- confirm in copy_tests.
    copy_tests(BinaryFoldingTemplate, FreezingCpuTests, "cpu")

# GPU variant: defined only when HAS_GPU is true and TEST_WITH_ASAN is false.
if HAS_GPU and not TEST_WITH_ASAN:

    class FreezingGpuTests(TestCase):
        # Class-level settings for the GPU run. Within this file only `device`
        # is read (as self.device) by the BinaryFoldingTemplate tests.
        common = check_model_gpu
        device = GPU_TYPE
        # NOTE(review): this stores an autocast *instance* (the constructor is
        # called here), whereas FreezingCpuTests stores the autocast class
        # itself -- confirm the difference is intended.
        autocast = torch.amp.autocast(device_type=GPU_TYPE)

    # Hand the template and the GPU class to copy_tests with the GPU_TYPE tag.
    # NOTE(review): presumably this attaches the template's test_* methods to
    # FreezingGpuTests under device-suffixed names -- confirm in copy_tests.
    copy_tests(BinaryFoldingTemplate, FreezingGpuTests, GPU_TYPE)


# Drop the template's name from the module namespace once it has been passed
# to copy_tests above.
# NOTE(review): presumably so the device-less template is not itself collected
# and run as a test class -- confirm.
del BinaryFoldingTemplate

if __name__ == "__main__":
    # Imported lazily: only needed when the file is executed as a script.
    from torch._inductor.test_case import run_tests

    # Run only when at least one of the device variants above can exist.
    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")