from typing import List
from torch.distributed._shard.metadata import ShardMetadata


def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata):
    """
    Checks if two shards overlap.
    """
    # For each dim of each shard, check if one shard lies entirely past the
    # other shard along that dim. As an example for a 2D shard, we would
    # check if one shard is above or to the left of the other shard.
    ndims = len(shard1.shard_offsets)
    for i in range(ndims):
        if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]:
            return False
        if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]:
            return False

    return True
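

# Illustrative sketch (not part of the original module): how the pairwise check
# behaves for two shards of a 4 x 4 tensor. The ``placement`` strings are
# assumed example values.
def _example_pair_overlap() -> None:
    # Two shards that split a 4 x 4 tensor by column: they touch at column 2
    # but do not overlap.
    left = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 2], placement="rank:0/cpu")
    right = ShardMetadata(shard_offsets=[0, 2], shard_sizes=[4, 2], placement="rank:1/cpu")
    assert not _check_shard_metadata_pair_overlap(left, right)

    # Shifting the second shard one column to the left makes the ranges
    # intersect along dim 1, so the pair now overlaps.
    shifted = ShardMetadata(shard_offsets=[0, 1], shard_sizes=[4, 2], placement="rank:1/cpu")
    assert _check_shard_metadata_pair_overlap(left, shifted)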


def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
    """
    Ensures none of the shards overlap with each other.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard.
    Raises:
        ``ValueError`` if there's overlap in any two shards.
    """
    # TODO: evaluate optimizing this if needed.
    for i in range(len(shards)):
        for j in range(i + 1, len(shards)):
            if _check_shard_metadata_pair_overlap(shards[i], shards[j]):
                raise ValueError(f'Shards {shards[i]} and {shards[j]} overlap')
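

# Illustrative sketch (not part of the original module): the validator applies
# the pairwise check to every pair and raises ``ValueError`` on the first
# overlapping pair it finds. Placement strings are assumed example values.
def _example_validate_non_overlapping() -> None:
    shards = [
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu"),
        ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu"),
    ]
    # Row-wise halves of a 4 x 4 tensor: no overlap, so this returns silently.
    validate_non_overlapping_shards_metadata(shards)

    # Growing the first shard to 3 rows makes it collide with the second one.
    shards[0] = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[3, 4], placement="rank:0/cpu")
    try:
        validate_non_overlapping_shards_metadata(shards)
    except ValueError as e:
        print(f"overlap detected as expected: {e}")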


def check_tensor(shards_metadata, tensor_dims) -> None:
    """
    Checks if the shards_metadata is compatible with the provided tensor dims.

    Args:
        shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata`
            objects representing each shard of the tensor.
        tensor_dims(Sequence of int): Dimensions of tensor to verify
    Raises:
        ``ValueError`` if not compatible.
    """
    # If the tensor's volume matches the total volume of all shards and
    # all shard boundaries are within tensor dims, we have a compatible
    # sharding spec for this tensor. Note that we have already verified
    # we don't have overlapping shards.
    tensor_rank = len(tensor_dims)
    shards_rank = len(shards_metadata[0].shard_offsets)
    if tensor_rank != shards_rank:
        raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}')

    total_shard_volume = 0
    for shard in shards_metadata:
        shard_volume = 1
        for i, shard_length in enumerate(shard.shard_sizes):
            shard_volume *= shard_length
            if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]:
                raise ValueError(
                    f'Shard offset {shard.shard_offsets[i]} and length '
                    f'{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}')
        total_shard_volume += shard_volume

    tensor_volume = 1
    for size in tensor_dims:
        tensor_volume *= size

    if total_shard_volume != tensor_volume:
        # TODO: Can we improve this error message to point out the gaps?
        raise ValueError(
            f'Total volume of shards: {total_shard_volume} '
            f'does not match tensor volume: {tensor_volume}, in other words '
            f'all the individual shards do not cover the entire tensor')
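

# Illustrative sketch (not part of the original module): ``check_tensor`` passes
# when the shards tile the tensor exactly, and raises when a shard spills past
# the tensor boundary or the shards leave a gap. Placement strings are assumed
# example values.
def _example_check_tensor() -> None:
    shards = [
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cpu"),
        ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cpu"),
    ]
    # The two shards exactly cover a 4 x 4 tensor.
    check_tensor(shards, [4, 4])

    # Against a 6 x 4 tensor the shards only cover 16 of the 24 elements,
    # so the volume check fails.
    try:
        check_tensor(shards, [6, 4])
    except ValueError as e:
        print(f"incompatible sharding: {e}")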


def get_split_size(dim_size, chunks):
    """
    Computes the split size in line with ``torch.chunk``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        chunks(int): Number of chunks to create for ``dim_size``.

    Returns:
        An int indicating the split size to use.
    """
    return (dim_size + chunks - 1) // chunks
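

# Illustrative sketch (not part of the original module): the split size is the
# ceiling of ``dim_size / chunks``, matching how ``torch.chunk`` sizes its chunks.
def _example_get_split_size() -> None:
    assert get_split_size(12, 4) == 3   # evenly divisible: 12 / 4
    assert get_split_size(13, 4) == 4   # ceil(13 / 4), so the last chunk is smaller
    assert get_split_size(3, 4) == 1    # more chunks than elements still yields 1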


def get_chunked_dim_size(dim_size, split_size, idx):
    """
    Computes the dim size of the chunk for provided ``idx`` given ``dim_size``
    and ``split_size``.

    Args:
        dim_size(int): Size of the dimension being chunked.
        split_size(int): The chunk size for each chunk of ``dim_size``.
        idx(int): The index of the chunk whose dim size is being requested.

    Returns:
        An int indicating the dim size of the chunk.
    """
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)
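

# Illustrative sketch (not part of the original module): chunking a dimension of
# size 13 into 4 chunks gives a split size of 4 and per-chunk sizes 4, 4, 4, 1;
# indices past the last chunk clamp to 0.
def _example_get_chunked_dim_size() -> None:
    dim_size = 13
    split_size = get_split_size(dim_size, 4)  # 4
    sizes = [get_chunked_dim_size(dim_size, split_size, idx) for idx in range(5)]
    assert sizes == [4, 4, 4, 1, 0]
    assert sum(sizes) == dim_size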


def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank):
    """
    Generate the start position and chunk size for the given rank under
    chunk sharding.

    Args:
        sharding_dim_size(int): The length of the dimension being sharded.
        world_size(int): Number of ranks.
        spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`):
            sharding spec.
        rank(int): Rank of the current process.

    Returns:
        start_pos(int): Start position of the sharded tensor on the given rank.
        chunk_size(int): Chunk size of the sharded tensor on the given rank.
    """
    split_size = get_split_size(sharding_dim_size, world_size)
    current_offsets = 0
    start_pos = current_offsets
    for idx, placement in enumerate(spec.placements):
        chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        if rank == placement.rank():
            start_pos = current_offsets
            break
        current_offsets += chunk_size
    return start_pos, chunk_size
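

# Illustrative sketch (not part of the original module): with a dimension of
# length 10 sharded across 4 ranks, the split size is 3 and the chunks are
# 3, 3, 3, 1, so rank 2 starts at offset 6 with a chunk of size 3. The
# ``ChunkShardingSpec`` usage and placement strings are assumed example values;
# the import is deferred into the function to avoid a potential import cycle.
def _example_get_chunk_sharding_params() -> None:
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cpu", "rank:1/cpu", "rank:2/cpu", "rank:3/cpu"],
    )
    start_pos, chunk_size = get_chunk_sharding_params(
        sharding_dim_size=10, world_size=4, spec=spec, rank=2)
    assert (start_pos, chunk_size) == (6, 3)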