from dataclasses import dataclass
from typing import Any, cast, List, NamedTuple, Optional, Tuple
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
class TensorMeta(NamedTuple):
    # Simple named tuple to represent tensor metadata;
    # intentionally kept minimal, for sharding propagation
    # purposes only.
shape: torch.Size
stride: Tuple[int, ...]
dtype: torch.dtype
# used internally to propagate the placements
@dataclass
class DTensorSpec:
mesh: DeviceMesh
placements: Tuple[Placement, ...]
# tensor meta will only be set during sharding propagation
tensor_meta: Optional[TensorMeta] = None
def __post_init__(self) -> None:
if not isinstance(self.placements, tuple):
self.placements = tuple(self.placements)
self._hash: Optional[int] = None
def __setattr__(self, attr: str, value: Any) -> None:
super().__setattr__(attr, value)
# Make sure to recompute the hash in case any of the hashed attributes
# change (though we do not expect `mesh` or `placements` to change)
if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
self._hash = None
def _hash_impl(self) -> int:
        # Hashing and equality checks for DTensorSpec are used to cache
        # sharding propagation results. We only need to consider the mesh,
        # placements, shape, dtype and stride.
        # Caveat: if we add more fields to them, we need to keep hash and
        # eq in sync.
if self.tensor_meta is not None:
return hash(
(
self.mesh,
self.placements,
self.tensor_meta.shape,
self.tensor_meta.stride,
self.tensor_meta.dtype,
)
)
return hash((self.mesh, self.placements))
def __hash__(self) -> int:
        # We lazily cache the hash of the spec to avoid recomputing it upon
        # each use, and we make sure to invalidate the cached hash when
        # `tensor_meta` changes by overriding `__setattr__`. This must be
        # lazy so that Dynamo does not try to hash non-singleton `SymInt`s
        # for the stride.
if self._hash is None:
self._hash = self._hash_impl()
return self._hash
def __eq__(self, __o: object) -> bool:
if not (
isinstance(__o, DTensorSpec)
and self.mesh == __o.mesh
and self.placements == __o.placements
):
return False
if self.tensor_meta is None or __o.tensor_meta is None:
return self.tensor_meta == __o.tensor_meta
return (
self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr]
and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr]
and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr]
)
def __str__(self) -> str:
"""
human readable representation of the DTensorSpec
"""
if len(self.placements) == 1:
placement_str = str(self.placements[0])
else:
placement_str = str(self.placements)
if self.tensor_meta is not None:
tensor_shape = str(tuple(self.tensor_meta.shape))
else:
tensor_shape = "unknown shape"
return f"Spec({placement_str} on {tensor_shape})"
@property
def shape(self) -> torch.Size:
if self.tensor_meta is None:
raise ValueError("tensor_meta is not set")
return self.tensor_meta.shape
@property
def stride(self) -> Tuple[int, ...]:
if self.tensor_meta is None:
raise ValueError("tensor_meta is not set")
return self.tensor_meta.stride
@property
def ndim(self) -> int:
if self.tensor_meta is None:
raise ValueError("tensor_meta is not set")
return len(self.tensor_meta.shape)
@property
def num_shards(self) -> int:
num_shards = 1
for i, placement in enumerate(self.placements):
if placement.is_shard():
num_shards *= self.mesh.size(i)
return num_shards
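        # Hedged illustration (not from the original source): on a 2 x 4 mesh
        # with placements (Shard(0), Shard(1)), num_shards is
        # mesh.size(0) * mesh.size(1) == 8; Replicate() and Partial()
        # placements contribute no factor.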
@property
def device_mesh(self) -> DeviceMesh:
        # Simple alias for the mesh field; makes some checks that mix
        # DTensor/DTensorSpec easier.
return self.mesh
@property
def dim_map(self) -> List[int]:
"""
dim_map is a property we derive from `placements` of
the distributed tensor. It simply return a list of ints
where dim_map[i] denotes the sharding mapping to the mesh
dimension, and len(dim_map) == dist_tensor.ndim
dim_map[i] = -1: means tensor dim i replicate on mesh
dim_map[i] = j: means tensor dim i shard on mesh dim j
For example, we have a dist tensor that have the shape of
[18, 20, 30], and device_mesh([0, 1, 2, 3]), placements:
[Shard(1)], the dim_map of this placement would be:
[-1, 0, -1]. This representation is pretty helpful during
sharding propagation where we could know exactly each
tensor dimension is sharded or not.
Note that if placements contains `_Partial`, we have to
explicitly deal with it, so that when we create a DTensorSpec
with dim_map, we could properly record the pending sums.
"""
# dims mapping of dist tensor sharding
# return size of tensor ndim, -1 represent replicate
# and int >=0 represent shard on that device mesh dim
r = [-1] * self.ndim
for i, placement in enumerate(self.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
if r[shard_dim] > -1:
raise ValueError(
f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]},"
" DTensor operator implementation does not support things like hybrid"
" sharding strategies yet (i.e. [Shard(0), Shard(0)])"
)
r[shard_dim] = i
return r
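        # Hedged illustration (not from the original source): for a rank-3
        # tensor on a 2-D mesh with placements (Shard(2), Replicate()), this
        # returns [-1, -1, 0] -- tensor dim 2 is sharded on mesh dim 0 and
        # every other tensor dim is replicated.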
@property
def num_shards_map(self) -> List[int]:
"""
dim_map is a property we derive from `placements` of
the distributed tensor. Unlike `dim_map`, `num_shards_map`
denotes how many shards each tensor dim has. Like `dim_map`:
len(num_shards_map) == dist_tensor.ndim
num_shards_map[i] = 1: means tensor dim i is not sharded
num_shards_map[i] = j: means tensor dim i has j shards in total
For example, we have a dist tensor of shape [18, 20, 30],
a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor
would be: [4, 2, 1].
"""
r = [1] * self.ndim
for i, placement in enumerate(self.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
r[shard_dim] *= self.mesh.size(i)
return r
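        # Hedged illustration (not from the original source): on a 2 x 4 mesh
        # with placements (Shard(0), Shard(0)), a rank-2 tensor gets
        # num_shards_map == [8, 1]; tensor dim 0 is split across both mesh
        # dims (2 * 4 shards), a hybrid sharding that `dim_map` cannot express.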
@property
def sums(self) -> List[int]:
"""
sums is a property we derive from `placements` of the
distributed tensor. It simply return a list of ints where
sums[i] denotes the pending sum (partial) on mesh dim i
"""
return [
idx
for idx, placement in enumerate(self.placements)
if placement.is_partial()
]
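        # Hedged illustration (not from the original source): with placements
        # (Partial(), Shard(0), Partial()), sums returns [0, 2] -- mesh dims 0
        # and 2 hold partial results that still need to be reduced.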
@classmethod
def from_dim_map(
cls,
mesh: DeviceMesh,
dim_map: List[int],
sums: List[int],
tensor_meta: Optional[TensorMeta] = None,
) -> "DTensorSpec":
"""
Construct a DTensorSpec from dim_map list and pending sum.
Args:
mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec
dim_map (List[int]): a list of integer that represents sharding on each
tensor dimension, see `dim_map` property doc for details
sums (List[int]): a list of integer that represents the dist tensor have
pending sum on which device mesh dimension.
tensor meta (TensorMeta): DTensor metadata
Return:
a class:`DTensorSpec` object
"""
# by default replicate on device mesh dims
placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
# find all mesh dims that need pending reductions
for s in sums:
placements[s] = Partial()
for i, m in enumerate(dim_map):
if m >= 0:
placement = placements[m]
if placement.is_shard():
placement = cast(Shard, placement)
raise RuntimeError(
f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}"
)
elif placement.is_partial():
raise RuntimeError(
f"DeviceMesh dimension {m} cannot be both shard and partial!"
)
placements[m] = Shard(i)
return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
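        # Hedged illustration (not from the original source): on a 2-D mesh,
        # from_dim_map(mesh, dim_map=[-1, 0], sums=[1]) yields placements
        # (Shard(1), Partial()): mesh dim 0 shards tensor dim 1 and mesh dim 1
        # carries a pending sum, i.e. the inverse of the `dim_map`/`sums`
        # properties above.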
def is_replicated(self) -> bool:
"""
return True if the current DTensorSpec replicates on all mesh dims (devices)
"""
return all(placement.is_replicate() for placement in self.placements)
def is_sharded(self) -> bool:
"""
return True if the current DTensorSpec is sharded on any mesh dims (devices)
"""
return any(placement.is_shard() for placement in self.placements)
def shallow_copy_with_tensor_meta(
self, tensor_meta: Optional[TensorMeta]
) -> "DTensorSpec":
"""
Shallow copy the DTensorSpec with a new tensor_meta.
"""
assert tensor_meta is not None, "shallow copy with no tensor_meta!"
return DTensorSpec(
self.mesh,
self.placements,
tensor_meta=tensor_meta,
)
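# Usage sketch (illustrative only, not part of the original module). A real
# DeviceMesh needs an initialized process group, so this is a hedged sketch
# rather than a standalone script; the mesh and tensor values are assumed
# purely for illustration.
#
#   mesh = DeviceMesh("cuda", torch.arange(4))        # 1-D mesh over 4 ranks
#   spec = DTensorSpec(mesh, (Shard(1),))
#   spec.tensor_meta = TensorMeta(
#       shape=torch.Size([18, 20, 30]),
#       stride=(600, 30, 1),
#       dtype=torch.float32,
#   )
#   spec.dim_map         # [-1, 0, -1]: tensor dim 1 sharded on mesh dim 0
#   spec.num_shards_map  # [1, 4, 1]:   tensor dim 1 split into 4 shards
#   spec.sums            # []:          no pending partial reductions
#   # Round-tripping through from_dim_map reconstructs the same placements:
#   DTensorSpec.from_dim_map(mesh, spec.dim_map, spec.sums, spec.tensor_meta)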