1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
|
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
from weakref import WeakKeyDictionary
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)
class _EffectType(Enum):
    """Kinds of side effects; every op sharing a kind is chained on one token."""

    ORDERED = "Ordered"
# Operator kinds that can be wrapped by ``with_effects``.
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]

# Registry of ops known to be side-effectful, mapped to their effect kind.
# Weak keys, so registering an op here does not by itself keep it alive.
SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary()
SIDE_EFFECTS[torch.ops.aten._print.default] = _EffectType.ORDERED
SIDE_EFFECTS[call_torchbind] = _EffectType.ORDERED
def _register_effectful_op(op: OpType, effect: _EffectType):
    """Record ``op`` in SIDE_EFFECTS as having side effect ``effect``.

    Re-registering an op with the same effect is a no-op.

    Raises:
        AssertionError: if ``op`` is not an OpOverload/HigherOrderOperator, or
            it is an OpOverload whose schema has aliasing annotations.
        RuntimeError: if ``op`` is already registered with a different effect.
    """
    assert isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))
    # has_aliasing() reports every HigherOrderOperator that is *not yet* in
    # SIDE_EFFECTS as aliasing, so applying it to a HOP here would make it
    # impossible to register any new HOP. Only OpOverloads carry a schema
    # with alias annotations to inspect.
    if isinstance(op, torch._ops.OpOverload):
        assert not has_aliasing(op), f"Ops with aliasing are not supported: {op}"
    if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
        raise RuntimeError(
            f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
            f"trying to register a different effect type {effect}."
        )
    SIDE_EFFECTS[op] = effect
def _deregister_effectful_op(op: OpType):
    """Drop ``op`` from SIDE_EFFECTS; raise RuntimeError if it was never registered."""
    is_registered = op in SIDE_EFFECTS
    if not is_registered:
        raise RuntimeError(f"Op {op} is not registered as effectful")
    SIDE_EFFECTS.pop(op)
