1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094
|
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import functools
import itertools
import logging
import unittest
from collections import defaultdict
from unittest import mock
import torch
import torch._dynamo.testing
import torch.nn.functional as F
from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import comms
from torch._inductor.utils import is_fallback_op, run_and_get_code
from torch.distributed._tensor import init_device_mesh
from torch.distributed.fsdp import (
fully_shard,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
from torch.testing import FileCheck
from torch.testing._internal.common_distributed import (
at_least_x_gpu,
skip_if_lt_x_gpu,
sm_is_or_higher_than,
)
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
from torch.testing._internal.inductor_utils import HAS_GPU
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def _count_op_in_graph(graph, op):
return sum(1 for node in graph.nodes if node.target is op)
def _is_fallback_op_in_snodes(snodes, op):
return any(is_fallback_op(snode.node, op) for snode in snodes)
# Keep a handle on the unpatched SDPA function: `_maybe_add_graph_break_to_sdpa`
# replaces `F.scaled_dot_product_attention` with a wrapper that forwards to it.
orig_F_scaled_dot_product_attention = F.scaled_dot_product_attention
class Mod(torch.nn.Module):
    """Stack of three ``Linear`` layers (784 -> 1024 -> 1024 -> 4096) created on the "cuda" device."""

    def __init__(self):
        super().__init__()
        layer_dims = ((28 * 28, 1024), (1024, 1024), (1024, 4096))
        # The generator is consumed in order, so layers are built first to last.
        self.encoder = torch.nn.Sequential(
            *(
                torch.nn.Linear(in_dim, out_dim, device="cuda")
                for in_dim, out_dim in layer_dims
            )
        )

    def forward(self, x):
        """Apply the encoder stack to ``x``."""
        encoded = self.encoder(x)
        return encoded
class TestFullyShardCompileCompute(FSDPTest):
    """Checks that ``torch._dynamo.config.skip_fsdp_hooks`` controls whether
    Dynamo's trace-rule check is invoked for code in the FSDP state/hooks file."""

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_disable_compiling_hooks(self):
        # Run the check once with hook tracing enabled and once with it skipped.
        self.run_subtests(
            {
                "skip_fsdp_hooks": [False, True],
            },
            self._test_disable_compiling_hooks,
        )

    def _test_disable_compiling_hooks(
        self,
        skip_fsdp_hooks: bool,
    ):
        """Compile and run a small fully-sharded MLP, counting trace-rule checks on FSDP hook code.

        Args:
            skip_fsdp_hooks: Value assigned to ``torch._dynamo.config.skip_fsdp_hooks``
                for the duration of the run. When True, no trace-rule check on
                the hooks file is expected; when False, at least one is.
        """
        torch._dynamo.reset()
        trace_rules_check_count = 0
        HOOKS_FILE_NAME = "torch/distributed/fsdp/_fully_shard/_fsdp_state.py"
        HOOK_WRAPPER_NAME = "fsdp_hook_wrapper"

        original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
        orig_trace_rules_check = torch._dynamo.trace_rules.check

        def patched_trace_rules_check(*args, **kwargs):
            nonlocal trace_rules_check_count
            f_code = args[0]
            # Count only checks on code objects from the hooks file, excluding
            # the hook wrapper itself.
            if (
                hasattr(f_code, "co_filename")
                and f_code.co_filename.endswith(HOOKS_FILE_NAME)
                and f_code.co_name != HOOK_WRAPPER_NAME
            ):
                trace_rules_check_count += 1
            return orig_trace_rules_check(*args, **kwargs)

        torch.distributed.barrier()
        torch._dynamo.config.skip_fsdp_hooks = skip_fsdp_hooks
        torch._dynamo.trace_rules.check = patched_trace_rules_check
        try:
            model = MLP(4)
            fully_shard(model)
            model.compile()
            model(torch.randn((4, 4), device="cuda"))
            torch.distributed.barrier()
        finally:
            # Always undo the global patches, even if compilation or the
            # forward pass raises, so later tests see the original state.
            torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
            torch._dynamo.trace_rules.check = orig_trace_rules_check

        if skip_fsdp_hooks:
            self.assertEqual(trace_rules_check_count, 0)
        else:
            self.assertGreater(trace_rules_check_count, 0)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
class TestFullyShardCompile(FSDPTest):
fake_pg = not at_least_x_gpu(2)
# Overrides the base-class method of the same name.
# The tests in this class need bf16 support, which requires an SM
# architecture of 8.0 or newer.
def skipTestForOldSm(self):
    # Assumption: this class only runs on GPU (see the `HAS_GPU` skip
    # decorator at the top of the class).
    local_device_index = self.rank % torch.cuda.device_count()
    device = torch.device("cuda", local_device_index)
    if sm_is_or_higher_than(device, 8, 0):
        return
    self.skipTest("bf16 requires sm >= 8.0")
@skipIfRocm
def test_dynamo_trace_use_training_state(self):
    # Verifies that the `use_training_state` context manager behaves the same
    # in eager and under `torch.compile(..., fullgraph=True)`.
    torch._dynamo.reset()
    # Only the context manager is under test, so a dummy FSDPParamGroup is enough.
    param_group = FSDPParamGroup(
        [],  # params: List[nn.Parameter],
        (torch.nn.Linear(1, 1),),  # module: Tuple[nn.Module, ...],
        None,  # mesh_info: FSDPMeshInfo,
        None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
        torch.device("cuda"),  # device: torch.device,
        None,  # shard_placement_fn: Optional[Callable],
        None,  # mp_policy: MixedPrecisionPolicy,
        None,  # offload_policy: OffloadPolicy,
    )

    def f(x):
        param_group._training_state = TrainingState.IDLE
        with param_group.use_training_state(TrainingState.FORWARD):
            if param_group._training_state != TrainingState.FORWARD:
                return x
            return x + 1

    zero_input = torch.zeros(1)
    self.assertEqual(param_group._training_state, TrainingState.IDLE)

    # Eager reference: the state is FORWARD inside the context, IDLE after it.
    eager_out = f(zero_input)
    self.assertEqual(param_group._training_state, TrainingState.IDLE)
    self.assertEqual(eager_out, zero_input + 1)

    # Compiled run must match eager and produce one graph with a single op.
    counter = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
    compiled_f = torch.compile(f, backend=counter, fullgraph=True)
    compiled_out = compiled_f(zero_input)
    self.assertEqual(param_group._training_state, TrainingState.IDLE)
    self.assertEqual(eager_out, compiled_out)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 1)
    self.assertEqual(len(counter.graphs), 1)
