# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from functools import lru_cache
from typing import cast, List, NamedTuple, Tuple
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
logger = logging.getLogger(__name__)
class _TransformInfo(NamedTuple):
mesh_dim: int
src_dst_placements: Tuple[Placement, Placement]
# logical_shape on this mesh dimension
logical_shape: List[int]
def _gen_transform_infos_non_cached(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
"""
Generate the transform infos from the source placements to the target placements.
To transform from source to target placement it might have multiple steps, i.e. it
might decompose Si -> Sj into Si -> R -> Sj.
This would detect if there're mis-aligned/nested shardings between src/dst placements.
E.g. Suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)),
in this case Shard(0) -> Shard(0) for mesh dimension 1 actually needs resharding, because in
the former is a nested-sharding of a tensor already already sharded dimension 0, whereras
the latter is the first sharding on tensor dimension 0.
"""
transform_infos: List[_TransformInfo] = []
device_mesh = src_spec.device_mesh
my_coordinate = device_mesh.get_coordinate()
assert my_coordinate is not None
    # logical shape records the logical tensor shape on each mesh dimension;
    # this is useful for ensuring uneven sharding gets the correct output shape
initial_logical_shape = list(src_spec.shape)
mesh_dims_to_logical_shape = [initial_logical_shape]
if device_mesh.ndim == 1:
# if device_mesh is 1D, redistribute is a simple direct transformation
transform_infos.append(
_TransformInfo(
mesh_dim=0,
src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]),
logical_shape=initial_logical_shape,
)
)
return transform_infos
    # Handle multi-dim device mesh placement redistribution.
    # First, build the logical shape for each mesh dim so that uneven shards can be
    # correctly all-gathered on each mesh dim (with dynamic padding); an illustrative
    # example follows the loop below.
for i, src in enumerate(src_spec.placements):
current_logical_shape = mesh_dims_to_logical_shape[i]
if isinstance(src, Shard):
if i < device_mesh.ndim - 1:
# calculate and save the logical shape for this sharding
mesh_dim_size = device_mesh.size(mesh_dim=i)
local_shard_size, _ = src._local_shard_size_on_dim(
current_logical_shape[src.dim],
mesh_dim_size,
my_coordinate[i],
)
new_logical_shape = list(current_logical_shape)
new_logical_shape[src.dim] = local_shard_size
mesh_dims_to_logical_shape.append(new_logical_shape)
else:
mesh_dims_to_logical_shape.append(current_logical_shape)
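    # Illustrative sketch (assumed values, not executed): for a tensor of shape [9]
    # on a 2x2 mesh with src placements (Shard(0), Shard(0)), chunk-style sharding
    # splits the 9 rows into [5, 4] on mesh dim 0, so on a rank with coordinate (0, *)
    # mesh_dims_to_logical_shape would be roughly [[9], [5]], while on a rank with
    # coordinate (1, *) it would be [[9], [4]].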
# Next, we need to derive the transform infos from src to dst placements,
# here we use a greedy search with step by step state transformations
current_placements = list(src_spec.placements)
target_placements = list(dst_spec.placements)
if src_spec.num_shards > 1:
        # If src_spec has sharding, it could potentially be misaligned with dst_spec;
        # a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))).
        # In those cases, we first traverse from inner placement to outer placement
        # to detect misaligned shardings and properly replicate nested shardings first.
for mesh_dim in reversed(range(len(current_placements))):
current = current_placements[mesh_dim]
target = target_placements[mesh_dim]
            # If target is not Shard, we can directly redistribute since we are traversing from inner
            # to outer placements here
if isinstance(target, Shard):
# If target is Shard, check for nested sharding on the tensor dim BEFORE the current mesh_dim
shard_dim = target.dim
current_mesh_sharding, target_mesh_sharding = [], []
for i, (s, p) in enumerate(zip(current_placements, target_placements)):
if i >= mesh_dim:
break
if s.is_shard(shard_dim):
current_mesh_sharding.append(i)
if p.is_shard(shard_dim):
target_mesh_sharding.append(i)
if current_mesh_sharding != target_mesh_sharding:
# if current/target_placements have misaligned sharding on the tensor dim BEFORE the current
# mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding
target = Replicate()
if current != target:
transform_infos.append(
_TransformInfo(
mesh_dim=mesh_dim,
src_dst_placements=(current, target),
logical_shape=mesh_dims_to_logical_shape[mesh_dim],
)
)
current_placements[mesh_dim] = target
    # We then traverse from outer placement to inner placement to collect the remaining
    # transform infos (e.g. a mesh dim that was replicated to clear nested sharding may
    # need to be re-sharded to reach the target placement).
for mesh_dim, (current, target) in enumerate(
zip(current_placements, target_placements)
):
if current != target:
transform_infos.append(
_TransformInfo(
mesh_dim=mesh_dim,
src_dst_placements=(current, target),
logical_shape=mesh_dims_to_logical_shape[mesh_dim],
)
)
current_placements[mesh_dim] = target
return transform_infos
@lru_cache(maxsize=None)
def _gen_transform_infos(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
return _gen_transform_infos_non_cached(src_spec, dst_spec)
def redistribute_local_tensor(
local_tensor: torch.Tensor,
current_spec: DTensorSpec,
target_spec: DTensorSpec,
*,
async_op: bool = False,
is_backward: bool = False,
) -> torch.Tensor:
"""
This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to
the target DTensorSpec, which involves the necessary collective calls to transform
the local shard of the DTensor from its current spec to the target spec.
"""
if current_spec.mesh != target_spec.mesh:
# TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
raise NotImplementedError("Cross device mesh comm not supported yet!")
new_local_tensor = None
device_mesh = current_spec.mesh
my_coordinate = device_mesh.get_coordinate()
if my_coordinate is None:
# if rank is not part of mesh, we skip redistribute and simply return local_tensor,
# which should be an empty tensor
return local_tensor
has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any(
isinstance(s, torch.SymInt) for s in target_spec.shape
)
if has_symints:
transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
else:
transform_infos = _gen_transform_infos(current_spec, target_spec)
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
        if current == target:
            # shortcut: just reuse the original local tensor
new_local_tensor = local_tensor
continue
logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)
if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert (
current.is_shard()
), f"Current placement should be shard but found {current}"
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
                # Skip the replicate -> partial transformation when in the backward pass and
                # keep the grad replicated. Although converting the replicated gradients back
                # to partial logically conforms to the same layout, it is useless in practice:
                # a reduce would be needed later anyway, which is more expensive than simply
                # keeping the gradient replicated.
new_local_tensor = (
partial_spec._partition_value(local_tensor, device_mesh, i)
if not is_backward
else local_tensor
)
elif current.is_shard():
if not is_backward:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
# for backward shard -> partial, we just need to convert the shard to replicate
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
assert new_local_tensor is not None
local_tensor = new_local_tensor
assert new_local_tensor is not None, "redistribute failed!"
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
new_local_tensor = new_local_tensor.wait()
return new_local_tensor
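# A minimal usage sketch (illustrative only, not executed here); it assumes a 1-D mesh
# of 4 ranks and goes through the public DTensor API, which routes redistribution
# through the Redistribute autograd.Function below:
#
#   mesh = DeviceMesh("cuda", list(range(4)))
#   dt = dtensor.distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
#   replicated = dt.redistribute(mesh, [Replicate()])  # all_gather under the hood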
class Redistribute(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
# pyre-fixme[2]: Parameter must be annotated.
ctx,
input: "dtensor.DTensor",
device_mesh: DeviceMesh,
placements: Tuple[Placement, ...],
async_op: bool = False,
):
current_spec = input._spec
ctx.current_spec = current_spec
ctx.async_op = async_op
if current_spec.placements != placements:
target_spec = DTensorSpec(
device_mesh, placements, tensor_meta=input._spec.tensor_meta
)
local_tensor = input._local_tensor
output = redistribute_local_tensor(
local_tensor, current_spec, target_spec, async_op=async_op
)
else:
# use the same local tensor if placements are the same.
output = input._local_tensor
target_spec = current_spec
return dtensor.DTensor(
output,
target_spec,
requires_grad=input.requires_grad,
)
@staticmethod
def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override]
previous_spec = ctx.current_spec
current_spec = grad_output._spec
async_op = ctx.async_op
local_tensor = grad_output._local_tensor
output = redistribute_local_tensor(
local_tensor,
current_spec,
previous_spec,
async_op=async_op,
is_backward=True,
)
# normalize the target placement to replicate if it is partial
normalized_placements: List[Placement] = []
for previous_placement in previous_spec.placements:
if previous_placement.is_partial():
                # keep the target placement as Replicate instead of Partial in this case
normalized_placements.append(Replicate())
else:
normalized_placements.append(previous_placement)
spec = DTensorSpec(
previous_spec.device_mesh,
tuple(normalized_placements),
tensor_meta=TensorMeta(
shape=grad_output.shape,
stride=grad_output.stride(),
dtype=grad_output.dtype,
),
)
output_dtensor = dtensor.DTensor(
output,
spec,
requires_grad=grad_output.requires_grad,
)
return (
output_dtensor,
None,
None,
None,
)