File: diagnostics.py

# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import functools
from typing import Any, TYPE_CHECKING

import onnxscript  # type: ignore[import]
from onnxscript.function_libs.torch_lib import graph_building  # type: ignore[import]

import torch
import torch.fx
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import decorator, formatter
from torch.onnx._internal.fx import registration, type_utils as fx_type_utils


if TYPE_CHECKING:
    import logging

# NOTE: The following limit caps the number of items displayed in diagnostics for a
# list, tuple, or dict. It is chosen so that common useful scenarios, such as
# operator arguments, are covered in full, while avoiding excessive processing of
# very large containers such as the dictionary mapping fx nodes to onnx nodes.
_CONTAINER_ITEM_LIMIT: int = 10

# NOTE(bowbao): This is a shim over `torch.onnx._internal.diagnostics`, which is
# used in `torch.onnx` and loaded together with `torch`. Hence anything that
# depends on `onnxscript` cannot live there.

# [NOTE: `dynamo_export` diagnostics logging]
# The 'dynamo_export' diagnostics leverage the PT2 artifact logger to control the
# verbosity of the logs recorded in each SARIF log diagnostic. Terminal logging is
# disabled by default; it can be activated by setting the environment variable
# `TORCH_LOGS="onnx_diagnostics"`. Setting this variable also pins the logging
# level to `logging.DEBUG`, overriding the verbosity level specified in the
# diagnostic options.
# See `torch/_logging/__init__.py` for more on PT2 logging.
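#
# Example (assumed usage): enable terminal output for these diagnostics by setting
# the artifact flag before running the exporting script, e.g. from a shell
# (`export_model.py` is a hypothetical script name):
#
#   TORCH_LOGS="onnx_diagnostics" python export_model.py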
_ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME = "onnx_diagnostics"
diagnostic_logger = torch._logging.getArtifactLogger(
    "torch.onnx", _ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME
)


def is_onnx_diagnostics_log_artifact_enabled() -> bool:
    return torch._logging._internal.log_state.is_artifact_enabled(
        _ONNX_DIAGNOSTICS_ARTIFACT_LOGGER_NAME
    )


@functools.singledispatch
def _format_argument(obj: Any) -> str:
    return formatter.format_argument(obj)


def format_argument(obj: Any) -> str:
    # Avoid shadowing the imported `formatter` module with the dispatched handler.
    format_fn = _format_argument.dispatch(type(obj))
    return format_fn(obj)
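

# Illustrative example: dispatch picks the handler registered for the argument's
# type below. The exact dtype abbreviation depends on
# `fx_type_utils.from_torch_dtype_to_abbr`; "f32" below is an assumed rendering
# of `torch.float32`:
#
#   >>> format_argument(torch.ones(2, 3))  # handled by `_torch_tensor`
#   'Tensor(f32[2, 3])'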


# NOTE: EDITING BELOW? READ THIS FIRST!
#
# The functions below register per-type handlers for `format_argument` via the
# `functools.singledispatch` registry. They are invoked by the diagnostics system
# when recording function arguments and return values as part of a diagnostic.
# Hence, expensive work must be avoided here; for example, do not call
# `torch.fx.GraphModule.print_readable()`.


@_format_argument.register
def _torch_nn_module(obj: torch.nn.Module) -> str:
    return f"torch.nn.Module({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_graph_module(obj: torch.fx.GraphModule) -> str:
    return f"torch.fx.GraphModule({obj.__class__.__name__})"


@_format_argument.register
def _torch_fx_node(obj: torch.fx.Node) -> str:
    node_string = f"fx.Node({obj.target})[{obj.op}]:"
    if "val" not in obj.meta:
        return node_string + "None"
    return node_string + format_argument(obj.meta["val"])


@_format_argument.register
def _torch_fx_symbolic_bool(obj: torch.SymBool) -> str:
    return f"SymBool({obj})"


@_format_argument.register
def _torch_fx_symbolic_int(obj: torch.SymInt) -> str:
    return f"SymInt({obj})"


@_format_argument.register
def _torch_fx_symbolic_float(obj: torch.SymFloat) -> str:
    return f"SymFloat({obj})"


@_format_argument.register
def _torch_tensor(obj: torch.Tensor) -> str:
    return f"Tensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})"


@_format_argument.register
def _int(obj: int) -> str:
    return str(obj)


@_format_argument.register
def _float(obj: float) -> str:
    return str(obj)


@_format_argument.register
def _bool(obj: bool) -> str:
    return str(obj)


@_format_argument.register
def _str(obj: str) -> str:
    return obj


@_format_argument.register
def _registration_onnx_function(obj: registration.ONNXFunction) -> str:
    # TODO: Compact display of `param_schema`.
    return f"registration.ONNXFunction({obj.op_full_name}, is_custom={obj.is_custom}, is_complex={obj.is_complex})"


@_format_argument.register
def _list(obj: list) -> str:
    list_string = f"List[length={len(obj)}](\n"
    if not obj:
        return list_string + "None)"
    for i, item in enumerate(obj):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only first _CONTAINER_ITEM_LIMIT items.
            list_string += "...,\n"
            break
        list_string += f"{format_argument(item)},\n"
    return list_string + ")"


@_format_argument.register
def _tuple(obj: tuple) -> str:
    tuple_string = f"Tuple[length={len(obj)}](\n"
    if not obj:
        return tuple_string + "None)"
    for i, item in enumerate(obj):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only first _CONTAINER_ITEM_LIMIT items.
            tuple_string += "...,\n"
            break
        tuple_string += f"{format_argument(item)},\n"
    return tuple_string + ")"


@_format_argument.register
def _dict(obj: dict) -> str:
    dict_string = f"Dict[length={len(obj)}](\n"
    if not obj:
        return dict_string + "None)"
    for i, (key, value) in enumerate(obj.items()):
        if i >= _CONTAINER_ITEM_LIMIT:
            # NOTE: Print only first _CONTAINER_ITEM_LIMIT items.
            dict_string += "...\n"
            break
        dict_string += f"{key}: {format_argument(value)},\n"
    return dict_string + ")"
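

# Illustrative example of the container formatting above; items beyond
# `_CONTAINER_ITEM_LIMIT` collapse into "...":
#
#   >>> format_argument([1, 2, 3])
#   'List[length=3](\n1,\n2,\n3,\n)'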


@_format_argument.register
def _torch_nn_parameter(obj: torch.nn.Parameter) -> str:
    return f"Parameter({format_argument(obj.data)})"


@_format_argument.register
def _onnxscript_torch_script_tensor(obj: graph_building.TorchScriptTensor) -> str:
    return f"`TorchScriptTensor({fx_type_utils.from_torch_dtype_to_abbr(obj.dtype)}{_stringify_shape(obj.shape)})`"  # type: ignore[arg-type]  # noqa: B950


@_format_argument.register
def _onnxscript_onnx_function(obj: onnxscript.OnnxFunction) -> str:
    return f"`OnnxFunction({obj.name})`"


@_format_argument.register
def _onnxscript_traced_onnx_function(obj: onnxscript.TracedOnnxFunction) -> str:
    return f"`TracedOnnxFunction({obj.name})`"


# Adapted from torch/fx/graph.py to follow the torch shape format.
def _stringify_shape(shape: torch.Size | None) -> str:
    if shape is None:
        return ""
    return f"[{', '.join(str(x) for x in shape)}]"


rules = diagnostics.rules
levels = diagnostics.levels
RuntimeErrorWithDiagnostic = infra.RuntimeErrorWithDiagnostic
LazyString = formatter.LazyString
DiagnosticOptions = infra.DiagnosticOptions


@dataclasses.dataclass
class Diagnostic(infra.Diagnostic):
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)

    def log(self, level: int, message: str, *args, **kwargs) -> None:
        if self.logger.isEnabledFor(level):
            # Mirror stdlib `logging`: apply %-formatting only when args are given,
            # so a message containing a literal '%' does not raise.
            formatted_message = message % args if args else message
            if is_onnx_diagnostics_log_artifact_enabled():
                # Only log to terminal if the artifact is enabled.
                # See [NOTE: `dynamo_export` diagnostics logging] for details.
                self.logger.log(level, formatted_message, **kwargs)

            self.additional_messages.append(formatted_message)
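
    # Example (assumed usage): messages use %-style lazy formatting, mirroring the
    # stdlib `logging` API (`diagnostic` and `node_name` below are hypothetical
    # names):
    #
    #   diagnostic.log(logging.INFO, "Rewrote node %s", node_name)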


@dataclasses.dataclass
class DiagnosticContext(infra.DiagnosticContext[Diagnostic]):
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
    _bound_diagnostic_type: type[Diagnostic] = dataclasses.field(
        init=False, default=Diagnostic
    )

    def __enter__(self):
        self._previous_log_level = self.logger.level
        # If the `onnx_diagnostics` artifact is enabled via `TORCH_LOGS`, the logger
        # level is already pinned to `logging.DEBUG`, so skip the adjustment based
        # on `options.verbosity_level` performed by `super().__enter__()`.
        # See [NOTE: `dynamo_export` diagnostics logging] for details.
        if is_onnx_diagnostics_log_artifact_enabled():
            return self
        return super().__enter__()


diagnose_call = functools.partial(
    decorator.diagnose_call,
    diagnostic_type=Diagnostic,
    format_argument=format_argument,
)
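

# Example (assumed usage): `diagnose_call` is applied as a decorator so that calls
# to exporter internals are recorded as diagnostics, with arguments rendered by
# `format_argument`. A minimal sketch, assuming `rules.fx_pass` exists in this
# build (`run_pass` is a hypothetical function name):
#
#   @diagnose_call(rules.fx_pass)
#   def run_pass(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
#       ...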


@dataclasses.dataclass
class UnsupportedFxNodeDiagnostic(Diagnostic):
    unsupported_fx_node: torch.fx.Node | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # NOTE: This is a workaround to ensure that the additional field is set and
        # not None. Ideally it would not be declared Optional at all, but that is a
        # known limitation of `dataclasses` with inheritance and default values;
        # it is resolvable in Python 3.10 with `kw_only=True`.
        # https://stackoverflow.com/questions/69711886/python-dataclasses-inheritance-and-default-values
        if self.unsupported_fx_node is None:
            raise ValueError("unsupported_fx_node must be specified.")
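

# Illustrative alternative (not applied here): on Python >= 3.10 the extra field
# could be declared keyword-only and required, removing the runtime None-check:
#
#   @dataclasses.dataclass(kw_only=True)
#   class UnsupportedFxNodeDiagnostic(Diagnostic):
#       unsupported_fx_node: torch.fx.Node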