@skipIfRocm
def test_trace_fsdp_copy_(self):
    # Compares eager vs. `aot_eager`-compiled results of a function that
    # writes through a custom out-variant op and then calls `fsdp.copy_`.
    @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
    def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
        torch.add(x, 1, out=out)

    def f(x):
        scratch = torch.zeros(2)
        first_view = scratch.view(-1)
        torch.ops.mylib.add_one_out(x, out=first_view)
        second_view = scratch.view(-1)
        torch.ops.fsdp.copy_(x, second_view)

    ref_x = torch.zeros(2)
    x = copy.deepcopy(ref_x)

    # Eager run mutates `ref_x`; the compiled run must mutate `x` identically.
    f(ref_x)
    compiled_f = torch.compile(f, backend="aot_eager")
    compiled_f(x)
    self.assertEqual(x, ref_x)
def _get_resize_count_in_fx_graph(self, graph: torch.fx.Graph):
resize_count = 0
for node in graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default
):
resize_count += 1
return resize_count
def _assert_no_aliased_unsharded_params_in_graph_inputs(
    self, model, graph: torch.fx.Graph
) -> None:
    """Assert that no unsharded-param graph input shares storage with another graph input.

    Args:
        model: Module whose parameter count must equal the number of
            unsharded-param graph inputs found in ``graph``.
        graph: FX graph to inspect.

    Raises:
        AssertionError: If no unsharded-param inputs are found, if their count
            differs from ``len(list(model.parameters()))``, or if any of them
            aliases another graph input.
    """
    # Function-scope import: only this check needs the helper.
    from torch._inductor.fx_utils import get_node_storage

    # FSDP2 unsharded params are mutated in the graph without going through functionalization.
    # Therefore, we want to make sure they don't have aliases in the graph inputs, to make it easier
    # for us to do the replacement of unsharded params with the all-gathered temporary buffer directly
    # in downstream users in the graph.
    unsharded_param_ops = (
        torch.ops.inductor.resize_storage_bytes_.default,
        torch.ops.fsdp.copy_.default,
    )
    storage_id_to_graph_inputs = defaultdict(list)
    unsharded_param_graph_inputs = set()
    for node in graph.nodes:
        if node.op == "placeholder":
            # Group graph inputs by underlying storage. Without this the
            # aliasing check below would iterate over an empty mapping and
            # could never fail.
            storage_id = get_node_storage(node)
            if storage_id is not None:
                storage_id_to_graph_inputs[storage_id].append(node)
        elif (
            node.op == "call_function"
            and node.target in unsharded_param_ops
            and node.args[0].op == "placeholder"
        ):
            # A graph input that is resized / copied into in-graph is treated
            # as an unsharded param.
            unsharded_param_graph_inputs.add(node.args[0])
    self.assertGreater(len(unsharded_param_graph_inputs), 0)
    self.assertEqual(
        len(unsharded_param_graph_inputs),
        len(list(model.parameters())),
        """\
Expected all model parameters to be wrapped by FSDP2 and
have their unsharded version as graph input, but it's not true!
""",
    )
    no_aliased_unsharded_params_in_graph_inputs = True
    err_msg = ""
    for aliased_graph_inputs in storage_id_to_graph_inputs.values():
        if len(aliased_graph_inputs) > 1 and any(
            x in unsharded_param_graph_inputs for x in aliased_graph_inputs
        ):
            no_aliased_unsharded_params_in_graph_inputs = False
            err_msg += f"""\n
Found aliased unsharded param in graph inputs: {aliased_graph_inputs},
val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
"""
    self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg)
