# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
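

# The helper below builds remote-device placement strings of the form
# "rank:<global_rank>/<device_type>[:<device_index>]", e.g. "rank:10/cuda:2"
# when there are 8 devices per node, or "rank:10/cpu" for CPU tensors.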
def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
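    # Each chunk's dim-0 offset is the running sum of the preceding chunk sizes;
    # the final cumulative value (the full dim-0 length) is dropped below.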
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
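

# Usage sketch for _create_chunk_sharded_tensor (illustrative only; the tensor
# shape and variable names below are hypothetical):
#
#     pg = dist.group.WORLD
#     sharded = _create_chunk_sharded_tensor(
#         torch.randn(16, 4),
#         rank=dist.get_rank(),
#         world_size=dist.get_world_size(),
#         num_devices_per_node=torch.cuda.device_count(),
#         pg=pg,
#     )
#     # Each rank now holds at most one local shard of up to ceil(16 / world_size) rows.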


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.detach().clone()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]
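    # Construct the DTensor as fully replicated (run_check=False skips the
    # verification collective), then redistribute to the sharded placement so
    # each rank keeps only its own dim-0 chunk as the local tensor.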
    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )
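

# Usage sketch for _create_chunk_dtensor (illustrative only; the mesh and tensor
# below are hypothetical):
#
#     mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
#     dtensor = _create_chunk_dtensor(torch.randn(16, 4), dist.get_rank(), mesh)
#     # For a 1-D (FSDP-style) mesh, dtensor.placements == (Shard(dim=0),).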


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert (
        root_mesh == tensor.device_mesh
    ), "The device mesh of a tensor should be a root mesh."

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
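

# Usage sketch for _all_gather_dtensor (illustrative only, continuing the
# _create_chunk_dtensor example above):
#
#     full = _all_gather_dtensor(dtensor, dtensor.device_mesh)
#     # full is a plain torch.Tensor with the original, unsharded shape.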