1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578
|
import contextlib
import ctypes
import inspect
import sys
import types
from abc import ABC
from typing import Any, Dict
import torch._C
import torch.jit
from torch import _utils_internal
# Query `hasattr` only once: `sys.getdlopenflags` / `sys.setdlopenflags` are
# not available on every platform.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.

    When the interpreter does not expose ``sys.getdlopenflags`` and
    ``sys.setdlopenflags`` this is a no-op. Otherwise the previous flags are
    restored on exit, including when the body of the ``with`` statement raises.
    """
    if not _SET_GLOBAL_FLAGS:
        yield
        return

    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        # Restore unconditionally: without this, a failing library load would
        # leave RTLD_GLOBAL set for every later dlopen in the process.
        sys.setdlopenflags(old_flags)
def has_key(op, k):
    """
    Tell whether ``op`` has a kernel registered for dispatch key ``k``.

    The C++ dispatcher is asked first (using ``op.name()``); if it reports no
    kernel, the Python-side ``op.py_kernels`` table is consulted.
    """
    registered_in_cpp = torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), k)
    return registered_in_cpp or k in op.py_kernels
# TODO(voz) We are missing an entire axis of registration - Modes for the python key
class PyOperatorABC(ABC):
    """
    Base class shared by the operator types in this module.

    Every method here is a placeholder that does nothing and returns ``None``;
    subclasses provide the real behaviour.
    """

    def __call__(self, *args, **kwargs):
        """Invoke the operator. The base implementation is a no-op."""

    def py_impl(self, dispatch_key, fn):
        """Register a Python kernel. The base implementation is a no-op."""

    def name(self):
        """Return the operator's name. The base implementation returns ``None``."""