def _remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks(
self, model, *, bwd_resize_count_before_pass=None, fwd_fullgraph=False
):
def _run_with_checks(graph, orig_fn):
if (
self._is_bwd_fx_graph(graph)
and bwd_resize_count_before_pass is not None
):
self.assertEqual(
bwd_resize_count_before_pass,
self._get_resize_count_in_fx_graph(graph),
)
self._assert_no_aliased_unsharded_params_in_graph_inputs(model, graph)
orig_fn(graph)
if fwd_fullgraph:
return mock.patch.object(
comms,
"remove_fsdp2_unsharded_param_graph_input_usage",
functools.partial(
_run_with_checks,
orig_fn=comms.remove_fsdp2_unsharded_param_graph_input_usage,
),
)
else:
return contextlib.nullcontext()
def _check_fsdp_copy_and_resize_ops_count_in_graph(
    self,
    graph,
    *,
    fwd_copy_count,
    fwd_resize_count,
    bwd_copy_count,
    bwd_resize_count,
):
    """Assert the number of ``fsdp.copy_`` and ``inductor.resize_storage_bytes_`` ops in ``graph``.

    The backward expectations are used when
    ``torch._dynamo.compiled_autograd.in_compiled_autograd_region`` is truthy,
    the forward expectations otherwise.
    """
    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
        copy_count, resize_count = bwd_copy_count, bwd_resize_count  # bwd graph
    else:
        copy_count, resize_count = fwd_copy_count, fwd_resize_count  # fwd graph

    actual_copy_count = _count_op_in_graph(graph, torch.ops.fsdp.copy_.default)
    self.assertEqual(
        actual_copy_count,
        copy_count,
        f"Unexpected number of `fsdp.copy_` ops (expected {copy_count}, got {actual_copy_count}) in graph: {graph}",
    )

    resize_op = torch.ops.inductor.resize_storage_bytes_.default
    actual_resize_count = _count_op_in_graph(graph, resize_op)
    self.assertEqual(
        actual_resize_count,
        resize_count,
        f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}",  # noqa: B950
    )
def _reinplace_all_gather_with_optional_checks(self, fwd_fullgraph):
def _run_with_checks(graph, orig_fn):
self.assertGreater(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
),
0,
)
orig_fn(graph)
self.assertEqual(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor.default
),
0,
)
self.assertGreater(
_count_op_in_graph(
graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default
),
0,
)
if fwd_fullgraph:
return mock.patch.object(
comms,
"reinplace_fsdp_all_gather",
functools.partial(
_run_with_checks,
orig_fn=comms.reinplace_fsdp_all_gather,
),
)
else:
return contextlib.nullcontext()
def _is_fwd_graph(self, snodes):
    """Return True if the first ``fsdp.all_gather_copy_in`` node reads a buffer whose name starts with ``primals_``.

    Fails the test if ``snodes`` contains no ``fsdp.all_gather_copy_in``
    fallback node.
    """
    ag_copy_in_snode = next(
        (
            snode
            for snode in snodes
            if is_fallback_op(snode.node, torch.ops.fsdp.all_gather_copy_in.default)
        ),
        None,
    )
    self.assertTrue(ag_copy_in_snode is not None)
    return any(
        dep.name.startswith("primals_")
        for dep in ag_copy_in_snode.read_writes.reads
    )
def _is_bwd_fx_graph(self, graph):
for node in graph.nodes:
if (
node.op == "call_function"
and node.target
== torch.ops._c10d_functional.reduce_scatter_tensor.default
):
return True
return False
def _maybe_run_decide_global_ordering_of_comms_with_checks(self, fwd_fullgraph):
def _check_fsdp_ops_in_snodes(snodes, is_fwd_graph, expect=True):
assert_method = self.assertTrue if expect else self.assertFalse
common_ops = {
torch.ops.fsdp.all_gather_copy_in.default,
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
torch.ops.fsdp.split_with_sizes_copy.default,
}
bwd_only_ops = {
torch.ops.fsdp.chunk_cat.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
}
for op in common_ops:
assert_method(
_is_fallback_op_in_snodes(
snodes,
op,
),
msg=f"{op}",
)
if not is_fwd_graph:
for op in bwd_only_ops:
assert_method(
_is_fallback_op_in_snodes(
snodes,
op,
),
msg=f"{op}",
)
def _decide_global_ordering_of_comms_with_checks(
snodes, name_to_buf, name_to_fused_node, orig_fn
):
is_fwd_graph = self._is_fwd_graph(snodes)
_check_fsdp_ops_in_snodes(snodes, is_fwd_graph, expect=True)
new_snodes = orig_fn(snodes, name_to_buf, name_to_fused_node)
_check_fsdp_ops_in_snodes(new_snodes, is_fwd_graph, expect=False)
return new_snodes
if fwd_fullgraph:
return mock.patch.object(
comms,
"decide_global_ordering_of_comms",
functools.partial(
_decide_global_ordering_of_comms_with_checks,
orig_fn=comms.decide_global_ordering_of_comms,
),
)
else:
return contextlib.nullcontext()
def inductor_code_check_no_compute_op(self, file_check):
    """Chain ``check_not`` calls for every compute-op pattern onto ``file_check`` and return the result."""
    forbidden_patterns = (
        " = aten.",
        " = extern_kernels.",
        " = triton_",
        " = torch.ops.",
        " = inductor_ops.",
        " aten.",
        " extern_kernels.",
        " triton_",
        " torch.ops.",
        " inductor_ops.",
    )
    for pattern in forbidden_patterns:
        file_check = file_check.check_not(pattern)
    return file_check
