File: test_exc.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (354 lines) | stat: -rw-r--r-- 10,034 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
# Owner(s): ["module: dynamo"]

import logging
import unittest

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    munge_exc,
    skipIfWindows,
    TEST_Z3,
)
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class ExcTests(LoggingTestCase):
    maxDiff = None

    def test_unsupported_real_stack(self):
        # exercise Unsupported constructor and augment_exc_message
        def fn002(x):
            torch._dynamo.graph_break()

        def fn001(x):
            x = x + 1
            fn002(x)

        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: torch.compile(fn001, backend="eager", fullgraph=True)(
                torch.randn(1)
            ),
            """\
'skip function graph_break in file _dynamo/decorators.py'

from user code:
   File "test_exc.py", line N, in fn001
    fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()""",
        )

    @torch._dynamo.config.patch(verbose=True, suppress_errors=True)
    @make_logging_test()
    @unittest.skipIf(IS_FBCODE, "stack trace slightly different in fbcode")
    def test_internal_error_suppress_errors(self, records):
        def fn001(x):
            def f(ctx):
                raise AssertionError

            comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise AssertionError
AssertionError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)


========== The above exception occurred while processing the following code ==========

  File "test_exc.py", line N, in test_internal_error_suppress_errors
    torch.compile(fn001, backend="eager")(torch.randn(1))
  File "test_exc.py", line N, in fn001
    comptime(f)

==========""",
        )

    @make_logging_test()
    def test_not_implemented_error(self, records):
        def fn001(x):
            def f(ctx):
                raise NotImplementedError

            # Ensure graph break is not possible
            for i in range(3):
                comptime(f)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "WON'T CONVERT")

        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
  File "test_exc.py", line N, in f
    raise NotImplementedError
torch._dynamo.exc.InternalTorchDynamoError: NotImplementedError:

from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @torch._dynamo.config.patch(inject_BUILD_SET_unimplemented_TESTING_ONLY=True)
    @make_logging_test(dynamo=logging.DEBUG)
    def test_unsupported_error(self, records):
        def fn001(x):
            return {1, 2}

        torch.compile(fn001, backend="eager")(torch.randn(1))

        # TODO: There is no graph break log!  This is because the graph break
        # logging is not in a centralized location; unsupported
        # instruction bypasses it
        self.getRecord(records, "Graph break:")

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_internal_error_no_suppress(self):
        def fn001(x):
            # NB: avoid decorator, as 3.11 changed the line number attributed
            # in this situation
            def f(ctx):
                raise AssertionError

            comptime(f)

        # NB: OK for user code to be truncated here, because the regular
        # exception backtrace has the rest of the crumbs
        self.assertExpectedInlineMunged(
            AssertionError,
            lambda: torch.compile(fn001, backend="eager")(torch.randn(1)),
            """\


from user code:
   File "test_exc.py", line N, in fn001
    comptime(f)""",
        )

    @make_logging_test(graph_breaks=True)
    def test_graph_break_log(self, records):
        def fn002(x):
            x = x + 1
            torch._dynamo.graph_break()
            x = x + 1
            return x

        def fn001(x):
            return fn002(x)

        torch.compile(fn001, backend="eager")(torch.randn(1))

        record = self.getRecord(records, "Graph break in user code")

        # TODO: This should also report the enclosing frames; need to plumb
        # frame object to it
        self.assertExpectedInline(
            munge_exc(record.getMessage()),
            """\
Graph break in user code at test_exc.py:N
Reason: Unsupported: 'skip function graph_break in file _dynamo/decorators.py'
User code traceback:
  File "test_exc.py", line N, in fn001
    return fn002(x)
  File "test_exc.py", line N, in fn002
    torch._dynamo.graph_break()
""",  # noqa: B950
        )

    @make_logging_test(graph_breaks=True)
    def test_graph_break_log_generic_jump(self, records):
        def fn(x):
            if x.sum() > 0:
                return x + 1
            else:
                return x - 1

        torch.compile(fn, backend="eager")(torch.ones(3, 3))

        # check for record existence
        self.getRecord(records, "Graph break in user code")

    @torch._dynamo.config.patch(suppress_errors=False)
    def test_backend_suppress_line(self):
        def fn001(x):
            x = torch.relu(x)
            return x + 1

        # Do NOT let this get attributed to x + 1
        self.assertExpectedInlineMunged(
            torch._dynamo.exc.BackendCompilerFailed,
            lambda: torch.compile(fn001, backend="relu_compile_error_TESTING_ONLY")(
                torch.randn(1)
            ),
            """\
backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
        translation_validation_no_bisect=True,
    )
    @skipIfWindows(
        msg='AssertionError: "tran[551 chars]s1 s2 s3) s0)\n  ==> (<= (+ s1 s2) (+ s0 (* -1[511 chars][0])'  # noqa: PLR0133
        != 'tran[551 chars]s1 s2) (+ s0 (* -1 s3)))\n  ==> (<= (+ s1 s2) [483 chars][0])"'
    )
    def test_trigger_on_error(self):
        from torch.fx.experimental.validator import ValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            ValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed.

Model:
  ==> L['shape'][0]: 0
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 1
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 0
  ==> s2: 1
  ==> s3: 1

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)
  ==> (True)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2 s3) s0)
  ==> (<= (+ s1 s2) (+ s0 (* -1 s3)))
  ==> (<= (+ s1 s2) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (<= s1 (+ s0 (* -1 s2)))
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)
  ==> (>= 0 s1)
  ==> (And (<= (+ s1 s2) s0) (<= (* -1 s0) (+ s1 s2)))

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )

    @skipIf(not TEST_Z3, "z3 not installed")
    @torch._dynamo.config.patch(
        assume_static_by_default=False,
        suppress_errors=False,
    )
    @torch.fx.experimental._config.patch(
        inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
        translation_validation=True,
    )
    def test_trigger_bisect_on_error(self):
        from torch.fx.experimental.validator import BisectValidationException

        @torch.compile
        def fn(x, shape):
            return x.split(shape)

        self.assertExpectedInlineMunged(
            BisectValidationException,
            lambda: fn(torch.randn(20), (5, 10, 5)),
            """\
translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)

Failure occurred while running node:
    %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})

Model:
  ==> L['shape'][0]: 1
  ==> L['shape'][1]: 1
  ==> L['shape'][2]: 0
  ==> L['x'].size()[0]: 3
  ==> L['x'].storage_offset(): 0
  ==> L['x'].stride()[0]: 1
  ==> s0: 3
  ==> s1: 1
  ==> s2: 1
  ==> s3: 0

Assertions:
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 1)

Target Expressions:
  ==> (!= (+ s1 s2 s3) s0)
  ==> (<= 0 s1)
  ==> (<= 0 s2)
  ==> (<= 0 s3)
  ==> (<= 2 s0)
  ==> (== 0 L['x'].storage_offset())
  ==> (== 1 L['x'].stride()[0])
  ==> (== L['shape'][0] s1)
  ==> (== L['shape'][1] s2)
  ==> (== L['shape'][2] s3)
  ==> (== L['x'].size()[0] s0)
  ==> (> s0 0)

Failed Source Expressions:
  ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()