File: _func_map.py

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, Optional, Sequence, Tuple, Union

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor.placement_types import Placement


try:
    from torch.utils import _cxx_pytree as pytree
except ImportError:
    from torch.utils import _pytree as pytree  # type: ignore[no-redef]


__all__ = ["local_map"]

PlacementType = Optional[Sequence[Placement]]
InputPlacements = Optional[Tuple[PlacementType, ...]]
OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]]


def local_map(
    func: Callable,
    out_placements: OutputPlacements,
    in_placements: Optional[InputPlacements] = None,
    device_mesh: Optional[DeviceMesh] = None,
    *,
    redistribute_inputs: bool = False,
):
    """
    :meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s
    to a function that is written to be applied on ``torch.Tensor`` s. It does so by
    extracting the local components of the :class:`DTensor` s, calling the function, and
    wrapping the outputs into :class:`DTensor` s according to ``out_placements``.

    Args:
        func (Callable): the function to be applied on each local shard of
            :class:`DTensor` s.
        out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]):
            the desired placements of the :class:`DTensor` s in ``func``'s flattened output.
            If the flattened output is a single value, ``out_placements`` should be of type
            `PlacementType`. If the flattened output has multiple values, ``out_placements``
            should be a tuple of `PlacementType` values that map 1:1 to the flattened output.
            For a :class:`Tensor` output, the `PlacementType` is its placements (a
            `Tuple[Placement]` value); for a non-Tensor output, the `PlacementType` should
            be `None`.
            Note that the only exception is when no :class:`DTensor` argument is passed
            in. In that case, even if ``out_placements`` is not ``None``, the returned
            function ignores the desired placements because it is not running on
            :class:`DTensor` s.
        in_placements (Tuple[`PlacementType`, ...], optional):
            the required placements of the :class:`DTensor` s in the flattened inputs of ``func``.
            If ``in_placements`` is specified, :meth:`local_map` examines whether the
            placements of each :class:`DTensor` argument match the required placements.
            If they do not match and ``redistribute_inputs`` is ``False``, an exception is
            raised; if ``redistribute_inputs`` is ``True``, the argument is first
            redistributed to the required placements before its local tensor is passed
            to ``func``. The only exception is when the required placements are not
            ``None`` and the argument is a plain :class:`torch.Tensor`: in that case the
            placements check is skipped and the argument is passed directly to ``func``.
            If ``in_placements`` is ``None``, no placements check is performed.
            Default: None
        device_mesh (:class:`DeviceMesh`, optional):
            the device mesh that all the :class:`DTensor` s are placed on. If not
            specified, it is inferred from the input :class:`DTensor` s' device mesh.
            `local_map` requires all :class:`DTensor` s to be placed on the same
            device mesh. Default: None.
        redistribute_inputs (bool, optional):
            the bool value indicating whether to reshard the input :class:`DTensor` s when
            their placements are different from the required input placements. If this
            value is ``False`` and some :class:`DTensor` input has a different placement,
            an exception will be raised. Default: False.

    Returns:
        A ``Callable`` that applies ``func`` to the local shards of the input :class:`DTensor` s
        and wraps the return values of ``func`` into :class:`DTensor` s according to
        ``out_placements``.

    Raises:
        AssertionError: If the input :class:`DTensor` s are not all placed on the same
            device mesh, or if they are placed on a device mesh different from the
            ``device_mesh`` argument passed in.

        AssertionError: For any non-DTensor output, its corresponding placement in
            ``out_placements`` must be ``None``; an ``AssertionError`` is raised otherwise.

        ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs
            a redistribution according to ``in_placements``.

    Example:
        >>> # xdoctest: +SKIP("distributed")
        >>> import torch
        >>> import torch.distributed._functional_collectives as funcol
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> from torch.distributed.tensor import Replicate, Shard, distribute_tensor
        >>> from torch.distributed.tensor.experimental import local_map
        >>>
        >>> # example setup (assumption for illustration): a 1-d device mesh over 2 GPUs
        >>> device_mesh = init_device_mesh("cuda", (2,))
        >>>
        >>> def mm_allreduce_forward(device_mesh, W, X):
        >>>     partial_sum_tensor = torch.mm(W, X)
        >>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
        >>>     return reduced_tensor
        >>>
        >>> W = torch.randn(12, 8, requires_grad=False)
        >>> X = torch.randn(8, 16, requires_grad=False)
        >>> Y = torch.mm(W, X)
        >>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
        >>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
        >>>
        >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
        >>> local_mm_allreduce_forward = local_map(
        >>>     mm_allreduce_forward,
        >>>     out_placements=[Replicate()],
        >>>     in_placements=[None, col_wise, row_wise],  # None for the non-tensor device_mesh arg
        >>>     device_mesh=device_mesh,
        >>> )
        >>>
        >>> W_dt = distribute_tensor(W, device_mesh, col_wise)  # column-wise sharded W tensor
        >>> X_dt = distribute_tensor(X, device_mesh, row_wise)  # row-wise sharded X tensor
        >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)  # apply local_mm_allreduce_forward to DTensors
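        >>>
        >>> # A further illustrative sketch (names below are hypothetical, not part of the API):
        >>> # with redistribute_inputs=True, an input whose placements differ from
        >>> # in_placements is resharded before func runs instead of raising ValueError.
        >>> local_mm_allreduce_forward_reshard = local_map(
        >>>     mm_allreduce_forward,
        >>>     out_placements=[Replicate()],
        >>>     in_placements=[None, col_wise, row_wise],
        >>>     device_mesh=device_mesh,
        >>>     redistribute_inputs=True,
        >>> )
        >>> W_row_dt = distribute_tensor(W, device_mesh, row_wise)  # deliberately row-wise sharded
        >>> Y_dt2 = local_mm_allreduce_forward_reshard(device_mesh, W_row_dt, X_dt)
        >>>
        >>> # Another illustrative sketch: when the flattened output mixes Tensor and
        >>> # non-Tensor values, out_placements maps 1:1 to the flattened output, with
        >>> # None for each non-Tensor value.
        >>> def mm_allreduce_with_rows(device_mesh, W, X):
        >>>     Y = funcol.all_reduce(torch.mm(W, X), "sum", device_mesh)
        >>>     return Y, W.shape[0]  # (Tensor, int) output
        >>> local_mm_allreduce_with_rows = local_map(
        >>>     mm_allreduce_with_rows,
        >>>     out_placements=([Replicate()], None),
        >>>     in_placements=[None, col_wise, row_wise],
        >>>     device_mesh=device_mesh,
        >>> )
        >>> Y_dt3, num_rows = local_mm_allreduce_with_rows(device_mesh, W_dt, X_dt)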

    .. note:: This API is currently experimental and subject to change
    """

    def wrapped(*args, **kwargs):
        # process input args
        flat_args, args_spec = pytree.tree_flatten(args)
        if in_placements is not None:
            assert len(in_placements) == len(flat_args), (
                f"in_placements length {len(in_placements)} does not match the number "
                f"of input args {len(flat_args)}!"
            )

        # we assume every DTensor object is placed on the same device mesh
        flat_local_args = []
        nonlocal device_mesh  # access var device_mesh from the outer scope
        seen_dtensor_arg = False
        for idx, arg in enumerate(flat_args):
            if isinstance(arg, DTensor):
                # TODO: the current code doesn't consider the uneven sharding case
                # Need to think about what the consequence is when the input DTensor
                # is uneven sharded.
                if device_mesh is None:  # infer device mesh from the DTensor arg
                    device_mesh = arg.device_mesh

                # this function is applied to at least one DTensor argument
                seen_dtensor_arg = True

                assert arg.device_mesh == device_mesh, (
                    f"arg {arg} in local_map has a mismatched device mesh: "
                    f"{arg} has device mesh {arg.device_mesh} while "
                    f"the expected device mesh is {device_mesh}!"
                )
                if in_placements is not None:
                    spec = in_placements[idx]
                    assert (
                        spec is not None
                    ), f"DTensor input {arg} expects placements but received {spec}!"

                    if not isinstance(spec, tuple):
                        spec = tuple(spec)

                    if arg.placements != spec:
                        if redistribute_inputs:
                            # redistribute to input placements
                            arg = arg.redistribute(device_mesh, spec)
                        else:
                            raise ValueError(
                                f"arg {arg} in local_map has mismatched placements: "
                                f"arg placements are {arg.placements} but the required "
                                f"input placements are {spec}! "
                                "If resharding is intended, pass "
                                "redistribute_inputs=True to local_map."
                            )

                local_arg = arg.to_local()
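                # to_local() may hand back an AsyncCollectiveTensor; wait on the pending
                # collective so func operates on a fully materialized local tensor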
                if isinstance(local_arg, AsyncCollectiveTensor):
                    local_arg = local_arg.wait()

                flat_local_args.append(local_arg)
            else:
                # Non-Tensor input must have None in `in_placements`
                if in_placements is not None and not isinstance(arg, torch.Tensor):
                    spec = in_placements[idx]
                    assert spec is None, (
                        f"Non-Tensor input {arg} expects None placements "
                        f"but received {spec}!"
                    )

                flat_local_args.append(arg)

        local_args = pytree.tree_unflatten(flat_local_args, args_spec)

        out = func(*local_args, **kwargs)

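        # only wrap outputs back into DTensors if at least one DTensor input was seen;
        # otherwise func ran on plain tensors and its output is returned unchanged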
        if seen_dtensor_arg:
            # process output
            flat_out, out_spec = pytree.tree_flatten(out)

            flat_dist_out = []
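            # normalize out_placements into a tuple with one PlacementType per flattened output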
            out_placements_tuple = (
                out_placements
                if isinstance(out_placements, tuple)
                else (out_placements,)
            )
            assert len(flat_out) == len(out_placements_tuple), (
                "local_map requires one PlacementType to be provided for each output value;"
                f" received {len(out_placements_tuple)} out_placements but"
                f" {len(flat_out)} are expected!"
            )
            for out, spec in zip(flat_out, out_placements_tuple):
                if isinstance(out, torch.Tensor):
                    assert not isinstance(
                        out, DTensor
                    ), f"torch.Tensor output expected but received {type(out)}: {out}"

                    flat_dist_out.append(
                        DTensor.from_local(out, device_mesh, spec, run_check=False)
                    )
                else:
                    assert (
                        spec is None
                    ), f"Non-tensor output {out} expects None placements but received {spec}!"

                    flat_dist_out.append(out)

            return pytree.tree_unflatten(flat_dist_out, out_spec)
        else:
            return out

    return wrapped