1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728
|
# Owner(s): ["oncall: distributed"]
import itertools
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._utils import (
_compute_local_shape_and_global_offset,
_explicit_order_placements,
compute_global_tensor_shape,
compute_local_shape_and_global_offset,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
# Shorthand for the ``torch.ops.c10d_functional`` op namespace; its ops (e.g.
# ``all_gather_into_tensor``) are used as keys into
# ``CommDebugMode.get_comm_counts()`` when counting collectives below.
c10d_functional = torch.ops.c10d_functional
class LocalTest(TestCase):
def test_explicit_order_placements(self):
# mesh_shape: ShapeType, placements: Sequence[Placement]
test_cases = [
{
"mesh_shape": [2, 4],
"placements": [Replicate(), Replicate()],
"ordered": [(0, Replicate()), (1, Replicate())],
},
{
"mesh_shape": [3, 2],
"placements": [Shard(0), Replicate()],
"ordered": [(0, Shard(0)), (1, Replicate())],
},
{
"mesh_shape": [2, 4],
"placements": [_StridedShard(0, split_factor=4), Shard(0)],
"ordered": [(1, Shard(0)), (0, Shard(0))],
},
{
"mesh_shape": [2, 3, 4],
"placements": [Shard(0), _StridedShard(0, split_factor=4), Shard(0)],
"ordered": [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))],
},
{
"mesh_shape": [2, 3, 4],
"placements": [
_StridedShard(0, split_factor=12),
_StridedShard(0, split_factor=4),
Shard(0),
],
"ordered": [(2, Shard(0)), (1, Shard(0)), (0, Shard(0))],
},
]
for test_case in test_cases:
actual = _explicit_order_placements(
test_case["mesh_shape"], test_case["placements"]
)
expected = test_case["ordered"]
self.assertEqual(
actual,
expected,
f"mesh_shape={test_case['mesh_shape']} placements={test_case['placements']}, output: {actual=}, {expected=}",
)
error_cases = [
{
"mesh_shape": [2, 3, 4],
"placements": [Shard(0), _StridedShard(0, split_factor=3), Shard(0)],
"exception_type": RuntimeError,
"exception_text": "Can only convert _StridedShard to ordered Shard if split_factor",
},
{
"mesh_shape": [2, 3, 4],
"placements": [
_StridedShard(0, split_factor=3),
Shard(0),
Shard(0),
],
"exception_type": NotImplementedError,
"exception_text": r"Strided sharding does not allow Shard\(\) to appear after the strided part has ended",
},
{
"mesh_shape": [2, 3],
"placements": [
Shard(0),
],
"exception_type": RuntimeError,
"exception_text": "Expected one placement per mesh dim",
},
]
for test_case in error_cases:
with self.assertRaisesRegex(
test_case["exception_type"], test_case["exception_text"]
):
_explicit_order_placements(
test_case["mesh_shape"], test_case["placements"]
)
def test_compute_local_shape_and_global_offset_uneven(self):
# This case is not only 'uneven' bug also has an empty shard
# (e.g. most DP ranks have local shape 18,4096, one has 8,4096, one has 0,4096
global_shape = (4096, 4096)
DP = 30
TP = 8
mesh_shape = (DP, TP)
placements = [_StridedShard(0, split_factor=8), Shard(0)]
TP_shard_size = global_shape[0] / TP
for my_coordinate in itertools.product(range(DP), range(TP)):
local_shape, global_offset = _compute_local_shape_and_global_offset(
global_shape, mesh_shape, list(my_coordinate), placements
)
dp_rank, tp_rank = my_coordinate
expected_shard_size = 18
expected_shard_offset = tp_rank * TP_shard_size + 18 * dp_rank
if dp_rank == 28:
expected_shard_size = 8
elif dp_rank == 29:
expected_shard_size = 0
# we define the offset value of a zero-sized shard as the dim size
# this actually matters, because DCP uses offset to deduplicate shards when saving
expected_shard_offset = 4096
self.assertEqual(local_shape, (expected_shard_size, 4096))
self.assertEqual(global_offset, (expected_shard_offset, 0))
class UtilTest(DTensorTestBase):
@property
def world_size(self):
return 8
def _compute_start_end_offsets(self, global_offset, local_size, n_dim):
offset = []
for i in range(n_dim):
offset.append(((global_offset[i]), (global_offset[i] + local_size[i])))
return offset
@with_comms
def test_compute_global_tensor_shape_1D(self):
one_d_placements = [[Shard(1)], [Shard(0)], [Replicate()]]
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
for placements in one_d_placements:
if isinstance(placements[0], Shard):
uneven_dim = list(range(self.world_size))
local_shape = (
torch.Size([5, uneven_dim[self.rank]])
if placements[0].dim == 1
else torch.Size([uneven_dim[self.rank], 5])
)
expected_global_shape = (
torch.Size([5, sum(uneven_dim)])
if placements[0].dim == 1
else torch.Size([sum(uneven_dim), 5])
)
else:
expected_global_shape = torch.Size([5, 5])
local_shape = torch.Size([5, 5])
global_shape = compute_global_tensor_shape(
local_shape, device_mesh, placements
)
self.assertEqual(global_shape, expected_global_shape)
@with_comms
def test_compute_global_tensor_shape_1D_invalid_shape(self):
one_d_placement = [Shard(1)]
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
uneven_dim = list(range(self.world_size))
local_shape = (
torch.Size([5, uneven_dim[self.rank]])
if self.rank % 2 == 0
else torch.Size([6, uneven_dim[self.rank]])
)
with self.assertRaisesRegex(
RuntimeError,
"Non-sharded dimensions should have identical size across ranks.",
):
_ = compute_global_tensor_shape(
local_shape,
device_mesh,
one_d_placement,
)
@with_comms
def test_compute_global_tensor_shape_failure_2D(self):
placement_2D = [Shard(0), Shard(1)]
device_mesh_2D = init_device_mesh(self.device_type, (2, 2))
with self.assertRaisesRegex(
NotImplementedError,
"compute_global_tensor_shape only supports 1 placement for now.",
):
_ = compute_global_tensor_shape(
torch.Size([2, 2]),
device_mesh_2D,
placement_2D,
)
placement_1D = [Shard(0)]
with self.assertRaisesRegex(
RuntimeError,
"Expected one placement per mesh dim",
):
_ = compute_global_tensor_shape(
torch.Size([2, 2]),
device_mesh_2D,
placement_1D,
)
@with_comms
def test_compute_local_shape_and_global_offset_1D(self):
one_d_placements = [[Shard(0)], [Replicate()]]
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
for placements in one_d_placements:
# When the placements is [Shard(0)], we test for three different scenarios:
# 1) sharding resulting in empty shards on all or some of the ranks
# 2) sharding resulting in shards of different size across different ranks
# 3) sharding resulting in non-empty shards of same size across all ranks
for size in range(self.world_size * 2 + 1):
global_tensor = torch.arange(size)
global_shape = global_tensor.size()
dtensor = distribute_tensor(global_tensor, device_mesh, placements)
local_size, global_offset = compute_local_shape_and_global_offset(
global_shape, device_mesh, placements
)
dim = self._compute_start_end_offsets(global_offset, local_size, 1)
dim0_start, dim0_end = dim[0][0], dim[0][1]
# Check the local tensor of dtensor is exactly the same
# if we slice the global_tensor with local_size and global_offset
self.assertEqual(
dtensor.to_local(),
global_tensor[dim0_start:dim0_end],
)
@with_comms
def test_compute_local_shape_and_global_offset_2D(self):
two_d_placements_options = [Shard(0), Shard(1), Replicate()]
# Generating 6 two-d placements combinations
two_d_placements = list(
itertools.combinations_with_replacement(two_d_placements_options, 2)
)
# mesh: 2 * 4
device_mesh = init_device_mesh(self.device_type, (2, 4))
for placements in two_d_placements:
for dim_0_size in range(1, 9):
nelem = 64 // dim_0_size * dim_0_size
global_tensor = torch.arange(nelem).view(dim_0_size, -1)
global_shape = global_tensor.size()
dtensor = distribute_tensor(global_tensor, device_mesh, placements)
local_size, global_offset = compute_local_shape_and_global_offset(
global_shape, device_mesh, placements
)
dim = self._compute_start_end_offsets(global_offset, local_size, 2)
dim0_start, dim0_end = dim[0][0], dim[0][1]
dim1_start, dim1_end = dim[1][0], dim[1][1]
# Check the local tensor of dtensor is exactly the same
# if we slice the global_tensor with local_size and global_offset
self.assertEqual(
dtensor.to_local(),
global_tensor[dim0_start:dim0_end, dim1_start:dim1_end],
)
@with_comms
def test_fsdp_tp_meta_compute(self):
# FSDP + TP sharding
tp_size = 2
dp_size = self.world_size // tp_size
global_mesh = init_device_mesh(
self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
# local shard shape is [2, 2]
global_tensor_shape = torch.Size([2 * self.world_size, 2])
placements = [_StridedShard(0, split_factor=tp_size), Shard(0)]
local_shape, global_offset = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
assert global_mesh.get_coordinate is not None
dp_rank = global_mesh.get_local_rank("dp")
tp_rank = global_mesh.get_local_rank("tp")
shard_idx_on_dim_0 = tp_rank * dp_size + dp_rank
expected_local_shape = (2, 2)
expected_global_offset = (shard_idx_on_dim_0 * 2, 0)
self.assertEqual(local_shape, expected_local_shape)
self.assertEqual(global_offset, expected_global_offset)
@with_comms
def test_uneven_fsdp_tp_meta_compute(self):
# FSDP + TP uneven sharding
tp_size = 2
dp_size = self.world_size // tp_size
global_mesh = init_device_mesh(
self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
global_tensor_shape = torch.Size([15, 5])
placements = [_StridedShard(0, split_factor=tp_size), Shard(0)]
local_shape, global_offset = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
rank = global_mesh.get_rank()
expected_shapes = [2, 2, 2, 2, 2, 2, 2, 1]
expected_offsets = [0, 8, 2, 10, 4, 12, 6, 14]
self.assertEqual(local_shape[0], expected_shapes[rank])
self.assertEqual(global_offset[0], expected_offsets[rank])
@with_comms
def test_hsdp_tp_meta_compute(self):
# HSDP + TP sharding
tp_size = 2
dp_shard_size = 2
dp_replic_size = self.world_size // (dp_shard_size * tp_size)
global_mesh = init_device_mesh(
self.device_type,
(dp_replic_size, dp_shard_size, tp_size),
mesh_dim_names=("dp_replic", "dp_shard", "tp"),
)
# local shard shape is [2, 2]
global_tensor_shape = torch.Size([2 * dp_shard_size * tp_size, 2])
placements = [Replicate(), _StridedShard(0, split_factor=tp_size), Shard(0)]
local_shape, global_offset = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
assert global_mesh.get_coordinate is not None
dp_shard_rank = global_mesh.get_local_rank("dp_shard")
tp_rank = global_mesh.get_local_rank("tp")
shard_idx_on_dim_0 = tp_rank * dp_shard_size + dp_shard_rank
expected_local_shape = (2, 2)
expected_global_offset = (shard_idx_on_dim_0 * 2, 0)
self.assertEqual(local_shape, expected_local_shape)
self.assertEqual(global_offset, expected_global_offset)
# TODO: remove this test once we support general meta compute on strided sharding
@with_comms
def test_strided_sharding_assumption_in_meta_compute(self):
# current ``compute_local_shape_and_global_offset`` does not allow Shard(i)
# placement to appear after the strided sharding part has ended. This test
# check that ``compute_local_shape_and_global_offset`` does not allow placements
# that violate the assumption and does not forbid the allowed ones.
# Test 0: 2-D mesh
mesh_size_0 = 2
mesh_size_1 = self.world_size // mesh_size_0
global_mesh = init_device_mesh(
self.device_type,
(mesh_size_0, mesh_size_1),
mesh_dim_names=("mesh-0", "mesh-1"),
)
global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size])
for shard_dim in [0, 1]:
placements = [
_StridedShard(shard_dim, split_factor=mesh_size_1),
Shard(shard_dim),
]
_, _ = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
# Test 1: 3-D mesh
mesh_size_0 = 2
mesh_size_1 = 2
mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1)
global_mesh = init_device_mesh(
self.device_type,
(mesh_size_0, mesh_size_1, mesh_size_2),
mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"),
)
# legal placements: Shard() appear after the strided part but it's on another
# tensor dimension.
placements = [
_StridedShard(0, split_factor=mesh_size_1),
Shard(0),
Shard(1),
]
_, _ = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
# illegal placements: Shard() appear after the strided part and it's on the
# same tensor dimension.
placements = [
_StridedShard(0, split_factor=mesh_size_1),
Shard(0),
Shard(0),
]
with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"):
_, _ = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
# Test 2: 4-D mesh
mesh_size_0 = 1
mesh_size_1 = 2
mesh_size_2 = 2
mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2)
global_mesh = init_device_mesh(
self.device_type,
(mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3),
mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"),
)
# legal placements: Shard() appear after the strided part but it's on another
# tensor dimension.
placements = [
_StridedShard(0, split_factor=mesh_size_1),
_StridedShard(1, split_factor=mesh_size_3),
Shard(0),
Shard(1),
]
local_shape, _ = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
expected_local_shape = (
2 * mesh_size_1 * mesh_size_3,
2 * mesh_size_0 * mesh_size_2,
)
self.assertEqual(local_shape, expected_local_shape)
# illegal placements: Shard() appear after the strided part and it's on the
# same tensor dimension.
placements = [
_StridedShard(0, split_factor=mesh_size_1),
_StridedShard(1, split_factor=mesh_size_3),
Shard(0),
Shard(0),
]
with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"):
_, _ = compute_local_shape_and_global_offset(
global_tensor_shape, global_mesh, placements
)
class TestStridedSharding(DTensorTestBase):
@property
def world_size(self):
return 4
@with_comms
def test_1d_mesh_strided_sharding(self):
mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
# Test 1: 1-d tensor over 1-d mesh
x = torch.arange(2 * self.world_size, device=self.device_type)
"""
contiguous sharding: [0, 1 | 2, 3 | 4, 5 | 6, 7]
"""
shard_placement = _StridedShard(0, split_factor=1) # same as Shard(0)
tensor_list, _ = shard_placement._split_tensor(x, self.world_size)
shard_x = tensor_list[self.rank]
self.assertEqual(shard_x, x.view(self.world_size, -1)[self.rank])
# shard_to_replicate
full_tensor = shard_placement._to_replicate_tensor(
shard_x,
mesh_1d,
mesh_dim=0,
current_logical_shape=list(x.shape),
)
self.assertEqual(full_tensor, x)
"""
strided sharding: [0, 4 | 1, 5 | 2, 6 | 3, 7]
"""
shard_placement = _StridedShard(0, split_factor=2)
tensor_list, _ = shard_placement._split_tensor(x, self.world_size)
shard_x = tensor_list[self.rank]
self.assertEqual(
shard_x, x.view(-1, self.world_size).swapdims(-1, 0)[self.rank]
)
# shard_to_replicate
full_tensor = shard_placement._to_replicate_tensor(
shard_x,
mesh_1d,
mesh_dim=0,
current_logical_shape=list(x.shape),
)
self.assertEqual(full_tensor, x)
@with_comms
def test_2d_mesh_strided_sharding(self):
# Test 2: 1-d tensor over 2-d mesh
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dim0", "dim1")
)
mesh_dim0_size = mesh_2d["dim0"].size()
mesh_dim1_size = mesh_2d["dim1"].size()
mesh_dim0_local_rank = mesh_2d["dim0"].get_local_rank(mesh_dim=0)
mesh_dim1_local_rank = mesh_2d["dim1"].get_local_rank(mesh_dim=0)
x = torch.arange(2 * self.world_size, device=self.device_type)
"""
contiguous sharding: [
[ 0, 1 | 2, 3 ],
[ 4, 5 | 6, 7 ],
]
"""
# shard on mesh dim-0
shard_placement_dim0 = _StridedShard(0, split_factor=1) # same as Shard(0)
tensor_list, _ = shard_placement_dim0._split_tensor(x, mesh_dim0_size)
expected_shard_dim0 = x.view(mesh_dim0_size, -1)[mesh_dim0_local_rank]
shard_x = tensor_list[mesh_dim0_local_rank]
self.assertEqual(shard_x, expected_shard_dim0)
# shard on mesh dim-1
shard_placement_dim1 = _StridedShard(0, split_factor=1) # same as Shard(0)
tensor_list, _ = shard_placement_dim1._split_tensor(shard_x, mesh_dim1_size)
expected_shard_dim1 = shard_x.view(mesh_dim1_size, -1)[mesh_dim1_local_rank]
shard_x = tensor_list[mesh_dim1_local_rank]
self.assertEqual(shard_x, expected_shard_dim1)
# shard_to_replicate on mesh dim-1
full_tensor = shard_placement_dim1._to_replicate_tensor(
shard_x,
mesh_2d,
mesh_dim=1,
current_logical_shape=list(expected_shard_dim0.shape),
)
self.assertEqual(full_tensor, expected_shard_dim0)
# shard_to_replicate on mesh dim-0
full_tensor = shard_placement_dim0._to_replicate_tensor(
full_tensor,
mesh_2d,
mesh_dim=0,
current_logical_shape=list(x.shape),
)
self.assertEqual(full_tensor, x)
"""
strided sharding: [
[ 0, 1 | 4, 5 ],
[ 2, 3 | 6, 7 ],
]
"""
split_factor = 2
# shard on mesh dim-0
shard_placement_dim0 = _StridedShard(0, split_factor=split_factor)
tensor_list, _ = shard_placement_dim0._split_tensor(x, mesh_dim0_size)
shard_x = tensor_list[mesh_dim0_local_rank]
expected_shard_dim0 = (
torch.tensor([0, 1, 4, 5], device=self.device_type)
if mesh_dim0_local_rank == 0
else torch.tensor([2, 3, 6, 7], device=self.device_type)
)
self.assertEqual(shard_x, expected_shard_dim0)
# shard on mesh dim-1
shard_placement_dim1 = _StridedShard(0, split_factor=1) # same as Shard(0)
tensor_list, _ = shard_placement_dim1._split_tensor(shard_x, mesh_dim1_size)
shard_x = tensor_list[mesh_dim1_local_rank]
expected_shard_dim1 = expected_shard_dim0.view(mesh_dim1_size, -1)[
mesh_dim1_local_rank
]
self.assertEqual(shard_x, expected_shard_dim1)
# shard_to_replicate on mesh dim-1
full_tensor = shard_placement_dim1._to_replicate_tensor(
shard_x,
mesh_2d,
mesh_dim=1,
current_logical_shape=list(expected_shard_dim0.shape),
)
self.assertEqual(full_tensor, expected_shard_dim0)
# shard_to_replicate on mesh dim-0
full_tensor = shard_placement_dim0._to_replicate_tensor(
full_tensor,
mesh_2d,
mesh_dim=0,
current_logical_shape=list(x.shape),
)
self.assertEqual(full_tensor, x)
@with_comms
def test_2d_mesh_2d_tensor_strided_sharding(self):
# Test 2: 1-d tensor over 2-d mesh
mesh_2d = init_device_mesh(
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dim0", "dim1")
)
mesh_dim0_size = mesh_2d["dim0"].size()
mesh_dim1_size = mesh_2d["dim1"].size()
mesh_dim0_local_rank = mesh_2d["dim0"].get_local_rank(mesh_dim=0)
mesh_dim1_local_rank = mesh_2d["dim1"].get_local_rank(mesh_dim=0)
x = torch.arange(2 * self.world_size, device=self.device_type).reshape(2, -1)
"""
strided sharding:
rank 0: [[0], [4]]
rank 1: [[2], [6]]
rank 2: [[1], [5]]
rank 3: [[3], [7]]
"""
split_factor = 2
# shard on mesh dim-0
shard_placement_dim0 = _StridedShard(1, split_factor=split_factor)
tensor_list, _ = shard_placement_dim0._split_tensor(x, mesh_dim0_size)
shard_x = tensor_list[mesh_dim0_local_rank]
expected_shard_dim0 = (
torch.tensor([[0, 2], [4, 6]], device=self.device_type)
if mesh_dim0_local_rank == 0
else torch.tensor([[1, 3], [5, 7]], device=self.device_type)
)
self.assertEqual(shard_x, expected_shard_dim0)
# shard on mesh dim-1
shard_placement_dim1 = _StridedShard(1, split_factor=1) # same as Shard(1)
tensor_list, _ = shard_placement_dim1._split_tensor(shard_x, mesh_dim1_size)
shard_x = tensor_list[mesh_dim1_local_rank]
expected_shard_dim1 = [
torch.tensor(value, device=self.device_type)
for value in [[[0], [4]], [[2], [6]], [[1], [5]], [[3], [7]]]
][self.rank]
self.assertEqual(shard_x, expected_shard_dim1)
# shard_to_replicate on mesh dim-1
full_tensor = shard_placement_dim1._to_replicate_tensor(
shard_x,
mesh_2d,
mesh_dim=1,
current_logical_shape=list(expected_shard_dim0.shape),
)
self.assertEqual(full_tensor, expected_shard_dim0)
# shard_to_replicate on mesh dim-0
full_tensor = shard_placement_dim0._to_replicate_tensor(
full_tensor,
mesh_2d,
mesh_dim=0,
current_logical_shape=list(x.shape),
)
self.assertEqual(full_tensor, x)
class Test2DStridedLocalShard(DTensorTestBase):
@property
def world_size(self):
return 4
@with_comms
def test_fsdp1_tp_2d_dtensor_local_shards_and_offsets(self):
# We are mimicking the behavior of FSDP1 + TP.
# Currently, the 2D DTensor's local shard is correct, since from_local + redistribute incurs a all_gather behind the scene.
# When we have a global_tensor of [0, 1, 2, 3, 4, 5, 6, 7], the local shard of 2D DTensor would be:
# rank0: [0, 1], rank1: [2, 3], rank2: [4, 5], rank3: [6, 7]
with CommDebugMode() as comm_mode:
global_tensor = torch.arange(8).view(4, 2)
mesh_2d = init_device_mesh(
self.device_type, (2, 2), mesh_dim_names=("DP", "TP")
)
tp_mesh = mesh_2d["TP"]
dtensor_tp = distribute_tensor(
global_tensor, tp_mesh, placements=[Shard(0)]
)
dtensor_2d = DTensor.from_local(
dtensor_tp.to_local(), mesh_2d, [Replicate(), Shard(0)], run_check=False
).redistribute(mesh_2d, [Shard(0), Shard(0)])
self.assertEqual(
comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
)
self.assertEqual(
dtensor_2d.to_local(), global_tensor[self.rank : self.rank + 1]
)
# compute_local_shape_and_global_offset currently does take into consideration of strided sharding,
# which should after strided sharding is added.
local_size, global_offset = compute_local_shape_and_global_offset(
global_tensor.shape, mesh_2d, [Shard(0), Shard(0)]
)
self.assertEqual(local_size, torch.Size([1, 2]))
self.assertEqual(global_offset, torch.Size([self.rank, 0]))
@with_comms
def test_fsdp2_tp_2d_dtensor_local_shards_and_offsets(self):
# We are mimicking the behavior of FSDP2 + TP.
# Currently, the 2D DTensor's local shard is incorrect for resharding, since we want to avoid extra communication.
# It's incorrect for resharding, since `compute_local_shape_and_global_offset`
# doesn't know the correct offsets for resharding.
# When we have a global_tensor of [0, 1, 2, 3, 4, 5, 6, 7], the local shard of 2D DTensor would be:
# local tensor -- rank0: [0, 1], rank1: [4, 5], rank2: [2, 3], rank3: [6, 7]
# current offsets -- rank0: [0, 0], rank1: [1, 0], rank2: [2, 0], rank3: [3, 0]
# Ideally, with strided sharding, the offsets should be rank0: [0, 0], rank1: [2, 0], rank2: [1, 0], rank3: [3, 0]
# TODO: to make the local shard of FSDP2 + TP correct for resharding, it would require strided_sharding
# as well as let compute_local_shape_and_global_offset takes into consideration of strided_sharding.
global_tensor = torch.arange(8).view(4, 2)
with CommDebugMode() as comm_mode:
mesh_2d = init_device_mesh(
self.device_type, (2, 2), mesh_dim_names=("DP", "TP")
)
tp_mesh = mesh_2d["TP"]
dtensor_tp = distribute_tensor(
global_tensor, tp_mesh, placements=[Shard(0)]
)
chunks = list(torch.chunk(dtensor_tp.to_local(), 2, dim=0))
shard_rank = 0 if self.rank // 2 == 0 else 1
sharded_param = chunks[shard_rank]
spec_2d = DTensorSpec(
mesh=mesh_2d,
placements=(_StridedShard(0, split_factor=2), Shard(0)),
tensor_meta=TensorMeta(
global_tensor.size(),
global_tensor.stride(),
global_tensor.dtype,
),
)
dtensor_2d = DTensor(
sharded_param,
spec_2d,
requires_grad=False,
)
self.assertEqual(
comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 0
)
self.assertEqual(global_tensor, dtensor_2d.full_tensor())
if __name__ == "__main__":
    # Run this module's test cases through PyTorch's common test runner.
    run_tests()
|