1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065
|
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import pathlib
import tempfile
import unittest
from numpy.testing import assert_array_equal
import torch
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
init_device_mesh,
)
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed._tensor.placement_types import (
DTensorSpec,
Partial,
Replicate,
Shard,
TensorMeta,
)
from torch.distributed.tensor._api import _shard_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_utils import IS_FBCODE, run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torch.testing._internal.logging_utils import LoggingTestCase
# Shorthand for the functional-collective op namespace; its ops are used below
# as keys into ``CommDebugMode.comm_counts``.
c10d_functional = torch.ops.c10d_functional
class DummyMLP(torch.nn.Module):
    """Small two-layer perceptron (5 -> 1024 -> 4) used as a parallelization target.

    Args:
        device: Device handed to both ``torch.nn.Linear`` layers.
    """

    def __init__(self, device):
        super().__init__()
        hidden_dim = 1024
        self.net1 = torch.nn.Linear(5, hidden_dim, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(hidden_dim, 4, device=device)

    def forward(self, x):
        """Apply ``net1``, a ReLU, then ``net2`` to ``x``."""
        activated = F.relu(self.net1(x))
        return self.net2(activated)

    def reset_parameters(self, *args, **kwargs):
        """Overwrite all weights and biases with fixed constants (no autograd tracking)."""
        constant_fills = (
            (self.net1.weight, 0.5),
            (self.net2.weight, 1),
            (self.net1.bias, 1.5),
            (self.net2.bias, 1.2),
        )
        with torch.no_grad():
            for param, value in constant_fills:
                param.fill_(value)
class DTensorTest(DTensorTestBase):
@with_comms
def test_dtensor_constructor(self):
    """Construct a DTensor directly from a local shard and an explicit spec."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    shard = torch.randn(3, 3, requires_grad=True)

    # The spec describes the global tensor: world_size shards of 3 rows each.
    global_meta = TensorMeta(
        torch.Size([self.world_size * 3, 3]),
        shard.stride(),
        shard.dtype,
    )
    spec = DTensorSpec(mesh, (Shard(0),), tensor_meta=global_meta)

    dist_tensor = DTensor(shard, spec, requires_grad=True)
    self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3)))

    # Passing requires_grad=False for a local tensor that requires grad is
    # expected to emit a UserWarning whose message matches "To construct".
    with self.assertWarnsRegex(UserWarning, "To construct"):
        DTensor(shard, spec, requires_grad=False)
@with_comms
def test_meta_dtensor(self):
    """Meta-device DTensors can be re-allocated on the test device and initialized."""
    mesh = self.build_device_mesh()
    meta_tensor = torch.randn(1024, 2048, device="meta")

    for placements in ([Shard(0)], [Replicate()]):
        # Path 1: DTensor created with distribute_tensor from a meta tensor.
        dt = distribute_tensor(meta_tensor, mesh, placements)
        self.assertTrue(dt.is_meta)
        dt = torch.empty_like(dt, device=self.device_type)
        torch.nn.init.constant_(dt, 1.2)
        expected_local = torch.empty_like(dt.to_local()).fill_(1.2)
        self.assertFalse(dt.is_meta)
        self.assertEqual(dt.device.type, self.device_type)
        self.assertEqual(dt.to_local(), expected_local)

        # Path 2: DTensor created with DTensor.from_local from a meta tensor.
        dt = DTensor.from_local(meta_tensor, mesh, placements)
        dt = torch.empty_like(dt, device=self.device_type)
        torch.nn.init.constant_(dt, 1.5)
        self.assertEqual(dt.device.type, self.device_type)
        expected_local = torch.empty_like(dt.to_local()).fill_(1.5)
        self.assertEqual(dt.to_local(), expected_local)
@with_comms
def test_modules_w_meta_dtensor(self):
    """A module parallelized on the meta device matches one built on the real device."""
    mesh = self.build_device_mesh()
    plan = {
        "net1": ColwiseParallel(),
        "net2": RowwiseParallel(),
    }

    # Model under test: built on "meta", parallelized, then materialized.
    meta_model = parallelize_module(DummyMLP("meta"), mesh, plan)
    meta_model.to_empty(device=self.device_type)
    meta_model.reset_parameters()
    meta_optim = torch.optim.SGD(meta_model.parameters(), lr=0.1)

    # Reference model: built directly on the test device.
    ref_model = parallelize_module(DummyMLP(self.device_type), mesh, plan)
    ref_optim = torch.optim.SGD(ref_model.parameters(), lr=0.1)
    ref_model.reset_parameters()

    # First forward/backward/step must agree.
    torch.manual_seed(0)
    batch = torch.randn(20, 5, device=self.device_type)
    meta_out = meta_model(batch)
    ref_out = ref_model(batch)
    self.assertEqual(meta_out, ref_out)
    meta_out.sum().backward()
    ref_out.sum().backward()
    meta_optim.step()
    ref_optim.step()

    # After one optimizer step the two models must still agree on new input.
    torch.manual_seed(1)
    batch = torch.randn(20, 5, device=self.device_type)
    self.assertEqual(meta_model(batch), ref_model(batch))
@with_comms
def test_dtensor_stride(self):
    """Check the stride a DTensor reports for dim-0, dim-1 and permuted local shards."""
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

    # Sharding on dim 0 leaves the stride equal to the local (8, 1).
    shard0_spec = [Shard(0)]
    local_tensor = torch.randn(4, 8)
    dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
    self.assertEqual(dist_tensor.stride(), (8, 1))

    # Sharding on dim 1 makes the row stride the global row length.
    shard1_spec = [Shard(1)]
    local_tensor = torch.randn(8, 4)
    dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
    self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))

    # Local tensor that is a permuted (non-contiguous) view, sharded on dim 1.
    local_tensor = torch.randn(8, 4, 8)
    local_tensor_t = local_tensor.permute(1, 2, 0)
    self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
    dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
    global_stride = (8 * self.world_size, 1, 32 * self.world_size)
    self.assertEqual(dist_tensor.stride(), global_stride)
@with_comms
def test_from_local(self):
    """Check global sizes from ``from_local`` and that gradients reach the local tensor."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    shard_spec = [Shard(0)]
    local = torch.randn(3, 3)

    # Shard(0): the global dim 0 is world_size times the local dim 0.
    sharded = DTensor.from_local(local, mesh, shard_spec)
    self.assertEqual(sharded.size(), torch.Size([self.world_size * 3, 3]))

    # Replicate and Partial: the global size equals the local size.
    for same_size_spec in ([Replicate()], [Partial()]):
        dt = DTensor.from_local(local, mesh, same_size_spec)
        self.assertEqual(dt.size(), local.size())

    # Autograd interop: start from a leaf torch.Tensor and derive a non-leaf
    # tensor from it; the DTensor wrapping it must be a non-leaf as well.
    leaf = torch.randn(3, 3, requires_grad=True)
    scaled = leaf * 3
    dist_tensor = DTensor.from_local(scaled, mesh, shard_spec)
    self.assertFalse(dist_tensor.is_leaf)

    # Arbitrary op on the DTensor keeps the result a DTensor.
    output = dist_tensor * 3
    self.assertIsInstance(output, DTensor)

    # Call backward on the DTensor itself with a DTensor gradient of ones.
    grad_output = DTensor.from_local(torch.ones(3, 3), mesh, shard_spec)
    output.backward(grad_output)

    # The gradient must flow back to the original leaf: d(x * 3 * 3)/dx == 9.
    self.assertIsNotNone(leaf.grad)
    self.assertEqual(leaf.grad, torch.ones(3, 3) * 9)
@with_comms
def test_from_local_uneven_sharding(self):
    """``from_local`` accepts explicit shape and stride for unevenly sized shards."""
    mesh = init_device_mesh(self.device_type, (self.world_size,))

    # world_size + 1 rows cannot be split evenly across world_size ranks.
    full = torch.randn(self.world_size + 1, 2)
    shards, _ = Shard(0)._split_tensor(
        full,
        mesh.size(mesh_dim=0),
        with_padding=False,
        contiguous=True,
    )

    dtensor = DTensor.from_local(
        shards[self.rank],
        mesh,
        (Shard(0),),
        shape=full.size(),
        stride=full.stride(),
    )

    self.assertEqual(dtensor.size(), full.size())
    self.assertEqual(dtensor.stride(), full.stride())
@with_comms
def test_from_local_uneven_sharding_raise_error(self):
    """``from_local`` must reject ``shape`` without ``stride`` and vice versa."""
    mesh_shape = (self.world_size,)
    device_mesh = init_device_mesh(self.device_type, mesh_shape)

    # world_size + 1 rows cannot be split evenly across world_size ranks.
    uneven_dim0_size = self.world_size + 1
    global_tensor = torch.randn(uneven_dim0_size, 2)
    shard_placement = Shard(0)
    tensor_list, _ = shard_placement._split_tensor(
        global_tensor,
        device_mesh.size(mesh_dim=0),
        with_padding=False,
        contiguous=True,
    )

    # Supplying only one of the two keyword arguments is an error either way.
    incomplete_kwargs = (
        {"shape": global_tensor.size()},
        {"stride": global_tensor.stride()},
    )
    for kwargs in incomplete_kwargs:
        with self.assertRaisesRegex(
            RuntimeError, "Please pass both shape and stride at the same time."
        ):
            DTensor.from_local(
                tensor_list[self.rank],
                device_mesh,
                (Shard(0),),
                **kwargs,
            )
@with_comms
def test_from_local_negative_dim(self):
    """``Shard(-1)`` on a 2-D tensor is reported back as shard dim 1."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    sharded = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(-1)])
    self.assertEqual(sharded.placements[0].dim, 1)
@with_comms
def test_to_local(self):
    """Check ``to_local`` values, autograd through it, and the no-grad fast path."""
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    placements = (Shard(0),)
    local_tensor_with_grad = torch.randn(
        3, 3, device=self.device_type, requires_grad=True
    )
    dist_tensor_shape = torch.Size([self.world_size * 3, 3])
    spec = DTensorSpec(
        mesh=device_mesh,
        placements=placements,
        tensor_meta=TensorMeta(
            dist_tensor_shape,
            local_tensor_with_grad.stride(),
            local_tensor_with_grad.dtype,
        ),
    )
    sharded_tensor = DTensor(
        local_tensor_with_grad,
        spec,
        requires_grad=True,
    )
    self.assertEqual(sharded_tensor.size(), dist_tensor_shape)
    self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad)

    # Autograd interop: the DTensor is a leaf; run an op on it, then mix its
    # local tensor with a plain torch.Tensor.
    temp_st = sharded_tensor * 3
    new_tensor_with_grad = torch.randn(
        3, 3, device=self.device_type, requires_grad=True
    )
    res = temp_st.to_local() + new_tensor_with_grad

    # Backward is called on the plain torch.Tensor and must propagate through
    # the DTensor back to the leaf: d(x * 3 + y)/dx == 3.
    res.sum().backward()
    self.assertIsNotNone(sharded_tensor.grad)
    self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3)

    # Case where the gradient's stride differs from the forward input's:
    # the hook transposes the incoming gradient.
    res = sharded_tensor.to_local()
    model = torch.nn.ReLU()
    res.register_hook(lambda grad: grad.t())
    target = torch.randn(3, 3, device=self.device_type)
    mae_loss = torch.nn.L1Loss()
    output = mae_loss(model(res), target)
    # The manual change to the grad stride leads to a failure of the copy op
    # afterwards, so a try/except is needed here.
    # NOTE(review): the stride assertion below only runs if backward() raises;
    # if it ever stops raising, this check is silently skipped.
    try:
        output.backward()
    except RuntimeError:
        self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])

    # Under no_grad, to_local must return the stored local tensor object itself.
    with torch.no_grad():
        local_no_grad = sharded_tensor.to_local()
        self.assertIs(local_no_grad, sharded_tensor._local_tensor)
@with_comms
def test_to_local_grad_hint(self):
    """``to_local(grad_placements=[Partial()])`` yields one all-gather and one reduce-scatter."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    global_tensor = torch.ones(8, 3, requires_grad=True)
    sharded = distribute_tensor(global_tensor, mesh, (Shard(0),))

    comm_mode = CommDebugMode()
    with comm_mode:
        replicated = sharded.redistribute(placements=[Replicate()])
        local_out = replicated.to_local(grad_placements=[Partial()])
        local_out.backward(torch.ones_like(local_out))

    # Collectives recorded while the debug mode was active.
    counts = comm_mode.comm_counts
    self.assertEqual(counts[c10d_functional.all_gather_into_tensor], 1)
    self.assertEqual(counts[c10d_functional.reduce_scatter_tensor], 1)

    # The full gradient is expected to equal ones scaled by world_size.
    replica_grad = sharded.grad.full_tensor()
    self.assertEqual(replica_grad, global_tensor * self.world_size)
@with_comms
def test_full_tensor_sync(self):
    """``full_tensor`` returns a plain tensor, not an ``AsyncCollectiveTensor``."""
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    placements = (Shard(0),)
    global_tensor = torch.ones(8, 3, requires_grad=True)

    sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
    full_out = sharded_dtensor.full_tensor()
    self.assertNotIsInstance(full_out, AsyncCollectiveTensor)
    self.assertEqual(full_out, global_tensor)
@with_comms
def test_full_tensor_grad_hint(self):
    """Backward through ``full_tensor(grad_placements=[Partial()])`` scales the grad by world_size."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    global_tensor = torch.ones(8, 3, requires_grad=True)
    sharded = distribute_tensor(global_tensor, mesh, (Shard(0),))

    full = sharded.full_tensor(grad_placements=[Partial()])
    full.sum().backward()

    self.assertEqual(sharded.grad.full_tensor(), global_tensor * self.world_size)
@with_comms
def test_dtensor_new_empty_strided(self):
    """``new_empty_strided`` on a DTensor returns a DTensor that supports autograd."""
    mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
    source = torch.randn(8, 8, requires_grad=True, device=self.device_type)
    sharded = distribute_tensor(source, mesh, [Shard(0)])
    fresh = sharded.new_empty_strided((8, 8), (8, 1), requires_grad=True)

    # The new tensor has the same global shape and its grad is a DTensor.
    self.assertEqual(fresh.shape, sharded.shape)
    fresh.sum().backward()
    self.assertIsNotNone(fresh.grad)
    self.assertIsInstance(fresh.grad, DTensor)

    # Backward through the sharded tensor and through the plain tensor must
    # produce matching gradients.
    sharded.to_local().sum().backward()
    source.sum().backward()
    self.assertEqual(sharded.grad, fresh.grad)
    replicated_grad = sharded.grad.redistribute(placements=[Replicate()])
    self.assertEqual(replicated_grad.to_local(), source.grad)
@with_comms
def test_dtensor_async_output(self):
    """Async redistribute output stays an ``AsyncCollectiveTensor`` until it is used.

    If the output of a DTensor operation isn't used in any compute, the
    output should be an AsyncCollectiveTensor (representing the fact that
    the collective has not been synced yet).
    """
    mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

    def fn(dt):
        dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True)
        # Make sure we haven't synced yet
        # TODO: figure out why this is returning None
        # self.assertTrue(_tensor_needs_wait(dt_out_redistribute))
        dt_out_redistribute_view = dt_out_redistribute.view(
            dt_out_redistribute.shape
        )
        local_tensor = dt_out_redistribute_view.to_local()
        return local_tensor

    x = torch.ones((4, 2), device=self.device_type)
    dt = distribute_tensor(x, mesh, [Shard(0)])
    out = fn(dt)
    # Make sure we haven't synced yet
    self.assertEqual(type(out), AsyncCollectiveTensor)
    self.assertFalse(out.completed)
    out_view = out.view(-1)

    # A view of the output is still an `AsyncCollectiveTensor`
    self.assertEqual(type(out_view), AsyncCollectiveTensor)
    self.assertFalse(out.completed)

    # Use the data, requiring a sync
    ref = torch.ones((4, 2), device=self.device_type) + 1
    ref = ref.view(-1)
    out_data = out_view + 1
    self.assertEqual(type(out_data), torch.Tensor)
    self.assertEqual(out_data, ref)

    # async_op defaults to False: the result must not be an AsyncCollectiveTensor
    sync_out = dt.redistribute(mesh, [Replicate()])
    self.assertNotIsInstance(sync_out, AsyncCollectiveTensor)
    self.assertEqual(sync_out.to_local(), x)
@with_comms
def test_from_local_then_to_local(self):
    """End-to-end autograd: torch.Tensor -> DTensor -> torch.Tensor."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    shard_spec = [Shard(0)]

    # Step 1: a leaf tensor and a non-leaf tensor derived from it.
    leaf = torch.randn(3, 3, device=self.device_type, requires_grad=True)
    shifted = leaf + 8

    # Step 2: wrapping a non-leaf local tensor must give a non-leaf DTensor.
    dist_tensor = DTensor.from_local(shifted, mesh, shard_spec)
    self.assertFalse(dist_tensor.is_leaf)

    # Arbitrary op on the DTensor keeps the result a DTensor.
    output = dist_tensor * 6
    self.assertIsInstance(output, DTensor)

    # Step 3: go back to a local tensor and combine with another leaf tensor.
    other_leaf = torch.randn(3, 3, device=self.device_type, requires_grad=True)
    res = output.to_local() + other_leaf

    # Backward on the plain tensor must reach the original leaf:
    # d((x + 8) * 6 + y)/dx == 6.
    res.sum().backward()
    self.assertIsNotNone(leaf.grad)
    self.assertEqual(leaf.grad, torch.ones(3, 3) * 6)
@with_comms
def test_dtensor_spec_read_only_after_set(self):
    """Mutating the caller's placements list must not change the DTensor's placements."""
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    placements = [Shard(0)]
    local_tensor = torch.randn(3, 3)
    sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)

    # Modify the list that was passed in; the DTensor must hold its own copy.
    placements[0] = Replicate()
    self.assertIsNot(sharded_tensor.placements, placements)
    self.assertNotEqual(sharded_tensor.placements, placements)
