File: test_minifier.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (337 lines) | stat: -rw-r--r-- 11,775 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch

import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
from torch._dynamo.test_minifier_common import MinifierTestBase
from torch._inductor import config
from torch.export import load as export_load
from torch.testing._internal.common_utils import (
    IS_JETSON,
    IS_MACOS,
    skipIfXpu,
    TEST_WITH_ASAN,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu


class MinifierTests(MinifierTestBase):
    # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA)
    def _test_after_aot(self, device, expected_error):
        # NB: The program is intentionally quite simple, just enough to
        # trigger one minification step, no more (dedicated minifier tests
        # should exercise minifier only)
        run_code = f"""\
@torch.compile()
def inner(x):
    x = torch.relu(x)
    x = torch.cos(x)
    return x

inner(torch.randn(20, 20).to("{device}"))
"""
        self._run_full_test(run_code, "aot", expected_error, isolate=False)

    @unittest.skipIf(IS_JETSON, "Fails on Jetson")
    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "compile_error")
    def test_after_aot_cpu_compile_error(self):
        self._test_after_aot("cpu", "CppCompileError")

    @unittest.skipIf(IS_JETSON, "Fails on Jetson")
    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
    def test_after_aot_cpu_accuracy_error(self):
        self._test_after_aot("cpu", "AccuracyError")

    @requires_gpu
    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "compile_error")
    def test_after_aot_gpu_compile_error(self):
        self._test_after_aot(GPU_TYPE, "SyntaxError")

    @requires_gpu
    @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy")
    def test_after_aot_gpu_accuracy_error(self):
        self._test_after_aot(GPU_TYPE, "AccuracyError")

    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
    def test_constant_in_graph(self):
        run_code = """\
@torch.compile()
def inner(x):
    return torch.tensor(2) + torch.relu(x)

inner(torch.randn(2))
"""
        self._run_full_test(run_code, "aot", "AccuracyError", isolate=False)

    @requires_gpu
    @patch.object(config, "joint_graph_constant_folding", False)
    def test_rmse_improves_over_atol(self):
        # From https://twitter.com/itsclivetime/status/1651135821045719041?s=20
        run_code = """
@torch.compile()
def inner(x):
    return x - torch.tensor(655, dtype=torch.half, device='GPU_TYPE') * 100

inner(torch.tensor(655 * 100, dtype=torch.half, device='GPU_TYPE'))
""".replace(
            "GPU_TYPE", GPU_TYPE
        )

        # If we disable RMSE against fp64, this triggers accuracy error,
        # as the increased precision from torch.compile changes the result
        # of 655 * 100
        with dynamo_config.patch("same_two_models_use_fp64", False):
            self._run_full_test(
                run_code,
                "aot",
                "AccuracyError",
                isolate=False,
                # NB: need this to avoid refusing to minify when fp64 doesn't work
                # (which it doesn't, due to the config patch above)
                minifier_args=["--strict-accuracy"],
            )

        # But using fp64, we see that the intended semantics is the increased
        # 655 * 100 precision, and so we report no problem
        self._run_full_test(run_code, "aot", None, isolate=False)

    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
    @inductor_config.patch("cpp.inject_log1p_bug_TESTING_ONLY", "accuracy")
    def test_accuracy_vs_strict_accuracy(self):
        run_code = """
@torch.compile()
def inner(x):
    y = torch.log1p(x)
    b = y > 0
    # Need to ensure suffix removal hits a boolean output
    b = torch.logical_not(b)
    b = torch.logical_not(b)
    x = torch.relu(x)
    return torch.where(b, x, x)

inner(torch.randn(20))
"""

        # Strict accuracy gets hung up on the boolean mask difference, which
        # will localize the error to sigmoid, even though it doesn't actually
        # matter to the end result
        res = self._run_full_test(
            run_code,
            "aot",
            "AccuracyError",
            isolate=False,
            minifier_args=["--strict-accuracy"],
        )
        self.assertExpectedInline(
            res.repro_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg0_1):
        log1p = torch.ops.aten.log1p.default(arg0_1);  arg0_1 = None
        return (log1p,)""",
        )

        # FP accuracy will refuse to promote the logical_not on the outputs,
        # and so you'll get to the relu (unless the minifier somehow tries
        # removing entire suffix except the log1p first!)
        res = self._run_full_test(run_code, "aot", "AccuracyError", isolate=False)
        self.assertExpectedInline(
            res.repro_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg0_1):
        relu = torch.ops.aten.relu.default(arg0_1);  arg0_1 = None
        return (relu,)""",
        )

    @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy")
    def test_offload_to_disk(self):
        # Just a smoketest, this doesn't actually test that memory
        # usage went down.  Test case is carefully constructed to hit
        # delta debugging.
        run_code = """\
@torch.compile()
def inner(x):
    x = torch.sin(x)
    x = torch.sin(x)
    x = torch.cos(x)
    x = torch.relu(x)
    return x

inner(torch.randn(20, 20))
"""
        self._run_full_test(
            run_code,
            "aot",
            "AccuracyError",
            isolate=False,
            minifier_args=["--offload-to-disk"],
        )

    # Test that compile errors in AOTInductor can be repro'd (both CPU and CUDA)
    def _test_aoti(self, device, expected_error):
        # NB: The program is intentionally quite simple, just enough to
        # trigger one minification step, no more (dedicated minifier tests
        # should exercise minifier only)
        run_code = f"""\
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.sigmoid(x)
        return x
with torch.no_grad():
    model = Model().to("{device}")
    example_inputs = (torch.randn(8, 10).to("{device}"),)
    ep = torch.export.export(
        model, example_inputs
    )
    torch._inductor.aoti_compile_and_package(
        ep
    )
"""
        return self._run_full_test(run_code, None, expected_error, isolate=True)

    # Test that compile errors in AOTInductor can be repro'd (both CPU and CUDA)
    def _test_aoti_unflattened_inputs(self, device, expected_error):
        # NB: The program is intentionally quite simple, just enough to
        # trigger one minification step, no more (dedicated minifier tests
        # should exercise minifier only)

        # It tests that the minifier can handle unflattened inputs and kwargs
        run_code = f"""\
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, inp, *, k):
        x = inp["x"]
        y = inp["y"]
        x = self.fc1(x)
        y = self.fc1(y)
        k = self.fc1(k)
        x = self.relu(x)
        x = self.sigmoid(x)
        return x + y + k

with torch.no_grad():
    model = Model().to("{device}")
    val = torch.randn(8, 10).to("{device}")
    example_inputs = ({{"x": val.clone(), "y": val.clone()}},)
    kwargs = {{"k": val.clone()}}
    ep = torch.export.export(
        model, example_inputs, kwargs
    )
    torch._inductor.aoti_compile_and_package(
        ep, example_inputs, kwargs
    )
"""
        return self._run_full_test(run_code, None, expected_error, isolate=True)

    @unittest.skipIf(IS_JETSON, "Fails on Jetson")
    @inductor_config.patch(
        {
            "cpp.inject_relu_bug_TESTING_ONLY": "compile_error",
            "aot_inductor.dump_aoti_minifier": True,
        }
    )
    def test_aoti_cpu_compile_error(self):
        res = self._test_aoti("cpu", "CppCompileError")
        ep_file_path = res.get_exported_program_path()
        gm = export_load(ep_file_path).module()
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, linear):
    linear, = fx_pytree.tree_flatten_spec(([linear], {}), self._in_spec)
    relu = torch.ops.aten.relu.default(linear);  linear = None
    return pytree.tree_unflatten((relu,), self._out_spec)""",
        )

    @unittest.skipIf(IS_JETSON, "Fails on Jetson")
    @inductor_config.patch(
        {
            "cpp.inject_relu_bug_TESTING_ONLY": "compile_error",
            "aot_inductor.dump_aoti_minifier": True,
        }
    )
    def test_aoti_cpu_compile_error_unflatten(self):
        res = self._test_aoti_unflattened_inputs("cpu", "CppCompileError")
        ep_file_path = res.get_exported_program_path()
        gm = export_load(ep_file_path).module()
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, linear):
    linear, = fx_pytree.tree_flatten_spec(([linear], {}), self._in_spec)
    relu = torch.ops.aten.relu.default(linear);  linear = None
    return pytree.tree_unflatten((relu,), self._out_spec)""",
        )

    @requires_gpu
    @skipIfXpu(msg="AOTI for XPU not enabled yet")
    @inductor_config.patch(
        {
            "triton.inject_relu_bug_TESTING_ONLY": "compile_error",
            "aot_inductor.dump_aoti_minifier": True,
        }
    )
    def test_aoti_gpu_compile_error(self):
        res = self._test_aoti(GPU_TYPE, "SyntaxError")
        ep_file_path = res.get_exported_program_path()
        gm = export_load(ep_file_path).module()
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, linear):
    linear, = fx_pytree.tree_flatten_spec(([linear], {}), self._in_spec)
    relu = torch.ops.aten.relu.default(linear);  linear = None
    return pytree.tree_unflatten((relu,), self._out_spec)""",
        )

    @requires_gpu
    @skipIfXpu(msg="AOTI for XPU not enabled yet")
    @inductor_config.patch(
        {
            "triton.inject_relu_bug_TESTING_ONLY": "compile_error",
            "aot_inductor.dump_aoti_minifier": True,
        }
    )
    def test_aoti_gpu_compile_error_unflatten(self):
        res = self._test_aoti_unflattened_inputs(GPU_TYPE, "SyntaxError")
        ep_file_path = res.get_exported_program_path()
        gm = export_load(ep_file_path).module()
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, linear):
    linear, = fx_pytree.tree_flatten_spec(([linear], {}), self._in_spec)
    relu = torch.ops.aten.relu.default(linear);  linear = None
    return pytree.tree_unflatten((relu,), self._out_spec)""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    # CPU inductor currently breaks on macOS with C++ compile errors, and ASAN
    # runs are excluded because of
    # https://github.com/pytorch/pytorch/issues/98262 -- only run elsewhere.
    if not (IS_MACOS or TEST_WITH_ASAN):
        run_tests()