1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491
|
# mypy: allow-untyped-decorators
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from typing import List
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementList,
PlacementStrategy,
RuntimeSchemaInfo,
)
from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
generate_redistribute_costs,
infer_broadcast_dims_map,
is_tensor_shardable,
map_placements_after_broadcast,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
aten = torch.ops.aten
@register_op_strategy(aten.t.default)
def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    """Build the sharding strategy for ``aten.t``.

    Each candidate strategy of the input produces one output strategy that
    keeps the input spec as-is and mirrors it on the output, except that every
    ``Shard(d)`` placement is turned into ``Shard(1 - d)``.

    Args:
        mesh: device mesh handed in by the registration machinery (the output
            spec uses the mesh recorded on the input spec instead).
        op_schema: schema whose first argument is the input ``OpStrategy``.

    Returns:
        An ``OpStrategy`` with one ``PlacementStrategy`` per input candidate.
    """
    input_op_strategy = op_schema.args_schema[0]
    assert isinstance(input_op_strategy, OpStrategy)

    def _swap_shard_dim(placement: Placement) -> Placement:
        # NOTE(review): ``1 - dim`` presumes a 2-D input -- confirm with callers.
        if isinstance(placement, Shard):
            return Shard(1 - placement.dim)
        return placement

    candidates = []
    for candidate in input_op_strategy.strategies:
        source_spec = candidate.output_spec
        swapped = tuple(_swap_shard_dim(p) for p in source_spec.placements)
        candidates.append(
            PlacementStrategy(
                output_specs=DTensorSpec(mesh=source_spec.mesh, placements=swapped),
                input_specs=(source_spec,),
            )
        )

    return OpStrategy(strategies=candidates)
