1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
|
# mypy: allow-untyped-defs
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
from __future__ import annotations
import contextlib
import gzip
from typing import TYPE_CHECKING
import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace
if TYPE_CHECKING:
from collections.abc import Generator
def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
    """Returns the current C++ call stack.

    This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call
    stack. The backtrace is a single string of C++ call stack frames separated by
    newline characters, each in the format r"frame #[0-9]+: (?P<frame_info>.*)".
    More info at `c10/util/Backtrace.cpp`.

    Args:
        frames_to_skip: Forwarded as-is to `cpp_backtrace.get_cpp_backtrace`.
        frames_to_log: Forwarded as-is to `cpp_backtrace.get_cpp_backtrace`.

    Returns:
        An `infra.Stack` with one `infra.StackFrame` per non-blank backtrace line.
        A frame's location message is the stripped text after the first ":" of
        its line, or "<unknown frame>" when the line contains no ":".
    """
    backtrace = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log)
    frame_messages = []
    for line in backtrace.split("\n"):
        if not line.strip():
            # Blank lines (e.g. the one left by a trailing newline) are not frames;
            # without this check they would be recorded as "<unknown frame>".
            continue
        _, sep, frame_info = line.partition(":")
        frame_messages.append(frame_info.strip() if sep else "<unknown frame>")
    return infra.Stack(
        frames=[
            infra.StackFrame(location=infra.Location(message=message))
            for message in frame_messages
        ]
    )
class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
    """Base class for all export diagnostics.

    Every diagnostic emitted by the TorchScript-based exporter is represented by
    this class. It extends `infra.Diagnostic` by capturing the Python call stack
    when constructed and, optionally, the C++ call stack as well.
    """

    # Set by `__init__`; `cpp_call_stack` stays `None` unless `cpp_stack=True`.
    python_call_stack: infra.Stack | None = None
    cpp_call_stack: infra.Stack | None = None

    def __init__(
        self,
        *args,
        frames_to_skip: int = 1,
        cpp_stack: bool = False,
        **kwargs,
    ) -> None:
        """Initializes the diagnostic and records call stacks.

        Args:
            *args: Positional arguments forwarded to `infra.Diagnostic`.
            frames_to_skip: Forwarded to both stack-recording methods.
            cpp_stack: Whether to also record the C++ call stack.
            **kwargs: Keyword arguments forwarded to `infra.Diagnostic`.
        """
        super().__init__(*args, **kwargs)
        self.python_call_stack = self.record_python_call_stack(
            frames_to_skip=frames_to_skip
        )
        if not cpp_stack:
            return
        self.cpp_call_stack = self.record_cpp_call_stack(frames_to_skip=frames_to_skip)

    def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
        """Captures the C++ call stack and attaches it to this diagnostic.

        Args:
            frames_to_skip: Forwarded to `_cpp_call_stack`.

        Returns:
            The captured stack, with its message set to "C++ call stack".
        """
        captured = _cpp_call_stack(frames_to_skip=frames_to_skip)
        captured.message = "C++ call stack"
        self.with_stack(captured)
        return captured
