# mypy: allow-untyped-defs
import logging
import math
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._C._distributed_c10d import _resolve_process_group
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
_get_group_size_by_name,
broadcast,
get_global_rank,
get_group_rank,
get_rank,
GroupMember,
ProcessGroup,
scatter,
Work,
)
logger = logging.getLogger(__name__)
if not torch._running_with_deploy():
@torch.library.register_fake("_dtensor::shard_dim_alltoall")
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
group_size = _get_group_size_by_name(group_name)
stacked_list = [torch.empty_like(input) for _ in range(group_size)]
group = _resolve_process_group(group_name)
group_rank = get_group_rank(group, get_rank())
return (
torch.cat(stacked_list, dim=gather_dim)
.chunk(group_size, dim=shard_dim)[group_rank]
.contiguous()
)
else:
import warnings
warnings.warn(
"PyTorch Distributed functional collectives do not work with torch::deploy."
)
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
if mesh.device_type == "cpu":
# Gloo does not support alltoall, so falling back to allgather + chunk
# TODO: This logs way too much
logger.warning(
"CPU process group does not support alltoall yet, falling back with allgather + chunk!"
)
out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim))
if isinstance(out, funcol.AsyncCollectiveTensor):
            # wait here to match the synchronous behavior of the alltoall path;
            # remove this once async alltoall is enabled
out = out.wait()
out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[
mesh.get_local_rank(mesh_dim)
]
return out.contiguous()
group_name = funcol._resolve_group_name((mesh, mesh_dim))
# TODO: enable async op for shard_dim_alltoall
return torch.ops._dtensor.shard_dim_alltoall(
input, gather_dim, shard_dim, group_name
)
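# Example (illustrative sketch, not executed at import time; assumes an
# already-initialized 1-D DeviceMesh `mesh` of size 4 on a CUDA backend):
# resharding a local shard from Shard(0) to Shard(1) on mesh dim 0. With a
# global tensor of shape [8, 8], each rank holds a [2, 8] local shard before
# the call and an [8, 2] local shard after it.
#
#   local = torch.randn(2, 8, device="cuda")
#   new_local = shard_dim_alltoall(
#       local, gather_dim=0, shard_dim=1, mesh=mesh, mesh_dim=0
#   )
#   assert new_local.shape == (8, 2)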
def mesh_scatter(
output: torch.Tensor,
scatter_list: List[torch.Tensor],
mesh: DeviceMesh,
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
scatter a list of tensors to a device mesh dimension. We by default
use the first rank of the mesh dimension as the source of truth, i.e
for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank
2 to rank 2/3.
Args:
output (torch.Tensor): the tensor to receive the scattered list.
scatter_list (List[torch.Tensor]): the tensor list to be scattered.
mesh_dim (int, optional): indicate which mesh dimension we want
to scatter on, we by default choose the first rank on the
mesh dimension as source of truth.
Returns:
A :class:`Work` object
"""
# TODO: Ideally we should use the meta tensor way
# (to register a meta kernel for the collective op)
# so that it would avoid the communication. Need to
# remove the check below once that is done.
if output.is_meta:
return None
dim_group = mesh.get_group(mesh_dim)
assert isinstance(dim_group, ProcessGroup)
    # src needs to be a global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)
if src_for_dim == get_rank():
fut = scatter(
output,
scatter_list=scatter_list,
src=src_for_dim,
group=dim_group,
async_op=async_op,
)
else:
fut = scatter(
output,
scatter_list=None,
src=src_for_dim,
group=dim_group,
async_op=async_op,
)
return fut
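# Example (illustrative sketch; assumes torch.distributed is initialized with
# 4 ranks and a 2x2 DeviceMesh `mesh` has been built):
#
#   output = torch.empty(4, 4, device="cuda")
#   # every rank along mesh dim 1 receives one tensor from the first rank of
#   # that mesh dimension
#   scatter_list = [
#       torch.full((4, 4), float(i), device="cuda") for i in range(mesh.size(1))
#   ]
#   work = mesh_scatter(output, scatter_list, mesh, mesh_dim=1, async_op=True)
#   if work is not None:
#       work.wait()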
def mesh_broadcast(
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
broadcast the tensor to a device mesh dimension. We by default
use the first rank of the mesh dimension as the source of truth, i.e
for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2
to rank 2/3.
Args:
tensor (torch.Tensor): tensor to broadcast.
mesh_dim (int, optional): indicate which mesh dimension we want
to scatter on, we by default choose the first rank on the
mesh dimension as source of truth.
Returns:
A :class:`Work` object
"""
# TODO: Ideally we should use the meta tensor way
# (to register a meta kernel for the collective op)
# so that it would avoid the communication. Need to
# remove the check below once that is done.
if tensor.is_meta:
return None
dim_group = mesh.get_group(mesh_dim)
assert isinstance(dim_group, ProcessGroup)
    # src needs to be a global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)
return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
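# Example (illustrative sketch; assumes the same 2x2 DeviceMesh `mesh` as above):
#
#   tensor = torch.randn(4, 4, device="cuda")
#   # after the call, every rank along mesh dim 0 holds the values from the
#   # first rank of that mesh dimension
#   work = mesh_broadcast(tensor, mesh, mesh_dim=0, async_op=True)
#   if work is not None:
#       work.wait()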
def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
if pad_size == 0:
return tensor
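    # F.pad takes pad sizes starting from the last dimension, so a list of
    # length 2 * (ndim - pad_dim) whose final entry is pad_size pads `pad_dim`
    # (and only `pad_dim`) at its end.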
pad = [0, 0] * (tensor.ndim - pad_dim)
pad[-1] = pad_size
return torch.nn.functional.pad(tensor, pad)
def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
if pad_size == 0:
return tensor
return tensor.narrow(
pad_dim,
start=0,
length=tensor.size(pad_dim) - pad_size,
)
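# Example (local-only, no communication): padding the last dim of a [2, 3]
# tensor by one element and trimming it back recovers the original tensor.
#
#   t = torch.arange(6).reshape(2, 3)
#   padded = pad_tensor(t, pad_dim=1, pad_size=1)            # shape [2, 4]
#   restored = unpad_tensor(padded, pad_dim=1, pad_size=1)   # shape [2, 3]
#   assert torch.equal(restored, t)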
def fill_empty_tensor_to_shards(
shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int
) -> List[torch.Tensor]:
if num_empty_tensors == 0:
return shards
tensor_size = list(shards[0].size())
tensor_size = [
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
]
tensor = shards[0].new_zeros(tensor_size)
shards.extend(tensor for _ in range(num_empty_tensors))
return shards
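# Example (local-only): padding a shard list for uneven sharding. With 2 real
# shards but 4 ranks on the shard dim, two size-0 shards are appended so that
# every rank has a tensor to participate in the collective with.
#
#   shards = [torch.ones(3, 4), torch.ones(3, 4)]
#   shards = fill_empty_tensor_to_shards(shards, shard_dim=0, num_empty_tensors=2)
#   assert len(shards) == 4 and shards[-1].shape == (0, 4)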
def check_tensor_meta(
local_tensor, check_shape_stride=False
) -> Optional["dtensor_spec.TensorMeta"]:
local_metadata = {
"dtype": local_tensor.dtype,
"requires_grad": local_tensor.requires_grad,
}
if check_shape_stride:
local_metadata.update(
{"shape": local_tensor.shape, "stride": local_tensor.stride()}
)
gathered_metadata = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_metadata, local_metadata)
# Check if metadata is consistent across ranks
if not all(meta == local_metadata for meta in gathered_metadata):
raise ValueError(
"Inconsistent tensor metadata (including shape and stride) across ranks."
)
return None
def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int:
assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
@dataclass
class MeshTopoInfo:
"""
Mesh information for collective cost estimation
"""
mesh: DeviceMesh
mesh_dim_devices: List[int]
mesh_dim_bandwidth: List[float]
mesh_dim_latency: List[float]
@staticmethod
@lru_cache(None)
def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
        # Generate mesh topology info for the intra-host/inter-host communication pattern.
        # Note that we make a bunch of simplifying assumptions:
        # 1. the mesh is homogeneous and uses the GPU/NCCL model
        # 2. the GPU arch is Ampere or Hopper
        # 3. all collectives use ring-based algorithms for now
num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type)
# the base bw number (intra-node), GB/s
base_bw = 87.7
mesh_dim_bandwidth = [base_bw] * mesh.ndim
# the latency in terms of us (intra-node, nv-link)
mesh_dim_latency = [0.6] * mesh.ndim
mesh_dim_devices = [1] * mesh.ndim
total_num_devices = 1
for mesh_dim in reversed(range(mesh.ndim)):
num_devices = mesh.size(mesh_dim)
mesh_dim_devices[mesh_dim] = num_devices
total_num_devices *= num_devices
if total_num_devices > num_devices_per_host:
# magic number for inter-host communication bandwidth/latency factor
# This number assumes latest GPU arch, i.e. Ampere or Hopper
                # TODO: see if we need to tweak this or offer a way for the
                # user to specify the bandwidth/latency numbers
mesh_dim_bandwidth[mesh_dim] *= 0.22
# set to ethernet latency for inter-host
mesh_dim_latency[mesh_dim] = 2.7
return MeshTopoInfo(
mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
)
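# Worked example (assuming 8 GPUs per host): for a 2x8 DeviceMesh spanning two
# hosts, the inner mesh dim (size 8) stays within a host, so it keeps the base
# bandwidth (87.7 GB/s) and NVLink latency (0.6 us); the outer mesh dim (size 2)
# pushes the cumulative device count past 8, so its bandwidth is scaled by 0.22
# (~19.3 GB/s) and its latency becomes the inter-host value (2.7 us).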
def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
num_hops = num_devices_on_mesh_dim - 1
# base latency + comm latency
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s
return latency + bw * 1e6 # rescale to us
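# Worked example: all-gathering 1 GB across an 8-device intra-host mesh dim
# gives num_hops = 7, latency = 6.6 + 7 * 0.6 = 10.8 us, and a bandwidth term
# of (1 * 7 / 8) / 87.7 s ~= 9977 us, i.e. roughly 9988 us in total.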
def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
    # allreduce has almost 2x the comm bytes compared to allgather/reduce_scatter
num_hops = 2 * num_devices_on_mesh_dim - 1
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
return latency + bw * 1e6
def reduce_scatter_cost(
bytes_gb: float,
mesh_topo: MeshTopoInfo,
mesh_dim: int,
) -> float:
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
num_hops = num_devices_on_mesh_dim - 1
# base latency + comm latency
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
return latency + bw * 1e6
def redistribute_cost(
current_spec: "dtensor_spec.DTensorSpec",
target_spec: "dtensor_spec.DTensorSpec",
) -> float:
"""
This function returns the cost of redistribute from current to target DTensorSpec.
NOTE:
1. Only consider communication cost here, since computation costs for redistribute
are quite trival (i.e. we only need to narrow or simple division)
2. Only consider redistribute cost on same mesh, cross mesh communication cost is
not quite needed for operator strategy estimation/selection.
"""
if current_spec.mesh != target_spec.mesh:
        # return infinite cost if the meshes are not the same
# TODO: see if we want to support this once there's cross mesh communication
return float("inf")
if current_spec.is_replicated():
# short-cut:
# comm cost is 0 if current spec is already full replication
return 0.0
mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
cost = 0.0
comm_bytes_gb = (
spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
)
    # Transformations considered for the redistribute cost:
# 1. allgather 2. alltoall
# 3. allreduce 4. reduce_scatter
for i, (current, target) in enumerate(
zip(current_spec.placements, target_spec.placements)
):
if current == target:
continue
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
if current.is_shard() and target.is_replicate():
            # allgather inflates the comm bytes to the full (unsharded) size on this mesh dim
comm_bytes_gb *= num_devices_on_mesh_dim
# add up allgather comm cost
cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
elif current.is_shard() and target.is_shard():
            # this should be an alltoall; since we haven't implemented it yet, add
            # a penalty to favor allgather instead
cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0
elif current.is_partial() and target.is_replicate():
# add up allreduce comm cost
cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
elif current.is_partial() and target.is_shard():
# add up reduce_scatter comm cost
cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
            # after reduce_scatter, the comm bytes for further collectives shrink by
            # the number of devices on this mesh dim
comm_bytes_gb /= num_devices_on_mesh_dim
elif current.is_shard() and target.is_partial():
# ban shard -> partial as it does not make sense to perform
# this redistribute
return float("inf")
return cost
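# Example (illustrative sketch; specs are normally produced by DTensor's
# sharding propagation rather than constructed by hand, and the exact
# DTensorSpec/TensorMeta constructors are internal and may change):
#
#   from torch.distributed.tensor import Replicate, Shard
#
#   meta = dtensor_spec.TensorMeta(
#       shape=torch.Size([8, 8]), stride=(8, 1), dtype=torch.float32
#   )
#   current = dtensor_spec.DTensorSpec(mesh, (Shard(0),), tensor_meta=meta)
#   target = dtensor_spec.DTensorSpec(mesh, (Replicate(),), tensor_meta=meta)
#   est = redistribute_cost(current, target)  # allgather cost along mesh dim 0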