import collections.abc
import copy
from typing import Optional, List, Sequence
import torch
from torch.distributed import distributed_c10d
from torch.distributed import rpc
from torch.distributed._shard.sharding_spec._internals import (
check_tensor,
validate_non_overlapping_shards_metadata,
)
from torch.distributed._shard.metadata import ShardMetadata
from .metadata import TensorProperties, ShardedTensorMetadata
from .shard import Shard
def _parse_and_validate_remote_device(pg, remote_device):
    """
    Validate a shard placement against the process group and resolve it to a
    (rank, device) pair. If the placement uses an RPC worker name, the name is
    resolved to a worker id via the current RPC agent.
    """
    if remote_device is None:
        raise ValueError("remote device is None")

    worker_name = remote_device.worker_name()
    rank = remote_device.rank()
    device = remote_device.device()

    # Validate rank, skip validation if rank is not part of process group.
    if not distributed_c10d._rank_not_in_group(pg):
        if rank is not None and (rank < 0 or rank >= distributed_c10d.get_world_size(pg)):
            raise ValueError(f'Invalid rank: {rank}')

    if worker_name is not None:
        if not rpc._is_current_rpc_agent_set():
            raise RuntimeError(f'RPC framework needs to be initialized for using worker names: {worker_name}')

        workers = rpc._get_current_rpc_agent().get_worker_infos()
        for worker in workers:
            if worker.name == worker_name:
                return worker.id, device

        raise ValueError(f'Invalid worker name: {worker_name}')

    return rank, device
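
# Illustrative sketch (not part of this module): resolving a placement to a
# (rank, device) pair with the helper above. This assumes the "rank:0/cuda:0"
# placement syntax and that torch.distributed exposes the _remote_device class;
# the process group is assumed to be initialized already.
#
#   from torch.distributed import _remote_device
#   pg = distributed_c10d._get_default_group()
#   rank, device = _parse_and_validate_remote_device(pg, _remote_device("rank:0/cuda:0"))
#   # rank == 0, device == torch.device("cuda:0")
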
def _validate_output_tensor_for_gather(
    my_rank: int,
    dst_rank: int,
    size: torch.Size,
    dst_tensor: Optional[torch.Tensor],
) -> None:
    if dst_rank == my_rank:
        if dst_tensor is None:
            raise ValueError(
                f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}"
            )
        if tuple(size) != tuple(dst_tensor.size()):
            raise ValueError(
                f"Argument ``dst_tensor`` has size {tuple(dst_tensor.size())}, "
                f"but should be {tuple(size)}"
            )
    elif dst_tensor is not None:
        raise ValueError(
            "Argument ``dst_tensor`` must NOT be specified "
            "on non-destination ranks."
        )
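
# Illustrative sketch (not part of this module): this validation supports a
# gather onto a single destination rank, where only that rank pre-allocates the
# full-size output tensor. The sizes below are made up for the example.
#
#   global_size = torch.Size([4, 4])
#   my_rank = distributed_c10d.get_rank()
#   dst_tensor = torch.empty(global_size) if my_rank == 0 else None
#   _validate_output_tensor_for_gather(my_rank, 0, global_size, dst_tensor)
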
def _flatten_tensor_size(size) -> torch.Size:
    """
    Check that a tensor size is valid, then flatten and return it as a torch.Size.

    Accepts either the dimensions as separate values collected into ``size``
    (e.g. ``(3, 4)``) or a single sequence of ints (e.g. ``((3, 4),)``).
    """
    if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
        dims = list(*size)
    else:
        dims = list(size)

    for dim in dims:
        if not isinstance(dim, int):
            raise TypeError(f'size has to be a sequence of ints, found: {dims}')

    return torch.Size(dims)
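
# Illustrative sketch (not part of this module): both calling conventions yield
# the same torch.Size.
#
#   _flatten_tensor_size((10, 5))     # torch.Size([10, 5]) -- dims passed directly
#   _flatten_tensor_size([(10, 5)])   # torch.Size([10, 5]) -- a single nested sequence
#   _flatten_tensor_size((10, 5.0))   # raises TypeError: not all dims are ints
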
def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
    if is_local:
        assert isinstance(ranks, int)
        if expected != actual:
            raise ValueError(f"Local shards' tensor {prop_name} property needs to be the same on rank:{ranks}! "
                             f"Found one local shard tensor {prop_name}={expected}, "
                             f"the other local shard tensor {prop_name}={actual}.")
    else:
        # mismatch check across ranks, the ranks list should contain exactly two ranks
        assert len(ranks) == 2
        if expected != actual:
            raise ValueError(f"ShardedTensor {prop_name} property does not match across ranks! "
                             f"Found {prop_name}={expected} on rank:{ranks[0]}, "
                             f"and {prop_name}={actual} on rank:{ranks[1]}.")
def build_metadata_from_local_shards(
    local_shards: List[Shard],
    global_size: torch.Size,
    current_rank: int,
    pg: distributed_c10d.ProcessGroup
) -> ShardedTensorMetadata:
    assert len(local_shards) > 0, "must have local shards!"
    local_shard_metadatas: List[ShardMetadata] = []

    first_shard_dtype = local_shards[0].tensor.dtype
    first_shard_layout = local_shards[0].tensor.layout
    first_shard_requires_grad = local_shards[0].tensor.requires_grad
    first_shard_is_pinned = local_shards[0].tensor.is_pinned()

    # 1). Validate local tensors and associated metadatas
    for i, local_shard in enumerate(local_shards):
        local_shard_tensor = local_shard.tensor
        local_shard_meta = local_shard.metadata
        local_shard_metadatas.append(local_shard_meta)
        rank, local_device = _parse_and_validate_remote_device(pg, local_shard_meta.placement)

        if local_shard_tensor.layout != torch.strided or local_shard_tensor.layout != first_shard_layout:
            raise ValueError(
                f'Only torch.strided layout is currently supported, but found '
                f'{local_shard_tensor.layout} on rank:{current_rank}!'
            )

        if not local_shard_tensor.is_contiguous():
            raise ValueError('Only torch.contiguous_format memory_format is currently supported!')

        if rank != current_rank:
            raise ValueError(
                f"Local shard metadata's rank does not match the rank in its process group! "
                f'Found current rank in the process group: {current_rank}, '
                f"local ShardMetadata placement's rank: {rank}"
            )

        if local_shard_tensor.device != local_device:
            raise ValueError(
                f"Local shard tensor device does not match the local Shard's placement! "
                f"Found local shard tensor device: {local_shard_tensor.device}, "
                f"local shard metadata placement device: {local_device}"
            )

        _raise_if_mismatch(local_shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank)
        _raise_if_mismatch(local_shard_tensor.is_pinned(), first_shard_is_pinned, "pin_memory", current_rank)
        _raise_if_mismatch(local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank)
        _raise_if_mismatch(local_shard_tensor.requires_grad, first_shard_requires_grad, "requires_grad", current_rank)

    # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then
    # do all_gather to collect local_sharded_tensor_metadata from all ranks
    local_tensor_properties = TensorProperties(
        dtype=first_shard_dtype,
        layout=first_shard_layout,
        requires_grad=first_shard_requires_grad,
        memory_format=torch.contiguous_format,
        pin_memory=first_shard_is_pinned
    )

    local_sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=local_shard_metadatas,
        size=global_size,
        tensor_properties=local_tensor_properties)

    return local_sharded_tensor_metadata
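
# Illustrative sketch (not part of this module): building per-rank metadata from
# a single local shard on each rank. The shard sizes/offsets are made up for a
# world size of 2 and a [4, 4] tensor chunked on dim 0; ``pg`` is assumed to be
# an already initialized process group.
#
#   current_rank = distributed_c10d.get_rank(pg)
#   shard_md = ShardMetadata(
#       shard_offsets=[2 * current_rank, 0],
#       shard_sizes=[2, 4],
#       placement=f"rank:{current_rank}/cpu",
#   )
#   local_shard = Shard(torch.zeros(2, 4), shard_md)
#   local_md = build_metadata_from_local_shards(
#       [local_shard], torch.Size([4, 4]), current_rank, pg)
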
def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]):
    global_sharded_tensor_metadata = None
    global_metadata_rank = 0

    for rank, rank_metadata in enumerate(gathered_metadatas):
        if rank_metadata is None:
            continue

        if global_sharded_tensor_metadata is None:
            global_sharded_tensor_metadata = copy.deepcopy(rank_metadata)
            global_metadata_rank = rank
        else:
            _raise_if_mismatch(global_sharded_tensor_metadata.size,
                               rank_metadata.size,
                               "global_size",
                               [global_metadata_rank, rank],
                               is_local=False)

            # no need to check layout and memory format, they were already checked in the local shards validation stage
            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.dtype,
                               rank_metadata.tensor_properties.dtype,
                               "dtype",
                               [global_metadata_rank, rank],
                               is_local=False)
            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.requires_grad,
                               rank_metadata.tensor_properties.requires_grad,
                               "requires_grad",
                               [global_metadata_rank, rank],
                               is_local=False)
            _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.pin_memory,
                               rank_metadata.tensor_properties.pin_memory,
                               "pin_memory",
                               [global_metadata_rank, rank],
                               is_local=False)

            # all validations passed, extend shards metadata
            global_sharded_tensor_metadata.shards_metadata.extend(rank_metadata.shards_metadata)

    if global_sharded_tensor_metadata is not None:
        # check if shards_metadata have overlapping shards
        validate_non_overlapping_shards_metadata(global_sharded_tensor_metadata.shards_metadata)

        # check if the shards_metadata is compatible with the global size of the sharded tensor.
        check_tensor(global_sharded_tensor_metadata.shards_metadata, global_sharded_tensor_metadata.size)
    else:
        raise ValueError("ShardedTensor has no local shards on any rank!")

    return global_sharded_tensor_metadata
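
# Illustrative sketch (not part of this module): the typical end-to-end flow is
# to build per-rank metadata, all-gather it across the process group, then merge
# the results. ``local_md`` and ``pg`` are assumed to come from the
# build_metadata_from_local_shards() sketch above.
#
#   world_size = distributed_c10d.get_world_size(pg)
#   gathered_metadatas: List[Optional[ShardedTensorMetadata]] = [None] * world_size
#   distributed_c10d.all_gather_object(gathered_metadatas, local_md, group=pg)
#   global_md = build_global_metadata(gathered_metadatas)
#   # global_md.shards_metadata now describes every shard of the global tensor.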