def _mm_like_strategy(
    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
    """Generate sharding strategies for a two-operand matmul-like op.

    All einsum strategies for ``mm_equation`` are enumerated on ``mesh``; the
    ones whose input specs cannot shard the actual operand shapes are dropped,
    and the survivors get redistribute costs attached for both operands.

    Args:
        mm_equation: einsum equation describing the op (e.g. ``"mk,kn->mn"``).
        mesh: device mesh to generate strategies on.
        op_schema: schema whose two arguments are the operand ``OpStrategy``s.

    Returns:
        The einsum ``OpStrategy`` with its ``strategies`` list replaced by the
        filtered, cost-annotated candidates.
    """
    lhs_strategy, rhs_strategy = op_schema.args_schema
    assert isinstance(lhs_strategy, OpStrategy)
    assert isinstance(rhs_strategy, OpStrategy)

    # enumerate every candidate the einsum equation allows on this mesh
    einsum_strategy = gen_einsum_strategies(mm_equation, mesh)

    kept = []
    for candidate in einsum_strategy.strategies:
        assert candidate.input_specs is not None
        lhs_spec = candidate.input_specs[0]
        rhs_spec = candidate.input_specs[1]

        # skip candidates that cannot shard the real operand shapes
        if not is_tensor_shardable(lhs_strategy.shape, lhs_spec):
            continue
        if not is_tensor_shardable(rhs_strategy.shape, rhs_spec):
            continue

        candidate.redistribute_cost = [
            generate_redistribute_costs(lhs_strategy, lhs_spec),
            generate_redistribute_costs(rhs_strategy, rhs_spec),
        ]
        kept.append(candidate)

    einsum_strategy.strategies = kept
    return einsum_strategy
def _addmm_like_strategy(
    mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
    """Generate sharding strategies for an ``addmm``-like op (``self + mat1 @ mat2``).

    The einsum strategies for ``mm_equation`` cover ``mat1`` and ``mat2``. For
    each of them a spec for ``self`` is derived from the matmul output
    placements, taking into account that ``self`` may be broadcast to the
    matmul output shape. Candidates whose ``mat1``/``mat2`` specs cannot shard
    the actual operand shapes are dropped; the rest get redistribute costs for
    all three operands.

    Args:
        mm_equation: einsum equation describing the matmul part.
        mesh: device mesh to generate strategies on.
        op_schema: schema whose three arguments are the ``self``, ``mat1`` and
            ``mat2`` ``OpStrategy``s.

    Returns:
        The einsum ``OpStrategy`` with its ``strategies`` list replaced by the
        filtered candidates, each with ``input_specs`` extended to
        ``(self, mat1, mat2)``.
    """
    self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema
    assert isinstance(self_strategy, OpStrategy)
    assert isinstance(mat1_strategy, OpStrategy)
    assert isinstance(mat2_strategy, OpStrategy)

    self_shape = self_strategy.shape
    mat1_shape = mat1_strategy.shape
    last_dim = len(mat1_shape) - 1
    # matmul output shape: mat1's shape with its last dim replaced by mat2's last dim
    mm_out_shape = torch.Size(
        [
            mat2_strategy.shape[-1] if i == last_dim else dim_size
            for i, dim_size in enumerate(mat1_shape)
        ]
    )

    # generate all possible strategies for mm
    einsum_strategy = gen_einsum_strategies(mm_equation, mesh)

    # The broadcast mapping depends only on the two shapes, so compute it once
    # instead of once per candidate strategy.
    broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape)

    # filter out invalid strategies and associate costs
    filtered_strategies = []
    for strtg in einsum_strategy.strategies:
        assert strtg.input_specs is not None
        mat1_spec = strtg.input_specs[0]
        mat2_spec = strtg.input_specs[1]
        out_spec = strtg.output_spec

        # self arg's spec should follow the output of mm, but needs to
        # account for self being broadcast to the mm output shape
        self_placements = map_placements_after_broadcast(
            out_spec.placements, mm_out_shape, broadcast_dims_map
        )
        self_spec = DTensorSpec(mesh=mesh, placements=self_placements)

        if is_tensor_shardable(mat1_strategy.shape, mat1_spec) and is_tensor_shardable(
            mat2_strategy.shape, mat2_spec
        ):
            # update input specs with the derived self spec
            strtg.input_specs = (self_spec, mat1_spec, mat2_spec)
            # associate costs, one entry per operand in argument order
            strtg.redistribute_cost = [
                generate_redistribute_costs(self_strategy, self_spec),
                generate_redistribute_costs(mat1_strategy, mat1_spec),
                generate_redistribute_costs(mat2_strategy, mat2_spec),
            ]
            filtered_strategies.append(strtg)

    einsum_strategy.strategies = filtered_strategies
    return einsum_strategy
@register_op_strategy(aten.mm.default)
def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.mm``, expressed as the einsum ``mk,kn->mn``."""
    equation = "mk,kn->mn"
    return _mm_like_strategy(equation, mesh, op_schema)
@register_op_strategy(aten.addmm.default)
def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.addmm``; the matmul part is ``mk,kn->mn``."""
    equation = "mk,kn->mn"
    return _addmm_like_strategy(equation, mesh, op_schema)
@register_op_strategy(aten.bmm.default)
def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.bmm``, expressed as the einsum ``bmk,bkn->bmn``."""
    equation = "bmk,bkn->bmn"
    return _mm_like_strategy(equation, mesh, op_schema)
@register_op_strategy(aten.baddbmm.default)
def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    """Sharding strategy for ``aten.baddbmm``; the matmul part is ``bmk,bkn->bmn``."""
    equation = "bmk,bkn->bmn"
    return _addmm_like_strategy(equation, mesh, op_schema)
@register_op_strategy(
    aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
)
def scaled_dot_product_flash_attention_strategy(
    mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
    """Sharding strategy for ``aten._scaled_dot_product_flash_attention``.

    Three per-mesh-dim placement patterns are offered and then expanded to the
    full mesh: full replication, sharding on dim 1 (num heads, for tensor
    parallelism) and sharding on dim 2 (sequence, for context parallelism).

    Each placement list is laid out as ``[outputs..., inputs...]``: nine output
    slots (entries that are ``None`` carry no placement) followed by q, k, v.
    The debug attention mask output is only sharded when ``return_debug_mask``
    is truthy; otherwise it is treated as empty and kept replicated in every
    pattern.

    Args:
        mesh: device mesh to expand the strategies on.
        op_schema: schema of the op call; argument 0 is the query
            ``OpStrategy`` and argument 5, when present, is ``return_debug_mask``.

    Returns:
        The ``OpStrategy`` produced by ``expand_to_full_mesh_op_strategy``.
    """
    # NOTE: currently we only support some simple strategies to support tensor parallelism
    # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
    # as it involves: matmul, pointwise, reduction ops together.
    return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
    q_input_strategy = op_schema.args_schema[0]
    assert isinstance(q_input_strategy, OpStrategy)
    # assuming q/k/v have the same shape

    single_mesh_dim_strategies: List[PlacementList] = []

    # placement list stores placements of [outputs, inputs]
    # in the sdpa case, we have 3 valid tensor outputs and 3 tensor inputs
    # first we can always accept full replication for both inputs and outputs
    all_replicate: PlacementList = [
        Replicate(),  # output
        Replicate(),  # logsumexp
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        None,  # philox_seed
        None,  # philox_offset
        Replicate(),  # debug attn mask
        Replicate(),  # q
        Replicate(),  # k
        Replicate(),  # v
    ]
    single_mesh_dim_strategies.append(all_replicate)

    # second we can accept the sharding pattern of tensor parallelism, which
    # shard on the num of head dim
    qkv_sharding = Shard(1)  # num head dim
    output_sharding = Shard(1)  # num head dim
    logsumexp_sharding = Shard(1)  # num head dim
    if return_debug_mask:
        debug_attn_mask_sharding: Placement = Shard(1)  # num head dim
    else:
        # empty debug mask, replicated
        debug_attn_mask_sharding = Replicate()

    num_heads_dim_sharding: PlacementList = [
        output_sharding,
        logsumexp_sharding,
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        None,  # philox_seed
        None,  # philox_offset
        debug_attn_mask_sharding,
        qkv_sharding,
        qkv_sharding,
        qkv_sharding,
    ]
    single_mesh_dim_strategies.append(num_heads_dim_sharding)

    # Context Parallelism: shards on the sequence dim.
    # Same rule as for head sharding: an empty debug mask (return_debug_mask
    # falsy) has nothing to shard along the sequence dim, so keep it replicated.
    if return_debug_mask:
        seq_debug_attn_mask_sharding: Placement = Shard(2)
    else:
        seq_debug_attn_mask_sharding = Replicate()

    seq_dim_sharding: PlacementList = [
        Shard(2),  # output
        Shard(2),  # logsumexp
        None,  # cum_seq_q
        None,  # cum_seq_k
        None,  # max_q
        None,  # max_k
        None,  # philox_seed
        None,  # philox_offset
        seq_debug_attn_mask_sharding,  # debugattn
        Shard(2),  # q
        Shard(2),  # k
        Shard(2),  # v
    ]
    single_mesh_dim_strategies.append(seq_dim_sharding)

    # the first 9 entries of every list are outputs, inputs start at index 9
    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=9
    )
@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
def scaled_dot_product_flash_attention_backward_strategy(
    mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
    """Sharding strategy for ``aten._scaled_dot_product_flash_attention_backward``.

    Offers full replication, sharding on dim 1 (num heads) and sharding on
    dim 2 (sequence). Each placement list holds the three gradient outputs
    first, then one entry per tensor input; tensor inputs beyond the first six
    are always replicated.

    Args:
        mesh: device mesh to expand the strategies on.
        op_schema: schema of the op call; argument 1 is the query ``OpStrategy``.

    Returns:
        The ``OpStrategy`` produced by ``expand_to_full_mesh_op_strategy``.
    """
    query_strategy = op_schema.args_schema[1]
    assert isinstance(query_strategy, OpStrategy)
    # assuming q/k/v have the same shape

    # count how many of the call arguments are tensors (i.e. carry an OpStrategy)
    num_tensor_inputs = sum(
        isinstance(arg_spec, OpStrategy) for arg_spec in op_schema.args_schema
    )
    # tensor inputs past grad_output/q/k/v/output/logsumexp, potentially
    # cum_seq_q, cum_seq_k, philox_seed, philox_offset
    # at indices 6, 7, 12, 13, respectively
    num_extra_inputs = num_tensor_inputs - 6

    # placement list stores placements of [outputs, inputs]
    # in the sdpa backward case, we have 3 tensor outputs and 6 to 10 tensor inputs

    # first we can always accept full replication for both inputs and outputs
    fully_replicated: PlacementList = [Replicate()] * (3 + num_tensor_inputs)

    # second we can accept the sharding pattern of tensor parallelism, which
    # shard on the num of head dim; the extra tensor inputs stay replicated
    head_sharded: PlacementList = [
        Shard(1),  # grad_q
        Shard(1),  # grad_k
        Shard(1),  # grad_v
        Shard(1),  # grad_output
        Shard(1),  # q
        Shard(1),  # k
        Shard(1),  # v
        Shard(1),  # output
        Shard(1),  # logsumexp
    ]
    head_sharded.extend([Replicate()] * num_extra_inputs)

    # Context Parallelism: shards on the sequence dim; the extra tensor
    # inputs stay replicated here as well
    seq_sharded: PlacementList = [
        Shard(2),  # grad_q
        Shard(2),  # grad_k
        Shard(2),  # grad_v
        Shard(2),  # grad_output
        Shard(2),  # q
        Shard(2),  # k
        Shard(2),  # v
        Shard(2),  # output
        Shard(2),  # logsumexp
    ]
    seq_sharded.extend([Replicate()] * num_extra_inputs)

    single_mesh_dim_strategies = [fully_replicated, head_sharded, seq_sharded]

    # the first 3 entries of every list are outputs, inputs start at index 3
    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=3
    )
@register_op_strategy(aten.constant_pad_nd.default)
def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    """Placeholder strategy for ``aten.constant_pad_nd``: replicate everything.

    Returns a single ``PlacementStrategy`` whose output spec and two input
    specs each carry one ``Replicate()`` placement, with a fixed redistribute
    cost of ``[[1]]``. ``op_schema`` is not inspected.
    """
    # TODO(d4l3k); implement a more correct strategy for constant_pad_nd
    # NOTE(review): a single-placement tuple presumably only fits a 1-D mesh -- confirm.

    def _replicated_spec() -> DTensorSpec:
        # build a fresh spec on every call so no spec object is shared
        return DTensorSpec(mesh, (Replicate(),))

    replicate_only = PlacementStrategy(
        output_specs=_replicated_spec(),
        input_specs=(_replicated_spec(), _replicated_spec()),
        redistribute_cost=[[1]],
    )
    return OpStrategy([replicate_only])
@register_op_strategy(
    aten._scaled_dot_product_efficient_attention.default,
    schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_efficient_attention_strategy(
    mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
    """Sharding strategy for ``aten._scaled_dot_product_efficient_attention``.

    Each placement list is laid out as ``[outputs..., inputs...]``: four output
    slots (output, logsumexp, and two ``None`` slots for philox_seed and
    philox_offset) followed by q, k, v and, when an attention bias is passed,
    one more entry for it. Every list must therefore have the same length for
    a given call.

    Patterns offered, in order: sequence-dim sharding (context parallelism,
    only when no attention bias is passed), full replication, and head-dim
    sharding (tensor parallelism).

    Args:
        mesh: device mesh to expand the strategies on.
        op_schema: schema of the op call; argument 0 is the query
            ``OpStrategy``, argument 3 the optional attention bias and
            argument 4 ``compute_log_sumexp``.

    Returns:
        The ``OpStrategy`` produced by ``expand_to_full_mesh_op_strategy``.
    """
    # NOTE: currently we only support some simple strategies to support tensor parallelism
    q_input_strategy = op_schema.args_schema[0]
    assert isinstance(q_input_strategy, OpStrategy)
    # assuming q/k/v have the same shape

    has_attn_bias = op_schema.args_schema[3] is not None
    compute_log_sumexp = op_schema.args_schema[4]

    single_mesh_dim_strategies: List[PlacementList] = []

    # placement list stores placements of [outputs, inputs]
    # in the sdpa case, we have 2 valid tensor outputs and 3 or 4 tensor inputs
    # first we can always accept full replication for both inputs and outputs
    all_replicate: PlacementList = [
        Replicate(),  # output
        Replicate(),  # logsumexp
        None,  # philox_seed
        None,  # philox_offset
        Replicate(),  # q
        Replicate(),  # k
        Replicate(),  # v
    ]
    if has_attn_bias:
        all_replicate.append(Replicate())  # attn bias

    # Context Parallelism: shards on the sequence dim.
    # This list has no slot for attn_bias, so with a bias present it would be
    # one input short compared to the other lists. Only offer it when the call
    # has no attn_bias.
    if not has_attn_bias:
        single_mesh_dim_strategies.append(
            [
                Shard(2),  # output
                Shard(2),  # logsumexp
                None,  # philox_seed
                None,  # philox_offset
                Shard(2),  # q
                Shard(2),  # k
                Shard(2),  # v
            ]
        )

    single_mesh_dim_strategies.append(all_replicate)

    # second we can accept the sharding pattern of tensor parallelism, which
    # shard on the heads dimension
    qkv_sharding = Shard(1)
    output_sharding = Shard(1)
    if compute_log_sumexp:
        logsumexp_sharding: Placement = Shard(1)
    else:
        # empty logsumexp, replicated
        logsumexp_sharding = Replicate()

    num_heads_dim_sharding: PlacementList = [
        output_sharding,
        logsumexp_sharding,
        None,  # philox_seed
        None,  # philox_offset
        qkv_sharding,
        qkv_sharding,
        qkv_sharding,
    ]
    if has_attn_bias:
        num_heads_dim_sharding.append(Shard(1))  # attn bias
    single_mesh_dim_strategies.append(num_heads_dim_sharding)

    # the first 4 entries of every list are outputs, inputs start at index 4
    return expand_to_full_mesh_op_strategy(
        mesh,
        op_schema,
        single_mesh_dim_strategies,
        input_index=4,
    )
@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
def scaled_dot_product_efficient_attention_backward_strategy(
    mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
    """Sharding strategy for ``aten._scaled_dot_product_efficient_attention_backward``.

    Each placement list is laid out as ``[outputs..., inputs...]``: four output
    slots (grad_q, grad_k, grad_v, grad_bias) followed by grad_output, q, k, v,
    the optional attn_bias, output, logsumexp, philox_seed and philox_offset.
    The grad_bias slot is ``None`` when no attention bias is passed. All lists
    have 12 entries without a bias and 13 with one.

    Patterns offered, in order: full replication, head-dim sharding (tensor
    parallelism) and sequence-dim sharding (context parallelism).

    Args:
        mesh: device mesh to expand the strategies on.
        op_schema: schema of the op call; argument 1 is the query
            ``OpStrategy`` and argument 4 the optional attention bias.

    Returns:
        The ``OpStrategy`` produced by ``expand_to_full_mesh_op_strategy``.
    """
    q_input_strategy = op_schema.args_schema[1]
    assert isinstance(q_input_strategy, OpStrategy)
    # assuming q/k/v have the same shape
    has_attn_bias = op_schema.args_schema[4] is not None

    # position of the optional attn_bias input inside a placement list:
    # 4 outputs, then grad_output, q, k, v occupy indices 4..7
    attn_bias_slot = 8

    single_mesh_dim_strategies: List[PlacementList] = []

    # placement list stores placements of [outputs, inputs]
    # in the sdpa backward case, we have 4 tensor outputs and 8 or 9 tensor inputs
    # NOTE: Output sharding of grad_bias on heads dim if attn_bias is present;
    # otherwise grad_bias will be empty and its DTensorSpec will be removed.
    # first we can always accept full replication for both inputs and outputs
    all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias)
    if not has_attn_bias:
        all_replicate[3] = None  # grad bias is None if attn_bias is not present
    single_mesh_dim_strategies.append(all_replicate)

    # second we can accept the sharding pattern of tensor parallelism, which
    # shard on the heads dimension
    grad_output_sharding = Shard(1)
    qkv_sharding = Shard(1)
    output_sharding = Shard(1)
    logsumexp_sharding = Shard(1)
    grad_qkv_sharding = Shard(1)
    grad_bias_sharding = Shard(1) if has_attn_bias else None

    num_heads_dim_sharding: PlacementList = [
        grad_qkv_sharding,
        grad_qkv_sharding,
        grad_qkv_sharding,
        grad_bias_sharding,
        grad_output_sharding,
        qkv_sharding,
        qkv_sharding,
        qkv_sharding,
        # the place for optional input attn_bias,
        output_sharding,
        logsumexp_sharding,
    ]
    # input sharding of attn_bias on heads dim if present
    if has_attn_bias:
        num_heads_dim_sharding.insert(attn_bias_slot, Shard(1))
    # accept replicate on the rest scalar tensor inputs
    # namely philox_seed and philox_offset
    num_heads_dim_sharding.extend([Replicate(), Replicate()])
    single_mesh_dim_strategies.append(num_heads_dim_sharding)

    # Context Parallelism: shards on the sequence dim
    seq_dim_sharding: PlacementList = [
        Shard(2),  # grad_q
        Shard(2),  # grad_k
        Shard(2),  # grad_v
        Shard(1) if has_attn_bias else None,  # grad_bias
        Shard(2),  # grad_output
        Shard(2),  # q
        Shard(2),  # k
        Shard(2),  # v
        # the place for optional input attn_bias,
        Shard(2),  # output
        Shard(2),  # logsumexp
    ]
    # The attn_bias slot belongs to THIS list. Inserting it into
    # num_heads_dim_sharding again would leave that list one entry too long
    # and this one an entry short.
    # NOTE(review): Shard(1) mirrors the grad_bias placement above -- confirm
    # this is the intended attn_bias placement under sequence sharding.
    if has_attn_bias:
        seq_dim_sharding.insert(attn_bias_slot, Shard(1))
    # accept replicate on the rest scalar tensor inputs
    # namely philox_seed and philox_offset
    seq_dim_sharding.extend([Replicate(), Replicate()])
    single_mesh_dim_strategies.append(seq_dim_sharding)

    # the first 4 entries of every list are outputs, inputs start at index 4
    return expand_to_full_mesh_op_strategy(
        mesh,
        op_schema,
        single_mesh_dim_strategies,
        input_index=4,
    )
|