# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import random
import unittest
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import List
from unittest.mock import patch

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as dist
import torch.optim as optim
from torch import nn
from torch._C import FileCheck
from torch._dynamo import config
from torch._dynamo.backends.distributed import DDPOptimizer
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import same
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.distributed._functional_collectives import _maybe_wrap_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
    DynamoDistributedSingleProcTestCase,
    import_transformers_or_skip,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import requires_cuda
from torch.testing._internal.inductor_utils import HAS_GPU


def reset_rng_state():
    torch.manual_seed(1337)
    random.seed(1337)
    np.random.seed(1337)


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


class ToyModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )

    def forward(self, inputs):
        if self.ctx_manager is not None:
            with self.ctx_manager():
                return self.net(inputs)
        else:
            return self.net(inputs)


def get_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = ToyModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


class MutatingModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )
        self.state = 1

    def forward(self, inputs):
        self.state = 2
        return self.net(inputs) * self.state


def get_mutating_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = MutatingModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


class ForcedGetAttrMod(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.__dict__["forced_linear"] = torch.nn.Linear(1, 1).to(device=device)
        self.counter = 0

    def forward(self, x):
        self.counter += 1
        return x * self.linear(x) * self.forced_linear.weight


def get_forced_getattr_module(device):
    mod = ForcedGetAttrMod(device).to(device=device)
    x = torch.randn(1, 1, device=device)
    return mod, x, mod(x)


class ToyInnerModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = [nn.Linear(100, 100), nn.Linear(100, 100)]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, inputs):
        return self.layers(inputs)


class ToyOuterModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.layers = [ToyInnerModel().to(device) for _ in range(2)]
        self.layers = nn.Sequential(
            self.layers[0], nn.ReLU(), self.layers[1], nn.ReLU()
        )

    def forward(self, inputs):
        return self.layers(inputs)


def get_toy_model_for_activation_checkpointing(device):
    m = ToyOuterModel(device).to(device)
    m.apply(init_weights)
    inputs = torch.rand(100, 100).to(device)
    return m, inputs


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def apply_fsdp_with_checkpointing(
    model, wrap_policy, checkpoint_policy, use_activation_checkpointing=True
):
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )

    model = FSDP(
        copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
    )
    if use_activation_checkpointing:
        checkpoint_wrapper_fn = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=checkpoint_wrapper_fn,
            check_fn=checkpoint_policy,
        )
    return model


def get_custom_model(device):
    class MyCustomLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            tmp = torch.mm(x, self.weight.t())
            # test an edge case where torch.where.scalar was decomposed to aten.where.self(tensor, tensor, tensor)
            # and the tensors T(0.3) and T(0.6) were not wrapped in FakeTensors during DDPOptimizer compilation
            return tmp + torch.where(tmp < 0.5, 0.3, 0.6)

    class MyLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            mods = [
                (MyLinear(), torch.nn.ReLU()),
                # sandwich the custom layer in the middle so regular layers come before and after it
                (MyCustomLinear(), torch.nn.ReLU()),
                (MyLinear(), torch.nn.ReLU()),
            ]
            self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

        def forward(self, x, y):
            # test special case where the 0th bucket (layers close to graph input) is at capacity, which would
            # trigger a new bucket, but there are only trivial ops without parameters to put into the new bucket.
            # optimize this case by fusing that 'empty bucket' back together with the previous full one
            return self.seq(x + y)

    m = MyModule().to(device)
    m.apply(init_weights)
    inputs = torch.rand((512, 512)).to(device)
    # test duplicated inputs
    inputs = (inputs, inputs)
    correct_outputs = m(*inputs)
    return m, inputs, correct_outputs
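

# Illustrative sketch only: _sketch_fuse_empty_buckets is a hypothetical helper,
# not code taken from DDPOptimizer. The MyModule.forward comment above describes
# a corner case where the bucket nearest the graph input is full, a new bucket
# is opened, and only parameter-free ops (such as the `x + y` add) land in it.
# Assuming each bucket is modeled as a (node_names, param_bytes) pair, listed
# in the order the buckets were created, fusing such an "empty" bucket back
# into the previous one could look like this.
def _sketch_fuse_empty_buckets(buckets):
    fused = []
    for names, nbytes in buckets:
        if nbytes == 0 and fused:
            # parameter-free bucket: merge its nodes into the previous bucket
            prev_names, prev_bytes = fused[-1]
            fused[-1] = (prev_names + list(names), prev_bytes)
        else:
            fused.append((list(names), nbytes))
    return fused


# Example: two 1 MiB buckets followed by a bucket holding only the input add.
# _sketch_fuse_empty_buckets(
#     [(["linear3"], 1048576), (["linear1"], 1048576), (["add"], 0)]
# ) returns [(["linear3"], 1048576), (["linear1", "add"], 1048576)]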


def get_hf_bert(rank):
    # Note: use @import_transformers_or_skip on your test case if you use this
    # in a multiprocessing test
    try:
        from transformers import AutoModelForMaskedLM, BertConfig
    except ImportError as e:
        raise unittest.SkipTest("Unable to import transformers") from e

    batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
    model = AutoModelForMaskedLM.from_config(config).to(device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
    decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(
        device
    )

    inputs = {"input_ids": input_ids, "labels": decoder_ids}
    model.train()
    return model, inputs


class CheckSplitsCompiler:
    def __init__(self) -> None:
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        self.compiler_called += 1
        return gm


# This simulates DDP but doesn't do any process communication; it only has
# enough properties for Dynamo's distributed optimization (DDPOptimizer) to
# kick in. Feel free to simulate more properties as necessary. The other
# important piece is patching DDP._active_ddp_module, which is what actually
# triggers DDP optimization.
class FakeDDP(nn.Module):
    def __init__(self, module, bucket_cap_mb=25):
        super().__init__()
        self.module = module
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        if not DDP._active_ddp_module:
            with self._inside_ddp_forward():
                return self.module.forward(*inputs, **kwargs)
        else:
            return self.module.forward(*inputs, **kwargs)
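

# Illustrative sketch only: _sketch_greedy_buckets is a hypothetical helper,
# not DDPOptimizer's actual bucketing code. FakeDDP converts bucket_cap_mb to
# bytes with int(bucket_cap_mb * 1024 * 1024), so bucket_cap_mb=1 gives
# 1048576 bytes. A 512x512 float32 weight is 512 * 512 * 4 = 1048576 bytes.
# Under a simple greedy policy (assumed here: seal a bucket as soon as it
# reaches the cap), each such weight fills a bucket on its own, while the
# default 25 MB cap keeps three of them together. This is presumably why
# test_unbacked_symbol_splitting_torture_multi, which uses bucket_cap_mb=1 with
# three 512x512 weights, expects three partitions.
def _sketch_greedy_buckets(param_nbytes, bucket_cap_mb=25):
    bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
    buckets, current, current_bytes = [], [], 0
    for idx, nbytes in enumerate(param_nbytes):
        current.append(idx)
        current_bytes += nbytes
        if current_bytes >= bucket_bytes_cap:
            # the bucket reached the cap: seal it and start a new one
            buckets.append(current)
            current, current_bytes = [], 0
    if current:
        # leftover parameters form a final, partially filled bucket
        buckets.append(current)
    return buckets


# Example with three 512x512 float32 weights (1 MiB each):
# _sketch_greedy_buckets([1048576] * 3, bucket_cap_mb=1) returns [[0], [1], [2]]
# _sketch_greedy_buckets([1048576] * 3) returns [[0, 1, 2]]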


def run_hf_bert_ddp(self, model, inputs, backend):
    reset_rng_state()
    correct_outputs = model(**inputs)
    correct_loss = correct_outputs.loss
    correct_loss.backward()

    reset_rng_state()
    opt_model = torch.compile(model, backend=backend)
    opt_outputs = opt_model(**inputs)
    opt_loss = opt_outputs.loss
    opt_loss.backward()

    inputs_flat = [inputs[k] for k in inputs]
    correct_results = collect_results(
        model, correct_outputs.logits, correct_loss, inputs_flat
    )
    opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
    self.assertTrue(same(correct_results, opt_results))


class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @patch.object(config, "optimize_ddp", True)
    def test_issue90375(self):
        class Model(nn.Module):
            def forward(self):
                return torch.randn(3) * torch.randn(3)

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(model, backend="aot_eager")
        opt_model()

    @patch.object(config, "optimize_ddp", True)
    def test_symbol_splitting(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x):
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = x + y @ self.weight2
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512))

    @patch.object(config, "optimize_ddp", True)
    def test_ddp_optimizer_inductor_strides_dont_specialize(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc_0 = nn.Linear(768, 768)
                self.fc_1 = nn.Linear(768, 768)

            def forward(self, x):
                x = self.fc_0(x)
                x = self.fc_1(x)
                return x

        model = Model()
        model = FakeDDP(model)

        inp = torch.randn((16, 18, 768))
        inp2 = torch.randn((16, 20, 768))

        torch._dynamo.mark_dynamic(inp, 1)
        torch._dynamo.mark_dynamic(inp2, 1)

        torch._dynamo.utils.clear_compilation_metrics()
        torch._dynamo.reset()
        try:
            DDP._active_ddp_module = model
            opt_model = torch.compile(model)
            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
            opt_model(inp)
            compile_count_before = len(torch._dynamo.utils.get_compilation_metrics())
            opt_model(inp2)
            compile_count_after = len(torch._dynamo.utils.get_compilation_metrics())
            # no recompiles
            self.assertEqual(compile_count_before, compile_count_after)
        finally:
            DDP._active_ddp_module = None

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_direct(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, u1 = y.tolist()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * u0
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_indirect(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, u1 = y.tolist()
                a = torch.ones(u0)
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_torture_multi(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))
                self.weight3 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                # partition one (contains the u0 def)
                u0, u1 = y.tolist()
                x = torch.cat([x, x])
                y1 = x @ self.weight1
                # partition two (contains the variable)
                y2 = y1 @ self.weight2
                a = torch.ones(u0)
                # partition three
                z = (x + y2 @ self.weight3) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model, bucket_cap_mb=1)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
    def test_unbacked_symbol_splitting_no_binding(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                nz = y.nonzero()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * (nz + 1).sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))

    @patch.object(config, "optimize_ddp", True)
    def test_call_method_forward(self):
        class Model(nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                layers = []
                for l in range(2):
                    layer = nn.ModuleList(
                        [
                            nn.LayerNorm(96),
                            nn.MultiheadAttention(
                                embed_dim=96, num_heads=4, batch_first=True
                            ),
                        ]
                    )
                    layers.append(layer)
                self.layers = nn.ModuleList(layers)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # x: [Batch, Freq, Time, Feature]
                B, F, T, H = x.shape
                for m in self.layers:
                    x = x.reshape(B * F, T, H)
                    x = m[0](x)
                    x, attn = m[1].forward(x, x, x)
                    x = x.reshape(B, F, T, H)
                return x

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(model)
        opt_model(torch.randn(2, 129, 100, 96))


# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a
# single process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it single process!
@requires_nccl()
class TestMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Note: MultiProcTestCase spawns processes per test and is slow.
    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
    sparingly for integration tests.
    """

    @skip_if_lt_x_gpu(2)
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch.compile(m, backend="aot_eager")
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    def _test_hf_bert_ddp_inductor(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "inductor")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        self._test_hf_bert_ddp_inductor(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor_static_graph(self):
        self._test_hf_bert_ddp_inductor(static_graph=True)

    def _test_hf_bert_aot_eager(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager(self):
        self._test_hf_bert_aot_eager(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager_static_graph(self):
        self._test_hf_bert_aot_eager(static_graph=True)

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_activation_checkpointing(self):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            apply_activation_checkpointing,
            checkpoint_wrapper,
            CheckpointImpl,
        )

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(64, 32)
                self.fc2 = torch.nn.Linear(32, 16)
                self.fc3 = torch.nn.Linear(16, 8)

            def forward(self, inp):
                return self.fc3(self.fc2(self.fc1(inp)))

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            model = MyModel().to(device="cuda")

            # Activation checkpointing for Linear layers.
            non_reentrant_wrapper = functools.partial(
                checkpoint_wrapper,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            check_fn = lambda submodule: isinstance(  # noqa: E731
                submodule, torch.nn.Linear
            )
            apply_activation_checkpointing(
                model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
            )

            model = DDP(model)
            x = torch.randn(10, 64).cuda()
            correct_outputs = model(x)

            opt_model = torch.compile(model)
            outputs = opt_model(x)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_aot_eager(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_setattr(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            self.assertEqual(len(counters["graph_break"]), 1)
            first_graph_break = list(counters["graph_break"].keys())[0]  # noqa: RUF015
            self.assertTrue("setattr" not in first_graph_break)

    @config.patch(inline_inbuilt_nn_modules=False)
    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_unspecialized_forced_getattr_no_inline(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_unspecialized_forced_getattr_inline(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="inductor")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch.compile(fsdp_m, backend="inductor")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_activation_checkpointing(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_toy_model_for_activation_checkpointing(
                f"cuda:{self.rank}"
            )
            is_inner = lambda module: isinstance(module, ToyInnerModel)  # noqa: E731
            wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
            model = apply_fsdp_with_checkpointing(model, wrap_policy, is_inner)
            correct_outputs = model(inputs)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
            opt_model = torch.compile(model, backend=cnt)
            outputs = opt_model(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            # Each FSDP module is a separate graph
            self.assertEqual(cnt.frame_count, 2)
            self.assertTrue(
                find_first_node(cnt.graphs[0], tag_activation_checkpoint) is not None
            )

    @import_transformers_or_skip()
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @config.patch(enable_compiler_collectives=True)
    @unittest.skipIf(
        PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Inaccurate results with fused SDPA kernels",
    )
    def test_hf_bert_fsdp(self):
        def apply_fsdp(model, wrap_policy):
            model = FSDP(
                copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
            )
            return model

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            for wrap_policy, test_instance in (
                (None, "FSDP without recursive wrapping"),
            ):
                print(f"Running hf_bert test for {test_instance}")
                model, inputs = get_hf_bert(self.rank)
                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch.compile(opt_model, backend="inductor")
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(
                    eager_model, correct_outputs.logits, correct_loss, inputs_flat
                )
                opt_results = collect_results(
                    opt_model, opt_outputs.logits, opt_loss, inputs_flat
                )
                self.assertTrue(same(correct_results, opt_results))
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
@patch.object(torch._inductor.config.triton, "cudagraphs", False)
@patch.object(torch._inductor.config, "fallback_random", True)
@config.patch(guard_nn_modules=True, enable_compiler_collectives=True)
def test_hf_bert_fsdp_activation_checkpointing(self):
from transformers.models.bert.modeling_bert import BertLayer
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
for wrap_policy, test_instance in (
(
functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer,)
),
"FSDP with recursive wrapping BertLayer instances",
),
):
print(
f"Running hf_bert_activation_checkpointing test for {test_instance}"
)
model, inputs = get_hf_bert(self.rank)
check_fn = lambda submodule: isinstance( # noqa: E731
submodule, BertLayer
)
reset_rng_state()
eager_model = apply_fsdp_with_checkpointing(
model, wrap_policy, check_fn
)
correct_outputs = eager_model(**inputs)
correct_loss = correct_outputs.loss
correct_loss.backward()
reset_rng_state()
opt_model = apply_fsdp_with_checkpointing(model, wrap_policy, check_fn)
opt_model = torch.compile(opt_model, backend="inductor")
opt_outputs = opt_model(**inputs)
opt_loss = opt_outputs.loss
opt_loss.backward()
inputs_flat = [inputs[k] for k in inputs]
correct_results = collect_results(
eager_model, correct_outputs.logits, correct_loss, inputs_flat
)
opt_results = collect_results(
opt_model, opt_outputs.logits, opt_loss, inputs_flat
)
self.assertTrue(same(correct_results, opt_results))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_tensor(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
class SimpleModel(nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
return self.linear(x)
torch._dynamo.utils.clear_compilation_metrics()
model = SimpleModel(10, 2).to(self.rank)
model.forward = torch.compile(model.forward)
ddp_model = DDP(model, device_ids=[self.rank])
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
def B(s):
return [torch.randn(s, 10), torch.randint(0, 2, (s,))]
if self.rank == 0:
dataloader = [B(5), B(8), B(6)]
else:
dataloader = [B(6), B(6), B(3)]
for data, labels in dataloader:
data, labels = data.to(self.rank), labels.to(self.rank)
optimizer.zero_grad()
output = ddp_model(data)
loss = loss_fn(output, labels)
loss.backward()
optimizer.step()
metrics = torch._dynamo.utils.get_compilation_metrics()
# Number of compiles same on all nodes
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_scalar(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
# TODO: This should be possible to do inside the function, but
device = f"cuda:{self.rank}"
@torch.compile()
def f(x, y):
return x + torch.ones(y, device=device).sum()
if self.rank == 0:
dataloader = [3, 3, 7]
else:
dataloader = [3, 4, 9]
for data in dataloader:
f(torch.randn(5, device=self.rank), data)
metrics = torch._dynamo.utils.get_compilation_metrics()
# Number of compiles same on all nodes
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_speculation_divergence(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(x, y):
zx = x.shape
zy = y.shape
return x.sum() + y.sum()
if self.rank == 0:
dataloader = [4, 4]
else:
dataloader = [3, 4]
for data in dataloader:
f(
torch.randn(data, device=self.rank),
torch.randn(data, device=self.rank),
)
metrics = torch._dynamo.utils.get_compilation_metrics()
# Number of compiles same on all nodes
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_graph_break_empty_graph_still_collective(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(x, y):
z = y
print("woof")
zx = x.shape
zy = y.shape
return x.sum() + y.sum()
if self.rank == 0:
dataloader = [5, 5, 6]
else:
dataloader = [3, 4, 5]
for data in dataloader:
f(
torch.randn(data, device=self.rank),
torch.randn(data, device=self.rank),
)
metrics = torch._dynamo.utils.get_compilation_metrics()
# Number of compiles same on all nodes
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_dim_mismatch(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(x, y):
zx = x.shape
zy = y.shape
return x.sum() + y.sum()
if self.rank == 0:
dataloader = [[4, 2]]
else:
dataloader = [[3]]
for data in dataloader:
f(
torch.randn(data, device=self.rank),
torch.randn(data, device=self.rank),
)
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_missing_source(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(rank, xs):
return xs[rank].sum()
xs = []
for _ in range(self.world_size):
xs.append(torch.randn(10, device=self.rank))
f(self.rank, xs)
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_scalar_missing_source(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(rank, xs):
return torch.tensor(xs[rank], device=self.rank)
xs = []
for i in range(self.world_size):
xs.append(10 + i)
f(self.rank, xs)
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_type_mismatch(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
@torch.compile()
def f(x):
if isinstance(x, int):
return torch.tensor(x, device=self.rank)
else:
return x.sum()
if self.rank == 0:
x = torch.randn(10, device=self.rank)
else:
x = 12
f(x)
# The following case deadlocks; it appears to be unsupported, so it is left disabled
"""
if self.rank == 0:
x = torch.randn(12, device=self.rank)
else:
x = 10
f(x)
"""
metrics = torch._dynamo.utils.get_compilation_metrics()
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_get_pg_attr(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
pg = dist.distributed_c10d._get_default_group()
device = f"cuda:{self.rank}"
@torch.compile(fullgraph=True)
def f(x):
if dist.distributed_c10d._rank_not_in_group(pg):
return x + 1
else:
return x - 1
x = torch.ones(4, device=device)
self.assertEqual(f(x), x - 1)
pg = dist.distributed_c10d.GroupMember.NON_GROUP_MEMBER
self.assertEqual(f(x), x + 1)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", False)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
def test_asymmetric_compilation(self):
from torch._dynamo.comptime import comptime
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
torch._dynamo.utils.clear_compilation_metrics()
device = f"cuda:{self.rank}"
pg = dist.distributed_c10d._get_default_group()
cnt = torch._dynamo.testing.CompileCounter()
sleep_time = 5
@torch.compile(backend=cnt)
def f(x):
if self.rank == 0:
comptime.sleep(sleep_time)
y = 2 * x
return y.sum()
backend = pg._get_backend(torch.device(device))
backend._set_default_timeout(timedelta(seconds=sleep_time - 2))
x = torch.ones(4, device=device)
# NCCL startup is lazy, so issue one allreduce to initialize it before compiling
w = pg.allreduce(x)
w.wait()
f(x)
if self.rank != 0:
# test fails with NCCL timeout without this line
dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
timedelta(seconds=sleep_time)
)
w = pg.allreduce(x)
w.wait()
torch.cuda.synchronize(device)
metrics = torch._dynamo.utils.get_compilation_metrics()
# Number of compiles same on all nodes
res = [None] * self.world_size
torch.distributed.all_gather_object(res, len(metrics))
for r in res[1:]:
self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", True)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
@patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
def test_asymmetric_compilation_with_fx_cache(self):
from torch._dynamo.utils import counters
from torch._inductor.utils import fresh_inductor_cache
with fresh_inductor_cache(), _dynamo_dist_per_rank_init(
self.rank, self.world_size
):
torch._dynamo.utils.clear_compilation_metrics()
device = f"cuda:{self.rank}"
pg = dist.distributed_c10d._get_default_group()
@torch.compile
def f(x):
y = 2 * x
return y.sum()
backend = pg._get_backend(torch.device(device))
backend._set_default_timeout(timedelta(seconds=5))
counters.clear()
x = torch.ones(4, device=device)
f(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
w = pg.allreduce(x)
w.wait()
torch.cuda.synchronize(device)
torch._dynamo.reset()
if self.rank == 0:
with fresh_inductor_cache():
f(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
else:
f(x)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
w = pg.allreduce(x)
w.wait()
torch.cuda.synchronize(device)
@requires_nccl()
@requires_cuda
class TestSingleProc(DynamoDistributedSingleProcTestCase):
"""
The test harness initializes the dist process group.
Test simple things here since they are simpler to debug.
Use TestMultiProc for things that really need to run on multiple nodes.
"""
def get_model(
self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
m = ToyModel(
in_feat=in_feat,
hidden_feat=hidden_feat,
out_feat=out_feat,
ctx_manager=ctx_manager,
).to(self.device)
m.apply(init_weights)
inputs = torch.rand(bsz, in_feat).to(self.device)
outputs = m(inputs)
return m, inputs, outputs
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_aot_eager(self):
from torch.nn.parallel import DistributedDataParallel as DDP
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids)
ddp_m = torch.compile(ddp_m, backend="aot_eager")
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_inductor(self):
from torch.nn.parallel import DistributedDataParallel as DDP
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids)
ddp_m = torch.compile(ddp_m, backend="inductor")
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@patch.object(config, "optimize_ddp", True)
def test_graph_split(self):
"""
Just ensures that the appropriate number of splits happens (based on
bucket size and model parameters) - verifies the number of times
the user-provided compiler is called by the DDPOptimizer which is
doing the graph splitting
"""
assert config.optimize_ddp
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
check_splits_compiler = CheckSplitsCompiler()
@torch.compile(backend=check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
# ensure compatibility with dynamo explain
explain_out = torch._dynamo.explain(ddp_m)(inputs)
break_reasons = explain_out.break_reasons
self.assertEqual(len(break_reasons), 3)
self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
@patch.object(config, "optimize_ddp", True)
def test_graph_split_ctx_manager(self):
"""
Ensures that we get the right number of splits and that the respective
context managers' effects are applied to the computation.
"""
for get_compiler in [
lambda: CheckSplitsCompiler(),
lambda: None,
]:
for ctx_manager, output_test in [
(
lambda: torch.autocast(
torch.device(self.device).type, torch.float16
),
lambda out: self.assertEqual(out.dtype, torch.float16),
),
(torch.enable_grad, lambda out: self.assertTrue(out.requires_grad)),
(torch.no_grad, lambda out: self.assertTrue(not out.requires_grad)),
]:
m, inputs, correct_outputs = self.get_model(
out_feat=1000,
hidden_feat=1000,
in_feat=1000,
ctx_manager=ctx_manager,
)
# input layer weight - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
# hidden layer weight - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
bucket_cap_mb = 3.5  # just under the 4MB size of each weight matrix
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)
compiler = get_compiler()
@torch.compile(backend=compiler.compile_fn if compiler else "aot_eager")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
if compiler:
self.assertEqual(compiler.compiler_called, 4)
output_test(opt_outputs)
# ensure compatibility with dynamo explain
explain_out = torch._dynamo.explain(ddp_m)(inputs)
break_reasons = explain_out.break_reasons
self.assertEqual(len(break_reasons), 4)
self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))
@patch.object(config, "optimize_ddp", True)
def test_compiled_flex_attention_full_model_ddp(self):
class Model(torch.nn.Module):
def __init__(self, S, H, D):
super().__init__()
self.S = S
self.H = H
self.D = D
alibi_bias = self.generate_alibi_bias(H)
self.register_buffer("alibi_bias", alibi_bias, persistent=True)
self.attention = flex_attention
self.project_qk = torch.nn.Linear(H * D, H * D * 2)
self.project_v = torch.nn.Linear(H * D, H * D)
def forward(self, hidden_states):
batch_size, _, _ = hidden_states.size()
query, key = self.project_qk(hidden_states).chunk(2, dim=2)
query = query.view(self.S, batch_size, self.H, self.D)
query = query.permute(1, 2, 0, 3)
key = key.view(self.S, batch_size, self.H, self.D)
key = key.permute(1, 2, 0, 3)
value = self.project_v(hidden_states)
value = value.view(self.S, batch_size, self.H, self.D)
value = value.permute(1, 2, 0, 3)
return self.attention(query, key, value, score_mod=self.alibi_score_mod)
def generate_alibi_bias(self, num_heads):
alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)]
return torch.tensor(alibi_bias)
def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
bias = (q_idx - kv_idx) * self.alibi_bias[h]
return score + bias
B = 16
H = 12
S = 512
D = 64
device = "cuda"
model = Model(S, H, D)
model.to(device)
model = torch.compile(model)
model = DDP(model, device_ids=self.device_ids)
hidden_states = torch.randn(B, S, H * D).to(device)
attention_scores = model(hidden_states)
torch.cuda.synchronize()
@patch.object(config, "optimize_ddp", True)
def test_compiled_flex_attention_local_ddp(self):
class Model(torch.nn.Module):
def __init__(self, S, H, D):
super().__init__()
self.S = S
self.H = H
self.D = D
alibi_bias = self.generate_alibi_bias(H)
self.register_buffer("alibi_bias", alibi_bias, persistent=True)
self.attention = torch.compile(flex_attention)
self.project_qk = torch.nn.Linear(H * D, H * D * 2)
self.project_v = torch.nn.Linear(H * D, H * D)
def forward(self, hidden_states):
batch_size, _, _ = hidden_states.size()
query, key = self.project_qk(hidden_states).chunk(2, dim=2)
query = query.view(self.S, batch_size, self.H, self.D)
query = query.permute(1, 2, 0, 3)
key = key.view(self.S, batch_size, self.H, self.D)
key = key.permute(1, 2, 0, 3)
value = self.project_v(hidden_states)
value = value.view(self.S, batch_size, self.H, self.D)
value = value.permute(1, 2, 0, 3)
return self.attention(query, key, value, score_mod=self.alibi_score_mod)
def generate_alibi_bias(self, num_heads):
alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)]
return torch.tensor(alibi_bias)
def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
bias = (q_idx - kv_idx) * self.alibi_bias[h]
return score + bias
B = 16
H = 12
S = 512
D = 64
device = "cuda"
model = Model(S, H, D)
model.to(device)
model = torch.compile(model)
model = DDP(model, device_ids=self.device_ids)
hidden_states = torch.randn(B, S, H * D).to(device)
attention_scores = model(hidden_states)
torch.cuda.synchronize()
@patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor(self):
"""
Same as test_graph_split, but using the inductor backend.
We observed issues with the inductor/fx interface in the past.
"""
assert config.optimize_ddp
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
@torch.compile(backend="inductor")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
@torch._inductor.config.patch(
{"layout_optimization": True, "keep_output_stride": False}
)
@patch.object(config, "optimize_ddp", True)
def _test_graph_split_inductor_layout_optimizations_impl(self, context):
assert config.optimize_ddp
channel_dim = 512
# channel dim must be > 64 for inductor to do layout optimization and use NHWC
class ToyModelConv(nn.Module):
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
*[
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
nn.ReLU(),
]
+ [
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
nn.ReLU(),
]
+ [
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
nn.ReLU(),
]
+ [
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
nn.ReLU(),
]
)
def forward(self, inputs):
return self.net(inputs)
def get_model():
m = ToyModelConv().to(self.device)
m.apply(init_weights)
inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device)
outputs = m(inputs)
return m, inputs, outputs
with context():
m, inputs, correct_outputs = get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
@torch.compile(backend="inductor")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor_layout_optimizations_training(self):
self._test_graph_split_inductor_layout_optimizations_impl(
contextlib.nullcontext
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor_layout_optimizations_inference(self):
self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad)
@patch.object(config, "optimize_ddp", True)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_graph_split_inductor_transpose(self):
assert config.optimize_ddp
B = 100
N = 30
D = 50
K = 70
class Foo(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear0 = nn.Linear(N, K)
self.linear1 = torch.nn.Linear(D * K, 2048)
def forward(self, x):
xt = x.transpose(2, 1)
xt = self.linear0(xt).flatten(1)
return self.linear1(xt)
mod = Foo().to(self.device)
compiled_mod = torch.compile(mod, backend="inductor")
ddp_compiled_mod = DDP(compiled_mod, device_ids=self.device_ids)
x = torch.randn((B, N, D), dtype=torch.float32, device=self.device)
self.assertTrue(same(mod(x), ddp_compiled_mod(x)))
x_1 = torch.randn((B * 2, N, D), dtype=torch.float32, device=self.device)
self.assertTrue(same(mod(x_1), ddp_compiled_mod(x_1)))
x_2 = torch.randn((B * 3, N, D), dtype=torch.float32, device=self.device)
self.assertTrue(same(mod(x_2), ddp_compiled_mod(x_2)))
@patch.object(config, "optimize_ddp", True)
def test_no_split(self):
"""
Ensures the DDPOptimizer returns a correct, compiled module without
introducing graph splits (because the model parameters fit in a single bucket).
"""
# DDP always creates a small 'first bucket', so only a tiny model avoids a split
m, inputs, correct_outputs = self.get_model(hidden_feat=5)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
check_splits_compiler = CheckSplitsCompiler()
@torch.compile(backend=check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 1)
@patch.object(config, "optimize_ddp", True)
def test_aot_autograd(self):
"""
Explicitly check that the AotAutograd family of compilers works,
since they require example inputs to be propagated between graph splits.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
@torch.compile(backend="aot_eager")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
opt_outputs.sum().backward()
self.assertTrue(same(correct_outputs, opt_outputs))
@patch.object(config, "optimize_ddp", True)
def test_custom_layer(self):
"""
Just ensures that the appropriate number of splits happen (based on
bucket size and model parameters) - verifies the number of times
the user-provided compiler is called by the DDPOptimizer which is
doing the graph splitting
"""
m, inputs, correct_outputs = get_custom_model(self.device)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)
check_splits_compiler = CheckSplitsCompiler()
@torch.compile(backend=check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(*inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_empty_graph_inductor(self):
def fn():
get_world_size = torch.distributed.distributed_c10d.get_world_size()
return (get_world_size,)
opt_fn = torch.compile(fn, backend="inductor")
res = None
try:
res = opt_fn()[0]
except Exception:
pass
self.assertEqual(res, 1)
@patch.object(config, "optimize_ddp", False)
def test_ignored_parameters(self):
"""
Verifies that DDP graph-split logic ignores parameters marked to be ignored on the DDP module.
Hooks up the graph-split optimizer manually so it can peek at internal state.
"""
m, inputs, correct_outputs = get_custom_model(self.device)
parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
parameter_ids_to_ignore = [
id(ddp_m.module.get_parameter(p)) for p in ddp_m.parameters_to_ignore
]
check_splits_compiler = CheckSplitsCompiler()
ddp_optimizer = DDPOptimizer(
bucket_bytes_cap=ddp_m.bucket_bytes_cap,
backend_compile_fn=check_splits_compiler.compile_fn,
)
@torch.compile(backend=ddp_optimizer.compile_fn)
def opt_fn(inputs):
return ddp_m(*inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 2)
for b in ddp_optimizer.buckets:
for p_id in b.param_ids:
self.assertFalse(p_id in parameter_ids_to_ignore)
@patch.object(config, "optimize_ddp", True)
def test_higher_order_op(self):
from torch.utils.checkpoint import checkpoint
N = 1000
class InnerModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(N, N)
self.linear2 = torch.nn.Linear(N, N)
def forward(self, x):
a = self.linear1(x)
a = self.linear2(a)
return a
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.inner_mod1 = InnerModule()
self.inner_mod2 = InnerModule()
def forward(self, x):
a = checkpoint(self.inner_mod1, x, use_reentrant=False)
a = torch.cos(a)
a = checkpoint(self.inner_mod2, a, use_reentrant=False)
a = torch.cos(a)
return a
mod = MockModule().cuda()
mod = DDP(mod, bucket_cap_mb=1)
x = torch.randn(N, N, device="cuda", requires_grad=True)
args = (x,)
backend = "aot_eager"
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
torch.compile(mod, backend=cnt)(*args)
def test_fsdp_orig_params_assert(self):
# Test with basic FSDP wrapping (outer wrap around whole model)
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
fsdp_m = FSDP(m, use_orig_params=False)
fsdp_m = torch.compile(fsdp_m)
self.assertRaisesRegex(
AssertionError,
"Dynamo only supports FSDP with use_orig_params=True",
fsdp_m,
inputs,
)
def test_fsdp_skip_guards(self):
"""
It's currently difficult to test dynamo guards. Most guard tests are indirect: modify something and
observe that the guard in question failed. In this case, since the FSDP guards were already deemed
useless and skipping them is expected to have no practical effect, it's pretty contrived to even try to
make those guards fail. Instead, we observe the 'guard source' printed by dynamo's comptime print_guards
function.
Note: comptime prints the guards before they are installed (or skipped), so in both cases
(skip or no skip) the same guards get printed. The difference is that in the skip case, they show up
with a special 'guard source' which will cause them not to be installed. So all we check for is the expected
guard source 'local_fsdp_module'.
"""
global GUARDS_FILE
GUARDS_FILE = StringIO()
for skip_guards, expected_guard_source in (
(True, "local_fsdp_module"),
(False, "local_unspecialized_nn_module"),
):
torch._dynamo.reset()
class ToyModel(nn.Module):
def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
super().__init__()
self.net = nn.Sequential(
*[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
)
def forward(self, inputs):
out = self.net(inputs)
@comptime
def _(ctx):
ctx.print_guards(file=GUARDS_FILE)
return out
device = f"cuda:{self.rank}"
m = ToyModel(
in_feat=10,
hidden_feat=5000,
out_feat=5,
).to(device)
inputs = torch.rand(20, 10).to(device)
m.apply(init_weights)
correct_outputs = m(inputs)
fsdp_m = FSDP(m, use_orig_params=True)
with torch._dynamo.config.patch(skip_fsdp_guards=skip_guards):
opt_m = torch.compile(fsdp_m, backend="aot_eager")
outputs = opt_m(inputs)
# Far from an exhaustive check of the expected guards; just check a couple of them.
FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
).run(
GUARDS_FILE.getvalue()
)
self.assertTrue(same(correct_outputs, outputs))
def test_fsdp_skip_register_attr_or_module(self):
"""
Ensure FSDP module parameters are not registered as attributes
in the fx graph;
see `not source.guard_source().is_fsdp_module()`
before calling `register_attr_or_module`
in variables/builder.py
"""
class ToyModel(nn.Module):
def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
super().__init__()
self.net = nn.Sequential(
*[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
)
def forward(self, inputs):
out = self.net(inputs)
return out
torch._dynamo.reset()
device = f"cuda:{self.rank}"
m = ToyModel(
in_feat=10,
hidden_feat=5000,
out_feat=5,
).to(device)
inputs = torch.rand(20, 10).to(device)
m.apply(init_weights)
correct_outputs = m(inputs)
fsdp_m = FSDP(m, use_orig_params=True)
def debug_compiler(gm, _):
for node in gm.graph.nodes:
if node.op == "get_attr":
for name in [
"l__self___net_0_weight",
"l__self___net_0_bias",
"l__self___net_2_weight",
"l__self___net_2_bias",
]:
self.assertFalse(
name in node.name,
f"FSDP module {name} should not be registered as attributes",
)
return gm
opt_m = torch.compile(fsdp_m, backend=debug_compiler)
outputs = opt_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
def test_fsdp_dup_tensors_same_source(self):
"""
Tests that FSDP-managed modules' parameters and buffers with the same
source are de-duplicated, meaning that they are each only passed once
as a graph input.
"""
class DuplicateModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self._param = torch.randn((3,), device="cuda")
self._buf = torch.nn.Buffer(
torch.randn((3,), requires_grad=False, device="cuda")
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Use `_param` and `_buf` each twice in this compiled forward
# to check whether they are de-duplicated by TorchDynamo
z = x + self._buf + self._buf
z += self._param + self._param
return z
model = DuplicateModule()
fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True)
fsdp_model = torch.compile(fsdp_model, backend="aot_eager")
inp = torch.randn((2, 3), device="cuda")
local_out = model(inp)
fsdp_out = fsdp_model(inp)
self.assertEqual(local_out, fsdp_out)

    @patch.object(config, "guard_nn_modules", True)
    def test_fsdp_dup_tensors_diff_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with different
        sources do not result in incorrect AOTAutograd de-dup guards like
        ``a is b``, where ``a`` and ``b`` are certainly not the same. We check
        this by asserting that there are no per-invocation recompiles.
        """

        class BufModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._buf = nn.Buffer(
                    torch.randn((3,), requires_grad=False, device="cuda")
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self._buf

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._param = nn.Parameter(torch.randn((1,), device="cuda"))
                self._buf_module = BufModule()
                # Share the buffer, meaning same tensor but different source
                self._buf = self._buf_module._buf

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use the same buffer tensor twice in the compiled forward,
                # including a data mutation to trigger de-dup logic
                self._buf.mul_(2)
                z = x + self._buf
                z = self._buf_module(z)
                z += self._param
                return z

        fsdp_model = FSDP(Model(), use_orig_params=True)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        fsdp_model = torch.compile(fsdp_model, backend=cnt)
        inp = torch.randn((2, 3), device="cuda")
        for _ in range(15):
            fsdp_model(inp)
        # Check for no recompiles (if there were incorrect de-dup guards, then
        # the frame count would be equal to the number of forward calls)
        self.assertEqual(cnt.frame_count, 1)

    def test_fsdp_staticmethod(self):
        """
        Tests that Dynamo compiles staticmethods for FSDP-managed modules
        correctly both when the staticmethod is invoked from the class and from
        the object itself.
        """

        class ModuleWithStaticMethod(nn.Module):
            def __init__(self, use_self: bool):
                super().__init__()
                self._use_self = use_self
                torch.manual_seed(42)  # force `_param` to be deterministic
                self._param = nn.Parameter(torch.randn((3,), device="cuda"))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                if self._use_self:
                    z = self._add(x, self._param)
                else:
                    z = ModuleWithStaticMethod._add(x, self._param)
                z *= 2
                return z

            @staticmethod
            def _add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        model = ModuleWithStaticMethod(False)
        x = torch.randn((2, 3), device="cuda")
        ref_out = model(x)
        test_outs: List[torch.Tensor] = []

        for use_self in (False, True):
            model = ModuleWithStaticMethod(use_self)
            fsdp_model = FSDP(model, use_orig_params=True)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            fsdp_model = torch.compile(fsdp_model, backend=cnt)
            test_outs.append(fsdp_model(x))
            # Check for no recompiles, which could happen if args are passed
            # to the staticmethod incorrectly (e.g. doubly passing `self`).
            # Exactly one frame is expected for the single forward call, and
            # its graph should contain the add and the in-place multiply.
            self.assertEqual(cnt.frame_count, 1)
        for test_out in test_outs:
            self.assertEqual(test_out, ref_out)
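
    def test_async_subclass_no_specialize(self):
        """
        Tests that inputs wrapped by ``_maybe_wrap_tensor`` (the async
        collective tensor subclass) do not force shape specialization under
        ``dynamic=True``: two inputs of different sizes should compile only
        one frame.
        """
        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def f(x):
            return x + 1

        f(_maybe_wrap_tensor(torch.randn(10)))
        f(_maybe_wrap_tensor(torch.randn(12)))

        self.assertEqual(cnt.frame_count, 1)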


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()