1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757
|
"""
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.py -e MLP_operation_tracing
"""
import argparse
import os
from typing import Callable, Dict, Union
import torch
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
MLPModule,
MLPStacked,
ModelArgs,
NUM_DEVICES,
Transformer,
)
from torch.utils.checkpoint import checkpoint
def get_device_type() -> str:
    """Choose the device type used by every example.

    Returns ``"cuda"`` only when CUDA is available *and* at least four GPUs are
    visible (presumably one per worker rank, matching ``--nproc-per-node=4``);
    in every other case the examples fall back to ``"cpu"``.
    """
    has_enough_gpus = torch.cuda.is_available() and torch.cuda.device_count() >= 4
    if has_enough_gpus:
        return "cuda"
    return "cpu"
# Short handles to the registered operator namespaces.
aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
# aten overloads listed by this module (not referenced by the visible example code).
supported_ops = [aten.view.default, aten._to_copy.default]
class CommDebugModeExample:
    """
    Runnable examples of ``CommDebugMode`` features.

    Each ``example_*`` method builds a small model, runs it under ``CommDebugMode``
    and prints and/or writes out what was recorded. The recorded information covers
    parameter sharding, module-level collective counts, operation-level traces and
    JSON dumps. The MLP and Transformer examples parallelize their model over a 1-D
    ``DeviceMesh`` of ``NUM_DEVICES`` ranks. The activation-checkpointing example
    uses a plain local model.
    """

    def __init__(self, world_size: int, rank: int) -> None:
        # world_size and rank are stored for reference; none of the examples read them.
        self.world_size = world_size
        self.rank = rank
        self.device_type = get_device_type()

    def _MLP_model_setup(
        self, model_type: type, parallelize_plan: Union[None, dict] = None
    ) -> tuple[nn.Module, torch.Tensor]:
        """
        Create a tensor-parallel MLP-style model and a random input for the examples.

        Args:
            model_type: model class (e.g. ``MLPModule`` or ``MLPStacked``), constructed
                as ``model_type(self.device_type)``.
            parallelize_plan: plan handed to ``parallelize_module``. Defaults to
                column-wise sharding of ``net1`` and row-wise sharding of ``net2``.

        Returns:
            The parallelized model and an ``(8, 10)`` tensor of ``torch.rand`` values
            on ``self.device_type``.
        """
        if parallelize_plan is None:
            parallelize_plan = {
                "net1": ColwiseParallel(),
                "net2": RowwiseParallel(),
            }

        # 1-D mesh over ranks 0 .. NUM_DEVICES - 1
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
        # The input is drawn before the model is built; this order determines which
        # random numbers each consumes after seeding, so keep it.
        inp_size = [8, 10]
        inp = torch.rand(*inp_size, device=self.device_type)

        model = model_type(self.device_type)
        model = parallelize_module(model, device_mesh, parallelize_plan)
        return model, inp

    def _transformer_model_setup(
        self, is_seq_parallel: bool = False
    ) -> tuple[nn.Module, torch.Tensor]:
        """
        Create a parallelized Transformer model and a random token input for the examples.

        Args:
            is_seq_parallel: forwarded unchanged to ``Transformer.parallelize``.

        Returns:
            The parallelized model and an ``(8, 8)`` integer tensor with values in
            ``[0, model_args.vocab_size)`` on ``self.device_type``.
        """
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )

        model_args = ModelArgs()
        model = Transformer(model_args).to(device=self.device_type)
        model = Transformer.parallelize(model, device_mesh, is_seq_parallel)
        inp_size = [8, 8]
        inp = torch.randint(model_args.vocab_size, inp_size, device=self.device_type)

        return model, inp

    def example_MLP_distributed_sharding_display(self) -> None:
        """
        Example of obtaining the FQNs and parameters of all modules in a distributed
        model and printing their sharding info.

        Expected output:
        MLPModule.net1.weight: (Shard(dim=0),)
        MLPModule.net1.bias: (Shard(dim=0),)
        MLPModule.net2.weight: (Shard(dim=1),)
        MLPModule.net2.bias: (Replicate(),)
        """
        torch.manual_seed(0)
        model, inp = self._MLP_model_setup(model_type=MLPModule)

        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()

        print(comm_mode.get_sharding_info())

    def example_MLPStacked_distributed_sharding_display(self) -> None:
        """
        Example of obtaining the FQNs and parameters of all modules in a distributed
        model with nested modules and printing their sharding info.

        Expected output:
        MLPStacked.layers.0.net1.weight: (Shard(dim=0),)
        MLPStacked.layers.0.net1.bias: (Shard(dim=0),)
        MLPStacked.layers.0.net2.weight: (Shard(dim=1),)
        MLPStacked.layers.0.net2.bias: (Replicate(),)
        MLPStacked.layers.1.net1.weight: (Shard(dim=0),)
        MLPStacked.layers.1.net1.bias: (Shard(dim=0),)
        MLPStacked.layers.1.net2.weight: (Shard(dim=1),)
        MLPStacked.layers.1.net2.bias: (Replicate(),)
        """
        torch.manual_seed(0)
        # parallelize_plan keys are submodule FQNs relative to the module passed to
        # parallelize_module. They must therefore start at the model's own children
        # ("layers"), not at the model's class name. A "MLPStacked." prefix does not
        # resolve to any child and leaves the layers unsharded.
        parallelize_plan = {
            "layers.0.net1": ColwiseParallel(),
            "layers.0.net2": RowwiseParallel(),
            "layers.1.net1": ColwiseParallel(),
            "layers.1.net2": RowwiseParallel(),
        }

        model, inp = self._MLP_model_setup(
            model_type=MLPStacked, parallelize_plan=parallelize_plan
        )

        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()

        print(comm_mode.get_sharding_info())

    def example_MLP_module_tracing(self) -> None:
        """
        Example code to demonstrate CommDebugMode's module-level tracing using an MLP model.
        Prints a table of module-level collective tracing information and logs the table
        to comm_mode_log.txt.

        Expected Output:
        Global
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule.net1
        MLPModule.relu
        MLPModule.net2
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        """
        torch.manual_seed(0)
        model, inp = self._MLP_model_setup(model_type=MLPModule)

        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()

        # print the module level collective tracing information
        print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
        comm_mode.log_comm_debug_tracing_table_to_file(noise_level=0)

    def example_transformer_module_tracing(self) -> None:
        """
        Example code to demonstrate CommDebugMode's module-level tracing using a
        distributed Transformer model. Prints a table of module-level collective tracing
        information and logs the table to comm_mode_log.txt.

        Expected output:
        Global
        FORWARD PASS
        *c10d_functional.all_reduce: 6
        *c10d_functional.all_gather_into_tensor: 1
        Transformer
        FORWARD PASS
        *c10d_functional.all_reduce: 6
        *c10d_functional.all_gather_into_tensor: 1
        Transformer.tok_embeddings
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.pos_embeddings
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.dropout
        Transformer.layers.0
        FORWARD PASS
        *c10d_functional.all_reduce: 2
        Transformer.layers.0.attention_norm
        Transformer.layers.0.attention
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.0.attention.wq
        Transformer.layers.0.attention.wk
        Transformer.layers.0.attention.wv
        Transformer.layers.0.attention.wo
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.0.attention.resid_dropout
        Transformer.layers.0.ffn_norm
        Transformer.layers.0.feed_forward
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.0.feed_forward.w1
        Transformer.layers.0.feed_forward.gelu
        Transformer.layers.0.feed_forward.w2
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.0.feed_forward.resid_dropout
        Transformer.layers.1
        FORWARD PASS
        *c10d_functional.all_reduce: 2
        Transformer.layers.1.attention_norm
        Transformer.layers.1.attention
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.1.attention.wq
        Transformer.layers.1.attention.wk
        Transformer.layers.1.attention.wv
        Transformer.layers.1.attention.wo
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.1.attention.resid_dropout
        Transformer.layers.1.ffn_norm
        Transformer.layers.1.feed_forward
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.1.feed_forward.w1
        Transformer.layers.1.feed_forward.gelu
        Transformer.layers.1.feed_forward.w2
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        Transformer.layers.1.feed_forward.resid_dropout
        Transformer.norm
        Transformer.output
        FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        """
        torch.manual_seed(0)
        model, inp = self._transformer_model_setup()

        comm_mode = CommDebugMode()
        with comm_mode:
            model(inp)

        # print the module level collective tracing information
        print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
        comm_mode.log_comm_debug_tracing_table_to_file(noise_level=0)

    def example_MLP_operation_tracing(self) -> None:
        """
        Example code to demonstrate CommDebugMode's module operation-level tracing using
        a distributed MLP model. Prints a table of module operation-level collective
        tracing information and logs the table to comm_mode_log.txt.

        Expected output:
        Global
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        **aten.view.default
        **aten.sum.default
        **aten.ones_like.default
        BACKWARD PASS
        **aten.expand.default
        MLPModule
        *module type: class 'torch.testing._internal.distributed._tensor.common_dtensor.MLPModule'
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        **aten.view.default
        **aten.view.default
        **aten.view.default
        MLPModule.net1
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
        *weight: (Shard(dim=0),)
        *bias: (Shard(dim=0),)
        FORWARD PASS
        **aten.detach.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.view.default
        **aten.t.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.addmm.default
        shape: [torch.Size([16]), torch.Size([8, 10]), torch.Size([10, 16])]
        sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.addmm.default
        **aten.view.default
        BACKWARD PASS
        **aten.t.default
        shape: [torch.Size([8, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.mm.default
        shape: [torch.Size([16, 8]), torch.Size([8, 10])]
        sharding: [(Shard(dim=0),), (Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.mm.default
        **aten.t.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.sum.dim_IntList
        shape: [torch.Size([8, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.sum.dim_IntList
        **aten.view.default
        shape: [torch.Size([1, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.view.default
        **aten.detach.default
        shape: [torch.Size([16])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.t.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.detach.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        MLPModule.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
        FORWARD PASS
        **aten.view.default
        **aten.relu.default
        **aten.detach.default
        BACKWARD PASS
        **aten.detach.default
        **aten.threshold_backward.default
        MLPModule.net2
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
        *weight: (Shard(dim=1),)
        *bias: (Replicate(),)
        FORWARD PASS
        *c10d_functional.all_reduce: 1
        **aten.detach.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.view.default
        **aten.view.default
        shape: [torch.Size([8, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.view.default
        **aten.t.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.addmm.default
        shape: [torch.Size([10]), torch.Size([8, 16]), torch.Size([16, 10])]
        sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.div.Tensor
        **aten.addmm.default
        **_c10d_functional.all_reduce.default
        **aten.view.default
        BACKWARD PASS
        **aten.t.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.mm.default
        shape: [torch.Size([8, 10]), torch.Size([10, 16])]
        sharding: [(Replicate(),), (Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.mm.default
        **aten.t.default
        shape: [torch.Size([8, 10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.mm.default
        shape: [torch.Size([10, 8]), torch.Size([8, 16])]
        sharding: [(Replicate(),), (Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.mm.default
        **aten.t.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.sum.dim_IntList
        shape: [torch.Size([8, 10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.sum.dim_IntList
        **aten.view.default
        shape: [torch.Size([1, 10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.view.default
        **aten.detach.default
        shape: [torch.Size([10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10])]
        sharding: [(Replicate(),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        **aten.t.default
        shape: [torch.Size([16, 10])]
        sharding: [(Shard(dim=0),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.t.default
        **aten.detach.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        shape: [torch.Size([10, 16])]
        sharding: [(Shard(dim=1),)]
        device mesh: DeviceMesh([0, 1, 2, 3])
        **aten.detach.default
        **aten.detach.default
        """
        torch.manual_seed(0)
        model, inp = self._MLP_model_setup(model_type=MLPModule)

        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()

        # print the operation level collective tracing information
        print(comm_mode.generate_comm_debug_tracing_table(noise_level=3))
        comm_mode.log_comm_debug_tracing_table_to_file(noise_level=3)

    def example_transformer_operation_tracing(
        self, is_seq_parallel: bool = False
    ) -> None:
        """
        Example code to demonstrate CommDebugMode's module operation-level tracing using
        a distributed transformer model. Prints a table of module operation-level
        collective tracing information, excluding trivial operations, and logs the table
        to transformer_operation_log.txt.

        Args:
            is_seq_parallel: forwarded to the transformer setup (``Transformer.parallelize``).
        """
        torch.manual_seed(0)
        model, inp = self._transformer_model_setup(is_seq_parallel=is_seq_parallel)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(inp)

        # print the operation level collective tracing information; note that the
        # printed table (noise_level=2) and the logged file (noise_level=1) use
        # different noise levels
        print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))
        comm_mode.log_comm_debug_tracing_table_to_file(
            noise_level=1, file_name="transformer_operation_log.txt"
        )

    def example_MLP_json_dump(self) -> None:
        """
        Example code to demonstrate CommDebugMode's JSON dump using an MLP model. Writes
        the information to the default comm_mode_log.json file.
        """
        torch.manual_seed(0)
        model, inp = self._MLP_model_setup(model_type=MLPModule)

        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model(inp)
            output_tp.sum().backward()

        comm_mode.generate_json_dump()

    def example_transformer_json_dump(self, is_seq_parallel: bool = False) -> None:
        """
        Example code to demonstrate CommDebugMode's JSON dump using a transformer model,
        excluding trivial operations. Writes two user-named files:
        transformer_log.json (noise_level=1) and transformer_log_2.json (noise_level=2).

        Args:
            is_seq_parallel: forwarded to the transformer setup (``Transformer.parallelize``).
        """
        torch.manual_seed(0)
        model, inp = self._transformer_model_setup(is_seq_parallel=is_seq_parallel)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(inp)

        comm_mode.generate_json_dump(file_name="transformer_log.json", noise_level=1)
        comm_mode.generate_json_dump(file_name="transformer_log_2.json", noise_level=2)

    def example_activation_checkpointing(self) -> None:
        """
        Example code showing that CommDebugMode is able to differentiate between backward
        passes and activation checkpointing. Prints the tracing table, logs it to
        comm_mode_log.txt and writes the default comm_mode_log.json file.

        Expected output:
        Global
        FORWARD PASS
        **aten.sum.default
        **aten.ones_like.default
        BACKWARD PASS
        **aten.expand.default
        Foo
        *module type: class '__main__.CommDebugModeExample.example_activation_checkpointing.locals.Foo'
        FORWARD PASS
        **aten.relu.default
        **aten.empty.memory_format
        **aten.empty.memory_format
        **aten.relu.default
        BACKWARD PASS
        **aten.threshold_backward.default
        Foo.linears.0
        *module type: class 'torch.nn.modules.linear.Linear'
        FORWARD PASS
        **aten.addmm.default
        BACKWARD PASS
        **aten.mm.default
        **aten.sum.dim_IntList
        Foo.linears.1
        *module type: class 'torch.nn.modules.linear.Linear'
        FORWARD PASS
        **aten.addmm.default
        ACTIVATION CHECKPOINTING
        **aten.mm.default
        **aten.mm.default
        **aten.sum.dim_IntList
        **aten.threshold_backward.default
        """

        class Foo(torch.nn.Module):
            """Stack of ``n_layers`` Linear(dim, dim) + ReLU blocks.

            When ``use_ac`` is set, every block after the first is wrapped in
            non-reentrant activation checkpointing.
            """

            def __init__(self, n_layers: int, dim: int, use_ac: bool = False):
                super().__init__()
                self.linears = torch.nn.ModuleList()
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.linears.append(torch.nn.Linear(dim, dim))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for i, block in enumerate(self.linears):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                    assert x is not None
                    x = torch.nn.functional.relu(x)
                return x

        # seed for reproducibility, consistent with the other examples
        torch.manual_seed(0)

        bsz = 2
        dim = 8
        n_layers = 2

        # plain local (non-distributed) model with checkpointing enabled
        model = Foo(n_layers, dim, True)

        x = torch.randn(bsz, dim)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(x).sum().backward()

        print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))
        comm_mode.log_comm_debug_tracing_table_to_file(noise_level=2)
        comm_mode.generate_json_dump(noise_level=2)
