"""Diagnostic components for PyTorch ONNX export."""

import contextlib
from typing import Any, Optional, Tuple, TypeVar

import torch
from torch.onnx._internal.diagnostics import _rules, infra

# This is a workaround for mypy not supporting Self from typing_extensions.
_ExportDiagnostic = TypeVar("_ExportDiagnostic", bound="ExportDiagnostic")


class ExportDiagnostic(infra.Diagnostic):
    """Base class for all export diagnostics.

    This class represents all export diagnostics. It subclasses
    `infra.Diagnostic` and adds methods that attach further information (C++ and
    Python stacks, model and exporter source locations) to the diagnostic. The
    `with_*` methods are meant to return `self` so that calls can be chained;
    they are not implemented yet and currently raise `NotImplementedError`.
    """

    def __init__(
        self,
        rule: infra.Rule,
        level: infra.Level,
        message_args: Optional[Tuple[Any, ...]],
        **kwargs,
    ) -> None:
        super().__init__(rule, level, message_args, **kwargs)

    def with_cpp_stack(self: _ExportDiagnostic) -> _ExportDiagnostic:
        # TODO: Implement this.
        # self.stacks.append(...)
        raise NotImplementedError()
        return self

    def with_python_stack(self: _ExportDiagnostic) -> _ExportDiagnostic:
        # TODO: Implement this.
        # self.stacks.append(...)
        raise NotImplementedError()
        return self

    def with_model_source_location(
        self: _ExportDiagnostic,
    ) -> _ExportDiagnostic:
        # TODO: Implement this.
        # self.locations.append(...)
        raise NotImplementedError()
        return self

    def with_export_source_location(
        self: _ExportDiagnostic,
    ) -> _ExportDiagnostic:
        # TODO: Implement this.
        # self.locations.append(...)
        raise NotImplementedError()
        return self


class ExportDiagnosticTool(infra.DiagnosticTool):
    """Base class for all export diagnostic tools.

    This class is used to represent all export diagnostic tools. It is a subclass
    of infra.DiagnosticTool.
    """

    def __init__(self) -> None:
        super().__init__(
            name="torch.onnx.export",
            version=torch.__version__,
            rules=_rules.rules,
            diagnostic_type=ExportDiagnostic,
        )


class ExportDiagnosticEngine(infra.DiagnosticEngine):
    """PyTorch ONNX export diagnostic engine.

    The only reason for this subclass, rather than using the base class directly,
    is to provide a background context for `diagnose` calls inside the exporter.

    By design, each `torch.onnx.export` call should initialize one diagnostic
    context, and all `diagnose` calls inside the exporter should be made in the
    context of that export. However, because the diagnostic context is currently
    accessed through a global variable, there is no guarantee that it has been
    properly initialized. A default background context is therefore provided as a
    fallback; otherwise any direct invocation of exporter internals (e.g. from
    unit tests) would fail due to a missing diagnostic context.

    This class can be removed once the diagnostic context is passed explicitly
    through the exporter.
    """

    _background_context: infra.DiagnosticContext

    def __init__(self) -> None:
        super().__init__()
        self._background_context = infra.DiagnosticContext(
            ExportDiagnosticTool(), options=None
        )

    @property
    def background_context(self) -> infra.DiagnosticContext:
        return self._background_context

    def clear(self):
        super().clear()
        # The background context is created directly rather than through
        # `create_diagnostic_context`, so its diagnostics are cleared separately.
        self._background_context._diagnostics.clear()

    def sarif_log(self):
        log = super().sarif_log()
        # Report diagnostics emitted outside any export as an extra SARIF run.
        log.runs.append(self._background_context.sarif())
        return log
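

# Illustrative sketch only: `_ToyContext` and `_ToyEngine` are hypothetical,
# dependency-free stand-ins for `infra.DiagnosticContext` and
# `ExportDiagnosticEngine`. They are not part of this module's API and do not
# mirror the real classes' signatures. They model the background-context
# fallback described in `ExportDiagnosticEngine`, assuming the base engine's
# `clear` drops the per-export contexts it tracks.
class _ToyContext:
    def __init__(self, name):
        self.name = name
        self.diagnostics = []


class _ToyEngine:
    def __init__(self):
        # Per-export contexts tracked by the engine.
        self.contexts = []
        # Always-available fallback; created directly, so it is not tracked in
        # `self.contexts`.
        self.background_context = _ToyContext("background")

    def create_diagnostic_context(self, name):
        ctx = _ToyContext(name)
        self.contexts.append(ctx)
        return ctx

    def clear(self):
        # Drop per-export contexts, then empty the untracked background context.
        self.contexts.clear()
        self.background_context.diagnostics.clear()

    def run_names(self):
        # Like `sarif_log`: one run per export context, background run last.
        return [ctx.name for ctx in self.contexts] + [self.background_context.name]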
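

# Module-level engine and the current diagnostic context. `context` defaults to
# the background context and is rebound by `create_export_diagnostic_context`.
engine = ExportDiagnosticEngine()
context = engine.background_context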


@contextlib.contextmanager
def create_export_diagnostic_context():
    """Create a diagnostic context for export.

    This is a robustness workaround: export internals access the diagnostic
    context through the module-level global `context`. This function rebinds
    `context` to a new export context for the duration of the `with` block and
    resets it to the engine's background context on exit, even if an exception
    is raised. See `ExportDiagnosticEngine` for more details.
    """
    global context
    context = engine.create_diagnostic_context(ExportDiagnosticTool())
    try:
        yield context
    finally:
        context = engine.background_context
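

# Illustrative sketch only: `_SKETCH_FALLBACK`, `_sketch_current` and
# `_sketch_export_context` are hypothetical names, not part of this module's API.
# They reduce the pattern used by `create_export_diagnostic_context` to plain
# Python: a module-level global points at a fallback, and a context manager
# rebinds it for one export and always restores the fallback on exit.
import contextlib

_SKETCH_FALLBACK = "background"
_sketch_current = _SKETCH_FALLBACK


@contextlib.contextmanager
def _sketch_export_context(name):
    global _sketch_current
    _sketch_current = name
    try:
        yield _sketch_current
    finally:
        # Like the original, restore the fallback rather than the previously
        # active value, so a nested call resets the global when it exits.
        _sketch_current = _SKETCH_FALLBACK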