"""Module for handling ATen to ONNX functions registration.
https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462/torch/onnx/_internal/fx/registration.py
"""
# NOTE: Why do we need a different registry than the one in torchlib?
# The registry in torchlib is used to register functions that are already implemented in
# torchlib, and is designed to be a static singleton. It does not take into account custom ops,
# different opsets, etc. The registry implemented for the exporter is designed to be modifiable at
# export time by users, and is designed with dispatching in mind.
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
import importlib.util
import logging
import math
import operator
import types
from typing import Callable, Literal, Union
from typing_extensions import TypeAlias
import torch
import torch._ops
from torch.onnx._internal._lazy_import import onnxscript, onnxscript_apis
from torch.onnx._internal.exporter import _schemas
from torch.onnx._internal.exporter._torchlib import _torchlib_registry
TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Callable]
logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
class OnnxDecompMeta:
"""A wrapper of onnx-script function with additional metadata.
onnx_function: The onnx-script function from torchlib.
fx_target: The PyTorch node callable target.
is_custom: Whether the function is a custom function.
    is_complex: Whether the function handles complex-valued inputs.
device: The device the function is registered to. If None, it is registered to all devices.
"""
onnx_function: Callable
fx_target: TorchOp
is_custom: bool = False
is_complex: bool = False
device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
def _get_overload(qualified_name: str) -> TorchOp | None:
"""Obtain the torch op from <namespace>::<op_name>[.<overload>]"""
# TODO(justinchuby): Handle arbitrary custom ops
namespace, opname_overload = qualified_name.split("::")
op_name, *maybe_overload = opname_overload.split(".", 1)
if namespace == "_operator":
# Builtin functions
return getattr(operator, op_name)
if namespace == "math":
return getattr(math, op_name)
if namespace == "torchvision":
if importlib.util.find_spec("torchvision") is None:
logger.warning("torchvision is not installed. Skipping %s", qualified_name)
return None
try:
op_packet = getattr(getattr(torch.ops, namespace), op_name)
if maybe_overload:
overload = maybe_overload[0]
elif "default" in op_packet._overload_names or "" in op_packet._overload_names:
# Has a default overload
overload = "default"
else:
logger.warning(
"'%s' does not have a 'default' overload. This could be an error in specifying the op name. Ignoring.",
qualified_name,
stacklevel=1,
)
return None
return getattr(op_packet, overload) # type: ignore[call-overload]
except AttributeError:
if qualified_name.endswith("getitem"):
# This is a special case where we registered the function incorrectly,
# but for BC reasons (pt<=2.4) we need to keep it.
return None
logger.info("'%s' is not found in this version of PyTorch.", qualified_name)
return None
except Exception:
logger.exception("Failed to find torch op '%s'", qualified_name)
return None
class ONNXRegistry:
"""Registry for ONNX functions.
The registry maintains a mapping from qualified names to symbolic functions under a
    fixed opset version. It supports registering custom onnx-script functions and
    allows the dispatcher to dispatch calls to the appropriate function.
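
    Example (an illustrative sketch; assumes the default torchlib opset and that
    ``torch.ops.aten.add.Tensor`` is the overload of interest)::

        registry = ONNXRegistry.from_torchlib()
        # Decompositions are returned with custom registrations first.
        decomps = registry.get_decomps(torch.ops.aten.add.Tensor)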
"""
def __init__(self) -> None:
"""Initializes the registry"""
self._opset_version = onnxscript_apis.torchlib_opset_version()
self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {}
@property
def opset_version(self) -> int:
"""The ONNX opset version the exporter should target."""
return self._opset_version
@classmethod
def from_torchlib(cls) -> ONNXRegistry:
"""Populates the registry with ATen functions from torchlib.
Args:
torchlib_registry: The torchlib registry to use for populating the registry.
"""
registry = cls()
torchlib_ops = onnxscript_apis.get_torchlib_ops()
for meta in torchlib_ops:
qualified_name = meta.qualified_name
overload_func = meta.function
domain = meta.domain
name = meta.name
try:
# NOTE: This is heavily guarded with try-except because we don't want
# to fail the entire registry population if one function fails.
target = _get_overload(qualified_name)
if target is None:
continue
if isinstance(overload_func, onnxscript.OnnxFunction):
opset_version = overload_func.opset.version
else:
opset_version = 1
overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
overload_func,
domain,
name,
opset_version=opset_version,
)
onnx_decomposition = OnnxDecompMeta(
onnx_function=overload_func,
fx_target=target,
is_custom=False,
is_complex=meta.is_complex,
)
registry._register(target, onnx_decomposition)
except Exception:
logger.exception("Failed to register '%s'. Skipped", qualified_name)
continue
# Gather ops from the internal torchlib registry
# TODO(justinchuby): Make this the main registry after torchlib is migrated to PyTorch
        # Trigger registration: importing the ops module runs its registration
        # decorators as a side effect; the imported name itself is not needed afterwards.
from torch.onnx._internal.exporter._torchlib import ops
del ops
for target, implementations in _torchlib_registry.registry.items(): # type: ignore[assignment]
for impl in implementations:
onnx_decomposition = OnnxDecompMeta(
onnx_function=impl,
fx_target=target, # type: ignore[arg-type]
)
registry._register(target, onnx_decomposition) # type: ignore[arg-type]
return registry
def _register(
self,
target: TorchOp,
onnx_decomposition: OnnxDecompMeta,
) -> None:
"""Registers a OnnxDecompMeta to an operator.
Args:
target: The PyTorch node callable target.
onnx_decomposition: The OnnxDecompMeta to register.
"""
target_or_name: str | TorchOp
if isinstance(target, torch._ops.OpOverload):
# Get the qualified name of the aten op because torch._ops.OpOverload lookup in
# a dictionary is unreliable for some reason.
target_or_name = target.name()
else:
target_or_name = target
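        # Custom decompositions are prepended so that they take precedence over the
        # torchlib-provided ones when the dispatcher walks the list.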
if onnx_decomposition.is_custom:
self.functions.setdefault(target_or_name, []).insert(0, onnx_decomposition)
else:
self.functions.setdefault(target_or_name, []).append(onnx_decomposition)
def register_op(
self,
target: TorchOp,
function: Callable,
is_complex: bool = False,
) -> None:
"""Registers a custom operator: torch.ops.<namespace>.<op_name>.<overload>.
Args:
target: The PyTorch node callable target.
function: The onnx-script function to register.
            is_complex: Whether the function handles complex-valued inputs.
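
        Example (an illustrative sketch, not an official recipe; assumes ``onnxscript``
        is installed, and ``custom_relu`` is a hypothetical replacement for
        ``aten::relu``)::

            from onnxscript import opset18 as op

            def custom_relu(x):
                # Hypothetical replacement decomposition built from ONNX operators.
                return op.Relu(x)

            registry = ONNXRegistry.from_torchlib()
            registry.register_op(torch.ops.aten.relu.default, custom_relu)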
"""
if not hasattr(function, "signature"):
try:
# TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI
if isinstance(function, onnxscript.OnnxFunction):
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
function,
function.function_ir.domain,
function.name,
opset_version=function.opset.version,
)
else:
function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined]
function, "__custom", function.__name__
)
except Exception:
logger.exception(
"Failed to infer the signature for function '%s'", function
)
onnx_decomposition = OnnxDecompMeta(
onnx_function=function,
fx_target=target,
is_custom=True,
is_complex=is_complex,
)
self._register(target, onnx_decomposition)
def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
"""Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.
        The list is ordered by the time of registration, with custom operators
        placed first in the list.
Args:
target: The PyTorch node callable target.
Returns:
            A list of OnnxDecompMeta corresponding to the given target, or an empty
            list if the target is not in the registry.
"""
target_or_name: str | TorchOp
if isinstance(target, torch._ops.OpOverload):
# Get the qualified name of the aten op because torch._ops.OpOverload lookup in
# a dictionary is unreliable for some reason.
target_or_name = target.name()
else:
target_or_name = target
decomps = self.functions.get(target_or_name, [])
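        # sorted() is stable, so registration order is preserved within the custom
        # and non-custom groups while custom entries are moved to the front.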
return sorted(decomps, key=lambda x: x.is_custom, reverse=True)
def is_registered(self, target: TorchOp) -> bool:
"""Returns whether the given op is registered: torch.ops.<namespace>.<op_name>.<overload>.
Args:
target: The PyTorch node callable target.
Returns:
True if the given op is registered, otherwise False.
"""
return bool(self.get_decomps(target))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(functions={self.functions})"