def inductor_code_check_fsdp_all_gather(
    self,
    file_check,
    overlapped_compute_op_str,
    last_all_gather=False,
):
    """Extend ``file_check`` with the expected code pattern of one FSDP all-gather block.

    Args:
        file_check: Checker object to extend.
        overlapped_compute_op_str: Text expected between the all-gather and
            its wait, or None to expect nothing specific there.
        last_all_gather: When True, the trailing "no compute op" check after
            ``split_with_sizes_copy`` is omitted.

    Returns:
        The extended checker.
    """
    no_compute_op = self.inductor_code_check_no_compute_op

    checker = file_check.check("torch.ops.fsdp.all_gather_copy_in.")
    checker = no_compute_op(checker)
    checker = checker.check("torch.ops._c10d_functional.all_gather_into_tensor_out.")
    # Checks that AGWait is delayed, making the AG overlap with some compute op.
    if overlapped_compute_op_str is not None:
        checker = checker.check(str(overlapped_compute_op_str))
    checker = checker.check("torch.ops._c10d_functional.wait_tensor.")
    checker = no_compute_op(checker)
    checker = checker.check("torch.ops.fsdp.split_with_sizes_copy.")
    if last_all_gather:
        return checker
    # Checks that there is no compute op between this AGWait and next AG.
    return no_compute_op(checker)
def inductor_code_check_fsdp_reduce_scatter(
    self, file_check, overlapped_compute_op_str
):
    """Extend ``file_check`` with the expected code pattern of one FSDP reduce-scatter block.

    Args:
        file_check: Checker object to extend.
        overlapped_compute_op_str: Text expected between the reduce-scatter
            and its wait, or None to expect nothing specific there.

    Returns:
        The extended checker.
    """
    checker = file_check.check("torch.ops.fsdp.chunk_cat.")
    checker = self.inductor_code_check_no_compute_op(checker)
    checker = checker.check("torch.ops._c10d_functional.reduce_scatter_tensor.")
    # Checks that RSWait is delayed, making the RS overlap with some compute op.
    if overlapped_compute_op_str is not None:
        checker = checker.check(str(overlapped_compute_op_str))
    return checker.check("torch.ops._c10d_functional.wait_tensor.")
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_compiled_autograd_ctx(self):
    # Runs forward/backward of a fully-sharded Linear inside an explicitly
    # enabled compiled-autograd context for several iterations.
    self.skipTestForOldSm()
    dynamo_patch = torch._dynamo.config.patch(
        skip_fsdp_hooks=False,
    )
    functorch_patch = torch._functorch.config.patch(
        recompute_views=True,
    )
    with dynamo_patch, functorch_patch:
        inputs = torch.randn(8, 8)
        model = torch.nn.Linear(8, 8)
        fully_shard(model)
        model_compiled = torch.compile(model, backend="inductor")
        for step in range(10):
            # eager warmup for 1 iteration
            stance = "default" if step >= 1 else "force_eager"
            torch.compiler.set_stance(stance)
            bwd_compiler = torch.compile(backend="inductor", fullgraph=True)
            with torch._dynamo.compiled_autograd._enable(bwd_compiler):
                out = model_compiled(inputs)
                out.sum().backward()
