from abc import ABC, abstractmethod
from dataclasses import dataclass
import functools
from typing import Callable, Dict, List, TYPE_CHECKING
import torch
from ._internals import (
    check_tensor,
    get_chunked_dim_size,
    get_split_size,
    validate_non_overlapping_shards_metadata
)
from torch.distributed._shard.metadata import ShardMetadata
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.op_registry_utils import _decorator_func

if TYPE_CHECKING:
    # Only include ShardedTensor when type checking; exclude it
    # at run-time to resolve a circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor


class PlacementSpec(ABC):
    """
    Base class representing the placement of an entity. Subclasses of this
    class can be used to specify customized placements which might not be
    covered by existing APIs.
    """
    pass


@dataclass
class DevicePlacementSpec(PlacementSpec):
    """
    Associates placement of an entity with a single device.

    Args:
        device(:class:`torch.distributed._remote_device`): The device to place the entity on.
    """

    device: torch.distributed._remote_device

    def __post_init__(self):
        if not isinstance(self.device, torch.distributed._remote_device):
            self.device = torch.distributed._remote_device(self.device)
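
# Example (illustrative sketch, kept as a comment so nothing runs at import time):
# ``__post_init__`` coerces a plain device string into a ``_remote_device``, so the
# two constructions below are assumed to be equivalent:
#
#     spec = DevicePlacementSpec(torch.distributed._remote_device("rank:0/cuda:0"))
#     spec = DevicePlacementSpec("rank:0/cuda:0")   # coerced by __post_init__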


class ShardingSpec(ABC):
    """
    Base class representing sharding specifications.
    """
    @abstractmethod
    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
        """
        Given a global tensor size, define how to shard a tensor of this shape
        across ranks and return a ShardedTensorMetadata.

        Args:
            tensor_sizes (:class:`torch.Size`):
                The tensor shape to shard on, a `torch.Size` object that represents the
                tensor shape to be sharded according to the ShardingSpec.
            tensor_properties(:class:`torch.distributed._shard.sharded_tensor.TensorProperties`):
                Tensor properties used to create a ShardedTensor.

        Returns:
            A :class:`ShardedTensorMetadata` object that encodes the information about
            the layout of the ShardedTensor and its properties.
        """

    @abstractmethod
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        """
        Given a global tensor on src_rank, shard this tensor
        across ranks within the process group and return a ShardedTensor.

        Args:
            tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.

        Keyword args:
            src_rank (int, optional): The source rank which is used as the ground truth of
                the data for the parameter that would be sharded and scattered
                across the rest of the ranks.
                Default: 0.
            process_group (ProcessGroup, optional): The process group to work on. If None,
                the default process group will be used.

        Returns:
            A :class:`ShardedTensor` sharded from the given tensor.
        """
# Ops customized for a particular ShardingSpec.
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}

def _has_custom_op(sharding_spec, op):
    """
    Returns whether or not the ShardingSpec has a custom op implementation.
    """
    class_name = type(sharding_spec).__qualname__
    return class_name in _CUSTOM_SHARDING_SPEC_OPS and op in _CUSTOM_SHARDING_SPEC_OPS[class_name]


def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs, process_group):
    """
    Calls the custom op for this ShardingSpec if it exists.
    """
    class_name = type(sharding_spec).__qualname__
    if not _has_custom_op(sharding_spec, op):
        raise RuntimeError(f'Custom op: {op} not registered for {class_name}')
    func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op]
    return func(types, args, kwargs, process_group)


def custom_sharding_spec_op(sharding_spec_class, func):
    """
    Decorator to allow custom registration of ops.

    Args:
        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
        func(Callable): The op to override (ex: torch.bmm)
    """
    class_name = sharding_spec_class.__qualname__
    if class_name not in _CUSTOM_SHARDING_SPEC_OPS:
        _CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
    return functools.partial(
        _decorator_func,
        op=func,
        op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name]
    )
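
# Example (illustrative sketch): registering a custom handler for
# ``torch.nn.functional.linear`` on a hypothetical ``MyShardingSpec`` subclass, then
# dispatching to it. ``MyShardingSpec`` and ``my_sharded_linear`` are illustrative
# assumptions, not part of this module:
#
#     @custom_sharding_spec_op(MyShardingSpec, torch.nn.functional.linear)
#     def my_sharded_linear(types, args, kwargs, process_group):
#         ...  # custom sharded implementation of F.linear
#
#     # A dispatcher can then route the call through the spec:
#     if _has_custom_op(sharding_spec, torch.nn.functional.linear):
#         return _dispatch_custom_op(
#             sharding_spec, torch.nn.functional.linear, types, args, kwargs, process_group
#         )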


@dataclass
class EnumerableShardingSpec(ShardingSpec):
    """
    This is a type of ShardingSpec that allows users to specify a generic
    sharding scheme by enumerating exactly how each shard is laid out.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard. Note that none of the shards should overlap.
    """

    shards: List[ShardMetadata]

    def __post_init__(self):
        if len(self.shards) == 0:
            raise ValueError(f'Empty shard list provided: {self.shards}')

        # Validate that each shard has the same number of dimensions (tensor rank).
        rank = -1
        for shard in self.shards:
            if rank != -1 and rank != len(shard.shard_offsets):
                raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}')
            rank = len(shard.shard_offsets)

        validate_non_overlapping_shards_metadata(self.shards)

    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
        # check if shards form a valid tensor
        check_tensor(self.shards, tensor_sizes)
        return sharded_tensor_meta.ShardedTensorMetadata(
            self.shards,
            tensor_sizes,
            tensor_properties
        )

    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec
        raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!")
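
# Example (illustrative sketch): enumerating two shards that split a 10 x 5 tensor
# along dim 0 across two ranks. Kept as a comment so nothing runs at import time:
#
#     spec = EnumerableShardingSpec([
#         ShardMetadata(
#             shard_offsets=[0, 0],
#             shard_sizes=[5, 5],
#             placement="rank:0/cuda:0",
#         ),
#         ShardMetadata(
#             shard_offsets=[5, 0],
#             shard_sizes=[5, 5],
#             placement="rank:1/cuda:1",
#         ),
#     ])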


def _infer_sharding_spec_from_shards_metadata(shards_metadata):
    """
    Infer the sharding spec from the metadata of each shard of a ShardedTensor.
    If the tensor is sharded only on one dimension, we can then verify whether it's
    a ChunkShardingSpec or not. The way to verify it is to first get the total length
    and perform a chunk sharding with the given placements to see if we can have the
    same chunk size as the given shards_metadata. If not, we assume it's enum sharded.

    Args:
        shards_metadata (List[ShardMetadata]): List of metadata of the local shards.

    Returns:
        A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object describing
        the sharding of the given sharded tensor.
    """
    placements = []
    chunk_sharding_dim = None
    chunk_offset_list = []
    shard_size_list = []
    # collect local shard metadatas from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        placements.append(shard_metadata.placement)
        local_offsets = shard_metadata.shard_offsets
        chunk_offset_list.append(sum(local_offsets))
        shard_size_list.append(shard_metadata.shard_sizes)
        shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
        # If the offset is [0, 0, ..., 0] (all zeros),
        # we cannot decide how the tensor is sharded.
        if len(shard_dims) == 0:
            continue
        # If the offset is [0, N, ..., 0, M, 0, ..., 0],
        # we are sure it's sharded by more than one dimension.
        if len(shard_dims) != 1:
            chunk_sharding_dim = None
            break
        # If the offset is [0, 0, ..., 0, M, 0, ..., 0], i.e. it's sharded by just
        # one dimension, we need to make sure all ranks share the same dimension.
        # Use an explicit `is None` check so that dimension 0 is not treated as unset.
        if chunk_sharding_dim is None:
            chunk_sharding_dim = shard_dims[0]
        elif chunk_sharding_dim != shard_dims[0]:
            chunk_sharding_dim = None
            break

    if chunk_sharding_dim is not None:
        # Ensure we infer the correct placement order from offsets
        placements = [
            x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
        ]

        from .chunk_sharding_spec import ChunkShardingSpec
        chunk_spec = ChunkShardingSpec(
            dim=chunk_sharding_dim,
            placements=placements,
        )

        shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list])
        shard_total_length = sum(shard_sizes)
        chunks = len(placements)
        split_size = get_split_size(shard_total_length, chunks)
        chunk_shard_sizes = sorted(
            [
                get_chunked_dim_size(shard_total_length, split_size, idx)
                for idx in range(len(placements))
            ]
        )
        if shard_sizes == chunk_shard_sizes:
            return chunk_spec
    return EnumerableShardingSpec(shards_metadata)
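

# Example (illustrative sketch): two shards that evenly chunk a 10 x 5 tensor along
# dim 0 differ only in their dim-0 offsets and match the even chunk sizes, so the
# inference below is expected to return a ChunkShardingSpec with dim=0; uneven or
# multi-dimensional layouts fall back to an EnumerableShardingSpec. Kept as a comment
# so nothing runs at import time:
#
#     shards_metadata = [
#         ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 5], placement="rank:0/cuda:0"),
#         ShardMetadata(shard_offsets=[5, 0], shard_sizes=[5, 5], placement="rank:1/cuda:1"),
#     ]
#     spec = _infer_sharding_spec_from_shards_metadata(shards_metadata)
#     assert isinstance(spec, ChunkShardingSpec) and spec.dim == 0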