class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, *args, **kwargs) -> (new_token, *op_results)

    Higher-order operator that pins the relative order of side-effectful ops
    (for example prints, or ops that take torchbind objects). Graphs traced by
    AOTAutograd must be functional, so later optimization passes are free to
    reorder nodes; to prevent that, an "effect token" is threaded through each
    effectful call, creating a data dependence between consecutive ones.

    A token is just a dummy value (``torch.tensor([])``). There is one token
    per effect kind, as enumerated by ``_EffectType``.
    """

    def __init__(self) -> None:
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: OpType,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        # Validate the wrapped op before dispatching; the checks run in the
        # same order so the first failing condition is the one reported.
        accepted_kinds = (torch._ops.HigherOrderOperator, torch._ops.OpOverload)
        assert isinstance(op, accepted_kinds)
        assert not has_aliasing(op), "Ops with aliasing is not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)
# Module-wide WithEffects HOP instance. Per-dispatch-key implementations are
# attached to it in this module through ``with_effects.py_impl(...)``.
with_effects = WithEffects()
def has_aliasing(op: OpType):
    # NOT FOR PUBLIC USE
    """Return True if ``op`` must be treated as aliasing.

    A HigherOrderOperator has no schema to inspect, so it counts as
    non-aliasing only when it is registered in SIDE_EFFECTS. An OpOverload
    aliases when any argument or return in its schema has alias info.
    """
    if isinstance(op, torch._ops.HigherOrderOperator):
        return op not in SIDE_EFFECTS
    schema = op._schema
    return any(
        entry.alias_info is not None
        for entry in (*schema.arguments, *schema.returns)
    )
def has_effects(op, args, kwargs) -> bool:
    """Return whether ``op`` called with ``args``/``kwargs`` should be token-ordered.

    This requires that ``op`` is an OpOverload or HigherOrderOperator, that it
    does not alias, and that ``get_effect_key`` finds an effect kind for it.
    """
    # The profiler's RecordFunction ops must never show up in the graph.
    profiler_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
    if op in profiler_ops:
        return False
    if not isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)):
        return False
    if has_aliasing(op):
        return False
    return get_effect_key(op, args, kwargs) is not None
def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    """Return the effect kind of ``op`` for this call, or None if it has none.

    Ops already in SIDE_EFFECTS return their registered kind. Otherwise the op
    is treated as ORDERED when any positional argument *or keyword-argument
    value* is a ``torch.ScriptObject``. That discovery is cached in
    SIDE_EFFECTS so later calls with the same op skip the argument scan.

    Args:
        op: The operator being called.
        args: Its positional arguments.
        kwargs: Its keyword arguments (a mapping; may be empty or None).
    """
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    # Previously only positional args were scanned, so a script object passed
    # by keyword went undetected and the op was not token-ordered.
    kwarg_values = kwargs.values() if kwargs else ()
    for arg in (*args, *kwarg_values):
        if isinstance(arg, torch.ScriptObject):
            # Add it to the table so that next time we see the same op we don't
            # have to parse through the args again
            SIDE_EFFECTS[op] = _EffectType.ORDERED
            return _EffectType.ORDERED

    return None
def new_token_tensor() -> torch.Tensor:
    """Create a fresh effect token: an empty tensor used only as a dependency edge."""
    token = torch.tensor([])
    return token
@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """Eagerly run ``op`` and return ``(new_token, *outputs)``.

    A tuple result from ``op`` is flattened into the returned tuple; any other
    result (including None) becomes the single element after the token.
    """
    result = op(*args, **kwargs)
    fresh_token = new_token_tensor()
    outputs = result if isinstance(result, tuple) else (result,)
    return (fresh_token, *outputs)
@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """FakeTensorMode rule: run the dense implementation inside ``mode``."""
    with mode:
        return with_effects_dense(token, op, *args, **kwargs)
@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """Proxy-tracing rule: record a ``with_effects`` call_function node.

    The concrete outputs are computed with proxy tracing disabled. The inputs
    are then mapped to their proxies, a node targeting ``with_effects`` is
    created, and the concrete outputs are tied to that node's proxy.
    """
    with disable_proxy_modes_tracing():
        concrete_out = with_effects(token, op, *args, **kwargs)

    tracer = mode.tracer
    to_proxy = tracer.unwrap_proxy
    token_proxy = to_proxy(token)
    args_proxy = pytree.tree_map(to_proxy, args)
    kwargs_proxy = pytree.tree_map(to_proxy, kwargs)

    from torch.fx.node import has_side_effect

    # Mark ``op`` as side-effectful so graph.eliminate_dead_code does not
    # remove its calls when they have no outputs or their outputs are unused.
    has_side_effect(op)

    node_proxy = tracer.create_proxy(
        "call_function",
        with_effects,
        (token_proxy, op, *args_proxy),
        kwargs_proxy,
    )
    return track_tensor_tree(concrete_out, node_proxy, constant=None, tracer=tracer)
# Register AutogradCPU and AutogradCUDA as fallthrough keys for with_effects;
# this module provides no autograd-specific implementation for it.
with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)
def _get_schema(op, args) -> torch.FunctionSchema:
if isinstance(op, torch._ops.OpOverload):
return op._schema
elif op == call_torchbind:
return getattr(args[0], args[1]).schema
else:
raise RuntimeError(f"Unable to get schema for op {op}")
def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: OpType,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Run the effectful ``op`` through the ``with_effects`` HOP, threading the
    token for the op's effect kind so effectful calls stay ordered.

    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
            is true, we will create a token for every side effect type seen that
            does not have a token assigned yet. If this is false, the tokens
            should've all been created ahead of time, so we will error if there is
            no token mapping to every effect type.
        tokens: Map of effect type to tokens. This is to chain operators of the
            same effects together so that they do not get reordered in later
            optimization passes. Mutated in place: the entry for this op's
            effect kind is replaced with the new token produced by the call.
        op: The effectful operator (OpOverload or HigherOrderOperator).
        args: Positional arguments for ``op``. These are passed through
            ``ctx.unwrap_tensors``, so presumably functional-tensor-wrapped
            (NOTE(review): confirm against callers).
        kwargs: Keyword arguments for ``op``, unwrapped the same way.

    Returns:
        ``op``'s outputs re-wrapped with ``ctx.wrap_tensors``. Before wrapping
        they are None if the schema has no returns, the single output if it
        has exactly one, and otherwise the list of outputs.

    Raises:
        AssertionError: if ``op`` has no effect key; if no token exists for its
            effect kind and ``allow_token_discovery`` is False; or if the
            number of outputs does not match ``op``'s schema.
    """
    # Get the token for this op's effect kind. We can't do
    # `tokens.get(key, torch.tensor([]))` because this would create an empty
    # tensor during proxy mode tracing if the token doesn't exist. But the
    # tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert (
            allow_token_discovery
        ), f"Could not find a token for effect {key} which came from the function {op}"
        proxy_tensor_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.PROXY
        )
        if proxy_tensor_mode is not None:
            # If we discovered a new token during tracing, we are in backward.
            # Then we patch the graph, adding additional tangents_token as input to the joint graph.
            tracer = proxy_tensor_mode.tracer

            from torch.fx.experimental.proxy_tensor import (
                disable_proxy_modes_tracing,
                track_tensor_tree,
            )

            # Create the concrete token without recording it in the graph; it
            # is tied to a new "tangents_token" placeholder node just below.
            with disable_proxy_modes_tracing():
                token_tensor = new_token_tensor()

            token_proxy = proxy_tensor_mode.tracer.create_proxy(
                "placeholder", "tangents_token", (), {}, name="tangents_token"
            )
            track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer)

            tokens[key] = token_tensor
        else:
            tokens[key] = new_token_tensor()

    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # Strip the functionalization wrappers and call with_effects one dispatch
    # level down; it returns the new token followed by op's outputs.
    unwrapped_token = ctx.unwrap_tensors([token])[0]
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs
        )

    # Normalize the outputs to the shape the schema promises. For an op with
    # no outputs, the dense implementation in this module yields
    # (new_token, None), so unwrapped_outs is [None] here.
    schema = _get_schema(op, unwrapped_args)
    if len(schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)
|