import functools
from typing import Callable, Dict, TYPE_CHECKING
import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed import distributed_c10d
from torch.distributed.nn.functional import (
reduce_scatter,
)
from torch.distributed._shard.common_op_utils import _register_default_op
from torch.distributed._shard.op_registry_utils import _decorator_func
from torch.utils._pytree import tree_map
if TYPE_CHECKING:
    # Only include ShardedTensor when doing type checking; exclude it
    # at run time to avoid a circular dependency.
from torch.distributed._shard.sharded_tensor import ShardedTensor
# Custom PartialTensor ops
_PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {}
def _custom_partial_tensor_op(func):
"""
    Decorator for a custom partial tensor op.
Args:
func(Callable): Torch function for which we want to provide a PartialTensor
implementation (ex: torch.nn.functional.linear)
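    Example (an illustrative sketch; ``partial_add`` and the ``torch.add``
    override are hypothetical and only mirror the signature of the handlers
    registered further down in this module):
        >>> # xdoctest: +SKIP
        >>> @_custom_partial_tensor_op(torch.add)
        >>> def partial_add(types, args=(), kwargs=None, process_group=None):
                lhs, rhs = args[0], args[1]
                # Adding the local shards assumes ReduceOp.SUM, for which
                # elementwise addition commutes with the cross-rank reduction.
                return _PartialTensor(
                    torch.add(lhs._local_shard, rhs._local_shard),
                    process_group,
                    lhs._reduce_op,
                )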
"""
return functools.partial(
_decorator_func,
op=func,
op_table=_PARTIAL_TENSOR_OPS
)
class _PartialTensor(torch.Tensor):
"""
    PartialTensor is an abstraction to represent Tensors that need
    aggregation across multiple devices and multiple processes.
    PartialTensor is initialized in an SPMD-like fashion, where each rank
    initializes its own PartialTensor. The PartialTensor object on each rank
    then stores only the local partial shard, the process group and the
    reduce op used to aggregate the partial shards into the full tensor.
    PartialTensor doesn't provide any Tensor-like operations; it is a thin
    wrapper around the Tensor holding the local partial shard.
    We assume every local shard to have exactly the same size.
    Users can apply custom distributed sharded computations on top of
    this primitive.
Args:
local_partial_shard (Tensor): Partial result stored across ranks.
process_group (ProcessGroup): The process group to aggregate on.
        reduce_op (distributed_c10d.ReduceOp): The reduction op used to aggregate the partial result.
Default: ``distributed_c10d.ReduceOp.SUM``
Examples:
>>> # All tensors below are of torch.int64 type.
        >>> # We have a single process group with 2 ranks.
>>> # xdoctest: +SKIP
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor = torch.cat([tensor, tensor + 2])
>>> tensor
tensor([1, 2, 3, 4]) # Rank 0
tensor([3, 4, 5, 6]) # Rank 1
        >>> partial_tensor = _PartialTensor(tensor, reduce_op=distributed_c10d.ReduceOp.MAX)
>>> sharding_dim = 0
>>> collect_spec = shard_spec.ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
>>> complete_tensor = partial_tensor.reshard(collect_spec)
>>> complete_tensor
ShardedTensor(
ShardedTensorMetadata(
shards_metadata=[
ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)],
size=torch.Size([4])
)
>>> complete_tensor.local_tensor()
tensor([3, 4]) # Rank 0
tensor([5, 6]) # Rank 1
        >>> # All tensors below are of torch.int64 type.
        >>> # We have a single process group with 2 ranks; the default SUM reduce op is used.
>>> tensor = torch.tensor([1, 2]) + 2 * rank
>>> tensor = torch.cat([tensor, tensor + 2])
>>> tensor
tensor([1, 2, 3, 4]) # Rank 0
tensor([3, 4, 5, 6]) # Rank 1
>>> partial_tensor = _PartialTensor(tensor)
>>> complete_tensor = partial_tensor.reshard(collect_spec)
>>> complete_tensor
ShardedTensor(
ShardedTensorMetadata(
shards_metadata=[
ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)],
size=torch.Size([4])
)
>>> complete_tensor.local_tensor()
tensor([4, 6]) # Rank 0
tensor([8, 10]) # Rank 1
"""
_process_group: distributed_c10d.ProcessGroup
_local_shard: torch.Tensor
_reduce_op: distributed_c10d.ReduceOp
__slots__ = ["_process_group", "_local_shard", "_reduce_op"]
def __new__(cls, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM):
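        # Create a wrapper subclass that carries the metadata of the local shard
        # (size, dtype, layout, ...) without allocating storage of its own; the
        # actual data lives in the wrapped local shard assigned below.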
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
local_shard.size(),
dtype=local_shard.dtype,
layout=local_shard.layout,
pin_memory=local_shard.is_pinned(),
requires_grad=local_shard.requires_grad) # type: ignore[arg-type]
r._process_group = ( # type: ignore[attr-defined]
process_group
if process_group is not None
else distributed_c10d._get_default_group()
)
r._reduce_op = reduce_op
r._local_shard = local_shard
return r
def __post_init__(self):
if not isinstance(self._reduce_op, distributed_c10d.ReduceOp):
raise ValueError(
"reduce_op needs to be a member of distributed_c10d.ReduceOp."
)
def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> "ShardedTensor":
"""
The reshard happens in two steps logically:
1. Aggregate all the shards of the partial tensor.
2. Shard this tensor according to the provided spec.
In reality, for the sake of performance, we consolidate all partial tensors
        across multiple ranks and convert to a sharded tensor in one step.
Args:
resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
The specification describing how we reshard the aggregated local result.
Returns:
A :class:`ShardedTensor` filled with local aggregated result.
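        Example (condensed from the class docstring; assumes 2 ranks and the
        default ``SUM`` reduce op):
            >>> # xdoctest: +SKIP
            >>> # Rank 0 holds tensor([1, 2, 3, 4]); rank 1 holds tensor([3, 4, 5, 6]).
            >>> spec = shard_spec.ChunkShardingSpec(
                    dim=0,
                    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
                )
            >>> # 'tensor' is the rank-local tensor shown above.
            >>> sharded = _PartialTensor(tensor).reshard(spec)
            >>> sharded.local_tensor()
            tensor([4, 6])  # Rank 0
            tensor([8, 10]) # Rank 1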
"""
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
if self._local_shard.is_complex():
raise NotImplementedError("Only real partial tensor supported for reshard.")
sharding_dim = int(resharding_spec.dim) # type: ignore[attr-defined]
chunk_mode_res = self._local_shard.size(sharding_dim) % self._process_group.size()
local_shard = self._local_shard
        # Pad the sharding dim when its size is not divisible by the world size.
        if chunk_mode_res != 0:
            padding = [0] * (local_shard.dim() * 2)
            # F.pad orders pads from the last dim backwards, so the right-hand
            # pad for `sharding_dim` sits at index -(2 * sharding_dim + 1).
            padding[-(2 * sharding_dim + 1)] = self._process_group.size() - chunk_mode_res
local_shard = torch.nn.functional.pad(
local_shard,
tuple(padding),
"constant",
0,
)
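        # Work out where the current rank sits in the resharding spec. If the
        # placements are not in plain rank order, the chunks fed to reduce_scatter
        # below must be reordered so each rank receives the chunk the spec
        # assigns to it.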
current_rank = dist.get_rank(self._process_group) # type: ignore[attr-defined]
rank_idx = None
rearrange_local_shards = False
indices = [0] * self._process_group.size()
for idx, placement in enumerate(resharding_spec.placements): # type: ignore[attr-defined]
if placement.rank() == current_rank: # type: ignore[index, union-attr]
rank_idx = idx # type: ignore[attr-defined]
if placement.rank() != idx: # type: ignore[index, union-attr]
rearrange_local_shards = True
indices[placement.rank()] = idx # type: ignore[index, union-attr]
local_shards = local_shard.chunk(self._process_group.size(), dim=sharding_dim)
if rearrange_local_shards:
            # Reorder the chunks to match the placement order in the spec.
local_shards = [local_shards[idx] for idx in indices] # type: ignore[call-overload]
local_result = reduce_scatter(
torch.empty_like(local_shards[0]),
list(local_shards),
op=self._reduce_op,
group=self._process_group,
)
sharded_tensor_size = self._local_shard.size()
# Remove padding when the size is not divisible by the world size.
if chunk_mode_res != 0:
uneven_local_shards = self._local_shard.chunk(
self._process_group.size(), dim=sharding_dim
)
expected_size = uneven_local_shards[rank_idx].size() # type: ignore[index]
if local_result.size() != expected_size:
local_result = local_result.narrow(
sharding_dim,
0,
expected_size[sharding_dim],
)
return ShardedTensor._init_from_local_tensor(
local_result,
resharding_spec,
sharded_tensor_size,
process_group=self._process_group,
)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# Find process_group
process_group = None
def find_process_group(e):
nonlocal process_group
if process_group is None and isinstance(e, _PartialTensor):
process_group = e._process_group
tree_map(find_process_group, args)
tree_map(find_process_group, kwargs)
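        # Dispatch to a registered PartialTensor override (e.g. the transpose and
        # cat handlers below) if one exists for this torch function.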
if func in _PARTIAL_TENSOR_OPS:
return _PARTIAL_TENSOR_OPS[func](types, args, kwargs, process_group)
# Need to disable all dispatch to print args and kwargs appropriately.
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
try:
with torch._C.DisableTorchFunction():
raise RuntimeError(
f"torch function '{func.__name__}', with args: {args} and "
f"kwargs: {kwargs} not supported for PartialTensor!")
finally:
del guard
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise RuntimeError(
            f"A {cls.__name__} object is being used from c++ "
            f"while calling {func.__module__}.{func.__name__} "
            "but there is no custom __torch_dispatch__ implementation for it."
)
def __repr__(self):
return f"PartialTensor({super(_PartialTensor, self).__repr__()})"
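# Shared implementation for the two transpose overrides registered below: the
# transpose is applied to the local shard only, while the process group and
# reduce op carry over unchanged (an elementwise reduction commutes with
# transposing).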
def _transpose_impl(types, args=(), kwargs=None, process_group=None):
partial_tensor = args[0]
input = partial_tensor._local_shard
dim0 = args[1]
dim1 = args[2]
return _PartialTensor(
torch.transpose(input, dim0, dim1),
process_group,
partial_tensor._reduce_op
)
@_custom_partial_tensor_op(torch.Tensor.transpose)
def partial_transpose(types, args=(), kwargs=None, process_group=None):
return _transpose_impl(types, args, kwargs, process_group)
@_custom_partial_tensor_op(torch.transpose)
def partial_torch_transpose(types, args=(), kwargs=None, process_group=None):
return _transpose_impl(types, args, kwargs, process_group)
@_custom_partial_tensor_op(torch.cat)
def partial_cat(types, args=(), kwargs=None, process_group=None):
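    # All inputs must be _PartialTensors sharing the same reduce op; the result
    # wraps the concatenation of their local shards.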
input_list = args[0]
if len(input_list) == 0:
raise RuntimeError('Empty list of tensors to torch.cat!')
local_shards = []
for idx, input in enumerate(input_list):
if not isinstance(input, _PartialTensor):
raise RuntimeError('All inputs need to be an instance of _PartialTensor')
if idx == 0:
reduce_op = input._reduce_op
elif reduce_op != input._reduce_op:
raise RuntimeError(
'All _PartialTensor reduce_ops need to be the same, found: '
                f'{reduce_op} and {input._reduce_op}'
)
local_shards.append(input._local_shard)
if kwargs is None:
dim = 0
else:
if 'out' in kwargs:
raise RuntimeError('"out" kwarg is not supported!')
        dim = kwargs.get('dim', 0)
return _PartialTensor(torch.cat(local_shards, dim), process_group, input._reduce_op)
# Tensor properties access
_register_default_op(torch.Tensor.requires_grad.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.size, _custom_partial_tensor_op)
_register_default_op(torch.Tensor.dim, _custom_partial_tensor_op)
_register_default_op(torch.Tensor.ndim.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_contiguous, _custom_partial_tensor_op)
_register_default_op(torch.Tensor.contiguous, _custom_partial_tensor_op)