File: _draft_export.py

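"""Draft export: a tolerant variant of torch.export that records issues
encountered during tracing (missing or mismatched fake kernels,
data-dependent errors, constraint violations) in a DraftExportReport
instead of failing the export outright."""
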
import inspect
import logging
import os
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch._logging._internal
import torch._logging.structured
from torch._export.passes.insert_custom_op_guards import insert_custom_op_guards
from torch.export import ExportedProgram
from torch.export._trace import _export
from torch.export.dynamic_shapes import refine_dynamic_shapes_from_suggested_fixes


log = logging.getLogger(__name__)


class FailureType(IntEnum):
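    """Categories of issues that draft_export can surface in a report."""
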
    MISSING_FAKE_KERNEL = 1
    DATA_DEPENDENT_ERROR = 2
    CONSTRAINT_VIOLATION_ERROR = 3
    MISMATCHED_FAKE_KERNEL = 4

    def __str__(self) -> str:
        return self.name


def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str]) -> str:
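    # Format each interned stack frame as an indented
    # "File <path>, lineno <line>, in <name>" line; frames whose filenames
    # were never interned are skipped.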
    res = ""
    for frame in stack:
        if frame["filename"] not in str_to_filename:
            continue

        res += f"""
        File {str_to_filename[frame['filename']]}, lineno {frame['line']}, in {frame['name']}"""
    return res


def filter_stack(
    stack: List[Dict[str, str]], str_to_filename: Dict[str, str]
) -> List[Dict[str, str]]:
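    # Scan the stack from the innermost frame outward for the first frame
    # that resolves to a file outside the torch install tree, and return
    # the window of three frames ending at that user frame. If every
    # interned frame is internal to torch, fall back to the last three.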
    for i, s in enumerate(reversed(stack)):
        s["filename"] = str(s["filename"])
        if s["filename"] not in str_to_filename:
            continue
        torch_filepath = os.path.dirname(inspect.getfile(torch)) + os.path.sep
        if torch_filepath not in str_to_filename[s["filename"]]:
            return stack[len(stack) - i - 3 : len(stack) - i]
    return stack[-3:]


def hash_stack(stack: List[Dict[str, str]]) -> str:
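    # Stable key used to deduplicate logs that share the same stacktrace.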
    return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)


class FailureReport:
    def __init__(
        self, failure_type: FailureType, data: Dict[str, Any], xfail: bool = False
    ) -> None:
        self.failure_type: FailureType = failure_type
        self.data: Dict[str, Any] = data
        self.xfail: bool = xfail

    def __repr__(self) -> str:
        return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})"

    def print(self, str_to_filename: Dict[str, str]) -> str:
        if self.failure_type == FailureType.MISSING_FAKE_KERNEL:
            op = self.data["op"]

            return f"""Missing fake kernel.
    torch.ops.{op} is missing a fake kernel implementation.

    Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation.
"""  # noqa: B950

        elif self.failure_type == FailureType.CONSTRAINT_VIOLATION_ERROR:
            return f"""Constraint violation error.
    The specified input dynamic_shapes spec was found to be incorrect during tracing.
    Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}.
    This occurred at the following stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}.
    Because of this, we have modified the dynamic shapes structure to be the
    following. You can also use torch.export.Dim.AUTO instead to specify your
    dynamic shapes, and we will automatically infer the dynamism for you.
    ```
    dynamic_shapes = {self.data["new_dynamic_shapes"]}
    ```
"""

        elif self.failure_type == FailureType.DATA_DEPENDENT_ERROR:
            return f"""Data dependent error.
    When exporting, we were unable to figure out if the expression `{self.data["expr"]}` always holds.
    This occurred at the following stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}.
    As a result, it was specialized to evaluate to `{self.data["result"]}`, and asserts were inserted into the graph.

    Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
    Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.
"""  # noqa: B950

        elif self.failure_type == FailureType.MISMATCHED_FAKE_KERNEL:
            op = self.data["op"]
            reason = self.data["reason"]
            return f"""Mismatched fake kernel.
    torch.ops.{op} has a fake kernel implementation, but its behavior does not match that of the real kernel.
    The reason for the mismatch is: {reason}.

    Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation.
"""  # noqa: B950

        else:
            raise ValueError(f"Unknown failure type: {self.failure_type}")


class DraftExportReport:
    def __init__(self, failures: List[FailureReport], str_to_filename: Dict[str, str]):
        self.failures: List[FailureReport] = failures
        self.str_to_filename = str_to_filename

    def successful(self) -> bool:
        return len(self.failures) == 0 or all(
            failure.xfail for failure in self.failures
        )

    def __repr__(self) -> str:
        return f"DraftExportReport({self.failures})"

    def __str__(self) -> str:
        WARNING_COLOR = "\033[93m"
        GREEN_COLOR = "\033[92m"
        END_COLOR = "\033[0m"

        if self.successful():
            return f"""{GREEN_COLOR}
##############################################################################################
Congratulations: no issues were found during export, and it was able to soundly produce a graph.
You can now switch back to torch.export.export().
##############################################################################################
{END_COLOR}"""

        error = f"""{WARNING_COLOR}
###################################################################################################
WARNING: {len(self.failures)} issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

"""

        for i, failure in enumerate(self.failures):
            error += f"{i + 1}. {failure.print(self.str_to_filename)}\n"
        error += END_COLOR
        return error

    def apply_suggested_fixes(self) -> None:
        raise NotImplementedError("Not implemented yet")

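# Illustrative only: an empty report counts as successful, and printing it
# yields the green "no issues" banner.
#
#     report = DraftExportReport([], {})
#     assert report.successful()
#     print(report)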

class CaptureStructuredTrace(logging.Handler):
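    """Logging handler/context manager that records structured-trace
    payloads for the given log keys from the torch.__trace logger."""
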
    def __init__(self, specific_log_keys: List[str]):
        super().__init__()
        self.specific_log_keys = specific_log_keys
        self.logs: List[Tuple[str, Dict[str, Any]]] = []
        self.logger = logging.getLogger("torch.__trace")
        self.prev_get_dtrace = False

    def __enter__(self) -> "CaptureStructuredTrace":
        self.logs = []
        self.logger.addHandler(self)
        self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
        torch._logging._internal.GET_DTRACE_STRUCTURED = True
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:  # type: ignore[no-untyped-def]
        self.logs = []
        self.logger.removeHandler(self)
        torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
        self.prev_get_dtrace = False

    def emit(self, record: Any) -> None:
        metadata = record.metadata
        for key in self.specific_log_keys:
            if key in metadata:
                self.logs.append((key, metadata[key]))

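# A minimal usage sketch for CaptureStructuredTrace. Note that __exit__
# clears self.logs, so captured records must be read inside the context,
# as draft_export does below:
#
#     cap = CaptureStructuredTrace(["guard_added"])
#     with cap:
#         ...  # run tracing work that emits structured-trace records
#         for key, payload in cap.logs:
#             print(key, payload)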

def draft_export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...], List[Any]]] = None,
    preserve_module_call_signature: Tuple[str, ...] = (),
    strict: bool = False,
    pre_dispatch: bool = False,
) -> Tuple[ExportedProgram, DraftExportReport]:
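    """Export ``mod`` while tolerating issues that would fail a plain
    torch.export.export call, by propagating real tensors through fake
    kernels and retrying with refined dynamic shapes on constraint
    violations. Returns the ExportedProgram together with a
    DraftExportReport describing every issue encountered.
    """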
    kwargs = kwargs or {}
    dynamic_shapes = dynamic_shapes or {}

    capture_structured_log = CaptureStructuredTrace(
        [
            "propagate_real_tensors",
            "guard_added",
            "missing_fake_kernel",
            "mismatched_fake_kernel",
        ]
    )

    with torch._functorch.config.patch(
        fake_tensor_propagate_real_tensors=True,
        generate_fake_kernels_from_real_mismatches=True,
    ), capture_structured_log:
        try:
            new_shapes = None
            ep = _export(
                mod,
                args,
                kwargs,
                dynamic_shapes=dynamic_shapes,
                strict=strict,
                pre_dispatch=pre_dispatch,
                preserve_module_call_signature=preserve_module_call_signature,
            )
        except torch._dynamo.exc.UserError as exc:
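            # The dynamic_shapes spec raised a constraint violation; refine
            # it using the suggested fixes embedded in the error message,
            # then retry the export with the corrected spec.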
            new_shapes = refine_dynamic_shapes_from_suggested_fixes(
                exc.msg, dynamic_shapes
            )
            ep = _export(
                mod,
                args,
                kwargs,
                dynamic_shapes=new_shapes,
                strict=strict,
                pre_dispatch=pre_dispatch,
                preserve_module_call_signature=preserve_module_call_signature,
            )

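        # INTERN_TABLE maps each filename to an interned integer id; invert
        # it so the string ids recorded in stack frames can be resolved
        # back to file paths.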
        str_to_filename: Dict[str, str] = {
            str(v): k for (k, v) in torch._logging.structured.INTERN_TABLE.items()
        }
        failures: List[FailureReport] = []
        custom_ops_logs: Dict[
            Any, Tuple[Dict[str, Any], FailureType]
        ] = {}  # Dedup custom ops
        data_dependent_logs: Dict[
            str, Dict[str, Any]
        ] = {}  # Dedup data dependent errors based on stacktrace

        for log_name, log_contents in capture_structured_log.logs:
            failure_type = None

            if log_name == "propagate_real_tensors":
                log_contents["stack"] = filter_stack(
                    log_contents["stack"], str_to_filename
                )
                if hash_stack(log_contents["stack"]) in data_dependent_logs:
                    continue

                data_dependent_logs[hash_stack(log_contents["stack"])] = log_contents
                failure_type = FailureType.DATA_DEPENDENT_ERROR

            elif log_name == "guard_added":
                if new_shapes is None:
                    continue

                failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
                if len(log_contents["symbol_to_sources"]) == 0:
                    # We only want to include guards added that are relevant to
                    # the symbolic shapes corresponding to the inputs which were
                    # specified in the dynamic_shapes arg. These have a source.
                    continue

                log_contents["stack"] = filter_stack(
                    log_contents["stack"], str_to_filename
                )
                log_contents["new_dynamic_shapes"] = new_shapes
            elif log_name == "missing_fake_kernel":
                if log_contents["op"] in custom_ops_logs:
                    continue
                failure_type = FailureType.MISSING_FAKE_KERNEL
                custom_ops_logs[log_contents["op"]] = (log_contents, failure_type)
            elif log_name == "mismatched_fake_kernel":
                if (log_contents["op"], log_contents["reason"]) in custom_ops_logs:
                    continue
                failure_type = FailureType.MISMATCHED_FAKE_KERNEL
                custom_ops_logs[(log_contents["op"], log_contents["reason"])] = (
                    log_contents,
                    failure_type,
                )
            else:
                raise RuntimeError(f"Unknown log name: {log_name}")

            assert failure_type is not None
            failures.append(
                FailureReport(
                    failure_type,
                    log_contents,
                )
            )

        report = DraftExportReport(failures, str_to_filename)

        # Add asserts around custom ops
        insert_custom_op_guards(ep.graph_module, list(custom_ops_logs.keys()))

    ep._report = report
    if not report.successful():
        log.warning(report)
    return ep, report
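

# A minimal usage sketch (the toy module and input below are hypothetical):
#
#     class Add1(torch.nn.Module):
#         def forward(self, x: torch.Tensor) -> torch.Tensor:
#             return x + 1
#
#     ep, report = draft_export(Add1(), (torch.randn(3),))
#     if not report.successful():
#         print(report)  # numbered, human-readable failure reports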