def run_example(world_size: int, rank: int, example_name: str) -> None:
    """Instantiate ``CommDebugModeExample`` and run the example named *example_name*.

    Seeding is left to the individual example methods; nothing is seeded here.

    Args:
        world_size: number of worker ranks, passed to ``CommDebugModeExample``.
        rank: this process's rank, passed to ``CommDebugModeExample``.
        example_name: key of the example to run (e.g. ``"MLP_operation_tracing"``).

    Raises:
        KeyError: if *example_name* is not a registered example; the message lists
            the valid names.
    """
    # initializing class with all of the example functions
    instantiated_example = CommDebugModeExample(world_size, rank)

    # dispatch table: command-line example name -> bound example method
    name_to_example_code: Dict[str, Callable[[], None]] = {
        "MLP_distributed_sharding_display": instantiated_example.example_MLP_distributed_sharding_display,
        "MLPStacked_distributed_sharding_display": instantiated_example.example_MLPStacked_distributed_sharding_display,
        "MLP_module_tracing": instantiated_example.example_MLP_module_tracing,
        "transformer_module_tracing": instantiated_example.example_transformer_module_tracing,
        "MLP_operation_tracing": instantiated_example.example_MLP_operation_tracing,
        "transformer_operation_tracing": instantiated_example.example_transformer_operation_tracing,
        "MLP_json_dump": instantiated_example.example_MLP_json_dump,
        "transformer_json_dump": instantiated_example.example_transformer_json_dump,
        "activation_checkpointing": instantiated_example.example_activation_checkpointing,
    }

    try:
        example_fn = name_to_example_code[example_name]
    except KeyError:
        # keep the KeyError type, but make the failure actionable
        valid_names = ", ".join(name_to_example_code)
        raise KeyError(
            f"unknown example {example_name!r}; expected one of: {valid_names}"
        ) from None

    example_fn()
