1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
"""ONNX exporter exceptions."""
from __future__ import annotations
import textwrap
from typing import Optional
from torch import _C
from torch.onnx import _constants
from torch.onnx._internal import diagnostics
# Public API of this module: the exporter's warning and error types.
__all__ = [
    "OnnxExporterError",
    "OnnxExporterWarning",
    "CallHintViolationWarning",
    "CheckerError",
    "UnsupportedOperatorError",
    "SymbolicValueError",
]
class OnnxExporterWarning(UserWarning):
    """Root of the warning hierarchy used by the ONNX exporter.

    More specific exporter warnings derive from this class, so filtering on
    it covers all of them.
    """
class CallHintViolationWarning(OnnxExporterWarning):
    """Emitted when a function call violates one of the callee's type hints."""
class OnnxExporterError(RuntimeError):
    """Root of the error hierarchy raised by the ONNX exporter."""
class CheckerError(OnnxExporterError):
    """Signals that the ONNX checker found the exported model to be invalid."""
class UnsupportedOperatorError(OnnxExporterError):
    """Raised when an operator is unsupported by the exporter.

    Building the error also reports an ERROR-level diagnostic through
    ``diagnostics.context.diagnose``. Three situations are distinguished:

    * standard domain (``""``, ``aten``, ``prim``, ``quantized``) with a
      known newer opset (``supported_version`` given) that supports the op;
    * standard domain with no known supporting opset;
    * any other domain, treated as a custom operator that was not registered.
    """

    def __init__(
        self,
        domain: str,
        op_name: str,
        version: int,
        supported_version: Optional[int],
    ):
        is_standard_domain = domain in {"", "aten", "prim", "quantized"}
        qualified_name = f"{domain}::{op_name}"

        if is_standard_domain:
            msg = f"Exporting the operator '{qualified_name}' to ONNX opset version {version} is not supported. "
            if supported_version is None:
                # No opset is known to support this op: point users at the tracker.
                issues_url = _constants.PYTORCH_GITHUB_ISSUES_URL
                msg = (
                    msg
                    + "Please feel free to request support or submit a pull request on PyTorch GitHub: "
                    + issues_url
                )
                diagnostics.context.diagnose(
                    diagnostics.rules.missing_standard_symbolic_function,
                    diagnostics.levels.ERROR,
                    message_args=(qualified_name, version, issues_url),
                )
            else:
                # A newer opset supports the op: suggest exporting with it.
                msg = msg + (
                    f"Support for this operator was added in version {supported_version}, "
                    "try exporting with this version."
                )
                diagnostics.context.diagnose(
                    diagnostics.rules.operator_supported_in_newer_opset_version,
                    diagnostics.levels.ERROR,
                    message_args=(qualified_name, version, supported_version),
                )
        else:
            # Unknown namespace: most likely a custom op missing its registration.
            msg = (
                f"ONNX export failed on an operator with unrecognized namespace '{qualified_name}'. "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
            )
            diagnostics.context.diagnose(
                diagnostics.rules.missing_custom_symbolic_function,
                diagnostics.levels.ERROR,
                message_args=(qualified_name,),
            )

        super().__init__(msg)
class SymbolicValueError(OnnxExporterError):
    """Errors around TorchScript values and nodes.

    The error message is enriched with debugging context from the offending
    value: its type, the kind and source range of its containing node, and,
    on a best-effort basis, a listing of that node's inputs and outputs.
    """

    def __init__(self, msg: str, value: _C.Value):
        def render(entries) -> str:
            # One line per entry, or a placeholder when the node has none.
            rendered = [
                f" #{position}: {entry} (type '{entry.type()}')"
                for position, entry in enumerate(entries)
            ]
            return "\n".join(rendered) or " Empty"

        message = (
            f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
        )

        source_range = value.node().sourceRange()
        if source_range:
            message += f"\n (node defined in {source_range})"

        # The separator is appended even if the listing below fails.
        message += "\n\n"
        try:
            # Inputs are rendered before outputs, mirroring the final layout.
            node_io = (
                "Inputs:\n"
                + render(value.node().inputs())
                + "\n"
                + "Outputs:\n"
                + render(value.node().outputs())
            )
            message += textwrap.indent(node_io, " ")
        except AttributeError:
            message += (
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )

        super().__init__(message)
|