@with_comms
def test_dtensor_spec_hash(self):
    """Spec hashes match for equal mesh/placements/metadata and differ across placements."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    shard_spec = [Shard(0)]
    first = DTensor.from_local(torch.randn(3, 3), mesh, shard_spec)
    second = DTensor.from_local(torch.randn(3, 3), mesh, shard_spec)

    # Two DTensors with different data but the same mesh, placements and
    # tensor properties are expected to have equal spec hashes.
    self.assertEqual(hash(first._spec), hash(second._spec))

    # A different placement is expected to give a different hash.
    replicated = DTensor.from_local(
        torch.ones(3, 3), mesh, [Replicate()], run_check=False
    )
    self.assertNotEqual(hash(first._spec), hash(replicated._spec))
@with_comms
def test_dtensor_properties(self):
    """A DTensor reports the device type of the test harness."""
    mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    sharded = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])
    self.assertEqual(sharded.device.type, self.device_type)
@with_comms
def test_dtensor_save_load(self):
    """A DTensor round-trips through ``torch.save`` / ``torch.load`` in both load modes."""
    import io

    mesh = self.build_device_mesh()
    sharded = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])

    buffer = io.BytesIO()
    torch.save(sharded, buffer)

    # Load the same bytes first with weights_only=False, then weights_only=True.
    for weights_only in (False, True):
        buffer.seek(0)
        reloaded = torch.load(buffer, weights_only=weights_only)
        self.assertEqual(sharded, reloaded)
@with_comms
@unittest.skipIf(
    IS_FBCODE,
    "subprocess import torch fails with ModuleNotFoundError: No module named 'torch' in fbcode",
)
def test_dtensor_save_load_import(self):
    """Loading a saved DTensor in a subprocess with and without the DTensor import.

    With the import no error message is expected; without it, the
    weights-only load is expected to fail with the message below.
    """
    for should_import in [True, False]:
        mesh = self.build_device_mesh()
        sharded = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])

        if should_import:
            import_string = "import torch.distributed.tensor;"
            err_msg = None
        else:
            import_string = ""
            err_msg = (
                "_pickle.UnpicklingError: Weights only load failed. "
                "``torch.distributed.tensor`` must be imported to load DTensors"
            )

        # The subprocess must run while the temporary file still exists.
        with tempfile.NamedTemporaryFile() as f:
            torch.save(sharded, f)
            filename = pathlib.Path(f.name)
            self._attempt_load_from_subprocess(filename, import_string, err_msg)
@with_comms
def test_shard_tensor(self):
    """``_shard_tensor`` hands each rank its own row or column of a ws x ws tensor."""
    ws = self.world_size
    mesh = DeviceMesh(self.device_type, list(range(ws)))
    full_tensor = torch.arange(ws * ws).reshape(ws, ws)

    # dim 0 shards by row, dim 1 shards by column.
    for dim in (0, 1):
        placements = [Shard(dim)]
        sharded = _shard_tensor(full_tensor, placements, mesh)
        self.assertEqual(sharded.size(), torch.Size([ws, ws]))
        self.assertEqual(sharded.placements, placements)
        # This rank's shard is the length-1 slice at index ``rank`` along ``dim``.
        expected_local = full_tensor.narrow(dim, self.rank, 1)
        self.assertEqual(sharded.to_local(), expected_local)

    # Sharding must not modify the input tensor.
    self.assertEqual(full_tensor, torch.arange(ws * ws).reshape(ws, ws))
@with_comms
def test_shard_tensor_2d(self):
    """Sharding a (2, ws // 2) arange over a same-shaped mesh leaves one element per rank."""
    half = self.world_size // 2
    # The same tensor is used both as the mesh layout and as the data.
    rank_grid = torch.arange(self.world_size).reshape(2, half)
    mesh = DeviceMesh(self.device_type, rank_grid)

    # Shard rows over mesh dim 0 and columns over mesh dim 1.
    placements = [Shard(0), Shard(1)]
    sharded = _shard_tensor(rank_grid, placements, mesh)
    self.assertEqual(sharded.size(), torch.Size([2, half]))
    self.assertEqual(sharded.placements, placements)

    # The single local element equals this rank's id.
    self.assertEqual(sharded.to_local().item(), self.rank)
class DTensorMeshTest(DTensorTestBase):
@property
def world_size(self):
    # Fixed at 8: tests in this class reshape ``arange(world_size)`` into
    # (2, 4) and (2, 2, 2) device meshes.
    return 8
def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
    """Assert ``tensor`` equals the expectation matching this rank's mesh membership.

    Args:
        mesh: Container of ranks; membership is tested with ``self.rank in mesh``.
        exp_in_mesh: Expected value when this rank is in ``mesh``.
        exp_out_of_mesh: Expected value when this rank is not in ``mesh``.
        tensor: Actual value to compare.
    """
    expected = exp_in_mesh if self.rank in mesh else exp_out_of_mesh
    self.assertEqual(tensor, expected)
@with_comms
def test_dtensor_device_mesh_device_conversion(self):
    """``from_local`` with a default-device tensor yields a DTensor on the mesh's device type."""
    mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

    # The local tensor is created without an explicit device; both the
    # DTensor and its local shard must report the mesh's device type.
    dist_tensor = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])
    for device in (dist_tensor.device, dist_tensor.to_local().device):
        self.assertEqual(device.type, self.device_type)