class ExportDiagnosticEngine:
    """PyTorch ONNX Export diagnostic engine.

    The only purpose of creating this class instead of using `DiagnosticContext` directly
    is to provide a background context for `diagnose` calls inside exporter.

    By design, one `torch.onnx.export` call should initialize one diagnostic context.
    All `diagnose` calls inside exporter should be made in the context of that export.
    However, since diagnostic context is currently being accessed via a global variable,
    there is no guarantee that the context is properly initialized. Therefore, we need
    to provide a default background context to fallback to, otherwise any invocation of
    exporter internals, e.g. unit tests, will fail due to missing diagnostic context.
    This can be removed once the pipeline for context to flow through the exporter is
    established.
    """

    contexts: list[infra.DiagnosticContext]
    _background_context: infra.DiagnosticContext

    def __init__(self) -> None:
        # Contexts created via `create_diagnostic_context`, in creation order.
        self.contexts = []
        self._background_context = infra.DiagnosticContext(
            name="torch.onnx",
            version=torch.__version__,
        )

    @property
    def background_context(self) -> infra.DiagnosticContext:
        """The fallback context, named "torch.onnx", owned by this engine."""
        return self._background_context

    def create_diagnostic_context(
        self,
        name: str,
        version: str,
        options: infra.DiagnosticOptions | None = None,
    ) -> infra.DiagnosticContext:
        """Creates a new diagnostic context and registers it with this engine.

        Args:
            name: The subject name for the diagnostic context.
            version: The subject version for the diagnostic context.
            options: The options for the diagnostic context. A default
                `infra.DiagnosticOptions()` is used when None.

        Returns:
            A new diagnostic context, also appended to `self.contexts`.
        """
        if options is None:
            options = infra.DiagnosticOptions()
        context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
            name, version, options
        )
        self.contexts.append(context)
        return context

    def clear(self):
        """Forgets all registered contexts and empties the background context's diagnostics."""
        self.contexts.clear()
        self._background_context.diagnostics.clear()

    def to_json(self) -> str:
        """Serializes `sarif_log()` to a JSON string."""
        return formatter.sarif_to_json(self.sarif_log())

    def dump(self, file_path: str, compress: bool = False) -> None:
        """Dumps the SARIF log to a file.

        Args:
            file_path: Destination path; an existing file is overwritten.
            compress: If True, write gzip-compressed text; otherwise plain text.
        """
        # Serialize before opening the file so that a serialization error does not
        # leave behind a truncated or empty file.
        json_text = self.to_json()
        # JSON interchange is UTF-8 (RFC 8259); do not depend on the locale default.
        if compress:
            with gzip.open(file_path, "wt", encoding="utf-8") as f:
                f.write(json_text)
        else:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(json_text)

    def sarif_log(self) -> sarif.SarifLog:
        """Builds a SARIF log with one run per registered context.

        The background context's run is appended last, after the runs of all
        contexts in `self.contexts`.
        """
        log = sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[context.sarif() for context in self.contexts],
        )
        log.runs.append(self._background_context.sarif())
        return log
# Module-level engine instance shared by everything in this module.
engine = ExportDiagnosticEngine()
# The context returned by `export_context()`. It starts out as the engine's
# background context and is swapped by `create_export_diagnostic_context()`.
_context: infra.DiagnosticContext = engine.background_context
@contextlib.contextmanager
def create_export_diagnostic_context() -> (
    Generator[infra.DiagnosticContext, None, None]
):
    """Create a diagnostic context for export.

    This is a workaround for code robustness since diagnostic context is accessed by
    export internals via global variable. See `ExportDiagnosticEngine` for more details.

    While the ``with`` block is active, the module-level context (as returned by
    `export_context()`) is a fresh context registered with `engine`. On exit it is
    restored to `engine.background_context`, even if the block raised.

    Yields:
        The newly created export diagnostic context.

    Raises:
        AssertionError: If an export context is already active (nested export).
    """
    global _context
    # Identity, not dataclass equality: the background context is one object owned
    # by `engine`. An explicit raise (rather than `assert`) keeps this guard active
    # under `python -O`, while preserving the exception type callers may expect.
    if _context is not engine.background_context:
        raise AssertionError(
            "Export context is already set. Nested export is not supported."
        )
    _context = engine.create_diagnostic_context(
        "torch.onnx.export",
        torch.__version__,
    )
    try:
        yield _context
    finally:
        _context = engine.background_context
def diagnose(
    rule: infra.Rule,
    level: infra.Level,
    message: str | None = None,
    frames_to_skip: int = 2,
    **kwargs,
) -> TorchScriptOnnxExportDiagnostic:
    """Creates a diagnostic and logs it to the current diagnostic context.

    Convenience wrapper around ``context.log`` where ``context`` is whatever
    `export_context()` returns at call time.

    Args:
        rule: The `infra.Rule` the diagnostic is reported for.
        level: The `infra.Level` of the diagnostic.
        message: Optional message text.
        frames_to_skip: Forwarded to `TorchScriptOnnxExportDiagnostic`, which passes
            it on to its call-stack recording.
        **kwargs: Extra keyword arguments for `TorchScriptOnnxExportDiagnostic`.

    Returns:
        The diagnostic that was logged.
    """
    # Build the diagnostic in this frame, not in a helper. The recorded call stack
    # depends on call depth via `frames_to_skip`, so an extra frame would
    # presumably shift what gets recorded.
    new_diagnostic = TorchScriptOnnxExportDiagnostic(
        rule, level, message, frames_to_skip=frames_to_skip, **kwargs
    )
    current_context = export_context()
    current_context.log(new_diagnostic)
    return new_diagnostic
def export_context() -> infra.DiagnosticContext:
    """Returns the diagnostic context currently in effect.

    This is the module-level ``_context``. It is the context created by
    `create_export_diagnostic_context()` while one is active, and
    `engine.background_context` otherwise.
    """
    # Reading a module-level name needs no `global` declaration.
    return _context
|