1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
import contextlib
import warnings
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, _set_torch_dispatch_mode
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
class TorchDispatchMode:
"""
A ``TorchDispatchMode`` allows you to override the meaning of all
``__torch_dispatch__`` overrideable functions within a dynamic scope,
without having to actually create a tensor subclass or manually
monkey-patch functions in the PyTorch API. Some common situations
where you should use a mode:
* You want to override the meaning of factory functions, or other
functions that do not otherwise take a tensor as an argument
(these cannot be overridden with tensor subclasses).
* You want to override the behavior of all functions without needing
to wrap your inputs in tensor subclasses; e.g., if you are just
interested in logging intermediate computations.
* You want to control the order of execution of various tensor
subclasses explicitly, rather than implicitly via the return of
``NotImplemented``.
Independent subclasses of :class:`TorchDispatchMode` are compositional:
modes can be pushed onto a stack using ``with MyMode():``.
When you call functions in the PyTorch API inside your
``__torch_dispatch__`` implementation, by default, they will forward on to
the next mode on the mode stack. If you want recursively call back into
your current ``__torch_dispatch__`` implementation, either explicitly
invoke ``self.__torch_dispatch__(...)``, or use the context manager
``__torch_dispatch__(self)`` to make PyTorch
API self-referential (beware of infinite loops, in this case!)
"""
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
def __enter__(self):
_push_mode(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
_pop_mode()
@classmethod
def push(cls, *args, **kwargs):
warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
instance = cls(*args, **kwargs)
return instance
def _get_current_dispatch_mode():
stack_len = _len_torch_dispatch_stack()
return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None
def _get_current_dispatch_mode_stack():
stack_len = _len_torch_dispatch_stack()
return [_get_dispatch_stack_at(i) for i in range(stack_len)]
def _push_mode(mode):
if _len_torch_dispatch_stack() == 0:
_set_torch_dispatch_mode(_TorchDispatchStackMode())
_push_on_torch_dispatch_stack(mode)
def _pop_mode():
old = _pop_torch_dispatch_stack()
if _len_torch_dispatch_stack() == 0:
_set_torch_dispatch_mode(None)
return old
@contextlib.contextmanager
def _pop_mode_temporarily():
old = _pop_mode()
try:
yield old
finally:
_push_mode(old)
# a helper "mode" used by the torch dispatch push helper method. This is the only mode that will ever
# be active at the C++ level and it will run the current mode
class _TorchDispatchStackMode:
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
with _pop_mode_temporarily() as old:
if _len_torch_dispatch_stack() > 0:
_set_torch_dispatch_mode(self)
# we can't check the type of __torch_dispatch__ here but this is sufficient for checking it's a classmethod
if old.__torch_dispatch__.__self__ is type(old):
raise RuntimeError(f"{type(old)}'s torch_dispatch function " +
"should be a normal method not a class method")
return old.__torch_dispatch__(func, types, args, kwargs)
class BaseTorchDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
|