from dataclasses import dataclass
from functools import partial
from typing import Any, List, Optional, Tuple
import torch
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import (
_ProxyTensor,
fetch_tensor_proxy,
get_proxy,
get_proxy_slots,
set_proxy_slot,
track_tensor_tree,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import (
tree_flatten,
tree_map,
tree_map_only,
)
@dataclass
class _CommResult:
# a custom type wrapping both inplace output tensor and work handle
_tensor: torch.Tensor
_work: torch.distributed._Work
def _wait_comm(comm_result: _CommResult):
# This function is only used by tracing mode as a call_function node right
# before consuming a collective result tensor.
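    # In the traced graph it appears as the ``wait_comm`` call_function node
    # created in ``CommTensor.__torch_dispatch__``, consuming the proxy of the
    # collective's in-place output tensor.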
comm_result._work.wait()
return comm_result._tensor
def _wrap_comm_result(result: Tuple[Any, Any]) -> Tuple[Any, Any]:
def wrap(work, e):
assert isinstance(e, torch.Tensor), (
"Excepting collection of tensors as the first element in the "
"return value of communication operations."
)
return _CommResult(e, work)
# E.g.,
# allreduce_ returns ([tensor], work)
# allgather_ returns ([[tensor1, tensor2]], work)
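    # After wrapping (the tree_map over result[0] below), these become, e.g.,
    # allreduce_: ([_CommResult(tensor, work)], work)
    # allgather_: ([[_CommResult(tensor1, work), _CommResult(tensor2, work)]], work)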
work = result[1]
return (tree_map(partial(wrap, work), result[0]), work)
def _get_tracer(obj: Any) -> Optional[torch.fx.Tracer]:
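    # Tensors tracked by proxy-tensor tracing (e.g. under ``make_fx``) carry
    # proxy slots keyed by their tracer; eager tensors have none, so this
    # returns None in eager mode.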
slots = get_proxy_slots(obj)
if slots is None:
return None
keys = tuple(slots.keys())
assert len(keys) == 1
return keys[0]
class CommTensor(torch.Tensor):
r"""
A Tensor subclass to wrap input tensors for collective communications. This
Tensor subclass works for both eager and tracing mode.
    In eager mode, it will record whether the inplace collective communication
    has been launched using this Tensor and remember the corresponding work
    handle. If yes, it will explicitly call wait() in the ``__torch_dispatch__``
    function before subsequent operations consume the value of the Tensor.
    In tracing mode, ``CommTensor`` inserts two nodes into the graph using the
    ``__torch_dispatch__`` function.
1. The first node is inserted right after the
communication, wrapping both the inplace output tensor and the returned
work handle into a custom ``_CommResult`` type. We have to do this because
``ProxyTorchDispatchMode`` only handles ``torch.Tensor``, ``_ProxyTensor``,
and ``torch.nn.Parameter`` objects and will treat the work handle
    as a constant and embed that into the graph. As a result, during execution,
    it would use the work handle created during tracing, which leads to
    incorrect results. The solution here is to manually create a proxy on the
    return value of ``allreduce_``, which is ``([tensor], work)``, and wrap that
    into ``([_CommResult(tensor, work)], work)``. In this way, subsequent nodes
    can directly consume ``_CommResult``.
2. The second node is inserted right before any subsequent node reads from
``_CommResult``. It will call ``wait()`` on the stashed work handle to ensure
that computation waits for communication.
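
    Example (a minimal usage sketch; it assumes a default process group has
    already been initialized, tracing is done via ``make_fx``, and ``fn`` is
    an illustrative name)::

        import torch.distributed as dist
        from torch.fx.experimental.proxy_tensor import make_fx

        def fn(x):
            y = CommTensor(x.clone())
            dist.all_reduce(y)
            # the traced graph gets a ``comm_result`` node after the
            # collective and a ``wait_comm`` node before ``y`` is consumed
            return y * 2

        traced = make_fx(fn)(torch.ones(4))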
"""
_supported_comms: List[str] = [
"allreduce_",
"allgather_",
"broadcast_",
"reduce_scatter_",
"scatter_",
]
_tensor: torch.Tensor
_work: Optional[torch.distributed._Work]
@staticmethod
def __new__(cls, tensor: torch.Tensor):
t = tensor._tensor if isinstance(tensor, CommTensor) else tensor
if _get_tracer(t) is None:
# noop for eager mode
return tensor
        # Use the unwrapped tensor to avoid nested CommTensor wrapping
r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
# The tensor object wrapped by this CommTensor
r._tensor = tensor # type: ignore[attr-defined]
# Record the LAST `work` object returned by collective communication
        # operations. If this is None, it means no collective has been launched
        # since the last time a tensor was wrapped by CommTensor.
r._work = None # type: ignore[attr-defined]
return r
def __repr__(self):
return f"CommTensor({self._tensor}, work={self._work})"
# disable __torch_function__ so that CommTensor can recursively dispatch
# with ProxyTorchDispatchMode in make_fx
__torch_function__ = _disabled_torch_function_impl
@classmethod
def _is_supported(cls, op_name):
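        # match by substring: the dispatched op's name only needs to contain
        # one of the supported in-place collective names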
        return any(comm in op_name for comm in cls._supported_comms)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# shared states when unwrapping args
tracer: Optional[torch.fx.Tracer] = None
work: Optional[torch.distributed._Work] = None
        # unwrap to ._tensor if this is a CommTensor, and insert/call wait()
# if communication has been launched on this tensor.
def unwrap(e: Any):
if isinstance(e, CommTensor):
nonlocal tracer, work
work = e._work
tracer = _get_tracer(e._tensor)
if work is not None:
if tracer is not None:
                        # insert a node into the traced graph.
proxy_res = tracer.create_proxy( # type: ignore[union-attr]
'call_function',
_wait_comm,
(get_proxy(e._tensor).proxy,),
{},
name="wait_comm"
)
# HACK: update the proxy for the inplace output
set_proxy_slot(e._tensor, tracer, proxy_res)
# For eager mode, simply wait.
# During tracing, still need to wait here, to make sure the
# execution during tracing is correct.
work.wait()
                # the communication has been waited on; stop propagating CommTensor
return e._tensor
else:
return e
def wrap(e: Any):
return CommTensor(e) if isinstance(e, torch.Tensor) else e
def set_work(work: torch.distributed._Work, e: Any):
if isinstance(e, CommTensor):
e._work = work # type: ignore[attr-defined]
elif isinstance(e, torch.Tensor):
raise RuntimeError(
"Type of output tensors from collective communication during "
"tracing should always be CommTensor instead of torch.Tensor"
)
return e
unwrapped_args = tree_map(unwrap, args)
unwrapped_kwargs = tree_map(unwrap, kwargs)
if cls._is_supported(func.__name__):
if tracer is not None:
# in tracing mode, get proxies for args
proxy_args, proxy_kwargs = tree_map_only(
_ProxyTensor,
lambda e: e.proxy,
tree_map_only(
torch.Tensor,
fetch_tensor_proxy(tracer),
(unwrapped_args, unwrapped_kwargs)
),
)
# get proxy for output tuple
proxy_res = func(*proxy_args, **proxy_kwargs)
# insert a node that wraps the output tuple into
# _CommResult(tensor, work)
comm_result_proxy = tracer.create_proxy( # type: ignore[union-attr]
'call_function',
_wrap_comm_result,
(proxy_res, ),
{},
name="comm_result"
)
with no_dispatch():
# disable dispatch to avoid trigger ProxyTorchDispatchMode logic
out = func(*unwrapped_args, **unwrapped_kwargs)
# wrap output with the proxy of _CommResult, so that subsequent
                # ops can link to it.
track_tensor_tree(out, comm_result_proxy, constant=None, tracer=tracer)
# N.B.: we still need to remember the work handle here, and wait
# for it later to make sure the execution during tracing is
                # correct. Also, remember that the comm has already been launched.
# args[0] is always the collection of output tensors
tree_map(partial(set_work, out[1]), args[0])
# HACK: update the proxy on the input argument as this is an
# inplace collective communication.
flat_args, args_spec = tree_flatten(unwrapped_args[0])
flat_out, out_spec = tree_flatten(out[0])
for a, o in zip(flat_args, flat_out):
set_proxy_slot(a, tracer, get_proxy(o))
return out
else:
# in eager mode, simply remember work handle as an attribute
out = func(*unwrapped_args, **unwrapped_kwargs)
tree_map(partial(set_work, out[1]), args[0])
return out
else:
if work is not None:
return func(*unwrapped_args, **unwrapped_kwargs)
else:
# we need to propagate CommTensor wrapping until the first
# subsequent operation has waited for it.
return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))