@with_comms
def test_dtensor_api_device_mesh_context_manager(self):
    """DTensor APIs pick up the mesh from an enclosing ``with DeviceMesh(...)`` block."""
    ws = self.world_size

    # Mesh bound with ``as`` and passed explicitly.
    with DeviceMesh(self.device_type, list(range(ws))) as mesh:
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(
            local_tensor, device_mesh=mesh, placements=placements
        )

    # Mesh taken implicitly from the context for from_local and redistribute.
    with DeviceMesh(self.device_type, list(range(ws))):
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, placements=placements)
        replica_tensor = sharded_tensor.redistribute(placements=[Replicate()])
        self.assertEqual(replica_tensor.size(), torch.Size([3 * ws, 3]))

    # Mesh taken implicitly from the context for distribute_tensor, including
    # a nested 2-D mesh context.
    with DeviceMesh(self.device_type, torch.arange(ws)):
        placements = [Shard(0)]
        global_tensor = torch.randn(torch.Size([3 * ws, 3]))
        sharded_tensor = distribute_tensor(global_tensor, placements=placements)
        self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))

        mesh_2d = DeviceMesh(self.device_type, torch.arange(ws).reshape(2, 4))
        with mesh_2d:
            tensor_2d = distribute_tensor(
                global_tensor, placements=[Shard(0), Replicate()]
            )
            self.assertEqual(tensor_2d.to_local().shape, torch.Size([3 * 4, 3]))

        # After leaving the 2-D context the outer 1-D mesh applies again.
        sharded_after_2d = distribute_tensor(global_tensor, placements=placements)
        self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3]))
