from contextlib import contextmanager

import torch
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._ops import PyOperator
from torch.utils._pytree import tree_flatten
from torch.fx.experimental.proxy_tensor import (
    ProxyTorchDispatchMode,
    get_isolated_graphmodule,
    get_proxy_slot,
    track_tensor_tree,
)
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode
"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
from contextlib import contextmanager
cond = PyOperator('cond')


# TODO(voz): Move out somewhere else once other py dispatched ops need it
@contextmanager
def suspend_mode(mode):
    assert mode is not None, "Cannot suspend None mode"
    assert isinstance(mode, TorchDispatchMode), f"Unexpected mode type {mode.__class__}"
    # Clear the current dispatch mode so ops inside the block dispatch
    # normally, then restore the mode on exit.
    torch._C._set_torch_dispatch_mode(None)
    try:
        yield
    finally:
        torch._C._set_torch_dispatch_mode(mode)
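

# A minimal usage sketch of `suspend_mode` (illustrative only; assumes some
# TorchDispatchMode is currently active):
#
#     mode = _get_current_dispatch_mode()
#     with suspend_mode(mode):
#         y = x + 1  # dispatches normally, bypassing `mode`
#     # `mode` is active again here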


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    def _unwrap_proxy(e):
        return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy)

    assert isinstance(operands, list), "Cond operands must be a list of tensors"
    assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"

    # Trace each branch into its own isolated GraphModule.
    true_graph = get_isolated_graphmodule(true_fn, operands, {})
    false_graph = get_isolated_graphmodule(false_fn, operands, {})

    # Collect the outputs of each branch graph.
    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == 'output':
            true_outs.extend(node.args)
    for node in false_graph.graph.nodes:
        if node.op == 'output':
            false_outs.extend(node.args)

    # Both branches must produce the same number of outputs with matching metadata.
    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_outs)
    assert len(flat_true_outs) == len(flat_false_outs)
    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        assert true_out.meta == false_out.meta
    # There are probably better ways - I know that create_arg has some
    # self-incrementing name magic to it, but since we explicitly have to get
    # the name for register_module, I was not sure how to do that. This kinda
    # simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    true_name = next_name
    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)
    args = (pred, true_graph, false_graph, [operands])
    proxy_args = pytree.tree_map(_unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                               name="conditional")

    # Run the taken branch for real so the proxy has concrete tensors to track.
    if pred:
        out = true_fn(*operands)
    else:
        out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
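
# For reference, a sketch of what `trace_cond` emits into the outer graph
# (names are illustrative and depend on tracer state): `true_graph_0` and
# `false_graph_0` become submodules of the tracer root, and the recorded node
# looks roughly like
#
#     conditional = cond(pred, true_graph_0, false_graph_0, [[x1, ..., xn]])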


@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU key"
    # Eager execution: just run whichever branch the predicate selects.
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


@cond.py_impl(DispatchKey.AutogradCPU)
def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
    assert all(not f.requires_grad for f in flat_operands
               if isinstance(f, torch.Tensor))

    # Keep the guard alive in this frame: it excludes AutogradCPU so that the
    # redispatch below reaches the CPU kernel.
    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)


@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    # Suspend the mode while tracing so that the tensor ops run inside
    # `trace_cond` do not re-enter this handler.
    with suspend_mode(mode):
        res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
    return res


# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)
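

if __name__ == "__main__":
    # A minimal end-to-end sketch, modeled on the tests for this prototype and
    # assuming a build where PyOperator dispatch routes through the py_impls
    # above. Illustrative only, not a stable API.
    from torch.fx.experimental.proxy_tensor import make_fx

    def true_fn(x):
        return x.sin()

    def false_fn(x):
        return x.cos()

    x = torch.randn(4)

    # Eager: dispatch lands on cond_dense via the CPU key (after the autograd
    # impl strips the AutogradCPU key and redispatches).
    print(cond(torch.tensor(True), true_fn, false_fn, [x]))  # ~ x.sin()

    # Traced: make_fx enables ProxyTorchDispatchMode, so dispatch reaches
    # `inner`, which embeds both branches as submodules of the traced graph.
    def f(x, pred):
        return cond(pred, true_fn, false_fn, [x])

    gm = make_fx(f)(x, torch.tensor(False))
    print(gm.code)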