1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542
|
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import itertools
import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
# Shorthand for the ``torch.ops.c10d_functional`` op namespace. Its members
# (e.g. ``all_gather_into_tensor``, ``all_reduce``, ``reduce_scatter_tensor``)
# are used by the tests below as keys into ``CommDebugMode.get_comm_counts()``.
funcol = torch.ops.c10d_functional
class RedistributeTest(DTensorTestBase):
@property
def world_size(self):
return 4
@with_comms
def test_shard_to_replicate_forward_backward(self):
# 1) test shard -> replicate forward
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
replica_spec = [Replicate()]
input_sizes_and_shard_dim = [
((self.world_size * 3, 3), 0),
((self.world_size * 3 + 1, 3), 0),
((self.world_size * 3 + 2, 3), 0),
((3, self.world_size * 3), 1),
((3, self.world_size * 3 + 1), 1),
((3, self.world_size * 3 + 2), 1),
]
comm_mode = CommDebugMode()
for input_size, shard_dim in input_sizes_and_shard_dim:
shard_spec = [Shard(shard_dim)]
expected_tensor = torch.randn(
input_size, device=self.device_type, requires_grad=True
)
dtensor = distribute_tensor(expected_tensor, device_mesh, shard_spec)
with comm_mode:
reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec)
self.assertEqual(reshard_dtensor.size(), torch.Size(input_size))
self.assertEqual(expected_tensor, reshard_dtensor.to_local())
self.assertEqual(
comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1
)
# 2) test shard -> replicate backward:
# should give gradient as shard
grad_output = torch.ones_like(reshard_dtensor)
with comm_mode:
reshard_dtensor.backward(grad_output)
grad_input = dtensor.grad
self.assertEqual(grad_input.placements, shard_spec)
self.assertEqual(
grad_input.to_local(), torch.ones(dtensor.to_local().size())
)
self.assertEqual(comm_mode.get_total_counts(), 0)
@with_comms
def test_replicate_to_replicate_forward_backward(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
replica_spec = [Replicate()]
local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
comm_mode = CommDebugMode()
# 1) test replicate -> replicate forward
replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec)
with comm_mode:
reshard_replica_tensor = replica_tensor.redistribute(
device_mesh, replica_spec
)
self.assertEqual(replica_tensor.size(), local_tensor.size())
self.assertEqual(replica_tensor, reshard_replica_tensor)
self.assertEqual(comm_mode.get_total_counts(), 0)
# 2) test replicate -> replicate backward:
# should give gradient as replicate
grad_output = torch.ones_like(reshard_replica_tensor)
with comm_mode:
reshard_replica_tensor.backward(grad_output)
grad_input = replica_tensor.grad
self.assertEqual(grad_input.placements, replica_spec)
self.assertEqual(grad_input.to_local(), torch.ones(12, 3))
self.assertEqual(comm_mode.get_total_counts(), 0)
@with_comms
def test_replicate_to_local_partial_grad(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
replica_spec = [Replicate()]
local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec)
comm_mode = CommDebugMode()
with comm_mode:
out = replica_tensor.redistribute(placements=[Replicate()]).to_local(
grad_placements=[Partial()]
)
out.backward(torch.ones_like(out))
self.assertEqual(comm_mode.get_total_counts(), 1)
self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
@with_comms
def test_replicate_to_shard_forward_backward(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
replica_spec = [Replicate()]
input_sizes_and_shard_dim = [
((self.world_size * 3, 3), 0),
((self.world_size * 3 + 1, 3), 0),
((self.world_size * 3 + 2, 3), 0),
((3, self.world_size * 3), 1),
((3, self.world_size * 3 + 1), 1),
((3, self.world_size * 3 + 2), 1),
]
comm_mode = CommDebugMode()
for input_size, shard_dim in input_sizes_and_shard_dim:
shard_spec = [Shard(shard_dim)]
# 1) test replicate -> shard forward
local_replica = torch.randn(
input_size, device=self.device_type, requires_grad=True
)
splitted_list = list(
torch.chunk(local_replica, self.world_size, dim=shard_dim)
)
# make local tensor as the element of the corresponding chunked list
local_tensor = splitted_list[self.rank]
replica_tensor = distribute_tensor(local_replica, device_mesh, replica_spec)
with comm_mode:
reshard_tensor = replica_tensor.redistribute(device_mesh, shard_spec)
self.assertEqual(reshard_tensor.size(), replica_tensor.size())
self.assertEqual(reshard_tensor.placements, shard_spec)
self.assertEqual(reshard_tensor.to_local(), local_tensor)
self.assertEqual(comm_mode.get_total_counts(), 0)
# 2) test replicate -> shard backward:
# should give gradient as replicate
grad_output = torch.ones_like(reshard_tensor)
with comm_mode:
reshard_tensor.backward(grad_output)
grad_input = replica_tensor.grad
self.assertEqual(grad_input.placements, replica_spec)
self.assertEqual(grad_input.to_local(), torch.ones(input_size))
self.assertEqual(comm_mode.get_total_counts(), 1)
self.assertEqual(
comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1
)
@with_comms
def test_partial_to_replicate_forward_backward(self):
# Although we don't allow user to reshard to produce a partial
# placement (i.e. user can't reshard to partial), we do allow
# replicate to partial internally, and also partial to replicate
# backward should work as expected
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True)
partial_spec = [Partial()]
replica_spec = [Replicate()]
comm_mode = CommDebugMode()
# test partial -> replicate, which trigger all_reduce
partial_tensor = DTensor.from_local(partial_local, device_mesh, partial_spec)
with comm_mode:
global_partial_tensor = partial_tensor.redistribute(
device_mesh, replica_spec
)
self.assertEqual(partial_tensor.size(), partial_local.size())
self.assertEqual(
partial_local * self.world_size, global_partial_tensor.to_local()
)
self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
# test backward to have replicate grad on partial
# for from_local backward, we want the replicate() -> partial() to be
# pass through.
with comm_mode:
global_partial_tensor.backward(torch.ones_like(global_partial_tensor))
self.assertIsNotNone(partial_local.grad)
self.assertEqual(partial_local.grad.size(), partial_local.size())
self.assertEqual(partial_local.grad, torch.ones_like(partial_local))
self.assertEqual(comm_mode.get_total_counts(), 0)
@with_comms
def test_replicate_to_partial(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
partial_spec = Partial()
replica_spec = Replicate()
# 1) test replicate -> partial forward
replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])
from torch.distributed.tensor._redistribute import Redistribute
comm_mode = CommDebugMode()
with comm_mode:
partial_tensor = Redistribute.apply(
replica_tensor, device_mesh, [partial_spec]
)
self.assertEqual(partial_tensor.size(), local_tensor.size())
# test it successfully zero out the contents on other ranks
self.assertEqual(
replica_tensor.to_local() / self.world_size, partial_tensor.to_local()
)
self.assertEqual(comm_mode.get_total_counts(), 0)
# replicate to partial on sub groups
local_tensor = torch.randn(12, 3, device=self.device_type)
device_mesh = DeviceMesh(
self.device_type,
torch.arange(self.world_size).reshape(self.world_size // 2, 2),
)
# 1) test replicate -> partial on 2d-mesh subgroups
replica_tensor = distribute_tensor(
local_tensor, device_mesh, [replica_spec, replica_spec]
)
with comm_mode:
partial_tensor = Redistribute.apply(
replica_tensor, device_mesh, [partial_spec, partial_spec]
)
self.assertEqual(partial_tensor.size(), local_tensor.size())
self.assertEqual(
replica_tensor.to_local() / self.world_size,
partial_tensor.to_local(),
)
self.assertEqual(comm_mode.get_total_counts(), 0)
@with_comms
def test_partial_to_shard(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
partial_spec = [Partial()]
my_rank = device_mesh.get_rank()
input_sizes_and_shard_dim = [
((self.world_size * 3, 3), 0),
((self.world_size * 3 + 1, 3), 0),
((self.world_size * 3 + 2, 3), 0),
((3, self.world_size * 3), 1),
((3, self.world_size * 3 + 1), 1),
((3, self.world_size * 3 + 2), 1),
]
comm_mode = CommDebugMode()
for input_size, shard_dim in input_sizes_and_shard_dim:
shard_spec = [Shard(shard_dim)]
partial_local = torch.ones(input_size, device=self.device_type)
partial_tensor = DTensor.from_local(
partial_local, device_mesh, partial_spec, run_check=False
)
full_chunk_size = (
input_size[shard_dim] + self.world_size - 1
) // self.world_size
chunk_sizes = [
max(
min(input_size[shard_dim], full_chunk_size * (idx + 1))
- full_chunk_size * idx,
0,
)
for idx in range(self.world_size)
]
local_shape = list(input_size)
local_shape[shard_dim] = chunk_sizes[my_rank]
# test partial to shard, trigger reduce_scatter
with comm_mode:
scatter_shard_tensor = partial_tensor.redistribute(
device_mesh, shard_spec
)
self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size())
self.assertEqual(scatter_shard_tensor.placements, shard_spec)
self.assertEqual(
scatter_shard_tensor.to_local(),
torch.ones(local_shape) * self.world_size,
)
self.assertEqual(
comm_mode.get_comm_counts()[funcol.reduce_scatter_tensor], 1
)
@with_comms
def test_redistribute_negative_shard_dim(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
shard_spec = [Shard(1)]
shard_minus_spec = [Shard(-1)]
shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
self.assertEqual(shard_tensor.placements[0].dim, 1)
reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
self.assertEqual(shard_tensor.placements[0].dim, 1)
@with_comms
def test_redistribute_uneven_sharding(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
data_to_test = [
# uneven on last mesh dim
torch.randn((10, 5), device=self.device_type),
# uneven on both mesh dims
torch.randn((9, 5), device=self.device_type),
# smaller than mesh dim shape
torch.randn((3, 5), device=self.device_type),
torch.randn((1, 3), device=self.device_type),
]
sharding_to_tests = [
[Shard(0), Shard(0)],
[Shard(0), Shard(1)],
]
for input_tensor in data_to_test:
for placements in sharding_to_tests:
dt = distribute_tensor(input_tensor, mesh, placements)
dt_full_tensor = dt.full_tensor()
self.assertEqual(dt_full_tensor, input_tensor)
@with_comms
def test_redistribute_shard_dim_change(self):
# test 1d device mesh
mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size))
data_to_test = [
# evenly sharded case
torch.randn((8, 8), device=self.device_type),
# 3d or more dims
torch.randn((8, 8, 8), device=self.device_type),
# uneven case 1
torch.randn((8, 5), device=self.device_type),
# uneven case 2
torch.randn((5, 8), device=self.device_type),
# uneven case 3
torch.randn((5, 5), device=self.device_type),
]
sharding_src_dst_pairs = [([Shard(0)], [Shard(1)]), ([Shard(1)], [Shard(0)])]
comm_mode = CommDebugMode()
for input_data in data_to_test:
for src, dst in sharding_src_dst_pairs:
expected_dt = distribute_tensor(input_data.clone(), mesh_1d, dst)
sharded_dt = distribute_tensor(input_data, mesh_1d, src)
with comm_mode:
out_dt = sharded_dt.redistribute(mesh_1d, dst)
self.assertEqual(out_dt.placements, expected_dt.placements)
local_out_dt = out_dt.to_local()
local_expected_dt = expected_dt.to_local()
self.assertEqual(out_dt.to_local(), expected_dt.to_local())
if self.device_type == "cuda":
self.assertEqual(
comm_mode.get_comm_counts()[
torch.ops._dtensor.shard_dim_alltoall
],
1,
)
else:
self.assertEqual(
comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
1,
)
# test 2d device mesh
mesh_2d = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, 2)
)
data_to_test_2d = [
# evenly sharded case
torch.randn((8, 8), device=self.device_type),
# 3d or more dims
torch.randn((8, 8, 8), device=self.device_type),
# uneven case 1
torch.randn((8, 5), device=self.device_type),
# uneven case 2
torch.randn((5, 8), device=self.device_type),
# uneven case 3
torch.randn((5, 5), device=self.device_type),
]
sharding_src_dst_pairs_2d = [
([Shard(0), Shard(1)], [Shard(0), Shard(0)]),
([Shard(0), Shard(1)], [Shard(1), Shard(0)]),
([Shard(0), Shard(0)], [Shard(1), Shard(1)]),
]
comm_counts_2d = [
1, # 1: S1 -> S0
2, # 1: S1 -> R, 0: S0 -> S1, 1: R -> S0
2, # 1: S0 -> R, 0: S0 -> S1, 1: R -> S1
]
for input_data in data_to_test_2d:
if input_data.ndim > 2:
sharding_spec_combs = sharding_src_dst_pairs_2d + [
([Shard(0), Shard(2)], [Shard(1), Shard(0)]),
([Shard(1), Shard(1)], [Shard(1), Shard(2)]),
]
comm_counts_2d = comm_counts_2d + [
2, # 1. S2 -> R, 0: S0 -> S1, 1: R -> S0
1, # 1: S1 -> S2
]
else:
sharding_spec_combs = sharding_src_dst_pairs_2d
for idx, (src, dst) in enumerate(sharding_spec_combs):
expected_dt = distribute_tensor(input_data.clone(), mesh_2d, dst)
sharded_dt = distribute_tensor(input_data, mesh_2d, src)
with comm_mode:
out_dt = sharded_dt.redistribute(mesh_2d, dst)
self.assertEqual(out_dt.placements, expected_dt.placements)
self.assertEqual(comm_mode.get_total_counts(), comm_counts_2d[idx])
local_out_dt = out_dt.to_local()
local_expected_dt = expected_dt.to_local()
self.assertEqual(local_out_dt, local_expected_dt)
@with_comms
def test_shard_dim_alltoall(self):
# init 2d mesh here so we can test when group_rank != global_rank
mesh = init_device_mesh(self.device_type, (2, 2))
tensor = torch.randn(12, self.world_size, device=self.device_type)
new_tensor = shard_dim_alltoall(tensor, 0, 1, mesh, 0)
meta_tensor = torch.randn(12, self.world_size, device="meta")
new_meta_tensor = shard_dim_alltoall(meta_tensor, 0, 1, mesh, 0)
self.assertEqual(new_tensor.shape, new_meta_tensor.shape)
self.assertEqual(new_tensor.stride(), new_meta_tensor.stride())
class MultiDimRedistributeTest(DTensorTestBase):
@property
def world_size(self) -> int:
return 8
@with_comms
def test_multi_dim_mesh(self):
devices = torch.arange(self.world_size)
for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]:
mesh_shape = torch.arange(self.world_size).view(-1, 2)
device_mesh = DeviceMesh(self.device_type, mesh_shape)
tensor_shape = (16, 24)
if torch.distributed.get_rank() == 0:
full_tensor = torch.randn(*tensor_shape)
else:
# these should be entirely ignored
# because distribute_tensor is expected to override shards in ranks != 0
full_tensor = torch.ones(*tensor_shape)
possibilities = [Replicate()] + [Shard(i) for i in range(full_tensor.ndim)]
all_outputs = list(itertools.product(*(mesh_shape.ndim * [possibilities])))
all_inputs = list(
itertools.product(*(mesh_shape.ndim * [possibilities + [Partial()]]))
)
for inputs in all_inputs:
# if partial, temporarily make it Replicated, then replace replicated with partial afterwards
repl_inputs = [Replicate() if s.is_partial() else s for s in inputs]
dt = distribute_tensor(full_tensor, device_mesh, repl_inputs)
if repl_inputs != inputs:
# create a new DTensor reinterpreting some of the replicated entires as "Partial"
dt = DTensor.from_local(
dt.to_local(), device_mesh, inputs, run_check=False
)
for outputs in all_outputs:
# redistribute on target outputs
dt2 = dt.redistribute(device_mesh, outputs)
# replicate and then get first shard
local_full = dt2.full_tensor()
if torch.distributed.get_rank() == 0:
self.assertEqual(local_full.shape, full_tensor.shape)
num_sums = 1
for idx, input in enumerate(inputs):
if input.is_partial():
num_sums *= mesh_shape.size(idx)
expected = num_sums * full_tensor
self.assertEqual(local_full, expected)
@with_comms
def test_redistribute_shard_dim_multi_dim_mesh(self):
mesh = init_device_mesh(self.device_type, (2, 2, 2))
input_data = torch.randn((8, 8, 8), device=self.device_type)
sharding_src_dst_pairs_3d = [
([Shard(0), Shard(0), Shard(0)], [Shard(1), Shard(1), Shard(1)]),
([Shard(0), Shard(1), Shard(0)], [Shard(1), Shard(0), Shard(0)]),
([Shard(0), Shard(1), Shard(2)], [Shard(2), Shard(1), Shard(0)]),
([Shard(1), Shard(0), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
([Shard(1), Replicate(), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
([Shard(0), Shard(0), Shard(1)], [Shard(0), Shard(1), Shard(2)]),
]
comm_counts_3d = [
3, # 2: S0 - R, 1: S1 -> R, 0: S0 -> S1
3, # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1, 1: R -> S0, 2: R -> S0
2, # 2: S2 -> R, 0: S1 -> S2
1, # 0: S1 -> R
2, # 2: S0 -> R, 1: R -> S0, 2: R -> S0, 0: S1 -> R
2, # 2: S1 -> S2, 1: S0 -> S1
]
comm_mode = CommDebugMode()
for idx, (src_placement, dst_placement) in enumerate(sharding_src_dst_pairs_3d):
expected_dt = distribute_tensor(input_data.clone(), mesh, dst_placement)
sharded_dt = distribute_tensor(input_data, mesh, src_placement)
with comm_mode:
out_dt = sharded_dt.redistribute(mesh, dst_placement)
self.assertEqual(out_dt.placements, expected_dt.placements)
self.assertEqual(comm_mode.get_total_counts(), comm_counts_3d[idx])
local_out_dt = out_dt.to_local()
local_expected_dt = expected_dt.to_local()
self.assertEqual(local_out_dt, local_expected_dt)
if __name__ == "__main__":
    # Run the tests in this file via PyTorch's shared test runner
    # (torch.testing._internal.common_utils.run_tests).
    run_tests()
|