@with_comms
def test_dtensor_2d_mesh(self):
    """``from_local`` on a (2, 4) mesh computes the right global size."""
    mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 4))

    # Each mesh dim shards a different tensor dim.
    dist_tensor = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0), Shard(1)])
    expected_size = torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
    self.assertEqual(dist_tensor.size(), expected_size)
    self.assertEqual(dist_tensor.device.type, self.device_type)
    self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

    # Both mesh dims shard tensor dim 0: the global dim 0 spans all ranks.
    dist_tensor = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0), Shard(0)])
    self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))
@with_comms
def test_device_mesh_nd(self):
    """``from_local`` on a (2, 2, 2) mesh computes the right global size."""
    mesh = DeviceMesh(
        self.device_type, torch.arange(self.world_size).reshape(2, 2, 2)
    )

    cases = (
        # Every mesh dim shards a different tensor dim.
        ([Shard(0), Shard(1), Shard(2)], torch.Size([6, 6, 6])),
        # Two mesh dims shard tensor dim 0.
        ([Shard(0), Shard(0), Shard(2)], torch.Size([12, 3, 6])),
    )
    for placements, expected_size in cases:
        dist_tensor = DTensor.from_local(torch.randn(3, 3, 3), mesh, placements)
        self.assertEqual(dist_tensor.size(), expected_size)
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
@with_comms
def test_dtensor_spec_local_shard_offset(self):
    """Global offsets of the local shard on a (2, 4) mesh for several placements."""
    from torch.distributed._tensor._utils import (
        compute_local_shape_and_global_offset,
    )

    ws = self.world_size
    mesh = DeviceMesh(self.device_type, torch.arange(ws).reshape(2, 4))
    tensor_shape = (3 * ws, 3 * ws)

    # Offset contributed by sharding over mesh dim 0 / mesh dim 1.
    mesh_dim0_offset = 3 * (ws // 2) * (self.rank // 4)
    mesh_dim1_offset = 3 * (ws // 4) * (self.rank % 4)

    # Each entry pairs a placement list with its expected (row, col) offset.
    cases = [
        ([Shard(0), Replicate()], (mesh_dim0_offset, 0)),
        ([Shard(1), Replicate()], (0, mesh_dim0_offset)),
        ([Replicate(), Shard(0)], (mesh_dim1_offset, 0)),
        ([Replicate(), Shard(1)], (0, mesh_dim1_offset)),
    ]

    logical_tensor = torch.randn(tensor_shape)
    for placements, expected_offset in cases:
        dtensor = distribute_tensor(logical_tensor, mesh, placements)
        _, offset = compute_local_shape_and_global_offset(
            dtensor.shape, mesh, dtensor.placements
        )
        self.assertEqual(expected_offset, offset)
@with_comms
def test_from_local_sub_mesh(self):
    """A DTensor on the sub-mesh [0, 2] only holds data on ranks 0 and 2."""
    mesh = DeviceMesh(self.device_type, [0, 2])
    dtensor = DTensor.from_local(torch.ones(3, 4), mesh, [Shard(0)])
    self.assertEqual(dtensor.size(), torch.Size([6, 4]))

    # Ranks outside the mesh are expected to see an empty local tensor.
    empty = torch.tensor([])
    self.sub_mesh_assert_equal(mesh.mesh, torch.ones(3, 4), empty, dtensor.to_local())

    # An op on the sub-mesh DTensor only changes the shards of ranks in the
    # mesh; ranks outside it are still expected to see an empty tensor.
    dtensor = dtensor + 2
    self.sub_mesh_assert_equal(
        mesh.mesh, torch.ones(3, 4) + 2, torch.tensor([]), dtensor.to_local()
    )
@with_comms
def test_default_value_sub_mesh(self):
    """Values seen by in-mesh and out-of-mesh ranks for scalar, 0-d and list results."""
    mesh = DeviceMesh(self.device_type, [0, 2])
    shard_spec = [Shard(0)]

    # Scalar (bool) result: ``equal`` returns a local result, expected to be
    # True on every rank.
    lhs = DTensor.from_local(torch.ones(4, 3), mesh, shard_spec)
    rhs = DTensor.from_local(torch.ones(4, 3), mesh, shard_spec)
    self.sub_mesh_assert_equal(mesh.mesh, True, True, lhs.equal(rhs))

    # 0-d tensor result: 12.0 inside the mesh, 0.0 outside it.
    summed = DTensor.from_local(torch.ones(4, 3), mesh, shard_spec).sum()
    self.sub_mesh_assert_equal(
        mesh.mesh, torch.tensor(12.0), torch.tensor(0.0), summed.to_local()
    )

    # List-of-tensors result: two (3, 2) pieces inside the mesh, two empty
    # tensors outside it.
    pieces = DTensor.from_local(torch.ones(3, 4), mesh, shard_spec).split(
        [2, 2], dim=1
    )
    self.sub_mesh_assert_equal(
        mesh.mesh,
        [torch.ones(3, 2)] * 2,
        [torch.tensor([])] * 2,
        [piece.to_local() for piece in pieces],
    )
@with_comms
def test_redistribute_sub_mesh(self):
    """Redistribute Shard -> Replicate -> Shard on the sub-mesh [0, 2]."""
    mesh = DeviceMesh(self.device_type, [0, 2])
    sharded = DTensor.from_local(torch.ones(4, 3), mesh, [Shard(0)])

    # Replicating gathers both (4, 3) shards into an (8, 3) tensor in-mesh.
    replicated = sharded.redistribute(placements=[Replicate()])
    self.sub_mesh_assert_equal(
        mesh.mesh, torch.ones(8, 3), torch.tensor([]), replicated.to_local()
    )

    # Sharding again returns to (4, 3) local shards in-mesh.
    resharded = replicated.redistribute(placements=[Shard(0)])
    self.sub_mesh_assert_equal(
        mesh.mesh, torch.ones(4, 3), torch.tensor([]), resharded.to_local()
    )
@with_comms
def test_implicit_replication(self):
    # Adds a plain torch.Tensor to a Shard(0) DTensor under
    # implicit_replication() and checks placement, global shape and local
    # values of the result.
    world_mesh = init_device_mesh(self.device_type, (self.world_size,))
    sharded = DTensor.from_local(torch.ones(4, 3), world_mesh, [Shard(0)])

    with implicit_replication():
        # The plain tensor is deliberately the left operand so that the
        # non-DTensor argument comes first in the op's args list.
        plain = torch.ones(3, device=self.device_type)
        summed = plain + sharded

        self.assertEqual(summed.placements, [Shard(0)])
        self.assertEqual(summed.shape, (4 * self.world_size, 3))

        local_part = summed.to_local()
        self.assertEqual(local_part.shape, (4, 3))
        self.assertEqual(local_part, torch.ones(4, 3) + torch.ones(3))
@with_comms
def test_auto_implicit_replication(self):
    # Checks that numel == 1 plain tensors can be combined with a sharded
    # DTensor without an explicit implicit_replication() context.
    world_mesh = init_device_mesh(self.device_type, (self.world_size,))
    local_ones = torch.ones(self.world_size, 3, device=self.device_type)
    sharded = DTensor.from_local(local_ones, world_mesh, [Shard(0)])

    # automatically turn tensor to DTensor replicate when ndim = 0 and numel = 1
    zero_dim = torch.tensor(1, device=self.device_type)

    def add_zero_dim():
        return zero_dim + sharded

    local_sum = add_zero_dim().to_local()
    self.assertEqual(local_sum, local_ones + zero_dim)
    # The 0-d case must not emit the warning about non-scalar numel=1 tensors.
    self.assertNotWarn(
        add_zero_dim,
        "Found a non-scalar tensor with numel=1 and ndim!=0",
    )

    # automatically turn tensor to DTensor replicate when ndim = 1 and numel = 1
    one_elem = torch.tensor([1], device=self.device_type)
    actual_local = (one_elem + sharded).to_local()
    self.assertEqual(actual_local, one_elem + local_ones)
@with_comms
def test_implicit_replication_for_foreach_ops(self):
    # Mixes a DTensor on a 2-D ("dp", "tp") mesh with a DTensor on a 1-D mesh
    # in foreach ops, with and without implicit_replication().
    cross_mesh_error = "DTensor does not support cross-mesh operation yet!"

    mesh_2d = init_device_mesh(
        self.device_type,
        (2, self.world_size // 2),
        mesh_dim_names=("dp", "tp"),
    )

    full_2d = torch.randn(4, 2)
    dtensor_2d = distribute_tensor(full_2d, mesh_2d, [Shard(0), Shard(1)])
    self.assertEqual(dtensor_2d.full_tensor(), full_2d)

    full_1d = torch.randn(4)
    dtensor_on_dp = distribute_tensor(full_1d, mesh_2d["dp"], [Shard(0)])
    operands = [dtensor_2d, dtensor_on_dp]

    # Check without implicit replication, cross mesh error raises.
    with self.assertRaisesRegex(RuntimeError, cross_mesh_error):
        torch._foreach_mul(operands, 2.0)

    # Check dtensor result matches tensor result.
    with implicit_replication():
        torch._foreach_mul_(operands, 2.0)
        self.assertEqual(operands[0].full_tensor(), full_2d * 2.0)
        self.assertEqual(operands[1].full_tensor(), full_1d * 2.0)

    # Rebuild a 1-D mesh from the "tp" process group and shard the 1-D tensor
    # on it instead of on the "dp" sub-mesh.
    mesh_from_group = DeviceMesh.from_group(
        mesh_2d["tp"].get_group(), self.device_type
    )
    dtensor_on_group_mesh = distribute_tensor(full_1d, mesh_from_group, [Shard(0)])
    operands = [dtensor_2d, dtensor_on_group_mesh]

    # Check that the cross mesh error still raises when the device meshes
    # don't belong to the same root mesh.
    # NOTE(review): this call is issued outside implicit_replication();
    # confirm whether it was meant to run inside that context.
    with self.assertRaisesRegex(RuntimeError, cross_mesh_error):
        torch._foreach_mul_(operands, 2.0)
@with_comms
def test_metadata_consistency_check(self):
    # On rank 0 only, alters one piece of tensor metadata (dtype,
    # requires_grad, stride) and checks that DTensor.from_local raises
    # ValueError with run_check=True and does not raise with run_check=False.
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    placements = [Shard(0)]

    def assert_only_checked_call_raises(candidate):
        # run_check=True is expected to raise ValueError ...
        with self.assertRaises(ValueError):
            DTensor.from_local(candidate, device_mesh, placements, run_check=True)
        # ... while run_check=False must not raise it.
        try:
            DTensor.from_local(candidate, device_mesh, placements, run_check=False)
        except ValueError:
            self.fail("Unexpected ValueError raised with run_check=False")

    # Create a local tensor with specific metadata and check dtype change
    candidate = torch.randn(3, 3, requires_grad=True, dtype=torch.float32)
    if self.rank == 0:
        candidate = candidate.to(dtype=torch.float64)
    assert_only_checked_call_raises(candidate)

    # Create a local tensor with specific metadata and check requires_grad change
    candidate = torch.randn(3, 3, requires_grad=True, dtype=torch.float32)
    if self.rank == 0:
        candidate.requires_grad = False
    assert_only_checked_call_raises(candidate)

    # Create a local tensor with specific metadata and check stride change
    candidate = torch.randn(3, 4, requires_grad=True, dtype=torch.float32)
    if self.rank == 0:
        candidate = candidate.t()  # transpose changes the stride
    assert_only_checked_call_raises(candidate)
class TestDTensorPlacementTypes(DTensorTestBase):
    """Tests for ``Shard._split_tensor`` on a 1-D mesh with a world size of 8."""

    @property
    def world_size(self):
        return 8

    def _create_tensor(self, size):
        """Return a seeded ``torch.rand(size)`` tensor, moved to CUDA when
        ``self.device_type`` is ``"cuda"``."""
        # Keep everything deterministic.
        torch.manual_seed(0)
        tensor = torch.rand(size)
        if self.device_type == "cuda":
            return tensor.cuda()
        return tensor

    @with_comms
    def test_split_tensor_1D(self) -> None:
        # Splits 1-D tensors of length 0..7 across the mesh with padding
        # enabled and checks both the reported pad sizes and which chunks are
        # empty once the padding is removed.
        # Imported once here instead of on every loop iteration.
        from torch.distributed.tensor._collective_utils import unpad_tensor

        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        shard_placement = Shard(0)

        for size in range(8):
            tensor = self._create_tensor(size)
            splitted_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor,
                mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            if size == 0:
                # when tensor size is 0, there is no padding needed for all the ranks.
                expected_pad_sizes = []
                assert_array_equal(expected_pad_sizes, pad_sizes)

                is_tensor_empty = [
                    chunk.numel() == 0 for chunk in splitted_tensor_list
                ]
                expected_is_tensor_empty = [True] * self.world_size
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
            else:
                # The first `size` entries are expected to need no padding;
                # every later entry is expected to be padded by one element.
                expected_pad_sizes = [
                    0 if idx < size else 1 for idx in range(self.world_size)
                ]
                assert_array_equal(expected_pad_sizes, pad_sizes)

                # Strip the padding again before checking for emptiness.
                unpadded_list = [
                    (
                        unpad_tensor(chunk, shard_placement.dim, pad_sizes[i])
                        if pad_sizes[i] > 0
                        else chunk
                    )
                    for i, chunk in enumerate(splitted_tensor_list)
                ]
                expected_is_tensor_empty = [
                    idx >= size for idx in range(self.world_size)
                ]
                is_tensor_empty = [
                    unpadded.numel() == 0 for unpadded in unpadded_list
                ]
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
class DTensorLogTest(LoggingTestCase):
    """Checks the stderr output of a DTensor script run with ``TORCH_LOGS=+dtensor``."""

    def test_dtensor_log(self):
        # Returns early (without failing) unless both torch.distributed and
        # CUDA are available.
        if not (torch.distributed.is_available() and torch.cuda.is_available()):
            return

        # Copy of the current environment with the logging and rendezvous
        # variables for a single-rank run overridden.
        env = {
            **os.environ,
            "TORCH_LOGS": "+dtensor",
            "RANK": "0",
            "WORLD_SIZE": "1",
            "MASTER_PORT": "12345",
            "MASTER_ADDR": "localhost",
        }

        _stdout, stderr = self.run_process_no_exception(
            """\
import logging
import torch
from torch.distributed._tensor import init_device_mesh, distribute_tensor, Shard
mesh = init_device_mesh("cuda", (1,), mesh_dim_names=("dp",))
placements = [Shard(0)]
tensor = torch.randn(12, 8, 8)
dtensor = distribute_tensor(tensor, mesh, placements)
dtensor.max()
""",
            env=env,
        )

        # Both markers are looked up in the child's decoded stderr.
        stderr_text = stderr.decode("utf-8")
        self.assertIn("_dispatch.py", stderr_text)
        self.assertIn("redistribute=False", stderr_text)
# Script entry point: hand control to the test runner when executed directly.
if __name__ == "__main__":
    run_tests()
|