# Module-level shorthands for two `torch._C` dispatcher bindings; `resolve_key`
# refers to them by these shorter names.
is_included_in_alias = torch._C._dispatch_is_included_in_alias
DispatchKey = torch._C.DispatchKey
# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: PyOperatorABC, k: DispatchKey):  # type: ignore[valid-type]
    """
    Resolve dispatch key ``k`` to the key whose kernel should be used for ``op``.

    Candidates are tried in a fixed order and the first match wins, so the
    order of the checks below is part of the contract:

    1. ``k`` itself, if ``has_key(op, k)``.
    2.1 ``CompositeExplicitAutogradNonFunctional``.
    2.2 ``CompositeExplicitAutograd``.
    2.3 ``CompositeImplicitAutogradNestedTensor``, then
        ``CompositeImplicitAutograd``; both are only returned when no backend
        kernel was found (``has_backend_kernel`` is false).
    2.4 ``Autograd``.
    Finally ``k`` itself again if the dispatcher reports a backend fallback
    for it.

    An alias candidate is considered when ``k`` is ``Undefined`` or
    ``is_included_in_alias(k, cand)`` holds, with two exceptions visible
    below: the nested-tensor candidate requires ``k`` to be different from
    ``Undefined``, and the ``Autograd`` candidate has no ``Undefined`` case.

    Args:
        op: operator to resolve for; must provide ``name()`` and whatever
            ``has_key`` reads from it.
        k: the dispatch key being resolved.

    Returns:
        The dispatch key (``k`` or one of the alias keys above) to use.

    Raises:
        RuntimeError: "ambiguous autogradother kernel" when ``k`` is
            ``AutogradOther``, a ``CompositeImplicitAutograd`` kernel exists
            and some AutogradOther backend also has a kernel; or
            "could not find kernel" when nothing matched.
    """
    # 1. (Direct) operator registration
    if has_key(op, k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        return cand
    # True when a backend associated with `k` has its own kernel, or when a
    # CompositeExplicitAutograd kernel exists; in either case the
    # CompositeImplicit* candidates below must not be returned.
    has_backend_kernel = torch._C._dispatch_has_kernel_for_any_dispatch_key(
        op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or has_key(op, DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutograd kernel if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and has_key(op, cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(
        op, cand
    ):
        if (
            k == DispatchKey.AutogradOther
            and torch._C._dispatch_has_kernel_for_any_dispatch_key(
                op.name(), torch._C._dispatch_autogradother_backends
            )
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
        # Otherwise a backend kernel exists: fall through to the next steps.
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and has_key(op, cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise RuntimeError("could not find kernel")
# Registry of every PyOperator created in this process, keyed by operator name.
# PyOperator.__init__ adds entries to it.
pyop_namespace = {}


class PyOperator(PyOperatorABC):
    """
    An operator whose kernels are Python callables.

    Kernels are registered with ``py_impl`` either for a ``DispatchKey``
    (stored in ``self.table``) or for a ``TorchDispatchMode`` subclass (stored
    in ``self.python_key_mode_table`` and used when dispatching on the
    ``Python`` key). Calling the operator computes a key set from the tensor
    arguments and dispatches on its highest-priority key.
    """

    def __init__(self, name):
        """Create the operator and register it in ``pyop_namespace`` under ``name``."""
        self._name = name
        self.table = {}
        self.python_key_mode_table = {}

        # Make _OPNamespace not scream, this whole name based association needs a good hard look
        self.__name__ = name
        pyop_namespace[name] = self

    def fallthrough(self, dispatch_key):
        """Register, for ``dispatch_key``, a kernel that re-dispatches past that key."""
        self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)

    def py_impl(self, dispatch_key_or_mode):
        """
        Return a decorator that registers a kernel for a dispatch key or a mode.

        Args:
            dispatch_key_or_mode: a ``torch._C.DispatchKey`` other than
                ``Python``, or a ``TorchDispatchMode`` subclass.

        Returns:
            A decorator that stores the function and returns it unchanged.

        Raises:
            AssertionError: if the key or mode is already registered, if the
                ``Python`` key is passed directly, or if the argument is
                neither a mode class nor a ``DispatchKey``.
        """

        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode) and issubclass(
                dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
            ):
                mode = dispatch_key_or_mode
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                return fn

            dispatch_key = dispatch_key_or_mode
            assert (
                dispatch_key != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
            assert isinstance(dispatch_key, torch._C.DispatchKey)
            assert dispatch_key not in self.table
            self.table[dispatch_key] = fn
            return fn

        return inner

    def dispatch(self, dispatch_key, *args, **kwargs):
        """
        Call the kernel registered for ``dispatch_key`` with the given arguments.

        For the ``Python`` key the kernel is looked up by the type of the
        currently active dispatch mode instead of by key.

        Raises:
            AssertionError: if no kernel is registered for the key, or, for the
                ``Python`` key, if no mode is active or the active mode's type
                has no registered kernel.
        """
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key == torch._C.DispatchKey.Python:
            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
            mode = _get_current_dispatch_mode()
            # Check the mode instance itself: `type(None)` is `NoneType`, never
            # `None`, so testing the type would make this assertion vacuous.
            assert (
                mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            curr_mode = type(mode)
            assert (
                curr_mode in self.python_key_mode_table
            ), f"Current active mode {curr_mode} not registered"
            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
            return self.python_key_mode_table[curr_mode](*args, **kwargs)

        assert dispatch_key in self.table
        return self.table[dispatch_key](*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """
        Invoke the operator.

        If any flattened argument overrides ``__torch_function__`` the call is
        handed to ``torch.overrides.handle_torch_function``; otherwise it is
        dispatched on the highest-priority key computed from the arguments.
        """
        flat_args = _to_flat_tuple(args, kwargs)
        if torch.overrides.has_torch_function(flat_args):
            return torch.overrides.handle_torch_function(
                self, flat_args, *args, **kwargs
            )

        dispatch_key_set = _compute_keyset(args, kwargs)
        return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)

    def name(self):
        """Return the name this operator was created with."""
        # The attribute is `_name`; `self.name` would return this bound method.
        return self._name

    # TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
    # as opposed to being this sort of explicit thing where ops are a little too key aware...
    def _fallthrough_fn(self, operator, dispatch_key):
        """
        Build the kernel installed by ``fallthrough``.

        The returned callable intersects the key set computed from its
        arguments with ``torch._C._dispatch_keyset_full_after(dispatch_key)``
        and dispatches on the highest-priority key of the result.
        ``operator`` is accepted for signature compatibility and is not used.
        """

        def inner(*args, **kwargs):
            all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key)
            all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
                args, kwargs
            )
            return self.dispatch(
                all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
            )

        return inner
def _to_flat_tuple(args, kwargs):
flat_args, _ = torch.utils._pytree.tree_flatten(args)
flat_kwargs, _ = torch.utils._pytree.tree_flatten(kwargs)
flat_all = flat_args + flat_kwargs
return flat_all
def _compute_keyset(args, kwargs):
    """Return the dispatch key set extracted from the tensors found in ``args`` and ``kwargs``."""
    return key_extractor(_get_tensors(args, kwargs))
def _get_tensors(args, kwargs):
    """Return, as a tuple, every ``torch.Tensor`` leaf of the flattened ``args`` and ``kwargs``."""
    leaves = _to_flat_tuple(args, kwargs)
    return tuple(leaf for leaf in leaves if isinstance(leaf, torch.Tensor))
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
    """
    Compute the dispatch key set for ``tensors``.

    Starts from the thread-local include set, unions in the keys of each
    tensor, and finally removes the thread-local exclude set.
    """
    collected = torch._C._dispatch_tls_local_include_set()
    for t in tensors:
        collected |= torch._C._dispatch_keys(t)
    return collected - torch._C._dispatch_tls_local_exclude_set()
# Each OpOverload object contains a pointer to a specific operator overload and a pointer
# to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload(PyOperatorABC):
    """
    A single overload of an operator.

    Besides forwarding calls to the underlying op, an instance carries Python
    kernels registered with ``py_impl`` and resolves attribute lookups whose
    name parses as a dispatch key (e.g. ``overload.CPU``) to the handler for
    that key. Resolved handlers are memoized as instance attributes and the
    memo is dropped whenever a new Python kernel is registered.
    """

    def __init__(self, overloadpacket, op, op_dk, schema, tags):
        """
        Args:
            overloadpacket: the parent ``OpOverloadPacket``.
            op: callable invoked by ``__call__``.
            op_dk: callable invoked as ``op_dk(dispatch_key, *args, **kwargs)``.
            schema: schema object providing ``name`` and ``overload_name``.
            tags: tags exposed through the ``tags`` property.
        """
        self._op = op
        self._op_dk = op_dk
        self._schema = schema
        self._overloadpacket = overloadpacket
        self._tags = tags
        self._overloadname = (
            "default" if schema.overload_name == "" else schema.overload_name
        )
        self._name = self._schema.name
        if schema.overload_name:
            self._name += "." + schema.overload_name
        self.py_kernels: Dict[torch._C.DispatchKey, Any] = {}  # type: ignore[name-defined]
        self.__name__ = "{}.{}".format(
            self._schema.name.split("::")[1], self._overloadname
        )
        # TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base...
        self.python_key_mode_table = {}
        # Names of the dispatch-key attributes memoized on this instance by
        # __getattr__; tracked so py_impl can drop them when registrations change.
        self._cached_dispatch_attrs = set()
        self.__module__ = overloadpacket.__module__
        op.__module__ = overloadpacket.__module__
        self.__qualname__ = self._name
        self.__annotations__ = {}

    # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverload(op='{}.{}', overload='{}')>".format(
            *self._schema.name.split("::"), self._overloadname
        )

    def __call__(self, *args, **kwargs):
        """Call the underlying op with the given arguments."""
        return self._op(*args, **kwargs)

    def __hash__(self):
        return hash(self._op)

    # `my_namespace.my_op_name.overload_name`
    def __str__(self):
        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

    @property
    def namespace(self):
        """The part of the schema name before ``::``."""
        return self._schema.name.split("::")[0]

    def decompose(self, *args, **kwargs):
        """
        Run the ``CompositeImplicitAutograd`` kernel of this overload, if any.

        A Python kernel registered for that key takes precedence over a kernel
        known to the dispatcher. Returns ``NotImplemented`` when neither exists.
        """
        dk = torch._C.DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not too necessary anymore, because we can
            # apply Python CompositeImplicitAutograd *before* tracing
            # using Python dispatcher (also taking advantage of the autograd
            # formula).  But it's included for completeness
            return self.py_kernels[dk](*args, **kwargs)
        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
            return self._op_dk(dk, *args, **kwargs)
        else:
            return NotImplemented

    def _cache_dispatch_attr(self, attr, value):
        """Memoize ``value`` as instance attribute ``attr`` and return it."""
        setattr(self, attr, value)
        self._cached_dispatch_attrs.add(attr)
        return value

    def _invalidate_dispatch_cache(self):
        """Drop every attribute memoized by ``__getattr__`` so it is resolved again."""
        for attr in self._cached_dispatch_attrs:
            self.__dict__.pop(attr, None)
        self._cached_dispatch_attrs.clear()

    def py_impl(self, dispatch_key_or_mode):
        """
        Return a decorator that registers a Python kernel for a key or a mode.

        Args:
            dispatch_key_or_mode: a ``torch._C.DispatchKey`` other than
                ``Python``, or a ``TorchDispatchMode`` subclass.

        Returns:
            A decorator that stores the function and returns it unchanged.

        Raises:
            AssertionError: if a mode is registered twice, if the argument is
                neither a mode class nor a ``DispatchKey``, or if the
                ``Python`` key is passed directly.
        """

        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode) and issubclass(
                dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
            ):
                mode = dispatch_key_or_mode
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                # Handlers memoized before this registration may now be stale.
                self._invalidate_dispatch_cache()
                return fn

            assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey)
            assert (
                dispatch_key_or_mode != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            self.py_kernels[dispatch_key_or_mode] = fn
            # Handlers memoized before this registration may now be stale.
            self._invalidate_dispatch_cache()
            return fn

        return inner

    # This implements the pre-computation logic for the Python dispatcher.
    def __getattr__(self, attr):
        """
        Resolve ``attr`` as a dispatch key name and return its handler.

        Only names starting with an upper-case letter are considered. The
        result is either a Python callable or a ``DispatchKey``, and is
        memoized on the instance so that later lookups bypass this method.

        Raises:
            AttributeError: if ``attr`` is empty, does not start with an
                upper-case letter, or cannot be parsed as a dispatch key.
        """
        if len(attr) == 0 or not attr[0].isupper():
            # Do not format `self` here: its string form reads attributes that
            # may themselves be missing, which would re-enter this method.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr}'"
            )

        try:
            key = torch._C._dispatch_key_parse(attr)
        except Exception as e:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr}': "
                "not a valid dispatch key name"
            ) from e

        if key == torch._C.DispatchKey.Python:
            if not self.python_key_mode_table:
                return self._cache_dispatch_attr(attr, key)

            def handler(*args, **kwargs):
                from torch.utils._python_dispatch import _get_current_dispatch_mode

                # TODO: We also need to handle tensor subclasses here
                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                # NB: when no mode is active this is `NoneType`, which is never
                # registered, so the call falls through to `_op_dk` below.
                curr_mode = type(_get_current_dispatch_mode())
                if curr_mode not in self.python_key_mode_table:
                    # TODO: This path is slow, should generally encourage this
                    # case to not happen
                    return self._op_dk(key, *args, **kwargs)
                # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
                return self.python_key_mode_table[curr_mode](*args, **kwargs)

            return self._cache_dispatch_attr(attr, handler)

        key = resolve_key(self, key)
        r = self.py_kernels.get(key, key)
        return self._cache_dispatch_attr(attr, r)

    def name(self):
        """Return the schema name, with ``.<overload_name>`` appended when the overload is named."""
        return self._name

    @property
    def overloadpacket(self):
        return self._overloadpacket

    @property
    def op(self):
        return self._op

    @property
    def tags(self):
        return self._tags

    # TODO: add more methods to expose information about input and output arguments
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
    """
    Handle on an operator before a particular overload has been chosen.

    Attribute access with an overload name (or ``default``) builds the
    matching ``OpOverload`` and caches it on the instance; calling the packet
    forwards to the underlying op.
    """

    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are accessible on the object through the properties
        # defined below but are immutable
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names

    # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        name_parts = self._qualified_op_name.split("::")
        return "<OpOverloadPacket(op='{}.{}')>".format(*name_parts)

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        name_parts = self._qualified_op_name.split("::")
        return "{}.{}".format(*name_parts)

    @property
    def op(self):
        return self._op

    def __getattr__(self, key):
        """
        Resolve a missing attribute.

        ``__file__`` yields a fixed string, other dunder names are forwarded
        to the underlying op, and anything else is treated as an overload
        name.

        Raises:
            AttributeError: if a dunder name is missing on the underlying op,
                or if the overload lookup raises ``RuntimeError``.
        """
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "torch.ops"

        # Dunder queries are answered by self._op directly so that they never
        # reach `_get_operation_overload` (which is an expensive operation).
        # This is ok since we are guaranteed that an overload name for an aten
        # op can't start with '__'.
        if key.startswith("__"):
            try:
                return getattr(self._op, key)
            except AttributeError:
                # Report the error against this object rather than against
                # self._op, which is not what the caller queried.
                raise AttributeError(
                    "'{}' can't have an overload name beginning with '__' and the "
                    "underlying op {} has no attribute {} either.".format(
                        str(self), str(self._op), key
                    )
                ) from None

        # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
        overload_name = "" if key == "default" else key
        try:
            # TODO: disallow access to overloads registered by JIT
            op_, op_dk_, tags = torch._C._get_operation_overload(
                self._qualified_op_name, overload_name
            )
            schema = torch._C._get_schema(self._qualified_op_name, overload_name)
            overload = OpOverload(self, op_, op_dk_, schema, tags)
            # cache the overload object
            setattr(self, key, overload)
        except RuntimeError:
            raise AttributeError(
                "The underlying op of '{}' has no overload name '{}'".format(
                    str(self), key
                )
            ) from None
        return overload

    def __call__(self, *args, **kwargs):
        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
        return self._op(*args, **kwargs)

    # TODO: use this to make a __dir__
    def overloads(self):
        """Return the overload names, with the empty name reported as ``default``."""
        return [overload_name or "default" for overload_name in self._overload_names]
