# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch._ops import OpOverload
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Placement
try:
from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
except ImportError:
from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
tree_leaves,
tree_map_only,
TreeSpec,
)
# Common type aliases
ArgsType = Tuple[object, ...]
KwargsType = Dict[str, object]
PlacementList = List[Optional[Placement]]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type
# should be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
def _rebuild_tensor_from_dtensor_meta(arg) -> object:
"""
This is used to propagate tensor metadata; it must be used under fake mode.
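
Example (a minimal illustrative sketch; assumes ``spec`` is a DTensorSpec whose
``tensor_meta`` is populated and that we are already under fake mode)::

    fake_local_tensor = _rebuild_tensor_from_dtensor_meta(spec)
    # matches spec.tensor_meta's shape, stride and dtype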
"""
assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
return torch.empty_strided(
arg.tensor_meta.shape,
arg.tensor_meta.stride,
dtype=arg.tensor_meta.dtype,
)
def _is_inplace_op(op: OpOverload):
# Simple analysis of the function schema to determine if this is an in-place
# variant. It might not be entirely correct, but it's good enough for now.
return op._schema.name[-1] == "_"
def _is_out_variant_op(op: OpOverload):
# Simple analysis of the function schema to determine if this is an out
# variant. It might not be entirely correct, but it's good enough for now.
return "out" in op._schema.overload_name
def _pretty_print_spec(spec: object) -> str:
if spec is None:
return "None"
elif isinstance(spec, DTensorSpec):
return "".join([str(p) for p in spec.placements])
elif isinstance(spec, Sequence):
return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")"
else:
raise RuntimeError(f"Unknown spec type to print: spec={spec}")
@dataclass
class PlacementStrategy:
"""
A placement strategy describes acceptable sharding placements of the output
and the tensor arguments of an operation.
note: when the op return value is a single DTensor object, output_specs is
DTensorSpec; when the return value is a tuple of Optional[DTensor],
output_specs is a tuple of Optional[DTensorSpec].
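
Example (a minimal illustrative sketch; assumes ``mesh`` is an existing 1-D
DeviceMesh and ``Shard`` comes from torch.distributed.tensor.placement_types)::

    spec = DTensorSpec(mesh=mesh, placements=(Shard(0),))
    strategy = PlacementStrategy(output_specs=spec, input_specs=[spec])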
"""
output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]]
input_specs: Optional[Sequence[DTensorSpec]] = None
# Redistribute costs for this op placement strategy: a nested list that records,
# for each operand of this operator, the cost of redistributing from each of
# that operand's possible placement strategies.
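# For example (illustrative numbers only), for an op with two operands where
# the first operand's OpStrategy has two candidate placements and the second
# has three, the cost might look like:
#   redistribute_cost = [[0.0, 4.0], [0.0, 2.0, 8.0]]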
redistribute_cost: Optional[List[List[float]]] = None
@cached_property
def output_spec(self) -> DTensorSpec:
"""
This function requires that the strategy have exactly one DTensorSpec as the
output spec. If the output_specs is a tuple, we throw an exception.
"""
if isinstance(self.output_specs, DTensorSpec):
return self.output_specs
else:
raise ValueError(
f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
)
def input_spec(self, index: int = 0) -> DTensorSpec:
assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
assert len(self.input_specs) > index, (
f"Invalid index {index} for input_specs of length "
f"{len(self.input_specs)}: {self.input_specs}"
)
return self.input_specs[index]
def __str__(self) -> str:
if self.input_specs is not None:
input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
else:
input_specs_str = ""
output_spec_str = _pretty_print_spec(self.output_specs)
return f"{input_specs_str}{output_spec_str}"
class StrategyType:
"""
Base class type for op strategies. We have two StrategyType subclasses:
OpStrategy and TupleStrategy.
"""
class OpStrategy(StrategyType):
"""
OpStrategy that consists of a list of placement strategies associated with the op
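
Example (a minimal illustrative sketch; assumes ``shard_strategy`` and
``replicate_strategy`` are PlacementStrategy objects built elsewhere)::

    op_strategy = OpStrategy([shard_strategy, replicate_strategy])
    op_strategy.max_num_shards()  # max number of shards across the strategies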
"""
def __init__(self, strategies: List[PlacementStrategy]) -> None:
super().__init__()
self.strategies: List[PlacementStrategy] = strategies
def __str__(self) -> str:
strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
mesh_shape = self.mesh_shape
return f"[{strategy_list_str}] @ mesh: {mesh_shape}"
def max_num_shards(self) -> int:
"""
Returns the max number of shards across all placement strategies
"""
return max(strategy.output_spec.num_shards for strategy in self.strategies)
@property
def mesh_shape(self):
output_spec = self.strategies[0].output_specs
if isinstance(output_spec, DTensorSpec):
return output_spec.mesh.shape
else:
assert isinstance(
output_spec, tuple
), "found no DTensorSpec in the OpStrategy!"
assert output_spec[0] is not None
return output_spec[0].mesh.shape
@property
def ndim(self):
return self.strategies[0].output_spec.ndim
@property
def shape(self):
return self.strategies[0].output_spec.shape
class TupleStrategy(StrategyType):
"""
TupleStrategy represents the case where the output strategy of this op is a
tuple of strategies, i.e. if the output of this op is a tuple or list of
tensors with possibly different placement strategies, we should return a
TupleStrategy that contains a tuple of OpStrategy, where each child represents
the sharding strategy of "each element" of the tuple/list of tensors the op
returns.

NOTE: if the output of the op is a List[Tensor] and all elements share the same
placement strategy, we should return a single OpStrategy instead of a
TupleStrategy.
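
Example (a minimal illustrative sketch; assumes ``first_out_strategy`` and
``second_out_strategy`` are OpStrategy objects, one per output of the op)::

    tuple_strategy = TupleStrategy([first_out_strategy, second_out_strategy])
    for child in tuple_strategy.childs:
        assert isinstance(child, OpStrategy)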
"""
def __init__(self, childs: Sequence[StrategyType]) -> None:
super().__init__()
self.childs: Sequence[StrategyType] = childs
def __str__(self) -> str:
child_strategies_str = ", ".join(
[f"{str(strat)}" for idx, strat in enumerate(self.childs)]
)
return f"TupleStrategy({child_strategies_str})"
@dataclass
class RuntimeSchemaInfo:
"""
RuntimeSchemaInfo stores operator schema related information for runtime (eager)
execution. It is mainly used in two ways: 1. to generate a hash of the args to
determine whether to re-run sharding prop or not, and 2. to determine whether we
need pytree flatten/unflatten.
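
Example (a minimal illustrative sketch; the argument index and kwarg name are
hypothetical, not taken from a real op registration)::

    # hash all args from positional index 1 onwards, plus the "dim" kwarg
    schema_info = RuntimeSchemaInfo(static_argnum=1, static_kwargkey=["dim"])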
"""
# This static_argnum records static arg "starting index" for ops that have non-tensor
# args/kwargs which would affect sharding propagation results. All args starting from
# this index would be hashed to our sharding cache.
# Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
static_argnum: int = 100
# This static_kwargkey records static kwarg names which would affect sharding prop
static_kwargkey: Optional[List[str]] = None
# Each op can decide if it wants to use pytree flatten/unflatten during operator
# eager execution. By default we don't do flatten/unflatten; we only do it if the
# op indicates it needs to. This is to accelerate eager performance.
needs_pytree: bool = False
@dataclass
class OpSchema:
"""
OpSchema is a data class that describes an operator's input schemas. It includes
DTensorSpecs (instead of DTensors) and non-tensor args/kwargs (positional order
preserved). It is mainly used by DTensor's dispatching logic to perform various
actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.)

NOTE: this should be used as a read-only data class
TODO: make this a frozen dataclass

Args:
    op: the operator overload we are intercepting
    args_schema: contains args except that the DTensor args have been replaced
        with their DTensorSpec or OpStrategy
    kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
        with their DTensorSpec or OpStrategy
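
Example (a minimal illustrative sketch; ``lhs_spec`` and ``rhs_spec`` are
assumed to be pre-built DTensorSpec objects, and aten.add.Tensor is just an
illustrative op)::

    op_schema = OpSchema(
        op=torch.ops.aten.add.Tensor,
        args_schema=(lhs_spec, rhs_spec),
        kwargs_schema={},
    )
    op_schema.args_spec  # -> (lhs_spec, rhs_spec)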
"""
op: OpOverload
args_schema: ArgsType
kwargs_schema: KwargsType
schema_info: Optional[RuntimeSchemaInfo] = None
@property
def args_spec(self) -> Tuple[DTensorSpec, ...]:
"""
args_spec: Tuple[DTensorSpec, ...]: contains a clean list of arg specs with NO
non-DTensor positional arguments (i.e. int/float/tuple, etc.); mainly used by
sharding propagation to propagate the output spec
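
Example (a minimal illustrative sketch; assumes ``spec`` is a DTensorSpec)::

    schema = OpSchema(torch.ops.aten.mul.Scalar, (spec, 2.0), {})
    schema.args_spec  # -> (spec,) -- the non-DTensorSpec arg is filtered out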
"""
args = (
tree_leaves(self.args_schema)
if self.schema_info is not None and self.schema_info.needs_pytree
else self.args_schema
)
return tuple(item for item in args if isinstance(item, DTensorSpec))
@property
def args_strategy(self) -> Tuple[OpStrategy, ...]:
# Filter out non-relevant values from the args schema to get a clean OpStrategy
# list. Kept separate from args_spec for ease of type annotation.
# TODO: see if we should merge this with args_spec
args = (
tree_leaves(self.args_schema)
if self.schema_info is not None and self.schema_info.needs_pytree
else self.args_schema
)
return tuple(item for item in args if isinstance(item, OpStrategy))
def __repr__(self) -> str:
args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
return (
f"OpSchema(op={self.op},"
f" args_schema=({args_schema}),"
f" kwargs_schema={self.kwargs_schema})"
)
def __str__(self) -> str:
args_schema: List[str] = []
mesh_shape = None
for arg in self.args_schema:
if isinstance(arg, DTensorSpec):
args_schema.append(str(arg))
mesh_shape = arg.mesh.shape
elif isinstance(arg, OpStrategy):
assert len(arg.strategies) == 1
args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
mesh_shape = arg.mesh_shape
elif isinstance(arg, TupleStrategy):
first_op_strtgy = arg.childs[0]
assert isinstance(first_op_strtgy, OpStrategy)
mesh_shape = first_op_strtgy.mesh_shape
args_schema.append(str(arg))
else:
args_schema.append(str(arg))
return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"
def __post_init__(self) -> None:
has_symints = False
for a in self.args_schema:
if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
has_symints = True
break
self.has_symints = has_symints
def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool:
arg = self.args_schema[arg_idx]
is_tensor = isinstance(arg, DTensorSpec)
if is_tensor:
return True
if not isinstance(arg, list):
return False
return all(isinstance(e, DTensorSpec) or e is None for e in arg)
def return_type_tuple_tensor_like(self) -> bool:
# all dispatch ops could only return Tuple[Tensor] or have None/ints/floats
# in the tuple, but the first element must be a Tensor, so this check is enough
return_types = self.op._schema.returns
return len(return_types) > 1 and isinstance(
return_types[0].type, torch.TensorType
)
def return_type_tensor(self) -> bool:
return_types = self.op._schema.returns
# all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
# return types, so this check is enough for tensor like types
return isinstance(return_types[0].type, torch.TensorType)
def __hash__(self) -> int:
# Only hash args and kwargs that op indicates to hash
if not self.schema_info:
static_argnum = len(self.args_schema)
static_kwargkey = None
else:
static_argnum = self.schema_info.static_argnum
static_kwargkey = self.schema_info.static_kwargkey
args_to_hash = tuple(
tuple(e) if isinstance(e, list) else e
for i, e in enumerate(self.args_schema)
if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum
)
if static_kwargkey is not None:
kwargs_to_hash = tuple(
self.kwargs_schema.get(k, None) for k in static_kwargkey
)
return hash((self.op, args_to_hash, kwargs_to_hash))
else:
return hash((self.op, args_to_hash))
def __eq__(self, other: object) -> bool:
# early return checks
if not isinstance(other, OpSchema):
return False
if self.op != other.op:
return False
if len(self.args_schema) != len(other.args_schema):
return False
# compare each element and early return if any of them is different
if not self.schema_info:
static_argnum = len(self.args_schema)
static_kwargkey = None
else:
static_argnum = self.schema_info.static_argnum
static_kwargkey = self.schema_info.static_kwargkey
for i, (self_arg, other_arg) in enumerate(
zip(self.args_schema, other.args_schema)
):
if isinstance(self_arg, DTensorSpec) and self_arg != other_arg:
return False
elif i >= static_argnum and self_arg != other_arg:
return False
# check kwarg equality when there's a static kwarg key
if static_kwargkey:
for key in static_kwargkey:
if self.kwargs_schema.get(key, None) != other.kwargs_schema.get(
key, None
):
return False
return True
def gen_fake_args(self) -> ArgsType:
"""
gen_fake_args: generate fake args for the operator. This is mainly used by
sharding propagation rules to run the local tensor operator and get the
output spec.
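
Example (a minimal illustrative sketch; assumes ``op_schema`` is an OpSchema
whose DTensorSpec args carry tensor_meta, and runs under fake mode)::

    with torch._subclasses.fake_tensor.FakeTensorMode():
        fake_args = op_schema.gen_fake_args()
        fake_kwargs = op_schema.gen_fake_kwargs()
        fake_out = op_schema.op(*fake_args, **fake_kwargs)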
"""
return tree_map_only(
DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
)
def gen_fake_kwargs(self) -> KwargsType:
"""
gen_fake_kwargs: generate fake kwargs for the operator. This is mainly used by
sharding propagation rules to run the local tensor operator and get the
output spec.
"""
return tree_map_only(
DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
)
def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
suggestion_args_spec = self.args_spec
new_arg_schema: List[object] = []
idx_of_args_spec = 0
if (
origin_schema.schema_info is not None
and origin_schema.schema_info.needs_pytree
):
args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
else:
args_schema = origin_schema.args_schema
for arg in args_schema:
if isinstance(arg, DTensorSpec):
new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
idx_of_args_spec += 1
else:
new_arg_schema.append(arg)
self.args_schema = tuple(new_arg_schema)
self.kwargs_schema = origin_schema.kwargs_schema
@dataclass
class OutputSharding:
"""
OutputSharding is a data class used by sharding propagation; it sets the
output_spec upon successful propagation. If needs_redistribute is set to True,
a redistribute_schema is returned together with it to indicate that the input
arguments need to be redistributed before the op execution.

NOTE: the redistribute_schema generated by sharding propagation should be
exactly the same as the operator's OpSchema, except for the DTensorSpecs.
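
Example (a minimal illustrative sketch; assumes ``out_spec`` is the propagated
DTensorSpec and ``suggested_schema`` is an OpSchema carrying the desired input
specs)::

    sharding = OutputSharding(
        output_spec=out_spec,
        redistribute_schema=suggested_schema,
        needs_redistribute=True,
    )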
"""
output_spec: OutputSpecType
redistribute_schema: Optional[OpSchema] = None
needs_redistribute: bool = False
@dataclass
class OpInfo:
"""
All runtime op execution info is packed here
"""
mesh: DeviceMesh
schema: OpSchema
flat_args_schema: List[object]
local_args: Sequence[object]
local_kwargs: Dict[str, object]
args_tree_spec: Optional[TreeSpec] = None
# the output sharding info
output_sharding: Optional[OutputSharding] = None