def _test_traceable_fsdp(
    self,
    model_init_fn,
    input_creation_fn,
    backend,
    fwd_fullgraph,
    *,
    bwd_resize_count_before_inductor=None,
):
    """Train a model compiled and in eager mode, then compare the per-iteration losses.

    Args:
        model_init_fn: Zero-argument callable returning ``(model, optim)``.
        input_creation_fn: Zero-argument callable returning one input batch.
        backend: Backend passed to ``torch.compile``.
        fwd_fullgraph: If True, exactly one graph break (at ``Tensor.backward``)
            is expected; otherwise more than one graph break is expected.
        bwd_resize_count_before_inductor: Forwarded as
            ``bwd_resize_count_before_pass`` to the helper that patches the
            unsharded-param graph-input removal pass.
    """

    def fwd_bwd(model, inp):
        # One training step without the optimizer: forward, sum-loss, backward.
        out = model(inp)
        loss = out.sum()
        loss.backward()
        return loss

    def run_iters(
        fwd_bwd_func,
        optim,
        n_iter=10,
        compiled_autograd_backend=None,
    ):
        # NOTE: `compiled_autograd_backend` is accepted but not read in this body.
        torch.manual_seed(42)
        losses = []
        for i in range(n_iter):
            # eager warmup for 1 iteration, so that all FSDP2 lazy-initialization is done in eager
            torch.compiler.set_stance("force_eager" if i < 1 else "default")
            inp = input_creation_fn()
            loss = fwd_bwd_func(inp)
            losses.append(loss.item())
            optim.step()
            optim.zero_grad(set_to_none=True)
        return losses

    def test_compiled():
        # Fresh model/optimizer, trained through `torch.compile`.
        model, optim = model_init_fn()
        fwd_bwd_fn = functools.partial(fwd_bwd, model)
        # Reset the counters so the graph-break assertions below only see this run.
        counters.clear()
        with self._remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks(
            model,
            bwd_resize_count_before_pass=bwd_resize_count_before_inductor,
            fwd_fullgraph=fwd_fullgraph,
        ):
            fwd_bwd_fn_compiled = torch.compile(
                fwd_bwd_fn,
                backend=backend,
                # NOTE: we can't set `fullgraph=True` here because we will always graph-break
                # on `loss.backward()` call in `fwd_bwd()`. This is okay as long as
                # it's the only graph-break in forward pass.
                fullgraph=False,
            )
            res = run_iters(
                fwd_bwd_fn_compiled,
                optim,
                compiled_autograd_backend=backend,
            )
            if fwd_fullgraph:
                self.assertEqual(len(counters["graph_break"]), 1)
                self.assertIn("Tensor.backward", counters["graph_break"])
            else:
                self.assertGreater(len(counters["graph_break"]), 1)
            return res

    def test_eager():
        # Fresh model/optimizer, trained without compilation (reference run).
        model, optim = model_init_fn()
        fwd_bwd_fn = functools.partial(fwd_bwd, model)
        res = run_iters(fwd_bwd_fn, optim)
        return res

    # Start from a clean Dynamo / compiled-autograd state.
    torch._dynamo.reset()
    torch._dynamo.compiled_autograd.reset()
    with torch._dynamo.config.patch(
        compiled_autograd=True,
        compiled_autograd_kwargs_override={
            "fullgraph": True,
        },
        inline_inbuilt_nn_modules=True,
        skip_fsdp_hooks=False,
    ), torch._functorch.config.patch(
        enable_autograd_cache=False,
        recompute_views=True,
    ), torch._inductor.config.patch(
        force_disable_caches=True,
        reorder_for_compute_comm_overlap=True,
        reorder_for_compute_comm_overlap_passes=[
            "sink_waits",
            "raise_comms",
            "reorder_compute_for_overlap",
        ],
    ):
        losses_compiled = test_compiled()
    losses_eager = test_eager()
    # Losses are only compared when `fake_pg` is False.
    if not self.fake_pg:
        for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
            self.assertTrue(
                torch.allclose(
                    torch.tensor(loss_compiled),
                    torch.tensor(loss_eager),
                    rtol=1e-5,
                    atol=1e-8,
                ),
                f"{loss_compiled} vs {loss_eager}",
            )
def _create_simple_mlp_factory_fns(self):
    """Return ``(model_init_fn, input_creation_fn)`` for a fully-sharded 3-layer MLP with hidden size 16."""
    hidden_dim = 16

    def model_init_fn():
        torch.manual_seed(self.rank)
        fsdp_config = {}

        def make_linear():
            return nn.Linear(hidden_dim, hidden_dim, device="cuda")

        # Arguments are evaluated left to right, so the three Linear layers
        # are constructed in the order they appear.
        model = nn.Sequential(
            make_linear(),
            nn.ReLU(),
            make_linear(),
            nn.ReLU(),
            make_linear(),
        )
        fully_shard(model, reshard_after_forward=True, **fsdp_config)
        optim = torch.optim.SGD(model.parameters(), lr=1e-4)
        return model, optim

    def input_creation_fn():
        torch.manual_seed(self.rank)
        return torch.randn((2, hidden_dim), device="cuda", requires_grad=False)

    return model_init_fn, input_creation_fn
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_aot_eager(self):
    # Simple MLP, "aot_eager" backend, full-graph forward expected.
    model_init_fn, input_creation_fn = self._create_simple_mlp_factory_fns()
    self._test_traceable_fsdp(
        model_init_fn,
        input_creation_fn,
        "aot_eager",
        fwd_fullgraph=True,
    )
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
    # Simple MLP, "aot_eager_decomp_partition" backend, full-graph forward expected.
    model_init_fn, input_creation_fn = self._create_simple_mlp_factory_fns()
    self._test_traceable_fsdp(
        model_init_fn,
        input_creation_fn,
        "aot_eager_decomp_partition",
        fwd_fullgraph=True,
    )
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_simple_mlp_fullgraph_backend_inductor(self):
    # Simple MLP, "inductor" backend, full-graph forward expected.
    self.skipTestForOldSm()
    model_init_fn, input_creation_fn = self._create_simple_mlp_factory_fns()
    self._test_traceable_fsdp(
        model_init_fn,
        input_creation_fn,
        "inductor",
        fwd_fullgraph=True,
    )
