# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, Optional, Sequence, Tuple, Union
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor.placement_types import Placement
try:
from torch.utils import _cxx_pytree as pytree
except ImportError:
from torch.utils import _pytree as pytree # type: ignore[no-redef]
__all__ = ["local_map"]
PlacementType = Optional[Sequence[Placement]]
InputPlacements = Optional[Tuple[PlacementType, ...]]
OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]]
def local_map(
func: Callable,
out_placements: OutputPlacements,
in_placements: Optional[InputPlacements] = None,
device_mesh: Optional[DeviceMesh] = None,
*,
redistribute_inputs: bool = False,
):
"""
:meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s
to a function that is written to be applied on ``torch.Tensor`` s. It does so by extracting
the local components of each :class:`DTensor`, calling the function, and wrapping the
outputs into :class:`DTensor` s according to ``out_placements``.
Args:
func (Callable): the function to be applied on each local shard of
:class:`DTensor` s.
out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]):
the desired placements of the :class:`DTensor` s in ``func``'s flattened output.
If the flattened ``output`` is a single value, ``out_placements`` should be
of type `PlacementType`. Otherwise, if the flattened ``output`` has multiple
values, ``out_placements`` should be a tuple of `PlacementType` values with a 1:1
mapping to the flattened ``output``.
For a :class:`Tensor` output, the `PlacementType` is its placements (a
`Tuple[Placement]` value). For a non-Tensor output, the `PlacementType`
must be `None`.
Note that the only exception is when no :class:`DTensor` argument is passed
in. In this case, even if `out_placements` is not `None`, the desired placements
are ignored because the function does not run with :class:`DTensor` s and
``func``'s output is returned as-is.
in_placements (Tuple[`PlacementType`, ...], optional):
the required placements of the :class:`DTensor` s in the flattened inputs of ``func``.
If ``in_placements`` is specified, :meth:`local_map` examines whether the
placements of each :class:`DTensor` argument match the required
placements. If the placements do not match and
``redistribute_inputs`` is ``False``, an exception is raised. Otherwise, if
``redistribute_inputs`` is ``True``, the argument is first redistributed to
the required sharding placements before its local tensor is passed to ``func``.
The only exception is when the required placements are not ``None`` and the
argument is a plain :class:`torch.Tensor`. In this case, the placements examination
is skipped and the argument is passed directly to ``func``.
If ``in_placements`` is ``None``, no placements examination is performed.
Default: None
device_mesh (:class:`DeviceMesh`, optional):
the device mesh that all the :class:`DTensor` s are placed on. If not
specified, this will be inferred from the input :class:`DTensor` s' device
mesh. `local_map` requires every :class:`DTensor` to be placed on the same
device mesh. Default: None.
redistribute_inputs (bool, optional):
the bool value indicating whether to reshard the input :class:`DTensor` s when
their placements differ from the required input placements. If this
value is ``False`` and some :class:`DTensor` input has different placements,
an exception will be raised; if it is ``True``, the input is redistributed as
shown in the second example below. Default: False.
Returns:
A ``Callable`` that applies ``func`` to the local shards of the input :class:`DTensor` s
and returns a :class:`DTensor` constructed from the return value of ``func``.
Raises:
AssertionError: If the input :class:`DTensor` s are not placed on the same device
mesh, or if they are placed on a different device mesh than the ``device_mesh``
argument passed in.
AssertionError: For any non-DTensor output, its corresponding
output placement in ``out_placements`` must be ``None``. An AssertionError will
be raised if this is not the case.
ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs
a redistribution according to ``in_placements``.
Example:
>>> # xdoctest: +SKIP("distributed")
>>> # setup (the mesh device type and size are illustrative)
>>> from torch.distributed.device_mesh import init_device_mesh
>>> from torch.distributed.tensor import distribute_tensor
>>> from torch.distributed.tensor.placement_types import Replicate, Shard
>>> import torch.distributed._functional_collectives as funcol
>>> device_mesh = init_device_mesh("cuda", (4,))
>>>
>>> def mm_allreduce_forward(device_mesh, W, X):
>>> partial_sum_tensor = torch.mm(W, X)
>>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>> return reduced_tensor
>>>
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
>>> local_mm_allreduce_forward = local_map(
>>> mm_allreduce_forward,
>>> out_placements=[Replicate()],
>>> in_placements=[None, col_wise, row_wise], # None for the non-tensor device_mesh arg
>>> device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, col_wise) # W tensor sharded column-wise
>>> X_dt = distribute_tensor(X, device_mesh, row_wise) # X tensor sharded row-wise
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
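>>>
>>> # A second, minimal sketch (reusing the setup above; the initial placements chosen
>>> # for W_dt are illustrative) of how redistribute_inputs=True reshards inputs whose
>>> # placements differ from in_placements before func is called:
>>> W_dt = distribute_tensor(W, device_mesh, row_wise) # W starts out row-wise sharded
>>> local_mm_allreduce_forward = local_map(
>>> mm_allreduce_forward,
>>> out_placements=[Replicate()],
>>> in_placements=[None, col_wise, row_wise],
>>> device_mesh=device_mesh,
>>> redistribute_inputs=True, # W_dt is redistributed to col_wise before func runs
>>> )
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # same result as the first example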
.. note:: This API is currently experimental and subject to change
"""
def wrapped(*args, **kwargs):
# process input args
flat_args, args_spec = pytree.tree_flatten(args)
if in_placements is not None:
assert len(in_placements) == len(flat_args), (
f"in_placements length {len(in_placements)} does not match the number "
f"of input args {len(flat_args)}!"
)
# we assume every DTensor object is placed on the same device mesh
flat_local_args = []
nonlocal device_mesh # access var device_mesh from the outer scope
seen_dtensor_arg = False
for idx, arg in enumerate(flat_args):
if isinstance(arg, DTensor):
# TODO: the current code doesn't consider the uneven sharding case
# Need to think about what the consequence is when the input DTensor
# is uneven sharded.
if device_mesh is None: # infer device mesh from the DTensor arg
device_mesh = arg.device_mesh
# this function is applied to at least one DTensor argument
seen_dtensor_arg = True
assert arg.device_mesh == device_mesh, (
f"arg {arg} in local_map has a mismatched device mesh: "
f"{arg} has device mesh {arg.device_mesh} while "
f"the expected device mesh is {device_mesh}!"
)
if in_placements is not None:
spec = in_placements[idx]
assert (
spec is not None
), f"DTensor input {arg} expects placements but received {spec}!"
if not isinstance(spec, tuple):
spec = tuple(spec)
if arg.placements != spec:
if redistribute_inputs:
# redistribute to input placements
arg = arg.redistribute(device_mesh, spec)
else:
raise ValueError(
f"arg {arg} in local_map has a mismatched placements: "
f"arg placements is {arg.placements} but the input "
f"placements is {spec}! "
"If redistribute_inputs is wanted, set "
"redistribute_inputs=True to local_map."
)
local_arg = arg.to_local()
if isinstance(local_arg, AsyncCollectiveTensor):
local_arg = local_arg.wait()
flat_local_args.append(local_arg)
else:
# Non-Tensor input must have None in `in_placements`
if in_placements is not None and not isinstance(arg, torch.Tensor):
spec = in_placements[idx]
assert spec is None, (
f"Non-Tensor input {arg} expects None placements "
f"but received {spec}!"
)
flat_local_args.append(arg)
local_args = pytree.tree_unflatten(flat_local_args, args_spec)
out = func(*local_args, **kwargs)
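# Wrap outputs back into DTensors only if at least one DTensor input was seen;
# otherwise func ran on plain tensors and its output is returned untouched.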
if seen_dtensor_arg:
# process output
flat_out, out_spec = pytree.tree_flatten(out)
flat_dist_out = []
out_placements_tuple = (
out_placements
if isinstance(out_placements, tuple)
else (out_placements,)
)
assert len(flat_out) == len(out_placements_tuple), (
"local_map requires one PlacementType be provided for each output value,"
f" received {len(out_placements_tuple)} out_placements but"
f" {len(flat_out)} is expected!"
)
for out, spec in zip(flat_out, out_placements_tuple):
if isinstance(out, torch.Tensor):
assert not isinstance(
out, DTensor
), f"torch.Tensor output expected but received {type(out)}: {out}"
flat_dist_out.append(
DTensor.from_local(out, device_mesh, spec, run_check=False)
)
else:
assert (
spec is None
), f"Non-tensor output {out} expects None placements but received {spec}!"
flat_dist_out.append(out)
return pytree.tree_unflatten(flat_dist_out, out_spec)
else:
return out
return wrapped