File: _diagnostic.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (129 lines) | stat: -rw-r--r-- 4,167 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Diagnostic components for PyTorch ONNX export."""

import contextlib
from typing import Any, Optional, Tuple, TypeVar

import torch
from torch.onnx._internal.diagnostics import _rules, infra

# This is a workaround for mypy not supporting Self from typing_extensions.
# Annotating ``self`` with this TypeVar (see the ``with_*`` methods of
# ExportDiagnostic) lets those fluent methods be typed as returning the same
# concrete (sub)class they were called on. The bound is a string forward
# reference because ExportDiagnostic is defined after this line.
_ExportDiagnostic = TypeVar("_ExportDiagnostic", bound="ExportDiagnostic")


class ExportDiagnostic(infra.Diagnostic):
    """Base class for all export diagnostics.

    This class is used to represent all export diagnostics. It is a subclass of
    infra.Diagnostic, and adds additional methods to add more information to the
    diagnostic.

    The ``with_*`` methods are intended as fluent builders that attach extra
    information (stacks or locations) and return ``self``. None of them is
    implemented yet; each currently raises ``NotImplementedError``.
    """

    def __init__(
        self,
        rule: infra.Rule,
        level: infra.Level,
        message_args: Optional[Tuple[Any, ...]],
        **kwargs,
    ) -> None:
        """Forward all arguments unchanged to ``infra.Diagnostic.__init__``.

        Args:
            rule: The rule this diagnostic is reported against.
            level: The severity level of the diagnostic.
            message_args: Arguments for the diagnostic message (presumably used
                to format the rule's message template; see ``infra.Diagnostic``).
            **kwargs: Any further keyword arguments accepted by the base class.
        """
        super().__init__(rule, level, message_args, **kwargs)

    def with_cpp_stack(self: _ExportDiagnostic) -> _ExportDiagnostic:
        """Attach a C++ stack to this diagnostic (not implemented yet).

        Raises:
            NotImplementedError: Always; this builder has not been implemented.
        """
        # TODO: Implement this, e.g. by appending to ``self.stacks`` and
        # returning ``self``.
        raise NotImplementedError("ExportDiagnostic.with_cpp_stack is not implemented.")

    def with_python_stack(self: _ExportDiagnostic) -> _ExportDiagnostic:
        """Attach a Python stack to this diagnostic (not implemented yet).

        Raises:
            NotImplementedError: Always; this builder has not been implemented.
        """
        # TODO: Implement this, e.g. by appending to ``self.stacks`` and
        # returning ``self``.
        raise NotImplementedError(
            "ExportDiagnostic.with_python_stack is not implemented."
        )

    def with_model_source_location(
        self: _ExportDiagnostic,
    ) -> _ExportDiagnostic:
        """Attach the model's source location to this diagnostic (not implemented yet).

        Raises:
            NotImplementedError: Always; this builder has not been implemented.
        """
        # TODO: Implement this, e.g. by appending to ``self.locations`` and
        # returning ``self``.
        raise NotImplementedError(
            "ExportDiagnostic.with_model_source_location is not implemented."
        )

    def with_export_source_location(
        self: _ExportDiagnostic,
    ) -> _ExportDiagnostic:
        """Attach the exporter's source location to this diagnostic (not implemented yet).

        Raises:
            NotImplementedError: Always; this builder has not been implemented.
        """
        # TODO: Implement this, e.g. by appending to ``self.locations`` and
        # returning ``self``.
        raise NotImplementedError(
            "ExportDiagnostic.with_export_source_location is not implemented."
        )


class ExportDiagnosticTool(infra.DiagnosticTool):
    """Base class for all export diagnostic tools.

    A subclass of ``infra.DiagnosticTool`` with a fixed identity. The tool is
    named ``"torch.onnx.export"`` and is versioned with the running
    ``torch.__version__``. It registers the rule set ``_rules.rules`` and uses
    ``ExportDiagnostic`` as its diagnostic type.
    """

    def __init__(self) -> None:
        # Gather the fixed tool description first, then hand it to the base class.
        tool_settings = {
            "name": "torch.onnx.export",
            "version": torch.__version__,
            "rules": _rules.rules,
            "diagnostic_type": ExportDiagnostic,
        }
        super().__init__(**tool_settings)


class ExportDiagnosticEngine(infra.DiagnosticEngine):
    """Diagnostic engine used by the PyTorch ONNX exporter.

    This subclass exists only to supply a *background* diagnostic context that
    `diagnose` calls inside the exporter can fall back on.

    Ideally, every `torch.onnx.export` call would set up exactly one diagnostic
    context, and every `diagnose` call made during that export would go through
    it. Today, however, the context is reached through a global variable, so
    nothing guarantees that it was initialized. Without a default background
    context, running exporter internals directly (for example from unit tests)
    would fail for lack of a diagnostic context. Once the context is threaded
    through the exporter explicitly, this class can be removed.
    """

    _background_context: infra.DiagnosticContext

    def __init__(self) -> None:
        super().__init__()
        fallback_tool = ExportDiagnosticTool()
        self._background_context = infra.DiagnosticContext(fallback_tool, options=None)

    @property
    def background_context(self) -> infra.DiagnosticContext:
        """The fallback context created when this engine was constructed."""
        return self._background_context

    def clear(self) -> None:
        """Clear the engine, then empty the background context's diagnostics."""
        super().clear()
        # The background context is built directly rather than through
        # ``create_diagnostic_context``, so the base ``clear`` presumably does
        # not reach it (TODO confirm against infra); reset it explicitly.
        background = self._background_context
        background._diagnostics.clear()

    def sarif_log(self):
        """Return the base SARIF log with the background context's run appended."""
        sarif_log = super().sarif_log()
        background_run = self._background_context.sarif()
        sarif_log.runs.append(background_run)
        return sarif_log


# Module-level diagnostic engine instance, created once when this module is imported.
engine = ExportDiagnosticEngine()
# Module-level "current" diagnostic context, which export internals access as a
# global. It starts out as the engine's background context and is temporarily
# rebound by `create_export_diagnostic_context`.
context = engine.background_context


@contextlib.contextmanager
def create_export_diagnostic_context():
    """Create a diagnostic context for export.

    This is a workaround for code robustness since diagnostic context is accessed by
    export internals via global variable. See `ExportDiagnosticEngine` for more details.

    On entry, a new context is created through the module-level ``engine`` for a
    fresh ``ExportDiagnosticTool``, bound to the module-level ``context`` global,
    and yielded. On exit (normal or via an exception), ``context`` is restored to
    whatever it was bound to on entry. Nested uses therefore unwind back to the
    enclosing export context instead of jumping straight to the background context.

    Yields:
        The newly created diagnostic context.
    """
    global context
    # Remember the binding seen on entry so that nested invocations restore the
    # enclosing context rather than unconditionally resetting to the background one.
    previous_context = context
    context = engine.create_diagnostic_context(ExportDiagnosticTool())
    try:
        yield context
    finally:
        context = previous_context