1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
from typing import List, Tuple
from torch.distributed._shard.sharding_spec import (
ShardMetadata,
)
def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ShardMetadata, current_shard: ShardMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Compute, per dimension, where ``saved_shard`` and ``current_shard`` overlap.

    The result holds one tuple for every dimension described by both shards'
    ``shard_offsets`` / ``shard_sizes``, laid out as:

        (dimension, offset into ``saved_shard``, offset into ``current_shard``, length)

    Both offsets are relative to the start of their own shard, not to the
    full tensor. The length is not clamped: for a dimension in which the two
    shards do not intersect it comes out zero or negative.
    """
    per_dimension = zip(
        saved_shard.shard_offsets,
        current_shard.shard_offsets,
        saved_shard.shard_sizes,
        current_shard.shard_sizes,
    )

    regions = []
    for axis, (saved_start, current_start, saved_extent, current_extent) in enumerate(
        per_dimension
    ):
        # The overlap begins at whichever shard starts later and ends at
        # whichever shard finishes first (global coordinates).
        overlap_start = max(saved_start, current_start)
        overlap_end = min(
            saved_start + saved_extent, current_start + current_extent
        )

        # Translate the global start into each shard's local coordinates; the
        # shard that starts later always gets a local offset of zero.
        regions.append(
            (
                axis,
                overlap_start - saved_start,
                overlap_start - current_start,
                overlap_end - overlap_start,
            )
        )
    return regions
|