if __name__ == "__main__":
    # Example names accepted on the command line. They must match the keys of the
    # dispatch table in ``run_example``. The numbered help text is generated from
    # this list, so the two cannot drift apart.
    example_names = [
        "MLP_distributed_sharding_display",
        "MLPStacked_distributed_sharding_display",
        "MLP_module_tracing",
        "transformer_module_tracing",
        "MLP_operation_tracing",
        "transformer_operation_tracing",
        "MLP_json_dump",
        "transformer_json_dump",
        "activation_checkpointing",
    ]

    parser = argparse.ArgumentParser(
        description="comm_mode_feature examples",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    example_prompt = (
        "choose one comm_mode_feature example from below:\n"
        + "".join(
            f"\t{idx}. {name}\n" for idx, name in enumerate(example_names, start=1)
        )
        + "e.g. you want to try the MLPModule sharding display example, please input 'MLP_distributed_sharding_display'\n"
    )
    parser.add_argument(
        "-e",
        "--example",
        help=example_prompt,
        required=True,
        # reject unknown names up front with a clear argparse error
        choices=example_names,
        # keep the usage line short instead of listing every choice inline
        metavar="EXAMPLE",
    )
    # Parse arguments before reading the torchrun environment, so that ``--help``
    # and invalid example names are reported even when not launched via torchrun.
    example = parser.parse_args().example

    # this script is launched via torchrun which automatically manages ProcessGroup
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Our example uses 4 worker ranks. Check explicitly rather than with ``assert``,
    # because asserts are stripped when Python runs with -O.
    if world_size != 4:
        raise RuntimeError(
            f"this example requires exactly 4 worker ranks (WORLD_SIZE=4), got {world_size}"
        )

    run_example(world_size, rank, example)
|