File: comm_tensor.py

from dataclasses import dataclass
from functools import partial
from typing import Any, List, Optional, Tuple


import torch
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import (
    _ProxyTensor,
    fetch_tensor_proxy,
    get_proxy,
    get_proxy_slots,
    set_proxy_slot,
    track_tensor_tree,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import (
    tree_flatten,
    tree_map,
    tree_map_only,
)


@dataclass
class _CommResult:
    # a custom type wrapping both inplace output tensor and work handle
    _tensor: torch.Tensor
    _work: torch.distributed._Work


def _wait_comm(comm_result: _CommResult):
    # This function is only used in tracing mode, as a call_function node right
    # before consuming a collective result tensor.
    comm_result._work.wait()
    return comm_result._tensor


def _wrap_comm_result(result: Tuple[Any, Any]) -> Tuple[Any, Any]:
    def wrap(work, e):
        assert isinstance(e, torch.Tensor), (
            "Excepting collection of tensors as the first element in the "
            "return value of communication operations."
        )

        return _CommResult(e, work)

    # E.g.,
    # allreduce_ returns ([tensor], work)
    # allgather_ returns ([[tensor1, tensor2]], work)
    work = result[1]
    return (tree_map(partial(wrap, work), result[0]), work)


def _get_tracer(obj: Any) -> Optional[torch.fx.Tracer]:
    slots = get_proxy_slots(obj)
    if slots is None:
        return None
    keys = tuple(slots.keys())
    assert len(keys) == 1
    return keys[0]


class CommTensor(torch.Tensor):
    r"""
    A Tensor subclass to wrap input tensors for collective communications. This
    Tensor subclass works for both eager and tracing mode.

    In eager mode, it will record whether the inplace collective communication
    has been launched using this Tensor and remember the corresponding work
    handle. If yes, it will expliclty call wait() in the ``__torch_dispatch__``
    function before subsequent operations consuming the value of the Tensor.

    In tracing mode, ``CommTensor`` inserts two node into the graph using the
    ``__torch_dispatch__`` function.
    1. The first node is inserted right after the
    communication, wrapping both the inplace output tensor and the returned
    work handle into a custom ``_CommResult`` type. We have to do this because
    ``ProxyTorchDispatchMode`` only handles ``torch.Tensor``, ``_ProxyTensor``,
    and ``torch.nn.Parameter`` objects and will treat the work handle
    as a constant and embed that into the graph. As a result, during execution,
    it will use the work handle created during tracing and will lead to wrong
    result. The solution in this test is to manually create a proxy on the
    return value of ``allreduce_`` which is ``([tensor], work)``, and wrap that
    to ``[(_CommResult(tensor, work)), work]``. In this way, subsequent nodes can
    directly consume ``_CommResult``.
    2. The second node is inserted right before any subsequent node reads from
    ``_CommResult``. It will call ``wait()`` on the stashed work handle to ensure
    that computation waits for communication.
    """

    _supported_comms: List[str] = [
        "allreduce_",
        "allgather_",
        "broadcast_",
        "reduce_scatter_",
        "scatter_",
    ]

    _tensor: torch.Tensor
    _work: Optional[torch.distributed._Work]

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        t = tensor._tensor if isinstance(tensor, CommTensor) else tensor
        if _get_tracer(t) is None:
            # noop for eager mode
            return tensor

        # Use the plain tensor to avoid nested CommTensor wrapping
        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
        # The tensor object wrapped by this CommTensor
        r._tensor = tensor  # type: ignore[attr-defined]
        # Record the LAST `work` object returned by collective communication
        # operations. If this is None, it means no collectives have been
        # launched since the last time a tensor was wrapped by CommTensor.
        r._work = None  # type: ignore[attr-defined]
        return r

    def __repr__(self):
        return f"CommTensor({self._tensor}, work={self._work})"

    # disable __torch_function__ so that CommTensor can recursively dispatch
    # with ProxyTorchDispatchMode in make_fx
    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def _is_supported(cls, op_name):
        return any(comm in op_name for comm in cls._supported_comms)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # shared states when unwrapping args
        tracer: Optional[torch.fx.Tracer] = None
        work: Optional[torch.distributed._Work] = None

        # unwrap to ._tensor if this is a CommTensor, and insert/call wait()
        # if communication has been launched on this tensor.
        def unwrap(e: Any):
            if isinstance(e, CommTensor):
                nonlocal tracer, work

                work = e._work
                tracer = _get_tracer(e._tensor)

                if work is not None:
                    if tracer is not None:
                        # insert a node to the traced graph.
                        proxy_res = tracer.create_proxy(  # type: ignore[union-attr]
                            'call_function',
                            _wait_comm,
                            (get_proxy(e._tensor).proxy,),
                            {},
                            name="wait_comm"
                        )
                        # HACK: update the proxy for the inplace output
                        set_proxy_slot(e._tensor, tracer, proxy_res)
                    # For eager mode, simply wait.
                    # During tracing, still need to wait here, to make sure the
                    # execution during tracing is correct.
                    work.wait()

                # communication has been waited on; stop propagating CommTensor
                return e._tensor
            else:
                return e

        def wrap(e: Any):
            return CommTensor(e) if isinstance(e, torch.Tensor) else e

        def set_work(work: torch.distributed._Work, e: Any):
            if isinstance(e, CommTensor):
                e._work = work  # type: ignore[attr-defined]
            elif isinstance(e, torch.Tensor):
                raise RuntimeError(
                    "Type of output tensors from collective communication during "
                    "tracing should always be CommTensor instead of torch.Tensor"
                )
            return e

        unwrapped_args = tree_map(unwrap, args)
        unwrapped_kwargs = tree_map(unwrap, kwargs)

        if cls._is_supported(func.__name__):
            if tracer is not None:
                # in tracing mode, get proxies for args
                proxy_args, proxy_kwargs = tree_map_only(
                    _ProxyTensor,
                    lambda e: e.proxy,
                    tree_map_only(
                        torch.Tensor,
                        fetch_tensor_proxy(tracer),
                        (unwrapped_args, unwrapped_kwargs)
                    ),
                )

                # get proxy for output tuple
                proxy_res = func(*proxy_args, **proxy_kwargs)
                # insert a node that wraps the output tuple into
                # _CommResult(tensor, work)
                comm_result_proxy = tracer.create_proxy(  # type: ignore[union-attr]
                    'call_function',
                    _wrap_comm_result,
                    (proxy_res, ),
                    {},
                    name="comm_result"
                )

                with no_dispatch():
                    # disable dispatch to avoid trigger ProxyTorchDispatchMode logic
                    out = func(*unwrapped_args, **unwrapped_kwargs)

                # wrap output with the proxy of _CommResult, so that subsequent
                # ops can link to it.
                track_tensor_tree(out, comm_result_proxy, constant=None, tracer=tracer)

                # N.B.: we still need to remember the work handle here, and wait
                # for it later to make sure the execution during tracing is
                # correct. Also, remember that the comm has already been launched.
                # args[0] is always the collection of output tensors
                tree_map(partial(set_work, out[1]), args[0])

                # HACK: update the proxy on the input argument as this is an
                # inplace collective communication.
                flat_args, args_spec = tree_flatten(unwrapped_args[0])
                flat_out, out_spec = tree_flatten(out[0])
                for a, o in zip(flat_args, flat_out):
                    set_proxy_slot(a, tracer, get_proxy(o))

                return out
            else:
                # in eager mode, simply remember work handle as an attribute
                out = func(*unwrapped_args, **unwrapped_kwargs)
                tree_map(partial(set_work, out[1]), args[0])
                return out
        else:
            if work is not None:
                return func(*unwrapped_args, **unwrapped_kwargs)
            else:
                # we need to propagate CommTensor wrapping until the first
                # subsequent operation has waited for it.
                return tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
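

# Illustrative usage sketch: a minimal example, assuming a default process
# group has been initialized and that the c10d collectives on this build are
# dispatcher ops visible to ``__torch_dispatch__`` (e.g., ``allreduce_``).
# The helper name ``_example_allreduce`` below is hypothetical and not part
# of this module's API.
def _example_allreduce(x: torch.Tensor) -> torch.Tensor:
    import torch.distributed as dist

    # Outside of tracing, ``CommTensor.__new__`` returns the plain tensor, so
    # this wrapping is a no-op in eager mode. Under ``make_fx`` tracing it
    # produces a CommTensor that records the collective's work handle.
    y = CommTensor(x + x)
    dist.all_reduce(y)
    # The first subsequent op that reads ``y`` waits on the collective: wait()
    # is called directly, and in tracing mode a ``wait_comm`` node is inserted
    # before this multiplication.
    return y * 2

# Tracing the sketch above (hypothetical):
#     from torch.fx.experimental.proxy_tensor import make_fx
#     traced = make_fx(_example_allreduce)(torch.ones(2, 2))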