# Resolution of torch.fn is different from torch.ops.aten.fn
# torch.fn uses the Python argparser, matches with the
# appropriate schema, and calls into the unboxed version of the method
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
# JIT creates a stack of all the overloads and then tries to match the
# correct one at runtime and always calls into the boxed version of the method
# Autograd codegen creates VariableType, TracerType,
# inplace or view type and python bindings.
# Aten codegen generates tensor methods for the the tensor class.
# _OpNamespace is a subclass of ModuleType because the torch script
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
# to work from script, we need to ensure ops and foo are modules
class _OpNamespace(types.ModuleType):
"""
An op namespace to dynamically bind Operators into Python.
Say a user has created a custom Operator called "my_namespace::my_op". To
call this op, the user will write torch.ops.my_namespace.my_op(...).
At startup, this operation will not yet be bound into Python. Instead, the
following sequence of magic tricks will occur:
1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
on the `torch.ops` object, which will create a new `_OpNamespace`
object called `my_namespace` and set it as an attribute on the `ops`
object.
2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
the `my_namespace` object, which will retrieve the operation via
`torch.get_operation`, a function bound from C++, and then in a similar
fashion bind this new object onto the `my_namespace` object.
3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
and subsequent accesses will incur no further lookup (the namespace and
operation will already exist).
"""
def __init__(self, name):
super(_OpNamespace, self).__init__("torch.ops." + name)
self.name = name
def __getattr__(self, op_name):
# It is not a valid op_name when __file__ is passed in
if op_name == "__file__":
return "torch.ops"
elif op_name == "__origin__":
raise AttributeError()
# Get the op `my_namespace::my_op` if available. This will also check
# for overloads and raise an exception if there are more than one.
namespace_name = self.name
qualified_op_name = "{}::{}".format(namespace_name, op_name)
try:
op, overload_names = torch._C._jit_get_operation(qualified_op_name)
except RuntimeError as e:
# Turn this into AttributeError so getattr(obj, key, default)
# works (this is called by TorchScript with __origin__)
raise AttributeError(
f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
) from e
# let the script frontend know that op is identical to the builtin op
# with qualified_op_name
torch.jit._builtins._register_builtin(op, qualified_op_name)
op.__module__ = self.__module__ + "." + namespace_name
opoverloadpacket = OpOverloadPacket(
qualified_op_name, op_name, op, overload_names
)
opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
# cache the opoverloadpacket to ensure that each op corresponds to
# a unique OpOverloadPacket object
setattr(self, op_name, opoverloadpacket)
return opoverloadpacket
class _PyOpNamespace(_OpNamespace):
    """Namespace object that exposes the module-level ``pyop_namespace`` registry."""

    def __init__(self):
        super().__init__("torch.ops")
        # Same dict object as the module-level registry, not a copy.
        self.pyop_namespace = pyop_namespace
class _Ops(types.ModuleType):
    """Module object behind ``torch.ops``; creates op namespaces on attribute access."""

    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        self.loaded_libraries = set()
        self.pyops = _PyOpNamespace()

    def __getattr__(self, name):
        """Return the registered pyop called ``name``, else create and cache a namespace for it."""
        # Check if the name is a pyop
        registry = self.pyops.pyop_namespace
        if name in registry:
            return registry[name]

        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        return namespace

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime. This allows dynamically
        loading custom operators. For this, you should compile your operator
        and the static registration code into a shared library object, and then
        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Nothing is loaded or recorded when ``sys.executable`` is
        ``"torch_deploy"``.

        Args:
            path (str): A path to a shared library to load.
        """
        if sys.executable == "torch_deploy":
            return

        resolved_path = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(resolved_path)
        self.loaded_libraries.add(resolved_path)


# The ops "namespace"
ops = _Ops()
|