File: errors.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (141 lines) | stat: -rw-r--r-- 4,695 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""ONNX exporter exceptions."""
from __future__ import annotations

import textwrap
from typing import Optional

from torch import _C
from torch.onnx import _constants
from torch.onnx._internal import diagnostics

# Public API of this module: the ONNX exporter's exception and warning types.
# Only these names are bound by ``from ... import *``.
__all__ = [
    "OnnxExporterError",
    "OnnxExporterWarning",
    "CallHintViolationWarning",
    "CheckerError",
    "UnsupportedOperatorError",
    "SymbolicValueError",
]


class OnnxExporterWarning(UserWarning):
    """Common ancestor of every warning category used by the ONNX exporter.

    Being a ``UserWarning`` subclass, it can be targeted by ``warnings`` filters
    to silence or escalate all exporter warnings at once.
    """


class CallHintViolationWarning(OnnxExporterWarning):
    """Issued when the arguments of a function call do not satisfy its type hints."""


class OnnxExporterError(RuntimeError):
    """Root of the exception hierarchy raised by the ONNX exporter.

    It derives from ``RuntimeError``, so handlers written for ``RuntimeError``
    also catch exporter failures.
    """


class CheckerError(OnnxExporterError):
    """Signals that the ONNX checker judged the exported model to be invalid."""


class UnsupportedOperatorError(OnnxExporterError):
    """Raised when an operator is unsupported by the exporter.

    Constructing the error also reports an ERROR-level diagnostic through
    ``diagnostics.context.diagnose``. The diagnostic rule depends on the case:

    * ``domain`` not in ``""``/``aten``/``prim``/``quantized``:
      ``missing_custom_symbolic_function``.
    * standard domain and ``supported_version`` given:
      ``operator_supported_in_newer_opset_version``.
    * standard domain and ``supported_version`` is ``None``:
      ``missing_standard_symbolic_function``.

    Args:
        domain: Namespace of the operator, e.g. ``"aten"``.
        op_name: Operator name within ``domain``.
        version: The ONNX opset version the export targeted.
        supported_version: Opset version in which support for the operator was
            added, or ``None``; in the latter case the message points users to
            the PyTorch GitHub issue tracker.
    """

    def __init__(
        self,
        domain: str,
        op_name: str,
        version: int,
        supported_version: Optional[int],
    ):
        qualified_name = f"{domain}::{op_name}"

        if domain not in {"", "aten", "prim", "quantized"}:
            # An unrecognized namespace: most likely a custom operator that was
            # registered under the wrong domain/version (or not at all).
            msg = (
                f"ONNX export failed on an operator with unrecognized namespace '{qualified_name}'. "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
            )
            rule = diagnostics.rules.missing_custom_symbolic_function
            rule_args = (qualified_name,)
        else:
            msg = f"Exporting the operator '{qualified_name}' to ONNX opset version {version} is not supported. "
            if supported_version is None:
                issues_url = _constants.PYTORCH_GITHUB_ISSUES_URL
                msg += (
                    "Please feel free to request support or submit a pull request on PyTorch GitHub: "
                    + issues_url
                )
                rule = diagnostics.rules.missing_standard_symbolic_function
                rule_args = (qualified_name, version, issues_url)
            else:
                msg += (
                    f"Support for this operator was added in version {supported_version}, "
                    "try exporting with this version."
                )
                rule = diagnostics.rules.operator_supported_in_newer_opset_version
                rule_args = (qualified_name, version, supported_version)

        diagnostics.context.diagnose(
            rule,
            diagnostics.levels.ERROR,
            message_args=rule_args,
        )
        super().__init__(msg)


def _format_indexed_values(header: str, values) -> str:
    """Render ``values`` as an indexed listing under ``header``.

    Each element is shown as ``"    #<i>: <value>  (type '<value.type()>')"``,
    one per line. If ``values`` yields nothing, the listing is ``"    Empty"``.

    Any ``AttributeError`` raised while formatting (e.g. an element without a
    ``type()`` method) propagates to the caller.
    """
    listing = "\n".join(
        f"    #{index}: {item}  (type '{item.type()}')"
        for index, item in enumerate(values)
    )
    return f"{header}:\n{listing or '    Empty'}"


class SymbolicValueError(OnnxExporterError):
    """Errors around TorchScript values and nodes.

    The message combines the caller's ``msg`` with debugging context about
    ``value``: its type, the kind and source range of the node that contains
    it, and (best effort) that node's inputs and outputs.

    Args:
        msg: Description of what went wrong.
        value: The TorchScript value that caused the error.
    """

    def __init__(self, msg: str, value: _C.Value):
        # Look the owning node up once; it is needed for the kind, the source
        # range and the input/output listing below.
        node = value.node()
        message = (
            f"{msg}  [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{node.kind()}'.] "
        )

        code_location = node.sourceRange()
        if code_location:
            message += f"\n    (node defined in {code_location})"

        # Listing the node's inputs/outputs is best effort: if any lookup while
        # formatting them raises AttributeError, fall back to a pointer at the
        # graph rather than failing to construct the error itself. Only the
        # formatting can raise, so only it sits inside the ``try``.
        try:
            node_io = (
                _format_indexed_values("Inputs", node.inputs())
                + "\n"
                + _format_indexed_values("Outputs", node.outputs())
            )
        except AttributeError:
            message += "\n\n" + (
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )
        else:
            message += "\n\n" + textwrap.indent(node_io, "    ")

        super().__init__(message)