# mypy: allow-untyped-defs
import collections.abc
import copy
from typing import List, Optional, Sequence, TYPE_CHECKING

import torch
from torch.distributed import distributed_c10d as c10d, rpc
from torch.distributed._shard.sharding_spec._internals import (
    check_tensor,
    validate_non_overlapping_shards_metadata,
)

from .metadata import ShardedTensorMetadata, TensorProperties
from .shard import Shard

if TYPE_CHECKING:
    from torch.distributed._shard.metadata import ShardMetadata

def _parse_and_validate_remote_device(pg, remote_device):
    """Parse a remote device placement, validate it against the process group,
    and return a ``(rank_or_worker_id, device)`` tuple."""
    if remote_device is None:
        raise ValueError("remote device is None")

    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate the rank; skip validation if this rank is not part of the process group.
    if rank is not None and not c10d._rank_not_in_group(pg):
        pg_global_ranks = c10d.get_process_group_ranks(pg)
        if rank not in pg_global_ranks:
            raise ValueError(
                f"Global rank {rank} does not exist in input process group: {pg_global_ranks}"
            )

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(
                f"RPC framework needs to be initialized for using worker names: {worker_name}"
            )

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f"Invalid worker name: {worker_name}")

    return rank, device
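
# Illustrative sketch (assumption, not exercised in this module): a rank-based
# placement resolves to a (rank, device) pair without touching RPC, e.g.
#
#   rank, device = _parse_and_validate_remote_device(
#       pg, torch.distributed._remote_device("rank:0/cuda:0")
#   )
#   # rank == 0, device == torch.device("cuda:0")
#
# whereas a worker-name placement such as _remote_device("trainer0/cpu") is
# resolved through the current RPC agent's worker infos instead.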

def _validate_output_tensor_for_gather(
    my_rank: int,
    dst_rank: int,
    size: torch.Size,
    dst_tensor: Optional[torch.Tensor],
) -> None:
    if dst_rank == my_rank:
        if dst_tensor is None:
            raise ValueError(
                f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}"
            )
        if tuple(size) != tuple(dst_tensor.size()):
            raise ValueError(
                f"Argument ``dst_tensor`` has size {tuple(dst_tensor.size())}, "
                f"but should be {tuple(size)}"
            )
    elif dst_tensor is not None:
        raise ValueError(
            "Argument ``dst_tensor`` must NOT be specified on non-destination ranks."
        )
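
# Illustrative call pattern (assumption): in a gather onto ``dst_rank``, only the
# destination rank passes a preallocated output tensor of the full global size;
# every other rank passes ``dst_tensor=None``, e.g.
#
#   out = torch.empty(global_size) if my_rank == dst_rank else None
#   _validate_output_tensor_for_gather(my_rank, dst_rank, global_size, out)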

def _flatten_tensor_size(size) -> torch.Size:
    """
    Check that the given size is valid, then flatten it and return a torch.Size.
    """
    if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
        dims = list(*size)
    else:
        dims = list(size)

    for dim in dims:
        if not isinstance(dim, int):
            raise TypeError(f"size has to be a sequence of ints, found: {dims}")

    return torch.Size(dims)
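
# Illustrative sketch (assumption about callers): they typically forward their
# ``*size`` varargs here, so both spellings flatten to the same result:
#
#   _flatten_tensor_size((4, 8))      # -> torch.Size([4, 8])
#   _flatten_tensor_size(([4, 8],))   # -> torch.Size([4, 8])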

def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
    if is_local:
        assert isinstance(ranks, int)
        if expected != actual:
            raise ValueError(
                f"Local shards' tensor {prop_name} property needs to be the same on rank:{ranks}! "
                f"Found one local shard tensor {prop_name}={expected}, "
                f"the other local shard tensor {prop_name}={actual}."
            )
    else:
        # When comparing across ranks, ``ranks`` should be a list of exactly two ranks.
        assert len(ranks) == 2
        if expected != actual:
            raise ValueError(
                f"ShardedTensor {prop_name} property does not match from different ranks! "
                f"Found {prop_name}={expected} on rank:{ranks[0]}, "
                f"and {prop_name}={actual} on rank:{ranks[1]}."
            )
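
# Illustrative sketch (assumption): the two modes differ only in how the offending
# rank(s) are reported; both of the following would raise ValueError, e.g.
#
#   _raise_if_mismatch(torch.float32, torch.float16, "dtype", 3)                       # local, single rank
#   _raise_if_mismatch(torch.float32, torch.float16, "dtype", [0, 3], is_local=False)  # cross-rank pair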

def build_metadata_from_local_shards(
    local_shards: List[Shard],
    global_size: torch.Size,
    current_rank: int,
    pg: c10d.ProcessGroup,
) -> ShardedTensorMetadata:
    assert len(local_shards) > 0, "must have local shards!"
    local_shard_metadatas: List[ShardMetadata] = []

    first_shard_dtype = local_shards[0].tensor.dtype
    first_shard_layout = local_shards[0].tensor.layout
    first_shard_requires_grad = local_shards[0].tensor.requires_grad
    first_shard_is_pinned = local_shards[0].tensor.is_pinned()

    # 1). Validate local tensors and their associated metadata.
    for local_shard in local_shards:
        local_shard_tensor = local_shard.tensor
        local_shard_meta = local_shard.metadata
        local_shard_metadatas.append(local_shard_meta)
        rank, local_device = _parse_and_validate_remote_device(
            pg, local_shard_meta.placement
        )

        if (
            local_shard_tensor.layout != torch.strided
            or local_shard_tensor.layout != first_shard_layout
        ):
            raise ValueError(
                f"Only torch.strided layout is currently supported, but found "
                f"{local_shard_tensor.layout} on rank:{current_rank}!"
            )

        if not local_shard_tensor.is_contiguous():
            raise ValueError(
                "Only torch.contiguous_format memory_format is currently supported!"
            )

        if rank != current_rank:
            raise ValueError(
                f"Local shard metadata's rank does not match the current rank in its process group! "
                f"Found current rank in the process group: {current_rank}, "
                f"local ShardMetadata placement's rank: {rank}"
            )

        if local_shard_tensor.device != local_device:
            raise ValueError(
                f"Local shard tensor device does not match the local Shard's placement! "
                f"Found local shard tensor device: {local_shard_tensor.device}, "
                f"local shard metadata placement device: {local_device}"
            )

        _raise_if_mismatch(
            local_shard_meta.shard_sizes,
            list(local_shard_tensor.size()),
            "size",
            current_rank,
        )
        _raise_if_mismatch(
            local_shard_tensor.is_pinned(),
            first_shard_is_pinned,
            "pin_memory",
            current_rank,
        )
        _raise_if_mismatch(
            local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank
        )
        _raise_if_mismatch(
            local_shard_tensor.requires_grad,
            first_shard_requires_grad,
            "requires_grad",
            current_rank,
        )

    # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank,
    #     then all_gather to collect local_sharded_tensor_metadata from all ranks.
    local_tensor_properties = TensorProperties(
        dtype=first_shard_dtype,
        layout=first_shard_layout,
        requires_grad=first_shard_requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=first_shard_is_pinned,
    )

    local_sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=local_shard_metadatas,
        size=global_size,
        tensor_properties=local_tensor_properties,
    )

    return local_sharded_tensor_metadata
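
# Illustrative flow (assumption; names other than the two builders are hypothetical):
# each rank validates its own shards and builds a "local" view of the metadata, which
# is then exchanged so every rank can assemble the global view:
#
#   local_md = build_metadata_from_local_shards(local_shards, global_size, my_rank, pg)
#   gathered: List[Optional[ShardedTensorMetadata]] = [None] * world_size
#   dist.all_gather_object(gathered, local_md, group=pg)
#   global_md = build_global_metadata(gathered)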

def build_global_metadata(
    gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]],
):
    global_sharded_tensor_metadata = None
    global_metadata_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue

        if global_sharded_tensor_metadata is None:
            global_sharded_tensor_metadata = copy.deepcopy(rank_metadata)
            global_metadata_rank = rank
        else:
            _raise_if_mismatch(
                global_sharded_tensor_metadata.size,
                rank_metadata.size,
                "global_size",
                [global_metadata_rank, rank],
                is_local=False,
            )

            # Layout and memory format were already checked during local shard
            # validation, so only dtype, requires_grad and pin_memory are compared here.
            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.dtype,
                rank_metadata.tensor_properties.dtype,
                "dtype",
                [global_metadata_rank, rank],
                is_local=False,
            )
            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.requires_grad,
                rank_metadata.tensor_properties.requires_grad,
                "requires_grad",
                [global_metadata_rank, rank],
                is_local=False,
            )
            _raise_if_mismatch(
                global_sharded_tensor_metadata.tensor_properties.pin_memory,
                rank_metadata.tensor_properties.pin_memory,
                "pin_memory",
                [global_metadata_rank, rank],
                is_local=False,
            )

            # All validations passed; extend the global shards metadata with this rank's shards.
            global_sharded_tensor_metadata.shards_metadata.extend(
                rank_metadata.shards_metadata
            )

    if global_sharded_tensor_metadata is not None:
        # Check that the gathered shards_metadata contains no overlapping shards.
        validate_non_overlapping_shards_metadata(
            global_sharded_tensor_metadata.shards_metadata
        )

        # Check that the shards_metadata is compatible with the global size of the sharded tensor.
        check_tensor(
            global_sharded_tensor_metadata.shards_metadata,
            global_sharded_tensor_metadata.size,
        )
    else:
        raise ValueError("ShardedTensor has no local shards on any rank!")

    return global_sharded_tensor_metadata
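
# Illustrative sketch (assumption): for a 4 x 8 global tensor split row-wise across
# two ranks, the merged shards_metadata would look like
#
#   from torch.distributed._shard.metadata import ShardMetadata
#   ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement="rank:0/cuda:0")
#   ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement="rank:1/cuda:1")
#
# which passes both validate_non_overlapping_shards_metadata and check_tensor
# against torch.Size([4, 8]).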