def _create_nested_fully_shard_factory_fns(self, fwd_fullgraph):
    """Return ``(model_init_fn, input_creation_fn)`` for a model whose layers are each fully sharded.

    When ``fwd_fullgraph`` is False, every submodule forward contains an
    explicit Dynamo graph break.
    """
    hidden_dim = 16

    class TestSubmodule(nn.Module):
        def __init__(self, hidden_dim):
            super().__init__()
            matrix = torch.zeros(
                hidden_dim, hidden_dim, dtype=torch.float, device="cuda"
            )
            vector = torch.zeros(hidden_dim, dtype=torch.float, device="cuda")
            self.param1 = nn.Parameter(matrix)
            self.param2 = nn.Parameter(vector)

        def forward(self, x):
            ret = torch.matmul(x, self.param1)
            if not fwd_fullgraph:
                torch._dynamo.graph_break()
            return torch.relu(ret * self.param2)

    class TestModule(nn.Module):
        def __init__(self, n_layers):
            super().__init__()
            self.layers = torch.nn.ModuleList()
            for _ in range(n_layers):
                self.layers.append(TestSubmodule(hidden_dim))

        def forward(self, x):
            # Layers are deliberately reused several times to cover the
            # "multiple all-gathers for the same parameter" case.
            # Case 1: rerun the same layer twice
            for layer in self.layers:
                x = layer(x)
                x = layer(x)
            # Case 2: iterate through all layers twice
            for _ in range(2):
                for layer in self.layers:
                    x = layer(x)
            return x

    def model_init_fn():
        torch.manual_seed(self.rank)
        fsdp_config = {}
        mesh = init_device_mesh("cuda", (self.world_size,))
        model = TestModule(n_layers=3)
        # Shard every layer first, then the root module.
        for mod in model.layers:
            fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
        model = fully_shard(
            model, mesh=mesh, reshard_after_forward=True, **fsdp_config
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-4)
        return model, optim

    def input_creation_fn():
        torch.manual_seed(self.rank)
        return torch.randn((2, hidden_dim), device="cuda", requires_grad=False)

    return model_init_fn, input_creation_fn
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_aot_eager(self):
    # TODO: fix fwd_fullgraph=False case
    for fwd_fullgraph in [True]:
        model_init_fn, input_creation_fn = (
            self._create_nested_fully_shard_factory_fns(fwd_fullgraph=fwd_fullgraph)
        )
        self._test_traceable_fsdp(
            model_init_fn,
            input_creation_fn,
            "aot_eager",
            fwd_fullgraph=fwd_fullgraph,
        )
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
    # TODO: fix fwd_fullgraph=False case
    for fwd_fullgraph in [True]:
        model_init_fn, input_creation_fn = (
            self._create_nested_fully_shard_factory_fns(fwd_fullgraph=fwd_fullgraph)
        )
        self._test_traceable_fsdp(
            model_init_fn,
            input_creation_fn,
            "aot_eager_decomp_partition",
            fwd_fullgraph=fwd_fullgraph,
        )
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
    # Nested fully-sharded model on the "inductor" backend: checks op counts
    # in the post-grad graphs and the comm/compute ordering in generated code.
    self.skipTestForOldSm()
    mm_call = "extern_kernels.mm("
    for fwd_fullgraph in [True]:
        post_grad_check = (
            functools.partial(
                self._check_fsdp_copy_and_resize_ops_count_in_graph,
                fwd_copy_count=0,
                fwd_resize_count=0,
                bwd_copy_count=0,
                bwd_resize_count=0,
            )
            if fwd_fullgraph
            else None
        )

        def run_case():
            return self._test_traceable_fsdp(
                *self._create_nested_fully_shard_factory_fns(
                    fwd_fullgraph=fwd_fullgraph
                ),
                "inductor",
                fwd_fullgraph=fwd_fullgraph,
                bwd_resize_count_before_inductor=48 if fwd_fullgraph else None,
            )

        with self._reinplace_all_gather_with_optional_checks(
            fwd_fullgraph
        ), self._maybe_run_decide_global_ordering_of_comms_with_checks(
            fwd_fullgraph
        ), torch._inductor.config.patch(
            post_grad_custom_post_pass=post_grad_check
        ):
            _, triton_codes = run_and_get_code(run_case)

        if not fwd_fullgraph:
            continue

        self.assertEqual(
            len(triton_codes),
            2,
            "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph",
        )
        fwd_code, bwd_code = triton_codes[0], triton_codes[1]

        # Forward: one all-gather without a required overlap, seven that
        # overlap with an mm, and a final one marked as the last all-gather.
        fwd_ag_blocks = (
            [dict(overlapped_compute_op_str=None)]
            + [dict(overlapped_compute_op_str=mm_call) for _ in range(7)]
            + [dict(overlapped_compute_op_str=mm_call, last_all_gather=True)]
        )
        file_check = FileCheck().check("def call(args):")
        for fwd_ag_block_info in fwd_ag_blocks:
            file_check = self.inductor_code_check_fsdp_all_gather(
                file_check, **fwd_ag_block_info
            )
        file_check.run(fwd_code)

        # Backward: three all-gather blocks followed by three reduce-scatter blocks.
        bwd_ag_blocks = [
            dict(overlapped_compute_op_str=None),
            dict(overlapped_compute_op_str=mm_call),
            dict(overlapped_compute_op_str=mm_call, last_all_gather=True),
        ]
        bwd_rs_blocks = [
            dict(overlapped_compute_op_str="extern_kernels.addmm("),
            # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
            dict(overlapped_compute_op_str=None),
            dict(overlapped_compute_op_str=None),
        ]
        file_check = FileCheck().check("def call(args):")
        for bwd_ag_block_info in bwd_ag_blocks:
            file_check = self.inductor_code_check_fsdp_all_gather(
                file_check, **bwd_ag_block_info
            )
        for bwd_rs_block_info in bwd_rs_blocks:
            file_check = self.inductor_code_check_fsdp_reduce_scatter(
                file_check, **bwd_rs_block_info
            )
        file_check.run(bwd_code)
