from dataclasses import dataclass
import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharded_tensor.utils import (
    _parse_and_validate_remote_device
)
from torch.distributed._shard._utils import narrow_tensor
import torch.distributed as dist
import torch.distributed.distributed_c10d as distributed_c10d
from typing import List, Union, TYPE_CHECKING
from ._internals import (
    get_chunked_dim_size,
    get_split_size,
)

from .api import ShardingSpec

if TYPE_CHECKING:
    # Only include ShardedTensor when doing type checking; exclude it
    # at run time to resolve the circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor

@dataclass
class ChunkShardingSpec(ShardingSpec):
    """
    This is a type of :class:`ShardingSpec` that defines the placement as being
    sharded across multiple devices. In particular, it represents sharding a
    Tensor along a single dimension into equal chunks (similar to
    :meth:`torch.chunk`).

    The semantics of how a tensor is partitioned are in line with
    :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the
    specified ``dim`` here and ``chunks`` in torch.chunk corresponds to the
    number of elements in ``placements``.

    Args:
        dim (int or str):
            The dimension to shard on. This can be an integer representing the
            dimension, or a string in the case of named tensors where
            dimensions are named. Note that named tensor support is not added
            yet.
        placements(List[Union[_remote_device, str]]):
            Specifies the placement of each shard of the Tensor. The size of
            the list represents the number of shards to be created. Each entry
            can be a :class:`torch.distributed._remote_device`, or a string
            accepted by :class:`torch.distributed._remote_device`
            (e.g. ``"rank:0/cuda:0"``).
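
    Example (an illustrative sketch; the placement strings assume two ranks,
    each owning one CUDA device)::

        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=[
        >>>         "rank:0/cuda:0",
        >>>         "rank:1/cuda:1",
        >>>     ],
        >>> )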
"""
ShardingDim = Union[int, str]
dim: ShardingDim
placements: List[Union[torch.distributed._remote_device, str]]
def __post_init__(self):
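        # Validate the sharding dim eagerly and normalize any placement given
        # as a string (e.g. "rank:0/cuda:0") into a
        # torch.distributed._remote_device.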
        self._verify_dim(self.dim)
        for i, remote_device in enumerate(self.placements):
            if not isinstance(remote_device, torch.distributed._remote_device):
                self.placements[i] = torch.distributed._remote_device(remote_device)

    @staticmethod
    def _verify_dim(dim):
        # Validate the sharding spec.
        # TODO: support named dimension
        if isinstance(dim, str):
            raise NotImplementedError(
                "ChunkShardingSpec does not support named dimension yet!"
            )

        if not isinstance(dim, int):
            raise ValueError(
                f"Sharding dim needs to be an integer, found: {dim}"
            )
    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
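        """
        Compute a ShardedTensorMetadata describing how a tensor of size
        ``tensor_sizes`` would be chunked along ``self.dim`` across
        ``self.placements``: one ShardMetadata (offsets, sizes, placement) per
        non-empty chunk. No data is allocated or communicated here.
        """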
        tensor_num_dim = len(tensor_sizes)

        self._verify_dim(self.dim)
        if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim:  # type: ignore[operator]
            raise ValueError(f"Invalid sharding dim: {self.dim}")

        shards_metadata = []
        sharding_dim_size = tensor_sizes[self.dim]  # type: ignore[index]
        chunks = len(self.placements)
        split_size = get_split_size(sharding_dim_size, chunks)
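        # ``split_size`` is the even chunk size along the sharding dim; as with
        # torch.chunk, the trailing chunk(s) may be smaller (or empty) when the
        # dim size is not evenly divisible by the number of placements.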
        for idx, placement in enumerate(self.placements):
            # generate ShardMetadata for each placement device
            chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
            if chunked_dim_size > 0:
                shard_size = list(tensor_sizes)
                current_offsets = [0] * tensor_num_dim
                current_offsets[self.dim] = split_size * idx  # type: ignore[index]
                shard_size[self.dim] = chunked_dim_size  # type: ignore[index]

                shard_metadata = ShardMetadata(
                    shard_offsets=current_offsets,
                    shard_sizes=shard_size,
                    placement=placement,
                )
                shards_metadata.append(shard_metadata)
                # current_offsets[self.dim] += chunked_dim_size  # type: ignore[index]

        return sharded_tensor_meta.ShardedTensorMetadata(
            shards_metadata,
            tensor_sizes,
            tensor_properties
        )
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        """
        Args:
            src_rank: group rank relative to ``process_group``

            N.B. If ``process_group`` is None, ``src_rank`` is a global rank.
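
        Example (an illustrative sketch, not a complete program; it assumes a
        two-rank process group is already initialized and that ``full_tensor``
        is a tensor of the same global size on every rank)::

            >>> spec = ChunkShardingSpec(
            >>>     dim=0,
            >>>     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
            >>> )
            >>> # ``shard`` wraps a scatter collective, so every rank in the
            >>> # group must call it; only the data on ``src_rank`` is used.
            >>> st = spec.shard(full_tensor, src_rank=0)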
"""
# relative imports to avoid circular dependency
from torch.distributed._shard.sharded_tensor import (
ShardedTensor
)
tensor_properties = sharded_tensor_meta.TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned()
)
current_rank = dist.get_rank(process_group)
tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
local_shards = []
local_tensor = None
local_metadata = None
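        # Only the source rank fills in real tensors below; every other rank
        # leaves the list as ``None`` entries, since dist.scatter only reads
        # ``scatter_list`` on the ``src`` rank.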
        tensors_to_scatter = [None] * dist.get_world_size(process_group)
        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
        chunks = len(self.placements)
        split_size = get_split_size(sharding_dim_size, chunks)
        scatter_shape = list(tensor.size())
        scatter_shape[self.dim] = split_size  # type: ignore[index]
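        # Pad the sharding dim up to the full ``split_size`` so that
        # dist.scatter works with equally sized tensors on every rank; shards
        # that are actually smaller get trimmed back down after the scatter.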
        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
            if current_rank == src_rank:
                # Narrow out the shard for this rank. We don't want autograd to
                # record the narrow op here, and 'local_shard' should be a leaf
                # variable in the autograd graph.
                narrowed_tensor = narrow_tensor(tensor, shard_meta)
                if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
                    # The last shard might be smaller than the other shards, so
                    # resize the narrowed tensor up to ``scatter_shape`` and use
                    # it for the scatter collective, as dist.scatter requires
                    # same-size inputs on every rank.
                    tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
                else:
                    tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

                tensors_to_scatter[rank] = tensor_to_scatter

            if current_rank == rank:
                local_tensor = torch.empty(
                    scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
                local_metadata = shard_meta

        # Every rank should have local_tensor and local_metadata initialized if
        # the metadata list was built correctly.
        assert local_tensor is not None
        assert local_metadata is not None
        # Scatter the shards to all ranks in the pg
        # scatter takes the global rank as ``src``
        src_for_scatter = src_rank
        if process_group is not None and process_group is not distributed_c10d._get_default_group():
            src_for_scatter = distributed_c10d.get_global_rank(process_group, src_for_scatter)

        dist.scatter(
            local_tensor,
            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
            src=src_for_scatter,
            group=process_group
        )

        if list(local_tensor.size()) != local_metadata.shard_sizes:
            # detach again after receiving to ensure local shards remain a leaf node
            local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

        # Sync requires_grad to local_shard.
        local_tensor.requires_grad = tensor.requires_grad

        local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

        st = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards,
            tensor_meta,
            process_group=process_group)

        # Manually set sharding_spec
        st._sharding_spec = self

        return st