@unittest.skip("TODO: fix fwd_fullgraph=False case")
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
    # Nested fully-sharded model on "inductor" with graph breaks in the forward.
    self.skipTestForOldSm()

    def run_case():
        return self._test_traceable_fsdp(
            *self._create_nested_fully_shard_factory_fns(fwd_fullgraph=False),
            "inductor",
            fwd_fullgraph=False,
        )

    _, triton_codes = run_and_get_code(run_case)
    # TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
    # there are several recompiles, need to figure out why.
    num_lowerings = len(triton_codes)
    self.assertGreater(
        num_lowerings,
        2,
        "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
    )
def _create_transformer_factory_fns(
    self, all_requires_grad, *, activation_checkpoint=False
):
    """Return ``(model_init_fn, input_creation_fn)`` for a fully-sharded ``Transformer``.

    Args:
        all_requires_grad: When False, ``requires_grad`` is toggled per
            parameter based on the substrings "attention.wq" / "attention.wv".
        activation_checkpoint: Forwarded to ``ModelArgs`` as
            ``checkpoint_activations``.
    """
    seq_len = 16
    vocab_size = 8
    n_layers = 3

    def model_init_fn():
        torch.manual_seed(self.rank)
        fsdp_config = {}
        mesh = init_device_mesh("cuda", (self.world_size,))
        model_args = ModelArgs(
            vocab_size=vocab_size,
            n_layers=n_layers,
            checkpoint_activations=activation_checkpoint,
        )
        model = Transformer(model_args)
        if not all_requires_grad:
            requires_grad_params = ["attention.wq", "attention.wv"]
            requires_grad_param_count = 0
            # NOTE(review): as written, the inner loop overwrites the flag on
            # every substring, so the last substring checked decides the final
            # `requires_grad` value. A parameter matching only "attention.wq"
            # is counted but appears to end up with requires_grad=False.
            # Confirm whether this is intended before changing it; sibling
            # tests have expectations tuned to the current behavior.
            for k, v in model.named_parameters():
                for substring in requires_grad_params:
                    if substring in k:
                        v.requires_grad_(True)
                        requires_grad_param_count += 1
                    else:
                        v.requires_grad_(False)
            assert requires_grad_param_count == n_layers * len(requires_grad_params)
        # Shard every transformer layer first, then the root module.
        for layer_id, mod in enumerate(model.layers):
            fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
        model = fully_shard(
            model, mesh=mesh, reshard_after_forward=True, **fsdp_config
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-4)
        return model, optim

    def input_creation_fn():
        torch.manual_seed(self.rank)
        # Integer inputs drawn from [0, vocab_size), shape (2, seq_len).
        inp = torch.randint(
            0, vocab_size, (2, seq_len), device="cuda", requires_grad=False
        )
        return inp

    return model_init_fn, input_creation_fn
def _maybe_add_graph_break_to_sdpa(self, fwd_fullgraph):
def _sdpa_with_graph_break(*args, **kwargs):
torch._dynamo.graph_break()
return orig_F_scaled_dot_product_attention(*args, **kwargs)
if not fwd_fullgraph:
return mock.patch.object(
F,
"scaled_dot_product_attention",
_sdpa_with_graph_break,
)
else:
return contextlib.nullcontext()
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_transformer_backend_aot_eager(self):
    # TODO: fix fwd_fullgraph=False case
    fullgraph_options = [True]
    requires_grad_options = [True, False]
    for fwd_fullgraph, all_requires_grad in itertools.product(
        fullgraph_options, requires_grad_options
    ):
        with self._maybe_add_graph_break_to_sdpa(fwd_fullgraph):
            with self._reinplace_all_gather_with_optional_checks(fwd_fullgraph):
                model_init_fn, input_creation_fn = (
                    self._create_transformer_factory_fns(
                        all_requires_grad=all_requires_grad
                    )
                )
                self._test_traceable_fsdp(
                    model_init_fn,
                    input_creation_fn,
                    "aot_eager",
                    fwd_fullgraph=fwd_fullgraph,
                )
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout has worse accuracy after decomp, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_aot_eager_decomp_partition(self):
    # TODO: fix fwd_fullgraph=False case
    fullgraph_options = [True]
    requires_grad_options = [True, False]
    for fwd_fullgraph, all_requires_grad in itertools.product(
        fullgraph_options, requires_grad_options
    ):
        with self._maybe_add_graph_break_to_sdpa(fwd_fullgraph):
            model_init_fn, input_creation_fn = self._create_transformer_factory_fns(
                all_requires_grad=all_requires_grad
            )
            self._test_traceable_fsdp(
                model_init_fn,
                input_creation_fn,
                "aot_eager_decomp_partition",
                fwd_fullgraph=fwd_fullgraph,
            )
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_inductor_fullgraph_True(self):
    # Transformer on the "inductor" backend with a full-graph forward, swept
    # over `all_requires_grad` and `activation_checkpoint`. Checks op counts in
    # the post-grad graphs and comm/compute ordering in the generated code.
    self.skipTestForOldSm()
    for (
        fwd_fullgraph,
        all_requires_grad,
        activation_checkpoint,
    ) in itertools.product([True], [True, False], [True, False]):
        log.warning(
            f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}"  # noqa: G004, G001, B950
        )
        # Install the checking wrappers around the inductor comm passes and a
        # post-grad pass that asserts copy/resize op counts.
        with self._reinplace_all_gather_with_optional_checks(
            fwd_fullgraph
        ), self._maybe_run_decide_global_ordering_of_comms_with_checks(
            fwd_fullgraph
        ), torch._inductor.config.patch(
            post_grad_custom_post_pass=functools.partial(
                self._check_fsdp_copy_and_resize_ops_count_in_graph,
                # NOTE: For the root unsharded params, we don't reshard after forward since for training,
                # the parameters would be freed and all-gathered immediately. Hence we still have
                # their resize and copy ops in the graph.
                fwd_copy_count=4,
                fwd_resize_count=4,
                bwd_copy_count=0,
                bwd_resize_count=4,
            )
            if fwd_fullgraph
            else None
        ):
            _, triton_codes = run_and_get_code(
                lambda: self._test_traceable_fsdp(
                    *self._create_transformer_factory_fns(
                        all_requires_grad=all_requires_grad,
                        activation_checkpoint=activation_checkpoint,
                    ),
                    "inductor",
                    fwd_fullgraph=fwd_fullgraph,
                    bwd_resize_count_before_inductor=76 if fwd_fullgraph else None,
                )
            )
        if fwd_fullgraph:
            self.assertEqual(
                len(triton_codes),
                2,
                "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph",
            )
            # Forward code: expected sequence of all-gather blocks and the op
            # each all-gather should overlap with.
            fwd_code = triton_codes[0]
            file_check = FileCheck().check("def call(args):")
            for fwd_ag_block_info in [
                dict(
                    overlapped_compute_op_str="triton_"
                    if all_requires_grad
                    else None,
                ),
                dict(
                    overlapped_compute_op_str="aten.native_dropout.",
                ),
                dict(
                    overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
                ),
                dict(
                    overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.",
                    last_all_gather=True,
                ),
            ]:
                file_check = self.inductor_code_check_fsdp_all_gather(
                    file_check, **fwd_ag_block_info
                )
            file_check.run(fwd_code)
            # Backward code: all-gather blocks first, then reduce-scatter
            # blocks; `None` entries are skipped.
            bwd_code = triton_codes[1]
            file_check = FileCheck().check("def call(args):")
            for bwd_ag_block_info in [
                dict(
                    overlapped_compute_op_str="extern_kernels.mm(",
                ),
                dict(
                    overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
                ),
                dict(
                    overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.",
                    last_all_gather=True,
                ),
            ]:
                if bwd_ag_block_info is not None:
                    file_check = self.inductor_code_check_fsdp_all_gather(
                        file_check, **bwd_ag_block_info
                    )
            for bwd_rs_block_info in [
                dict(overlapped_compute_op_str="extern_kernels.mm(")
                if all_requires_grad
                else None,
                dict(
                    overlapped_compute_op_str=None
                ),  # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None
                dict(overlapped_compute_op_str=None),
                dict(overlapped_compute_op_str=None) if all_requires_grad else None,
            ]:
                if bwd_rs_block_info is not None:
                    file_check = self.inductor_code_check_fsdp_reduce_scatter(
                        file_check, **bwd_rs_block_info
                    )
            file_check.run(bwd_code)
@unittest.skip("TODO: fix fwd_fullgraph=False case")
@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO: native_dropout causes CUDA IMA error, need to figure out why
@torch._inductor.config.patch(fallback_random=True)
def test_transformer_backend_inductor_fullgraph_False(self):
    # Transformer on "inductor" with a graph break injected into SDPA.
    self.skipTestForOldSm()
    fwd_fullgraph = False
    # TODO: fix numerical issue in activation_checkpoint=True case
    requires_grad_options = [True, False]
    checkpoint_options = [False]
    for all_requires_grad, activation_checkpoint in itertools.product(
        requires_grad_options, checkpoint_options
    ):
        log.warning(
            f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}"  # noqa: G004, G001, B950
        )

        def run_case():
            return self._test_traceable_fsdp(
                *self._create_transformer_factory_fns(
                    all_requires_grad=all_requires_grad,
                    activation_checkpoint=activation_checkpoint,
                ),
                "inductor",
                fwd_fullgraph=fwd_fullgraph,
            )

        with self._maybe_add_graph_break_to_sdpa(fwd_fullgraph):
            _, triton_codes = run_and_get_code(run_case)
        # TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
        # there are several recompiles, need to figure out why.
        num_lowerings = len(triton_codes)
        self.assertGreater(
            num_lowerings,
            2,
            "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph",
        )
def test_dynamo_recompiles_on_fsdp_layers(self):
    # Compiles each Linear child individually, wraps the whole module in FSDP,
    # and runs one forward pass.
    m = Mod()
    for name, child in m.encoder.named_children():
        if not isinstance(child, torch.nn.Linear):
            continue
        setattr(m.encoder, name, torch.compile(child))
    m = FSDP(m, sharding_strategy=ShardingStrategy.FULL_SHARD, use_orig_params=True)
    batch = torch.randn(32, 784, device="cuda")
    m(batch)
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
|