1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 
2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362 2363 2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462 2463 2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493 2494 2495 2496 2497 2498 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509 2510 2511 2512 2513 2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532 2533 2534 2535 2536 2537 2538 2539 2540 2541 2542 2543 2544 2545 2546 2547 2548 2549 2550 2551 2552 2553 2554 2555 2556 2557 2558 2559 2560 2561 2562 2563 2564 2565 2566 2567 2568 2569 2570 2571 2572 2573 2574 2575 2576 2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588 2589 2590 2591 2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616 2617 2618 2619 2620 2621 
2622 2623 2624 2625 2626 2627 2628 2629 2630 2631 2632 2633 2634 2635 2636 2637 2638 2639 2640 2641 2642 2643 2644 2645 2646 2647 2648 2649 2650 2651 2652 2653 2654 2655 2656 2657 2658 2659 2660 2661 2662 2663 2664 2665 2666 2667 2668 2669 2670 2671 2672 2673 2674 2675 2676 2677 2678 2679 2680 2681 2682 2683 2684 2685 2686 2687 2688 2689 2690 2691 2692 2693 2694 2695 2696 2697 2698 2699 2700 2701 2702 2703 2704 2705 2706 2707 2708 2709 2710 2711 2712 2713 2714 2715 2716 2717 2718 2719 2720 2721 2722 2723 2724 2725 2726 2727 2728 2729 2730 2731 2732 2733 2734 2735 2736 2737 2738 2739 2740 2741 2742 2743 2744 2745 2746 2747 2748 2749 2750 2751 2752 2753 2754 2755 2756 2757 2758 2759 2760 2761 2762 2763 2764 2765 2766 2767 2768 2769 2770 2771 2772 2773 2774 2775 2776 2777 2778 2779 2780 2781 2782 2783 2784 2785 2786 2787 2788 2789 2790 2791 2792 2793 2794 2795 2796 2797
|
# mypy: allow-untyped-defs
import contextlib
import functools
import logging
import os
import warnings
from enum import auto, Enum
from itertools import accumulate, chain
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
Iterator,
List,
NamedTuple,
no_type_check,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.fsdp._common_utils import (
_FSDPDeviceHandle,
_named_parameters_with_duplicates,
_no_dispatch_record_stream,
_set_fsdp_flattened,
HandleTrainingState,
)
from torch.distributed.utils import (
_alloc_storage,
_data_ptr_allocated,
_free_storage,
_p_assert,
)
from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined]
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
from ._fsdp_extensions import (
_ext_post_unflatten_transform,
_ext_pre_flatten_transform,
FSDPExtensions,
)
# Names exported by a star-import of this module; every entry is a class
# defined in this file.
__all__ = [
    "FlatParameter",
    "FlatParamHandle",
    "FlatParamShardMetadata",
    "ParamInfo",
    "SharedParamInfo",
    "HandleShardingStrategy",
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
"""
[Note: Fully Sharded Module]
We define the "fully sharded module" to be the original ``nn.Module`` that owns
a ``FlatParamHandle``. It is the *single* module logically responsible for the
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
forward or backward pass. The fully sharded module should be passed to the
``FlatParamHandle`` constructor.
For the wrapper code path:
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
runs the unshard/reshard on behalf of the fully sharded module by overriding
``nn.Module.forward``.
- The fully sharded module is exactly the module passed to the
``FullyShardedDataParallel`` constructor's ``module`` argument.
For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to ``fully_shard``
or a submodule chosen by the provided wrapping policy.
"""
# Name of the environment variable toggling whether to use unsafe `setattr()`
# for view setting in `_use_sharded_views()` and `_use_unsharded_views()`.
# 'Safe' should be the default since it respects method overrides, but for
# special cases such as high CPU overhead or intentionally bypassing checks
# in the overrides, 'unsafe' may be used.
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"
# Name of the environment variable toggling whether to check for
# parameter/gradient writeback in case their storages change after FSDP
# initialization. Checking should be the default since it prevents silent
# correctness errors, but since such changes are atypical, the check may be
# skipped to save CPU overhead, especially since it happens in the
# pre-forward and pre-backward each iteration.
# `FlatParamHandle.__init__` treats the value "1" as "skip the check".
_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"
# Name of the environment variable toggling whether a model in `.eval()` mode
# runs in fp32 or in the reduced precision.
# `FlatParamHandle.__init__` treats the value "1" as enabled.
_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"
# Arbitrary, recognizable value that padding in tensors is set to, for
# debuggability.
_FLAT_PARAM_PADDING_VALUE = 42
# Names of the environment variables for disabling the all-gather and
# reduce-scatter communication ops for ablation studies. Note that without
# these communication ops the training won't converge, and you probably need
# to disable correctness checks in your model.
# `FlatParamHandle.__init__` treats the value "1" as enabled for both.
_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
# TODO: Define this for now to avoid circular imports. See if we can remove.
class HandleShardingStrategy(Enum):
    """Sharding strategies that can be assigned to a flat-parameter handle.

    Member values come from ``auto()`` and therefore follow declaration
    order, so the order of the members below must not be changed.
    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    # The leading underscore does not exclude this name from the enum; it is
    # a regular member like the others.
    _HYBRID_SHARD_ZERO2 = auto()
# Groupings of sharded handle strategies for membership checks.
# NOTE(review): going by the constant names, the first group reshards the flat
# parameter after forward and the second does not; the code that consumes
# these tuples is outside this part of the file -- confirm there.
# `NO_SHARD` belongs to neither group.
RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.FULL_SHARD,
    HandleShardingStrategy.HYBRID_SHARD,
)
NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.SHARD_GRAD_OP,
    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
)
class ParamInfo(NamedTuple):
    """Record describing one original parameter.

    Fields, in positional order:
        param_name: Name of the parameter without any module prefix.
        module: Module that contains the parameter.
        module_name: Name associated with ``module``.
    """

    # Unprefixed, i.e. the bare attribute name of the parameter
    param_name: str
    module: nn.Module
    module_name: str
class SharedParamInfo(NamedTuple):
    """
    Extra bookkeeping for a parameter that is shared between modules.

    Among the occurrences of a shared parameter, one module and its parameter
    variable are designated the primary owner: the first one encountered in
    the parameter walk. The fields prefixed with "prim" refer to that primary
    owner. The primary module and parameter themselves do not get their own
    :class:`SharedParamInfo` instance.
    """

    # Non-primary occurrence of the shared parameter
    param_name: str  # unprefixed
    module: nn.Module
    module_name: str
    # Primary owner of the shared parameter
    prim_param_name: str  # unprefixed
    prim_module: nn.Module
    prim_module_name: str
class _ShardParamInfo(NamedTuple):
    """Per-parameter record of how an original parameter relates to a shard."""

    in_shard: bool
    # Offset and length for slicing the sharded flat parameter, e.g.
    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
    offset_in_shard: Optional[int]
    numel_in_shard: Optional[int]
    # Start and (inclusive) end index for selecting the part of the parameter
    # that lies in the local shard out of a flattened version of the unsharded
    # parameter, e.g. either
    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]` or
    # `param.as_strided((param.numel(),), (1,))[intra_param_start_idx : intra_param_end_idx + 1]`
    intra_param_start_idx: Optional[int]
    intra_param_end_idx: Optional[int]  # inclusive
class FlatParamShardMetadata(NamedTuple):
    """
    Metadata describing this rank's shard of the flat parameter.

    All fields are parallel tuples with one entry per parameter in this rank's
    shard; see :class:`FlatParameter` for the corresponding per-parameter
    data.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed names of the parameters in
            this rank's shard.
        param_shapes (Tuple[torch.Size, ...]): Shapes of the parameters in
            this rank's shard.
        param_strides (Tuple[Tuple[int, ...], ...]): Strides of the parameters
            in this rank's shard.
        param_contiguities (Tuple[bool, ...]): `.contiguous` call results of
            the parameters in this rank's shard.
        param_numels (Tuple[int, ...]): Numels of the parameters in this
            rank's shard.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels) giving this rank's part of each flattened
            original parameter.
    """

    param_names: Tuple[str, ...]
    param_shapes: Tuple[torch.Size, ...]
    param_strides: Tuple[Tuple[int, ...], ...]
    param_contiguities: Tuple[bool, ...]
    param_numels: Tuple[int, ...]
    param_offsets: Tuple[Tuple[int, int], ...]
class _FlatParameterMeta(_ParameterMeta):
    """Metaclass that customizes ``isinstance`` checks for flat parameters.

    For BC, any tensor instance carrying the ``_is_flat_param`` flag is
    treated as an instance of the class using this metaclass.
    """

    def __instancecheck__(self, instance):
        # NB: deliberately does NOT consult the super implementation
        if not isinstance(instance, torch.Tensor):
            return False
        # Hand back the flag itself; tensors without the flag yield `False`
        return getattr(instance, "_is_flat_param", False)
class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
    """
    This is the flat parameter used by :class:`FullyShardedDataParallel`.

    It is comprised of one or more original parameters, which are flattened and
    concatenated to construct the flat parameter.

    Under the current design, this parameter logically represents both the
    unsharded and sharded flat parameter, and its data changes storages
    dynamically.

    - In the :class:`FullyShardedDataParallel` constructor, the parameter
      is initialized as unsharded and then sharded in-place.
    - At runtime, the parameter is lazily (re)-initialized. The sharded
      parameter data is saved in ``self._local_shard``, and a new ``Tensor``
      ``self._full_param_padded`` is created, which is the all-gather
      destination and owns the unsharded parameter storage thereafter. (See
      :meth:`FlatParamHandle.init_flat_param_attributes`.)
    - Throughout runtime, the parameter data changes storages as needed,
      e.g. to the sharded flat parameter, low precision sharded flat
      parameter, or the unsharded flat parameter.

    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
    padding, we have two versions of the per-parameter numels, one that
    includes the padding (``_numels_with_padding``) and one that does not
    (``_numels``). The former may have length longer than the other data
    structures, while the latter has the same length as the number of actual
    original parameters like the other per-parameter data structures.

    NOTE: This is not a real class; instead, you will always get a Parameter
    back out if you try to create one of these. This is similar to the trick
    we implemented for Parameter to get it to work with subclasses; this
    is primarily so that FlatParameter supports combination with FakeTensor.

    Attributes:
        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
            without right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding.
        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
            with right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding. This
            is only set for sharded strategies since they require padding for
            the all-gather.
        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
            This is also set for ``NO_SHARD``, in which case it is the same as
            the unsharded sizes. (We omit "padded" because there is no
            analogous unpadded one.)

        _num_params (int): Number of original parameters flattened into this
            flat parameter. This is the length of the per-parameter data
            structures.
        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
            entry; see :class:`ParamInfo` for details.
        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
        _strides (Tuple[Tuple[int, ...], ...]): Each parameter's original
            stride.
        _contiguities (Tuple[bool, ...]): Each parameter's ``contiguous()``
            call result.
        _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
            prefixed from the ``_fully_sharded_module``. The names are
            guaranteed to be unique in the subtree rooted at that module.
        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
            extension (i.e. some per-parameter state) used to customize
            pre-flatten and post-unflatten behavior or ``None``. This is
            experimental, and users should not depend on its existence in the
            future.
        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
            including entries for the padding. This is used to construct views
            into the flat parameter via ``torch.split()``. This may have length
            longer than ``_num_params``.
        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
            padding. This has length equal to ``_num_params``.
        _is_padding_mask (List[bool]): Mask with one entry per entry of
            ``_numels_with_padding``; ``True`` marks a padding entry and
            ``False`` marks an original parameter.
        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
            shard parameter info; see :class:`_ShardParamInfo` for details.
        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
            info entries; see :class:`SharedParamInfo` for details.
        _modules (Set[nn.Module]): Modules that contain some original parameter
            that is flattened into the flat parameter.

        _shard_numel_padded (int): Numel padded for this rank's sharded flat
            parameter.
        _local_shard (Tensor): Sharded flat parameter with padding if using a
            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
            unsharded flat parameter, and there is no notion of a sharded flat
            parameter or padded unsharded flat parameter.
        _full_param_padded (Tensor): Unsharded flat parameter with padding.
            This is not defined for ``NO_SHARD``. When using mixed precision
            for parameters, this has the low precision.
        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
            parameter with padding. This is used for unsharding outside of
            computation when using mixed precision for parameters. This is
            never defined for ``NO_SHARD``.
        _post_backward_hook_handle (RemovableHandle):
            Flat parameter's post-backward hook handle. (Compile only)
        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
            Flat parameter's :class:`AccumulateGrad` object and post-backward
            hook handle. (Eager only)
        _post_backward_called (bool): Whether the post-backward hook has been
            called; initialized to ``False`` by :meth:`_init_metadata`.
        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
            This is only defined when parameter mixed precision is enabled. For
            ``NO_SHARD``, this is used for computation.
        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
            This is only defined when offloading parameters is enabled.
        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
            iterations for gradient accumulation without :meth:`no_sync`.

        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
            then each original parameter variable; otherwise, ``None``. This
            does not include any padding tensors.
        _shared_params (Optional[List[nn.Parameter]]): The original shared
            parameter variables if ``use_orig_params=True`` and ``None``
            otherwise.
        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
            views created in the forward and tracked by autograd when
            ``use_orig_params=True`` and is ``None`` otherwise. This is to
            preserve those ``Tensor`` variables for the backward to ensure that
            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change
            in which case the post-backward hook does not run. This is relevant
            for cases like reentrant activation checkpointing.
        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
            a mask over the original parameters' gradients indicating if it is
            logically ``None`` or not; otherwise, ``None``. This does not
            include entries for padding. This mask is needed because only some
            of the parameters may have ``None`` gradient, in which case the
            flat gradient must be non-``None`` and must use zeros to
            approximate those original ``None`` gradients. This mask informs
            FSDP to set the original parameter gradients to ``None`` (instead
            of zeros) as needed.
    """

    _unpadded_unsharded_size: torch.Size
    _padded_unsharded_size: torch.Size
    _sharded_size: torch.Size
    _num_params: int
    _param_infos: Tuple[ParamInfo, ...]
    _shapes: Tuple[torch.Size, ...]
    _strides: Tuple[Tuple[int, ...], ...]
    _contiguities: Tuple[bool, ...]
    _fqns: Tuple[str, ...]
    _param_extensions: Tuple[Optional[Any], ...]
    _numels_with_padding: Tuple[int, ...]
    _numels: Tuple[int, ...]
    _shard_param_infos: Tuple[_ShardParamInfo, ...]
    _shared_param_infos: Tuple[SharedParamInfo, ...]
    _modules: Set[nn.Module]
    _shard_numel_padded: int
    _local_shard: Tensor
    _full_param_padded: Tensor
    _full_prec_full_param_padded: Tensor
    # Eager only
    _post_backward_hook_state: Tuple[Any, Any]
    # Compile only
    _post_backward_hook_handle: Any
    _mp_shard: Tensor
    _cpu_grad: Tensor
    _saved_grad_shard: Tensor
    _params: Optional[List[nn.Parameter]]
    _shared_params: Optional[List[nn.Parameter]]
    _tensors: Optional[List[Optional[Tensor]]]
    _is_grad_none_mask: Optional[List[bool]]

    _is_padding_mask: List[bool]

    def __new__(cls, data=None, requires_grad=True):
        """Create a plain ``nn.Parameter`` tagged as a flat parameter.

        The returned object is an ``nn.Parameter`` (not an instance of this
        class) with ``_is_flat_param`` set to ``True``, which is what the
        metaclass' instance check looks for.
        """
        assert cls is FlatParameter, "subclasses FlatParameter not supported"
        r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
        r._is_flat_param = True  # type: ignore[attr-defined]
        return r

    # NB: This is not a regular method, because FlatParameters are not actually
    # instances of this class (see __new__ above). So you must call this
    # indirectly through the classmethod, passing the flat parameter
    # explicitly as `self`.
    @classmethod
    def _init_metadata(
        cls,
        self,
        param_infos: List[ParamInfo],
        numels: List[int],
        shapes: List[torch.Size],
        strides: List[Tuple[int, ...]],
        contiguities: List[bool],
        fqns: List[str],
        shared_param_infos: List[SharedParamInfo],
        param_extensions: List[Optional[Any]],
        params: Optional[List[nn.Parameter]],
        shared_params: Optional[List[nn.Parameter]],
        is_padding_mask: List[bool],
    ) -> None:
        """
        Initialize attributes holding metadata about the original parameters comprising the flat parameter.

        We expose this method separate from the constructor to keep the
        constructor only responsible for the flat parameter's tensor data. This
        method should only be called once per model, while the constructor may
        be called multiple times, e.g. when reloading from a checkpoint, in
        which case only the tensor data needs to be passed to the constructor.
        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
        metadata is correctly assumed to be unchanged.

        ``numels`` and ``params`` (if given) include entries for padding and
        must have the same length as ``is_padding_mask``; all other
        per-parameter lists must have the same length as ``param_infos``.
        ``params`` and ``shared_params`` must either both be given or both be
        ``None``.

        Args:
            See the Attributes in the class docstring.
        """
        num_params = len(param_infos)
        assert num_params == len(shapes)
        assert num_params == len(strides)
        assert num_params == len(contiguities)
        assert num_params == len(fqns)
        assert num_params == len(param_extensions)
        # `numels` is filtered against `is_padding_mask` with `zip()` below,
        # which silently truncates to the shorter input, so require equal
        # lengths up front instead of storing inconsistent metadata
        assert len(numels) == len(is_padding_mask), (
            "Expects numels and is_padding_mask to have the same length but "
            f"got {len(numels)} and {len(is_padding_mask)}"
        )
        self._num_params = num_params
        # NOTE: These are stored as passed in (i.e. lists), despite the tuple
        # annotations on the class
        self._param_infos = param_infos
        self._shapes = shapes
        self._strides = strides
        self._contiguities = contiguities
        self._fqns = fqns
        self._param_extensions = param_extensions
        self._is_padding_mask = is_padding_mask

        # Keep both views of the numels: one entry per original parameter and
        # one entry per flattened tensor including padding
        self._numels = tuple(
            numel
            for numel, is_padding in zip(numels, is_padding_mask)
            if not is_padding
        )
        self._numels_with_padding = tuple(numels)
        assert len(self._numels) == self._num_params

        self._shared_param_infos = tuple(shared_param_infos)
        self._modules = {pi.module for pi in self._param_infos}.union(
            {spi.module for spi in self._shared_param_infos}
        )
        assert (params is None) == (shared_params is None)
        if params is not None:
            assert shared_params is not None and len(shared_params) == len(
                shared_param_infos
            )
            # Same reasoning as for `numels`: avoid silent `zip()` truncation
            assert len(params) == len(is_padding_mask), (
                "Expects params and is_padding_mask to have the same length "
                f"but got {len(params)} and {len(is_padding_mask)}"
            )
            # Drop the padding tensors so that `_params` only holds the
            # original parameters
            self._params = [
                param
                for param, is_padding in zip(params, is_padding_mask)
                if not is_padding
            ]
            self._shared_params = shared_params
            # Mark the original parameters to avoid flattening them into
            # another `FlatParameter` during recursive construction
            for param in chain(self._params, self._shared_params):
                _set_fsdp_flattened(param)
            self._is_grad_none_mask = [False] * self._num_params
            self._tensors = [None] * self._num_params
        else:
            self._params = None
            self._shared_params = None
            self._is_grad_none_mask = None
            self._tensors = None
        self._unpadded_unsharded_size = self.size()
        _set_fsdp_flattened(self)
        # Tracks whether the `FlatParameter`'s post-backward hook has been
        # called to modify the behavior of the post-backward callback
        self._post_backward_called = False
class FlatParamHandle:
"""
A handle that manages a flat parameter (:class:`FlatParameter`).
This includes sharding and view management.
Args:
params (Sequence[nn.Parameter]): The parameters to flatten into the
flat parameter.
fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
device (torch.device): The compute and communication device, which
should be a non-CPU device. We refer to it as the compute device.
sharding_strategy (ShardingStrategy): Sharding strategy to apply to
this handle's ``FlatParameter``.
offload_params (bool): Whether to offload the handle's
``FlatParameter`` to CPU.
mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
setting passed to the FSDP constructor.
mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
precision setting passed to the FSDP constructor.
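keep_low_precision_grads (bool): Whether to keep gradients in low
precision.
process_group (dist.ProcessGroup): The process group that supplies this
handle's rank and world size and over which the ``FlatParameter`` is
all-gathered.
use_orig_params (bool): If ``True``, then FSDP preserves the original
parameter variables and returns them from ``named_parameters()``
(e.g. to support different optimizer hyperparameters within one
:class:`FlatParameter`). If ``False``, then FSDP reconstructs the
parameters every iteration and returns the :class:`FlatParameter` s
from ``named_parameters()``.
fsdp_extension (Optional[FSDPExtensions]): Optional extension passed to
the pre-flatten transform that is applied to each parameter before
it is flattened. (Default: ``None``)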
"""
##################
# INITIALIZATION #
##################
def __init__(
    self,
    params: Sequence[Union[nn.Parameter, Tensor]],
    fully_sharded_module: nn.Module,
    device: torch.device,
    sharding_strategy: HandleShardingStrategy,
    offload_params: bool,
    mp_param_dtype: Optional[torch.dtype],
    mp_reduce_dtype: Optional[torch.dtype],
    keep_low_precision_grads: bool,
    process_group: dist.ProcessGroup,
    use_orig_params: bool,
    *,
    fsdp_extension: Optional[FSDPExtensions] = None,
):
    """
    Build the handle's state and flatten ``params`` into ``self.flat_param``.

    Beyond storing the constructor arguments, this reads several
    environment-variable toggles (each is enabled only by the exact value
    ``"1"``), takes the rank and world size from ``process_group``,
    resolves the mixed-precision dtypes, constructs the flat parameter with
    its metadata, and finally switches to unsharded views with
    ``as_params=False``.

    Raises:
        ValueError: If ``params`` is empty.
    """
    super().__init__()
    # Copy into a list: it is indexed (`params[0]`) and handed to
    # `_init_flat_param_and_metadata()` below
    params = list(params)
    if len(params) == 0:
        raise ValueError(
            f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
        )
    self._init_setattr_fns()
    # Environment toggles: anything other than exactly "1" leaves them off
    self._skip_writeback_check = (
        os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
    )
    self._use_full_prec_in_eval = (
        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
    )
    self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
    self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"
    # Warn about every enabled toggle that can affect correctness
    if self._skip_writeback_check:
        _warn_skip_writeback_check(
            logger,
            f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
            "for parameter or gradient writeback. Changing parameter or "
            "gradient storages may lead to silent correctness errors.",
        )
    if self._use_fake_all_gather:
        _warn_use_fake_all_gather(
            logger,
            f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
            "all-gather ops. Your training will be incorrect, but "
            "can reveal how much time spent on all-gather ops.",
        )
    if self._use_fake_reduce:
        _warn_use_fake_reduce(
            logger,
            f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
            "reduce-scatter ops. Your training will be incorrect, but "
            "can reveal how much time spent on reduce-scatter ops.",
        )
    # Only align addresses for `use_orig_params=True` (for now)
    align_addresses = use_orig_params
    self._init_get_unflat_views_fn(align_addresses)
    self.device = device
    self._device_handle = _FSDPDeviceHandle.from_device(self.device)
    self.process_group = process_group
    # The fake process group is only created when one of the fake
    # collective toggles is enabled; it mirrors the real group's rank and
    # world size
    if self._use_fake_all_gather or self._use_fake_reduce:
        self._fake_process_group = FakeProcessGroup(
            rank=process_group.rank(), world_size=process_group.size()
        )
    self.rank = process_group.rank()
    self.world_size = process_group.size()
    self._sharding_strategy = sharding_strategy
    self._offload_params = offload_params
    self._use_orig_params = use_orig_params
    self._keep_low_precision_grads = keep_low_precision_grads
    self._training_state = HandleTrainingState.IDLE
    self._debug_level = dist.get_debug_level()
    self._fully_sharded_module = fully_sharded_module
    # For strategies that do not free after forward, we skip using sharded
    # views after forward since the unsharded data exists. We still switch
    # `self.flat_param` to point to the sharded flat parameter since what
    # it points to parameterizes behavior. We use the following attribute
    # to track which tensor data the parameters are unsharded views into.
    self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
    # The index in the state's `all_handles`, which must be the
    # same across ranks for the execution order validation to work
    self._handle_index: Optional[int] = None
    # Index in handles_to_pre_forward_order
    self._pre_forward_order_index: Optional[int] = None
    # Index in `handles_post_forward_order`
    self._post_forward_index: Optional[int] = None
    # Used for guarding against mistargeted forward prefetches
    self._needs_pre_forward_unshard = False
    # Used for guarding against mistargeted backward prefetches
    self._needs_pre_backward_unshard = False
    # Was the handle prefetched? Set on successful _prefetch_handle and unshard
    self._prefetched = False
    # Optimistically assume a valid input `params` and set dtype attributes
    # before `_init_flat_param_and_metadata()`, which performs the actual
    # validation
    self._orig_param_dtype = params[0].dtype
    self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
    assert self._fwd_bwd_param_dtype is not None  # mypy
    # An `aligned_numel` of 0 disables alignment padding
    self._aligned_numel = (
        _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
        if align_addresses
        else 0
    )
    self._fsdp_extension = fsdp_extension
    # Must run after `rank`, `world_size`, `_use_orig_params`, and
    # `_fsdp_extension` are set since the flattening logic reads them
    self._init_flat_param_and_metadata(
        params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
    )
    self._use_unsharded_views(as_params=False)
def _init_setattr_fns(self):
    """
    Choose the functions used to set tensors and parameters on modules.

    If the environment variable named by ``_FSDP_USE_UNSAFE_SETATTR`` is
    exactly ``"1"``, the ``_unsafe_setattr_*`` variants are installed;
    otherwise both setters are ``_safe_setattr_tensor_or_param``.
    """
    self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
    self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
    if os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") != "1":
        # Default path: one safe setter serves both tensors and parameters
        self._setattr_tensor = _safe_setattr_tensor_or_param
        self._setattr_param = _safe_setattr_tensor_or_param
        return
    self._setattr_tensor = _unsafe_setattr_tensor
    self._setattr_param = _unsafe_setattr_param
def _init_get_unflat_views_fn(self, align_addresses: bool):
self._get_unflat_views = (
self._get_unflat_views_aligned
if align_addresses
else self._get_unflat_views_unaligned
)
def _init_flat_param_and_metadata(
    self,
    params: List[Union[Tensor, nn.Parameter]],
    module: nn.Module,
    aligned_numel: int,
    use_orig_params: bool,
) -> None:
    """
    Initialize the ``FlatParameter`` and its metadata.

    This walks ``module.named_modules(remove_duplicate=False)`` and, for
    every parameter that is in ``params``, either records it as a tensor to
    flatten (first occurrence) or as a shared parameter (later
    occurrences). When ``aligned_numel > 0``, padding tensors are inserted
    before parameters so that each one starts at a multiple of
    ``aligned_numel``, and a final padding tensor makes the total numel
    divisible by ``self.world_size``.

    NOTE: This should only be called once at construction time, after which
    the ``FlatParameter`` metadata is assumed to be static.

    NOTE: The elements of ``params`` should only be ``Tensor`` s when
    composing with ``DTensor`` -based tensor parallelism, in which case the
    elements may be ``DTensor`` local shards.

    Args:
        params (List[Union[Tensor, nn.Parameter]]): Parameters to flatten.
        module (nn.Module): Module whose tree is searched for ``params``.
        aligned_numel (int): Numel to align each parameter to, or 0 to
            disable alignment padding.
        use_orig_params (bool): If ``True``, the flattened and shared
            parameters are also saved in the metadata.

    Raises:
        ValueError: If ``params`` is empty, if ``aligned_numel`` is
            negative, or if none of ``params`` is found in ``module``.
    """
    if len(params) == 0:
        raise ValueError("Expects non-empty `params`")
    if aligned_numel < 0:
        raise ValueError(
            f"Expects non-negative `aligned_numel` but got {aligned_numel}"
        )
    (
        dtype,
        flat_param_requires_grad,
        device,
    ) = self._validate_tensors_to_flatten(params)
    params_set = set(params)
    # For alignment padding, only `numels` gets strictly non-`None`
    # elements, and all other lists get `None` elements for padding.
    param_infos: List[ParamInfo] = []
    numels: List[int] = []
    shapes: List[torch.Size] = []
    strides: List[Tuple[int, ...]] = []
    contiguities: List[bool] = []
    fqns: List[str] = []
    shared_param_infos: List[SharedParamInfo] = []
    # Maps an already-seen parameter to where it was first found
    shared_param_memo: Dict[
        Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
    ] = {}
    params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
    shared_params: List[Union[Tensor, nn.Parameter]] = []
    param_extensions: List[Any] = []
    is_padding_mask: List[bool] = []
    total_numel = total_numel_without_padding = 0
    for submodule_name, submodule in module.named_modules(remove_duplicate=False):
        for param_name, param in _named_parameters_with_duplicates(
            submodule, recurse=False
        ):
            if param not in params_set:
                continue
            if param in shared_param_memo:  # shared reference
                prim_module, prim_module_name, prim_param_name = shared_param_memo[
                    param
                ]
                shared_params.append(param)
                shared_param_infos.append(
                    SharedParamInfo(
                        param_name,
                        submodule,
                        submodule_name,
                        prim_param_name,
                        prim_module,
                        prim_module_name,
                    )
                )
            else:
                if aligned_numel > 0:
                    # Insert padding so this parameter starts at a multiple
                    # of `aligned_numel`; a full `aligned_numel` of padding
                    # means the offset is already aligned, so skip it
                    numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                    if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                        padding_tensor = _construct_padding_tensor(
                            numel_to_pad, dtype, False, device
                        )
                        params_to_flatten.append(padding_tensor)
                        is_padding_mask.append(True)
                        numels.append(numel_to_pad)
                        total_numel += numel_to_pad
                transform_t, extension = _ext_pre_flatten_transform(
                    param,
                    self._fsdp_extension,
                )
                param = cast(nn.Parameter, transform_t)
                param_extensions.append(extension)
                shared_param_memo[param] = (submodule, submodule_name, param_name)
                params_to_flatten.append(param)
                is_padding_mask.append(False)
                param_infos.append(ParamInfo(param_name, submodule, submodule_name))
                numels.append(param.numel())
                shapes.append(param.shape)
                strides.append(param.stride())
                contiguities.append(_is_truly_contiguous(param))
                fqn = (
                    submodule_name + "." + param_name
                    if submodule_name
                    else param_name
                )
                fqns.append(fqn)
                total_numel += param.numel()
                total_numel_without_padding += param.numel()
    if len(params_to_flatten) == 0:
        # Keep the headline and the details on separate lines; the adjacent
        # literals are concatenated, so the separator must be explicit
        raise ValueError(
            "`params` were not found in `module`'s tree\n"
            f"params: {params}\nmodule: {module}"
        )
    if (
        self.rank == 0
        and aligned_numel > 0
        and total_numel != total_numel_without_padding
    ):
        logger.debug(
            "FSDP FlatParameter address alignment created "
            "%s numel of padding (%s vs. %s)",
            total_numel - total_numel_without_padding,
            total_numel,
            total_numel_without_padding,
        )
    if aligned_numel > 0:
        # Pad to be divisible by world size to avoid a copy for the
        # post-backward reduce-scatter
        numel_to_pad = self.world_size - (total_numel % self.world_size)
        if numel_to_pad > 0 and numel_to_pad < self.world_size:
            if self.rank == 0:
                logger.info(
                    "FSDP FlatParameter world size divisibility created "
                    "%s numel of padding",
                    numel_to_pad,
                )
            padding_tensor = _construct_padding_tensor(
                numel_to_pad, dtype, False, device
            )
            params_to_flatten.append(padding_tensor)
            is_padding_mask.append(True)
            numels.append(numel_to_pad)
            total_numel += numel_to_pad
    # Pass `aligned_numel=0` since we already included padding tensors
    self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
        params_to_flatten,
        aligned_numel=0,
        requires_grad=flat_param_requires_grad,
    )
    FlatParameter._init_metadata(
        self.flat_param,
        param_infos,
        numels,
        shapes,
        strides,
        contiguities,
        fqns,
        shared_param_infos,
        param_extensions,
        _convert_to_params(params_to_flatten) if use_orig_params else None,
        _convert_to_params(shared_params) if use_orig_params else None,
        is_padding_mask,
    )
def _validate_tensors_to_flatten(
    self, tensors: List[Union[Tensor, nn.Parameter]]
) -> Tuple:
    """
    Check that ``tensors`` can be flattened together and return metadata.

    Returns:
        Tuple: ``(dtype, requires_grad, device)`` where ``dtype`` and
        ``device`` are those of the tensors and ``requires_grad`` is the
        logical OR over every tensor's ``requires_grad``.

    Raises:
        ValueError: If any tensor is a ``FlatParameter``, if the first
            tensor is not floating point, if dtypes or devices are not
            uniform, or if ``requires_grad`` is not uniform while
            ``self._use_orig_params`` is ``False``.
    """
    seen_dtype: Optional[torch.dtype] = None
    # Running logical OR of `requires_grad`; `None` until a tensor is seen
    any_requires_grad: Optional[bool] = None
    seen_device: Optional[torch.device] = None
    for candidate in tensors:
        if isinstance(candidate, FlatParameter):
            raise ValueError("Cannot flatten a `FlatParameter`")
        if seen_dtype is None:
            # Only the first tensor needs the floating-point check since
            # the rest must match its dtype
            if not candidate.is_floating_point():
                raise ValueError("Cannot flatten integer dtype tensors")
        elif candidate.dtype != seen_dtype:
            raise ValueError(
                f"Must flatten tensors with uniform dtype but got {seen_dtype} "
                f"and {candidate.dtype}"
            )
        # For `use_orig_params=True`, permit non-uniform `requires_grad`
        if (
            not self._use_orig_params
            and any_requires_grad is not None
            and candidate.requires_grad != any_requires_grad
        ):
            raise ValueError(
                "Must flatten tensors with uniform `requires_grad` when "
                "`use_orig_params=False`"
            )
        if seen_device is not None and candidate.device != seen_device:
            raise ValueError(
                "Must flatten tensors on the same device but got both "
                f"{seen_device} and {candidate.device}"
            )
        seen_dtype = candidate.dtype
        any_requires_grad = any_requires_grad or candidate.requires_grad
        seen_device = candidate.device
    assert any_requires_grad is not None, "Requires non-empty `tensors` list"
    return seen_dtype, any_requires_grad, seen_device
def flatten_tensors(
    self,
    tensors: List[Tensor],
    aligned_numel: int,
) -> Tensor:
    """
    Concatenate ``tensors`` into one 1D tensor.

    If ``aligned_numel`` is greater than 0, padding tensors are inserted so
    that every input starts at a multiple of ``aligned_numel``, and trailing
    padding makes the total numel divisible by ``self.world_size``. If
    ``aligned_numel`` is 0, the inputs are concatenated without padding.

    NOTE: The padding algorithm mirrors the one used when the flat
    parameter's metadata is built at construction time and must be kept in
    sync with it. The two are separate because the initialization happens
    once, whereas this method may be called multiple times throughout
    training (e.g. for checkpointing).

    Raises:
        ValueError: If ``tensors`` is empty or ``aligned_numel`` is negative.
    """
    if len(tensors) == 0:
        raise ValueError("Expects non-empty `tensors`")
    if aligned_numel < 0:
        raise ValueError(
            f"Expects non-negative `aligned_numel` but got {aligned_numel}"
        )
    dtype, _, device = self._validate_tensors_to_flatten(tensors)

    def _as_1d(source: Tensor) -> Tensor:
        # Truly contiguous tensors are flattened directly; others are
        # re-strided into a 1D tensor of the same numel
        if _is_truly_contiguous(source):
            return torch.flatten(_detach_if_needed(source))
        return _detach_if_needed(source).as_strided((source.numel(),), (1,))

    if aligned_numel == 0:
        return torch.cat([_as_1d(source) for source in tensors], dim=0)

    pieces: List[Tensor] = []
    running_numel = 0
    for source in tensors:
        # A pad equal to `aligned_numel` means the offset is already aligned
        pad = aligned_numel - (running_numel % aligned_numel)
        if 0 < pad < aligned_numel:
            pieces.append(_construct_padding_tensor(pad, dtype, False, device))
            running_numel += pad
        pieces.append(_as_1d(source))
        running_numel += source.numel()
    # Trailing padding for divisibility by the world size
    tail_pad = self.world_size - (running_numel % self.world_size)
    if 0 < tail_pad < self.world_size:
        pieces.append(_construct_padding_tensor(tail_pad, dtype, False, device))
        running_numel += tail_pad
    return torch.cat(pieces, dim=0)
def flatten_tensors_into_flat_param(
    self,
    tensors: List[Tensor],
    aligned_numel: int,
    requires_grad: bool,
) -> FlatParameter:
    """
    Flatten ``tensors`` and wrap the result in a ``FlatParameter``.

    The flattening (including any padding for ``aligned_numel``) is done by
    :meth:`flatten_tensors`; ``requires_grad`` is forwarded to the
    ``FlatParameter`` constructor.
    """
    return FlatParameter(
        self.flatten_tensors(tensors, aligned_numel),
        requires_grad=requires_grad,
    )
def _init_param_reduce_dtypes(
self,
mp_param_dtype: Optional[torch.dtype],
mp_reduce_dtype: Optional[torch.dtype],
) -> None:
"""
Initialize param and reduce dtypes.
Precondition: ``self.flat_param`` is set. This ensures that this
handle's parameters have a single dtype.
Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
is ``None``, then we assume the original parameter dtype. One special
case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
is ``None``, in which case we assume the gradient reduction dtype
matches the forward/backward parameter dtype.
"""
# Save whether these dtypes were specified so that we permit the
# parameter dtype to change up until the lazy initialization
self._low_prec_param_dtype_specified = mp_param_dtype is not None
self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
if (
self._low_prec_param_dtype_specified
and not self._low_prec_reduce_dtype_specified
):
# Special case: infer gradient reduction mixed precision
self._fwd_bwd_param_dtype = mp_param_dtype
self._reduce_dtype = self._fwd_bwd_param_dtype
else:
self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
assert self._fwd_bwd_param_dtype is not None
assert self._reduce_dtype is not None
###################################
# SHARD INITIALIZATION & METADATA #
###################################
@torch.no_grad()
def shard(self):
    """
    Shard the handle's ``FlatParameter`` in place.

    For a sharded strategy, this rank's padded shard is obtained from
    ``_get_shard`` (new memory), the unsharded storage is resized to 0 if it
    is allocated (skipped while torchdynamo is compiling), and the flat
    parameter is pointed at the shard. For a non-sharded strategy, the data
    is left as is.

    Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
    metadata attributes are set for all sharding strategies.
    """
    param = self.flat_param
    if self.uses_sharded_strategy:
        _p_assert(
            param.storage_offset() == 0,
            "The `FlatParameter` is not the sole occupant of its storage",
        )
        local_shard, numel_padded = FlatParamHandle._get_shard(
            param, self.rank, self.world_size
        )
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            # Free the unsharded storage only if it is currently allocated
            if param._typed_storage()._size() > 0:
                param._typed_storage()._resize_(0)
        param.set_(local_shard)  # type: ignore[call-overload]
        # This rank owns the closed interval [first_idx, last_idx] of the
        # padded unsharded flat parameter
        first_idx = local_shard.numel() * self.rank
        last_idx = local_shard.numel() * (self.rank + 1) - 1
        self._init_shard_metadata(numel_padded, first_idx, last_idx)
    else:
        # No sharding: the "shard" spans the whole flat parameter
        self._init_shard_metadata(0, 0, param.numel() - 1)
    if self._use_orig_params:
        self._use_sharded_views()
def _init_shard_metadata(
    self,
    numel_padded: int,
    unsharded_start_idx: int,
    unsharded_end_idx: int,
) -> None:
    """
    Record this rank's shard metadata on the flat parameter.

    Sets ``_sharded_size``, ``_shard_param_infos``, and
    ``_shard_numel_padded`` on ``self.flat_param``.

    Args:
        numel_padded (int): Numel padded for this rank's sharded flat
            parameter.
        unsharded_start_idx (int): Start index in the unsharded flat
            parameter assigned to this rank.
        unsharded_end_idx (int): End index (inclusive) in the unsharded
            flat parameter assigned to this rank.

    Precondition: ``self.flat_param`` 's data is the sharded flat
    parameter.
    """
    param = self.flat_param
    param._sharded_size = param.size()  # type: ignore[attr-defined]
    local_numel = param.numel()  # includes `numel_padded`
    _p_assert(
        0 <= unsharded_start_idx <= unsharded_end_idx,
        f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
    )
    _p_assert(
        numel_padded <= local_numel,
        f"numel_padded: {numel_padded} "
        f"sharded_flat_param_numel: {local_numel}",
    )
    infos = self._get_shard_metadata(unsharded_start_idx, unsharded_end_idx)
    # One entry is expected per non-padding parameter
    assert (
        len(infos) == param._num_params
    ), f"Expects length {param._num_params} but got {len(infos)}"
    param._shard_param_infos = infos  # type: ignore[attr-defined]
    param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]
def _get_shard_metadata(
    self,
    unsharded_start_idx: int,
    unsharded_end_idx: int,
) -> Tuple[_ShardParamInfo, ...]:
    """
    Compute per-parameter shard metadata for the given shard interval.

    ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive) delimit
    the shard within the unsharded flat parameter. The result has one entry
    per non-padding parameter: parameters that do not overlap the interval
    get an entry marked as not in the shard, and overlapping ones get their
    offset in the shard, numel in the shard, and the inclusive intra-
    parameter start and end indices.
    """
    param_spans = self._get_flat_param_offsets()
    assert len(param_spans) == len(
        self.flat_param._numels_with_padding
    ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(param_spans)}"
    shard_numel = unsharded_end_idx - unsharded_start_idx + 1
    infos: List[_ShardParamInfo] = []
    # `param_start` and `param_end` are inclusive indices of the given
    # parameter in the unsharded flat parameter
    for (param_start, param_end), is_padding in zip(
        param_spans, self.flat_param._is_padding_mask
    ):
        if is_padding:
            continue
        if param_end < unsharded_start_idx or param_start > unsharded_end_idx:
            # No overlap with this rank's shard
            infos.append(_ShardParamInfo(False, None, None, None, None))
            continue
        # If the parameter begins at or after the shard start, then it is
        # used from its first element and sits at a nonnegative offset in
        # the shard. Otherwise, the shard begins partway through the
        # parameter, which is then at offset 0 in the shard.
        intra_start = max(unsharded_start_idx - param_start, 0)
        offset_in_shard = max(param_start - unsharded_start_idx, 0)
        assert 0 <= offset_in_shard < shard_numel, (
            f"Invalid `offset_in_shard` of {offset_in_shard} for "
            f"sharded flat parameter with {shard_numel} numel"
        )
        intra_end = min(param_end, unsharded_end_idx) - param_start
        numel_in_shard = intra_end - intra_start + 1
        infos.append(
            _ShardParamInfo(
                True,
                offset_in_shard,
                numel_in_shard,
                intra_start,
                intra_end,
            )
        )
    return tuple(infos)
@staticmethod
def _get_unpadded_shard(
    tensor: Tensor,
    rank: int,
    world_size: int,
) -> Tuple[Tensor, int]:
    """
    Return ``rank`` 's unpadded shard of ``tensor`` and its missing numel.

    ``tensor`` is viewed as 1D and chunked into ``world_size`` pieces. The
    returned tuple is the chunk for ``rank`` (an empty tensor if there are
    fewer chunks than ``rank + 1``) and the numel needed to pad that chunk
    to the size of the first chunk.

    If ``tensor`` is already flattened or may be viewed in the flattened
    shape (which is true in the expected usage), then this method does not
    allocate any new tensor memory.
    """
    if _is_truly_contiguous(tensor):
        flat = torch.flatten(tensor)
    else:
        flat = tensor.as_strided((tensor.numel(),), (1,))
    pieces = flat.chunk(world_size)
    if rank < len(pieces):
        shard = pieces[rank]
    else:
        # Not enough chunks to go around: this rank's shard is empty and
        # will consist entirely of padding
        shard = pieces[0].new_empty(0)
    numel_to_pad = pieces[0].numel() - shard.numel()
    assert (
        numel_to_pad >= 0
    ), "Chunk's size should be at most the first chunk's size"
    return shard, numel_to_pad
@staticmethod
def _get_shard(
    tensor: Tensor,
    rank: int,
    world_size: int,
) -> Tuple[Tensor, int]:
    """
    Return ``rank`` 's padded shard of ``tensor`` and the numel padded.

    The shard is cloned, so it owns new memory; this matters because the
    unsharded ``tensor`` may be deallocated after this method returns.
    Padding, if any, is appended on the right.
    """
    unpadded, numel_to_pad = FlatParamHandle._get_unpadded_shard(
        tensor, rank, world_size
    )
    owned = unpadded.clone()
    if numel_to_pad <= 0:
        return owned, numel_to_pad
    return F.pad(owned, [0, numel_to_pad]), numel_to_pad
@staticmethod
def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
    """
    Return the 1D shape of ``rank`` 's shard of ``tensor`` with padding.

    ``tensor`` must be 1D. The result is the unpadded shard's length plus
    the numel that would be padded.
    """
    assert len(tensor.shape) == 1, f"{tensor.shape}"
    shard, numel_to_pad = FlatParamHandle._get_unpadded_shard(
        tensor, rank, world_size
    )
    shard_shape = shard.size()
    assert len(shard_shape) == 1, f"{shard_shape}"
    padded_numel = shard_shape[0] + numel_to_pad
    return torch.Size([padded_numel])
def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
"""
Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).
NOTE: The returned list includes elements for alignment padding.
"""
cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
starts = [0] + cumulative_sum[:-1]
ends = [end - 1 for end in cumulative_sum] # inclusive
param_offsets = list(zip(starts, ends))
return param_offsets
@no_type_check
def shard_metadata(
    self,
) -> FlatParamShardMetadata:
    """
    Return the metadata describing this rank's shard of the flat parameter.

    Only parameters whose shard info is marked ``in_shard`` are included;
    for each, the FQN, shape, stride, contiguity, numel, and the
    ``(intra_param_start_idx, intra_param_end_idx)`` pair are reported.

    NOTE: The returned tuple does not include elements for alignment
    padding but does account for the padding.
    """
    param = self.flat_param
    per_param = zip(
        param._fqns,
        param._shapes,
        param._strides,
        param._contiguities,
        param._numels,
        param._shard_param_infos,
    )
    # One row per parameter that has data in this rank's shard
    rows = [
        (
            fqn,
            shape,
            stride,
            contiguous,
            numel,
            (info.intra_param_start_idx, info.intra_param_end_idx),
        )
        for fqn, shape, stride, contiguous, numel, info in per_param
        if info.in_shard
    ]
    return FlatParamShardMetadata(
        tuple(row[0] for row in rows),
        tuple(row[1] for row in rows),
        tuple(row[2] for row in rows),
        tuple(row[3] for row in rows),
        tuple(row[4] for row in rows),
        tuple(row[5] for row in rows),
    )
@no_type_check
@torch.no_grad()
def init_flat_param_attributes(self) -> None:
    """
    This initializes some attributes on the handle's ``FlatParameter``.

    This should be called during lazy initialization since it requires the
    parameter to be on the compute device if not offloading to CPU and we
    want to give users the chance to move the parameter appropriately after
    the FSDP constructor.

    For each tensor attribute on the ``FlatParameter``, see the unshard and
    reshard methods in this class for the allocation and free pattern.

    Attributes set on ``self.flat_param`` (depending on configuration):
    ``_local_shard``, ``_cpu_grad``, ``_mp_shard``, ``_full_param_padded``,
    ``_padded_unsharded_size``, and ``_full_prec_full_param_padded``.
    """
    flat_param = self.flat_param
    if flat_param.dtype != self._orig_param_dtype:
        # Entering this branch means that the user changed the parameter
        # dtype after FSDP initialization, in which case we may need to
        # refresh some saved dtype attributes (dtypes specified as a part
        # of mixed precision take precedence).
        if not self._low_prec_param_dtype_specified:
            self._fwd_bwd_param_dtype = flat_param.dtype
        # For `reduce_dtype`, require `param_dtype` was not specified since
        # then we infer the `reduce_dtype` from the specified `param_dtype`
        if (
            not self._low_prec_reduce_dtype_specified
            and not self._low_prec_param_dtype_specified
        ):
            self._reduce_dtype = flat_param.dtype
        self._orig_param_dtype = flat_param.dtype
    cpu_device = torch.device("cpu")
    # With offloading the flat parameter must live on CPU; otherwise it is
    # checked against the compute device
    if self._offload_params:
        _p_assert(
            flat_param.device == cpu_device,
            f"Expects the `FlatParameter` to be on CPU when parameter CPU "
            f"offloading is enabled, not {flat_param.device}",
        )
    else:
        self._check_on_compute_device(self.flat_param)
    # Keep a reference to the sharded data
    flat_param._local_shard = flat_param.data
    if self._offload_params:
        # Pin the memory for faster H2D transfer
        flat_param._local_shard = flat_param._local_shard.pin_memory(
            device=self.device
        )
        # Pre-allocate the sharded gradient on CPU to enable non-blocking
        # D2H transfer during the backward pass
        flat_param._cpu_grad = torch.zeros_like(
            flat_param._local_shard, device=cpu_device
        ).pin_memory(device=self.device)
    if self._uses_param_mixed_precision:
        # For parameter mixed precision, we maintain a low precision
        # sharded tensor on the compute device to be all-gathered (for
        # sharded strategies) or directly used (for `NO_SHARD`) for
        # computation.
        flat_param._mp_shard = torch.empty_like(
            flat_param._local_shard,
            device=self.device,
            dtype=self._fwd_bwd_param_dtype,
        )
        # Only the tensor is kept for now; its storage is freed right away
        _free_storage(flat_param._mp_shard)
    if self.uses_sharded_strategy:
        # We maintain a padded unsharded tensor that serves as the
        # all-gather destination and owns the original parameter storages.
        unsharded_param_dtype = (
            self._fwd_bwd_param_dtype
            if self._uses_param_mixed_precision
            else flat_param.dtype
        )  # use low precision if parameter mixed precision is enabled
        # Every rank's shard has the same numel, so the padded unsharded
        # numel is the local numel times the world size
        padded_unsharded_numel = flat_param.numel() * self.world_size
        flat_param._full_param_padded = torch.empty(
            padded_unsharded_numel,
            device=self.device,
            dtype=unsharded_param_dtype,
        )
        # Save the size before freeing the storage
        flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
        _free_storage(flat_param._full_param_padded)
        if self._uses_param_mixed_precision:
            # For parameter mixed precision, we maintain a full precision
            # padded unsharded tensor for when we force full precision.
            flat_param._full_prec_full_param_padded = torch.empty(
                padded_unsharded_numel,
                device=self.device,
                dtype=flat_param.dtype,  # full precision
            )
            _free_storage(flat_param._full_prec_full_param_padded)
###################
# UNSHARD/RESHARD #
###################
def pre_unshard(self) -> bool:
    """
    Return ``False`` if this is a no-op and ``True`` otherwise.

    The return value is ``True`` if the low precision shard was switched
    to or the flat parameter was moved to the compute device; otherwise it
    is whatever ``_writeback_orig_params()`` returned (when that check
    runs) or ``False``.

    Postcondition: ``self.flat_param`` 's data is on the device for
    communication and is what should be all-gathered. This means that it
    matches the dtype of the expected unsharded parameter.
    """
    if (
        self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
        and self._skipped_use_sharded_views
    ):
        # Since this path imposes special semantics for the unsharded flat
        # parameter (e.g. forcing full precision), use sharded views to
        # reuse the existing logic for that special handling
        self._use_sharded_views()
    ret = False
    # The writeback check only applies to `use_orig_params=True` and can be
    # disabled through the skip-writeback-check toggle
    if self._use_orig_params and not self._skip_writeback_check:
        ret = self._writeback_orig_params()
    # The branches below are mutually exclusive and are tried in order
    if (
        self.uses_sharded_strategy
        and not self._offload_params
        and not self.needs_unshard()
    ):
        pass  # no-op
    elif self._uses_param_mixed_precision and not self._force_full_precision:
        self._use_low_precision_shard()
        ret = True
    elif self._offload_params and self.flat_param.device != self.device:
        # NOTE: This creates a new tensor distinct from any attributes.
        self.flat_param_to(self.device, non_blocking=True)
        ret = True
    self._check_on_compute_device(self.flat_param)
    return ret
def _use_low_precision_shard(self):
    """
    Switch the flat parameter's data to the low precision sharded tensor.

    Storage for ``_mp_shard`` is allocated with the size of
    ``_local_shard``, the local shard is copied into it via the compute
    device, and ``flat_param.data`` is pointed at ``_mp_shard``.
    """
    self._check_low_precision_shard()
    param = self.flat_param
    _alloc_storage(
        param._mp_shard, param._local_shard.size()  # type: ignore[attr-defined]
    )
    shard_on_device = param._local_shard.to(  # type: ignore[attr-defined]
        self.device, non_blocking=True
    )
    # `copy_()` implicitly casts to the low precision
    param._mp_shard.copy_(shard_on_device)  # type: ignore[attr-defined]
    # Invariant: `_mp_shard` is always on the compute device.
    param.data = param._mp_shard  # type: ignore[attr-defined]
def unshard(self):
"""
Run the unshard logic.
This includes all-gathering the flat parameter
and switching to using the unsharded flat parameter. If the handle does
not need unsharding, then this only switches to using the unsharded
flat parameter. For ``NO_SHARD``, this is a no-op.
If FSDP is in :meth:`summon_full_params` and the handle uses parameter
mixed precision, then the parameter is forced to full precision.
"""
if not self.needs_unshard():
# Even when not needing an unshard, we should switch to using
# the unsharded flat parameter
unsharded_flat_param = (
self._get_padded_unsharded_flat_param()
if self.uses_sharded_strategy
else self.flat_param
)
self._use_unsharded_flat_param(unsharded_flat_param)
return
unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
self._use_unsharded_flat_param(padded_unsharded_flat_param)
def needs_unshard(self) -> bool:
"""Return if the handle's flat parameter needs to be unsharded."""
if not self.uses_sharded_strategy:
return False
unsharded_flat_param = self._get_padded_unsharded_flat_param()
already_unsharded = _same_storage_size(
unsharded_flat_param, unsharded_flat_param.numel()
)
return not already_unsharded
def _alloc_padded_unsharded_flat_param(self):
    """
    Allocate storage for the *padded* unsharded flat parameter.

    The unpadded unsharded flat parameter is always a view into the padded
    one. Which tensor is allocated depends on whether full precision is
    being forced (see :meth:`_get_padded_unsharded_flat_param`). The
    tensor's storage must be freed on entry.

    Returns:
        The padded unsharded flat parameter with its storage allocated.
    """
    self._check_sharded_strategy()
    param = self.flat_param
    padded = self._get_padded_unsharded_flat_param()
    self._check_storage_freed(padded)
    target_size = param._padded_unsharded_size  # type: ignore[attr-defined]
    _alloc_storage(padded, target_size)
    return padded
def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
"""
Return a reference to the padded unsharded flat parameter depending on the calling context.
This should only be called if using a sharded strategy.
"""
self._check_sharded_strategy()
flat_param = self.flat_param
if self._force_full_precision and self._uses_param_mixed_precision:
# When parameter mixed precision is enabled, we use a different
# tensor as the all-gather destination to preserve the invariant
# that `_full_param_padded` is in the low precision
unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
_p_assert(
unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
f"Expects full precision but got {self._fwd_bwd_param_dtype}",
)
# For no-reshard-after-forward strategies, `_full_param_padded` may
# still be allocated from a previous forward. As we are forcing
# full precision here, the full-precision unsharded copy may be
# modified, invalidating the existing low-precision unsharded copy,
# so we should free it here to ensure a new all-gather for the next
# forward/backward computation to persist the modifications.
if flat_param._full_param_padded.untyped_storage().size() > 0:
_free_storage(flat_param._full_param_padded)
else:
unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
return unsharded_flat_param
def _all_gather_flat_param(
    self,
    padded_unsharded_flat_param: Tensor,
) -> Tensor:
    """
    All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

    The source is ``self.flat_param.data`` (this rank's shard). The
    destination must have exactly ``world_size`` times the shard's numel.
    If the fake all-gather toggle is enabled, the fake process group is
    used in place of the real one.

    Returns:
        Tensor: ``padded_unsharded_flat_param``, i.e. the destination that
        was passed in.
    """
    # NOTE(review): within this class, `process_group` and `world_size` are
    # assigned in `__init__()`, although the message below says `shard()`
    _p_assert(
        hasattr(self, "process_group") and hasattr(self, "world_size"),
        "Expects a process group and world size to have been set via `shard()`",
    )
    sharded_flat_param = self.flat_param.data
    expected_numel = sharded_flat_param.numel() * self.world_size
    _p_assert(
        padded_unsharded_flat_param.numel() == expected_numel,
        f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
    )
    pg = (
        self._fake_process_group
        if self._use_fake_all_gather
        else self.process_group
    )
    # HACK this should be handled by C10D
    if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
        # For a CPU shard, gather into per-rank chunks of the destination
        # using the list-based collective
        tensor_list = list(
            torch.chunk(
                padded_unsharded_flat_param,
                dist.get_world_size(pg),  # type: ignore[arg-type]
            )
        )
        dist.all_gather(tensor_list, sharded_flat_param, group=pg)
    else:
        dist.all_gather_into_tensor(
            padded_unsharded_flat_param,
            sharded_flat_param,
            pg,
        )
    if self._offload_params:
        # In case of offloading, `flat_param.data` (i.e. sharded param) is
        # created on the pre-unshard stream. We need to hand it over to the
        # unshard stream for all-gather
        _no_dispatch_record_stream(
            sharded_flat_param,
            self._device_handle.current_stream(),  # unshard_stream
        )
    return padded_unsharded_flat_param
def _use_unsharded_flat_param(
    self,
    padded_unsharded_flat_param: torch.Tensor,
) -> None:
    """
    Point ``flat_param.data`` at the *unpadded* unsharded flat parameter.

    The unpadded tensor is a leading slice of ``padded_unsharded_flat_param``
    whose length is the numel of ``flat_param._unpadded_unsharded_size``.
    Depending on ``use_orig_params`` and the training state, the original
    parameters are then refreshed as unsharded views.
    """
    unpadded_numel = self.flat_param._unpadded_unsharded_size.numel()
    # slicing [:] is not visible to autograd because of .data
    self.flat_param.data = padded_unsharded_flat_param[:unpadded_numel]

    in_forward = self._training_state == HandleTrainingState.FORWARD
    in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE

    if not self._use_orig_params:
        if in_forward:
            self._use_unsharded_views(as_params=False)
        return

    if self._skipped_use_sharded_views and in_pre_backward:
        # This call corresponds to the complementary pre-backward
        # `_use_unsharded_views()` to the skipped pre-forward
        # `_use_sharded_views()`, so we should skip this one too.
        return

    # We use `Tensor` views in the forward so that they are tracked by
    # autograd. We use them in the pre-backward as well to support
    # reentrant activation checkpointing, which needs the views to be
    # tracked by autograd in the backward pass's recomputed forward.
    in_computation = in_forward or in_pre_backward
    self._use_unsharded_views(as_params=(not in_computation))
def post_unshard(self):
    """
    Run the post-unshard logic.

    Frees the low precision sharded flat parameter when parameter mixed
    precision is used together with a sharded strategy, then checks that
    the flat parameter is on the compute device.
    """
    free_low_precision_shard = (
        self._uses_param_mixed_precision and self.uses_sharded_strategy
    )
    if free_low_precision_shard:
        self._free_low_precision_sharded_param()
    self._check_on_compute_device(self.flat_param)
def _free_low_precision_sharded_param(self):
    """Free the storage of the low precision sharded flat parameter (``_mp_shard``)."""
    self._check_low_precision_shard()
    mp_shard = self.flat_param._mp_shard  # type: ignore[attr-defined]
    # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
    # unshard stream for sharded strategies, and consumed in both the
    # unshard and default streams for `NO_SHARD`. For sharded strategies,
    # the current stream here is the unshard stream, and for `NO_SHARD`,
    # it is the default stream. For `NO_SHARD`, only recording for the
    # default stream suffices since the default stream waits for the
    # unshard stream.
    current_stream = self._device_handle.current_stream()
    _no_dispatch_record_stream(mp_shard, current_stream)
    _free_storage(self.flat_param._mp_shard)  # type: ignore[attr-defined]
@torch.no_grad()
def unshard_grad(self):
    """
    Unshard the handle's ``FlatParameter``'s gradient.

    If all ranks have ``None`` gradient, then all original parameters will
    as well. This method performs an all-reduce and an all-gather. The
    additional all-reduce is tolerable since this method is not meant to be
    used on the computation critical path.

    For non-sharded strategies, this only refreshes the unsharded gradient
    views and returns.

    Postcondition: ``_saved_grad_shard`` is defined and contains the value
    to set ``flat_param.grad`` after gradients are resharded.
    """
    if not self.uses_sharded_strategy:
        self._use_unsharded_grad_views()
        return
    flat_param = self.flat_param
    self._check_unsharded(flat_param)

    # Check if all ranks have a `None` gradient
    num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
    num_grad_none[0] = flat_param.grad is None
    dist.all_reduce(num_grad_none, group=self.process_group)
    if num_grad_none[0] == self.world_size:
        flat_param._saved_grad_shard = None  # type: ignore[assignment]
        self._use_unsharded_grad_views()
        return

    if flat_param.grad is None:
        # In the case that only some ranks have `None` gradient, we use
        # zeros to approximate as a best effort attempt
        if self._debug_level == dist.DebugLevel.INFO:
            warnings.warn(
                f"[Rank {self.rank}] Only some but not all ranks have a "
                "`None` `FlatParameter` gradient, so FSDP is using zeros to "
                "approximate those ranks' sharded gradients being `None`"
            )
        flat_param._saved_grad_shard = None  # type: ignore[assignment]
        # Build the placeholder in the flat parameter's dtype rather than the
        # global default dtype. Otherwise this rank's all-gather buffers can
        # differ in dtype/byte size from ranks contributing a real gradient
        # (e.g. for a bf16 or fp64 model).
        # NOTE(review): assumes the other ranks' sharded gradients share the
        # flat parameter's dtype -- confirm for `keep_low_precision_grads`.
        sharded_grad = torch.zeros(
            flat_param._sharded_size,  # type: ignore[attr-defined]
            device=self.device,
            dtype=flat_param.dtype,
        )
    else:
        self._check_sharded(flat_param.grad)
        flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
        sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]

    padded_unsharded_grad = torch.empty(
        flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
        device=self.device,
        dtype=sharded_grad.dtype,
    )
    dist.all_gather_into_tensor(
        padded_unsharded_grad, sharded_grad, self.process_group
    )
    # Drop the trailing padding and expose the result as `.grad`
    unsharded_size = self.flat_param._unpadded_unsharded_size
    flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
        unsharded_size
    )
    self._use_unsharded_grad_views()
def reshard_grad(self):
    """
    Restore the sharded gradient saved before the gradient was unsharded.

    With ``use_orig_params``, the sharded gradient views are refreshed
    first. For sharded strategies, ``flat_param.grad`` is then set to
    ``_saved_grad_shard`` and that attribute is removed.
    """
    if self._use_orig_params:
        self._use_sharded_grad_views()
    if self.uses_sharded_strategy:
        self.flat_param.grad = self.flat_param._saved_grad_shard  # type: ignore[attr-defined]
        del self.flat_param._saved_grad_shard
def prepare_gradient_for_backward(self):
    """
    Prepare the gradient for the backward computation.

    Any existing gradient in ``.grad`` that is not an unpadded unsharded
    gradient on the flat parameter's device is saved (when it is a sharded
    gradient from a previous synced iteration) and ``.grad`` is cleared so
    that a new unsharded gradient can be computed.
    """
    _p_assert(
        self._training_state
        in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
        "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
    )
    flat_param = self.flat_param
    if flat_param.grad is None:
        return
    grad_is_usable_as_is = (
        flat_param.grad.size() == flat_param._unpadded_unsharded_size
        and flat_param.grad.device == flat_param.device  # else grad on CPU
    )
    if grad_is_usable_as_is:
        return

    self._check_on_compute_device(self.flat_param)
    grad_offloaded = flat_param.grad.device != self.device
    _p_assert(
        not grad_offloaded or self._offload_params,
        f"Expects the sharded gradient to be on {self.device} "
        f"but got {flat_param.grad.device}",
    )
    prev_iter_synced_gradients = (
        flat_param.grad.size()
        == flat_param._local_shard.size()  # type: ignore[attr-defined]
    )
    if not prev_iter_synced_gradients:
        padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
        _p_assert(
            flat_param.grad.size() == padded_unsharded_size,
            "Expects `.grad` to be the unsharded gradient in "
            f"`no_sync()` with size {padded_unsharded_size} "
            f"but got size {flat_param.grad.size()}",
        )
    else:
        # TODO (awgu): Gradient accumulation outside `no_sync()`
        # does not work with CPU offloading. The issue should be
        # that, in the post-backward hook, we cannot do an addition
        # between a CPU tensor (the existing sharded gradient) and
        # a GPU tensor (the new sharded gradient).
        if grad_offloaded:
            _p_assert(
                hasattr(flat_param, "_cpu_grad"),
                "`_cpu_grad` should be defined if the gradient is on CPU",
            )
            prev_sharded_grad = flat_param._cpu_grad  # type: ignore[attr-defined]
        else:
            flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
            prev_sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
        # If user specified to keep the gradient in low precision, then
        # the gradient may still be of the low precision dtype if the
        # user did not set the gradient to `None` after the previous
        # backward, in which case FSDP should cast back to the full
        # precision dtype so that FSDP can accumulate in that dtype in
        # the post-backward hook and assign to `.grad` in that dtype in
        # the post-backward callback.
        full_precision_dtype = flat_param._local_shard.dtype  # type: ignore[attr-defined]
        if (
            self._keep_low_precision_grads
            and prev_sharded_grad.dtype != full_precision_dtype
        ):
            prev_sharded_grad.data = prev_sharded_grad.to(full_precision_dtype)
    flat_param.grad = None
def prepare_gradient_for_optim(self):
    """
    Prepare the gradient for optimizer computation.

    Moves the sharded gradient (``_cpu_grad`` or ``_saved_grad_shard``) to
    the ``.grad`` attribute and removes ``_saved_grad_shard`` if present.
    """

    def cast_grad_to_param_dtype_if_needed(flat_param):
        # TODO (rohan-varma): test for full precision with keep_low_precision_grads
        if self._force_full_precision or not self._keep_low_precision_grads:
            return
        _p_assert(flat_param.grad is not None, "Unexpected None grad!")
        if flat_param.grad.dtype == self._fwd_bwd_param_dtype:
            return
        flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
        if self._use_orig_params:
            self._use_sharded_grad_views()

    flat_param = self.flat_param
    has_cpu_grad = hasattr(flat_param, "_cpu_grad")
    # TODO (awgu): We should replace these conditional checks to encode
    # the logical intention more directly.
    if has_cpu_grad:
        # NOTE: This branch includes `NO_SHARD`.
        self._check_sharded(flat_param)
        self._check_on_cpu(flat_param)
        flat_param.grad = flat_param._cpu_grad  # type: ignore[attr-defined]
        cast_grad_to_param_dtype_if_needed(flat_param)
    elif hasattr(flat_param, "_saved_grad_shard"):
        self._check_sharded(flat_param)
        self._check_on_compute_device(flat_param)
        if flat_param._saved_grad_shard is not None:
            self._check_on_compute_device(flat_param._saved_grad_shard)  # type: ignore[attr-defined]
        # If no sharded gradient was computed this iteration, then there is
        # no need to forward `_saved_grad_shard` to `grad`
        if flat_param._post_backward_called:  # type: ignore[attr-defined]
            flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
            if flat_param.grad is not None:
                cast_grad_to_param_dtype_if_needed(flat_param)
    else:
        _p_assert(
            not self.uses_sharded_strategy
            or not flat_param._post_backward_called,  # type: ignore[attr-defined]
            "All sharded parameters that received a gradient in the "
            "post-backward should use `_saved_grad_shard`",
        )
    # Delete `_saved_grad_shard` since its existence indicates a previous
    # gradient to accumulate with in the post-backward hook
    if hasattr(flat_param, "_saved_grad_shard"):
        del flat_param._saved_grad_shard
@contextlib.contextmanager
def to_cpu(self):
    """
    Move the unpadded unsharded flat parameter to CPU while in the context.

    On exit, a padded unsharded flat parameter is allocated again, the CPU
    data is copied into it, and the handle switches back to using it.

    For now, this assumes the ``FlatParameter`` is the unpadded unsharded
    flat parameter since (1) there is no reason to include the padding in
    the copy and (2) there is no use case for the sharded flat parameter.

    Precondition: ``self.flat_param`` 's data is the unpadded unsharded
    flat parameter on the compute device, and the handle uses a sharded
    strategy.

    Postcondition: Same as the precondition.
    """
    self._check_sharded_strategy()
    _p_assert(
        self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
        f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
    )
    self._check_on_compute_device(self.flat_param)
    # Check that the unpadded unsharded flat parameter is a view into the
    # padded unsharded flat parameter as expected
    # NOTE: This check is not strictly needed for correctness but is a
    # useful sanity check since the tensor should only be used internally.
    is_view_into_padded = _same_storage(
        self.flat_param, self._get_padded_unsharded_flat_param()
    )
    _p_assert(
        is_view_into_padded,
        "Expects the unpadded parameter to be a view into the padded parameter",
    )
    cpu_device = torch.device("cpu")
    self.flat_param_to(cpu_device)
    self._free_unsharded_flat_param()
    try:
        yield
    finally:
        _p_assert(
            self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
            f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
        )
        restored = self._alloc_padded_unsharded_flat_param()
        # Copy from CPU to the compute device
        restored[: self.flat_param.numel()].copy_(self.flat_param)
        self._use_unsharded_flat_param(restored)
def reshard(self, free_unsharded_flat_param: bool):
    """
    Run the reshard logic.

    Switches to using the sharded flat parameter and, if
    ``free_unsharded_flat_param`` is set, frees the unsharded flat
    parameter afterwards. Note that this also implicitly offloads the
    sharded flat parameter (if CPU offload is enabled) by pointing it to
    the ``_local_shard`` attribute which resides on CPU.

    Args:
        free_unsharded_flat_param (bool): Whether to free the unsharded
            flat parameter after switching to the sharded one.
    """
    # Switch to the sharded `FlatParameter` before freeing to prevent
    # "use-after-free"-type bugs with external profiling tools, where for
    # `use_orig_params=True`, the `param` does not point to valid memory
    # when setting `param.data = ...` in `_use_sharded_views()`.
    self._use_sharded_flat_param()
    if not free_unsharded_flat_param:
        return
    self._free_unsharded_flat_param()
def post_reshard(self):
    """
    Run the post-reshard logic.

    This includes freeing any memory that can now be freed given that the
    ``FlatParameter`` points to the full precision sharded flat parameter.

    Precondition: ``self.flat_param`` 's data points to the full precision
    sharded flat parameter.
    """
    # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it
    # is also the low precision *unsharded* flat parameter. Hence, we delay
    # the free until the reshard.
    if not self._uses_param_mixed_precision:
        return
    if self.uses_sharded_strategy:
        return
    if self._force_full_precision:
        # did not use the low precision shard
        return
    self._free_low_precision_sharded_param()
def _free_unsharded_flat_param(self):
    """
    Free the padded unsharded flat parameter.

    We allow this function to be called even when storage is not
    allocated. The tensor to free depends on the calling context since the
    unshard may have forced full precision, in which case a different
    tensor is used.
    """
    self._check_sharded_strategy()
    padded_unsharded = self._get_padded_unsharded_flat_param()
    self._check_on_compute_device(padded_unsharded)
    # Do not free the memory until all ops in the current stream finish
    current_stream = self._device_handle.current_stream()
    _no_dispatch_record_stream(padded_unsharded, current_stream)
    _free_storage(padded_unsharded)
def _use_sharded_flat_param(self) -> None:
"""Switches to using the sharded flat parameter."""
flat_param = self.flat_param
if self._use_orig_params:
in_forward = self._training_state == HandleTrainingState.FORWARD
skip_use_sharded_views = (
torch.is_grad_enabled()
and in_forward
and self._sharding_strategy
in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
)
# Only incur the extra `.data` call if needed
if skip_use_sharded_views:
unsharded_flat_param = flat_param.data
if self._offload_params:
device = flat_param._local_shard.device # type: ignore[attr-defined]
_p_assert(
device == torch.device("cpu"),
f"Expects the local shard to be on CPU but got {device}",
)
flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
if self._use_orig_params:
if skip_use_sharded_views: # type: ignore[possibly-undefined]
self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined]
else:
self._use_sharded_views()
# For the post-forward reshard, we may try to use sharded gradient
# views (or unsharded gradient views if a gradient was accumulated
# in `no_sync()`), but for the post-backward reshard, we delay the
# call to after the reduce-scatter.
if (
in_forward # type: ignore[possibly-undefined]
# Skip using gradient views if skipped using sharded views
# since exposing unsharded parameters with sharded gradients
# may be confusing to the user
and not self._skipped_use_sharded_views
):
# TODO: Change `_unpadded_unsharded_size` if we change the
# gradient to be computed directly with padding.
accumulated_grad_in_no_sync = (
flat_param.grad is not None
and self.uses_sharded_strategy
and flat_param.grad.shape == flat_param._unpadded_unsharded_size
)
if accumulated_grad_in_no_sync:
self._use_unsharded_grad_views()
else:
self._use_sharded_grad_views()
#########
# VIEWS #
#########
@no_type_check
def _get_unflat_views_unaligned(
    self,
    tensor: Optional[torch.Tensor] = None,
) -> Iterator[Tensor]:
    """
    Return unflattened ``Tensor`` views into ``tensor``.

    If ``tensor`` is ``None``, ``flat_param`` is used. The unflattening is
    based on ``flat_param`` 's metadata. The result is a lazy generator:
    each view is created when it is consumed.

    Examples for ``tensor`` include ``flat_param.grad`` or unsharded
    tensor optimizer state.
    """
    flat_param = self.flat_param
    source = flat_param if tensor is None else tensor
    per_param_metadata = zip(
        torch.split(source, flat_param._numels, dim=0),
        flat_param._shapes,
        flat_param._strides,
        flat_param._contiguities,
        flat_param._param_extensions,
    )
    return (
        _ext_post_unflatten_transform(
            chunk.view(shape) if contiguous else chunk.as_strided(shape, stride),
            param_extension,
            self._fsdp_extension,
        )
        for chunk, shape, stride, contiguous, param_extension in per_param_metadata
    )
@no_type_check
def _get_unflat_views_aligned(
    self,
    tensor: Optional[Tensor] = None,
) -> List[Tensor]:
    """
    Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.

    This method has the same contract as :meth:`_get_unflat_views_unaligned`
    except it skips the splits marked as padding in
    ``flat_param._is_padding_mask``, which may incur slightly more CPU
    overhead, and it returns a list rather than a generator.
    """
    flat_param = self.flat_param
    source = flat_param if tensor is None else tensor
    splits: List[Tensor] = torch.split(
        source, flat_param._numels_with_padding, dim=0
    )
    # Keep only the splits that hold parameter data; their position in this
    # list is the index into the per-parameter metadata
    param_splits = [
        split
        for split, is_padding in zip(splits, flat_param._is_padding_mask)
        if not is_padding
    ]
    views: List[Tensor] = []
    for idx, split in enumerate(param_splits):
        if flat_param._contiguities[idx]:
            raw_view = split.view(flat_param._shapes[idx])
        else:
            raw_view = split.as_strided(
                flat_param._shapes[idx], flat_param._strides[idx]
            )
        views.append(
            _ext_post_unflatten_transform(
                raw_view,
                flat_param._param_extensions[idx],
                self._fsdp_extension,
            )
        )
    return views
@no_type_check
@torch.enable_grad()
def _use_unsharded_views(self, as_params: bool) -> None:
    """
    Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.

    Args:
        as_params (bool): If ``True``, then registers the original
            parameters as ``nn.Parameter`` s; if ``False``, then registers
            the original parameters only as ``Tensor`` s. ``False`` should
            be used during forward/backward computation and when hiding the
            original parameters from :meth:`nn.Module.named_parameters`.

    Note:
        when prefetching for next forward, current forward may be
        annotated with `@torch.no_grad()`
        `@torch.enable_grad()` ensures non-empty `view.grad_fn`
        otherwise `_post_backward_hook` will not get called
    """
    flat_param = self.flat_param
    self._check_unsharded(flat_param)
    views = self._get_unflat_views()
    from torch.distributed.tensor import DTensor

    # First pass: the (non-shared) original parameters, one view each
    for i, (view, (param_name, module, _)) in enumerate(
        zip(views, flat_param._param_infos)
    ):
        if self._use_orig_params and as_params:
            if type(view) is DTensor:
                # A `DTensor` `view` is not compatible with assigning
                # `param.data = view`, so we cannot preserve the parameter
                # variable.
                self._setattr_param(
                    module,
                    param_name,
                    nn.Parameter(view, requires_grad=flat_param.requires_grad),
                )
                continue
            # Reuse the existing parameter variable and only swap its data
            param = self.flat_param._params[i]
            self._setattr_param(module, param_name, param)
            param.data = view
        elif as_params:
            # `use_orig_params=False`: register a new `nn.Parameter` each time
            self._setattr_param(
                module,
                param_name,
                nn.Parameter(view, requires_grad=flat_param.requires_grad),
            )
        else:  # `as_params=False`
            param_var: Tensor = view
            if self._use_orig_params:
                if self._training_state == HandleTrainingState.FORWARD:
                    # Save the `Tensor` for the pre-backward
                    self.flat_param._tensors[i] = view  # save for pre-backward
                elif self._training_state == HandleTrainingState.BACKWARD_PRE:
                    # Use the saved `Tensor` variable from the forward to
                    # preserve the autograd graph so that the post-backward
                    # hook fires (e.g. for reentrant AC)
                    tensor = self.flat_param._tensors[i]
                    tensor.data = view
                    param_var = tensor
            self._setattr_tensor(module, param_name, param_var)
            if (
                self._use_orig_params
                and self._training_state == HandleTrainingState.FORWARD
            ):
                # Also record the tensor in the module's `_parameters` dict
                module._parameters[param_name] = param_var
    # Second pass: shared parameters, which mirror whatever their primary
    # (`prim_module`, `prim_param_name`) attribute was just set to above
    for i, (
        param_name,
        module,
        _,
        prim_param_name,
        prim_module,
        _,
    ) in enumerate(self.flat_param._shared_param_infos):
        prim_param: Union[Tensor, nn.Parameter] = getattr(
            prim_module, prim_param_name
        )
        _p_assert(
            not as_params or isinstance(prim_param, nn.Parameter),
            f"as_params={as_params} type(prim_param)={type(prim_param)}",
        )
        if self._use_orig_params and as_params:
            # Preserve the shared parameter variable; point its data at the
            # primary parameter
            shared_param = self.flat_param._shared_params[i]
            self._setattr_param(module, param_name, shared_param)
            shared_param.data = prim_param
        elif as_params:
            self._setattr_param(module, param_name, prim_param)
        else:
            self._setattr_tensor(module, param_name, prim_param)
            if (
                self._use_orig_params
                and self._training_state == HandleTrainingState.FORWARD
            ):
                module._parameters[param_name] = prim_param
@no_type_check
def _use_unsharded_grad_views(self) -> None:
"""
Unflatten the unsharded flat parameter's gradient.
The original parameter variables' gradients are set to be views into
the unsharded flat parameter's gradient.
"""
# Expects the gradient to be in `flat_param.grad`
if self.flat_param.grad is None:
for param in chain(self.flat_param._params, self.flat_param._shared_params):
param.grad = None
return
self._check_unsharded(self.flat_param.grad)
views = self._get_unflat_views(self.flat_param.grad)
for i, (view, (param_name, module, _)) in enumerate(
zip(views, self.flat_param._param_infos)
):
_p_assert(
hasattr(module, param_name),
f"{self.flat_param._fqns[i]} is missing",
)
param = getattr(module, param_name)
if (
param.shape != view.shape
or param.dtype != view.dtype
or param.device != view.device
):
# NOTE: This is a hack using `.data` to side step the check
# that parameter/gradient sizes/dtypes/devices match. From
# calling `reshard()`, `param` has the sharded size, has the
# full precision dtype, and if CPU offloading is enabled, is on
# CPU. Thus, one or more of the following cases can hold when
# in `no_sync()`, where `view` is the original parameter's
# gradient:
# 1. `view` can have the unsharded size.
# 2. `view` can have the parameter low precision dtype.
# 3. `view` can be on GPU.
if param.grad is None:
param.grad = torch.empty_like(param)
param.grad.data = view
else:
param.grad = view
for i, (
param_name,
module,
module_name,
prim_param_name,
prim_module,
_,
) in enumerate(self.flat_param._shared_param_infos):
_p_assert(
hasattr(module, param_name),
f"{module_name + '.' + param_name if module_name else param_name} is missing",
) # did not save FQN info in `_shared_param_infos`
param = getattr(module, param_name)
prim_param = getattr(prim_module, prim_param_name)
if (
param.shape != prim_param.grad.shape
or param.dtype != prim_param.grad.dtype
or param.device != prim_param.grad.device
):
# NOTE: This is the same hack to use `.data` to side step the
# size check.
if param.grad is None:
param.grad = torch.empty_like(param)
param.grad.data = prim_param.grad
else:
param.grad = prim_param.grad
@contextlib.contextmanager
def unflatten_as_params(self) -> Generator:
    """
    Unflatten the original parameters.

    The function assumes that the flat parameter is unsharded. On entering
    the context, the original parameters are unflattened as
    ``nn.Parameter`` views into the flat parameter. On leaving it
    (including when the body raises), they are restored as ``Tensor``
    views into the flat parameter.
    """
    refresh_views = self._use_unsharded_views
    refresh_views(as_params=True)
    try:
        yield
    finally:
        refresh_views(as_params=False)
@no_type_check
@torch.no_grad()
def _use_sharded_views(self) -> None:
"""
Set the original parameter variables' data to be flattened views into the sharded flat parameter.
The views are kept as flattened to simplify the case where a parameter
is sharded across ranks. Parameters whose data is not present in the
sharded flat parameter have their data set to a size-0 empty tensor. We
do not delete them to ensure to preserve expected behaviors like model
printability. Parameters whose data is present must preserve their
variables to be passable to an optimizer.
"""
self._unsharded_flat_param_for_skipped_views = None
if not self.uses_sharded_strategy:
# For `NO_SHARD`, use the *unflattened* unsharded views since we
# have the unsharded parameter
self._use_unsharded_views(as_params=True)
return
flat_param = self.flat_param
self._check_sharded(flat_param)
# Construct once and reuse for all parameters not in the local shard
size_0_empty_tensor = torch.empty(
0,
dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
device=self.flat_param.device,
requires_grad=False,
)
for param, shard_param_info, (param_name, module, _) in zip(
flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
):
self._setattr_param(module, param_name, param)
if not shard_param_info.in_shard:
# Allow the original data to be freed via garbage collection
param.data = size_0_empty_tensor
else:
offset = shard_param_info.offset_in_shard
numel_in_shard = shard_param_info.numel_in_shard
param.data = flat_param[offset : offset + numel_in_shard]
assert self.flat_param._shared_params is not None
for i, (
param,
(param_name, module, _, prim_param_name, prim_module, _),
) in enumerate(
zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
):
self._setattr_param(module, param_name, param)
prim_param = getattr(prim_module, prim_param_name)
param.data = prim_param # could be both empty and non-empty
if self._training_state == HandleTrainingState.BACKWARD_POST:
# Clear the saved `Tensor`s since they are unneeded now
for i in range(len(self.flat_param._tensors)):
self.flat_param._tensors[i] = None
@no_type_check
@torch.no_grad()
def _use_sharded_grad_views(self) -> None:
"""
Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
This is a no-op if there is no gradient.
Parameters whose data is not present in the sharded flat parameter and
parameters with ``requires_grad=False`` have their gradients set to
``None``. Since the gradient variables do not need to be preserved,
this method does not manipulate existing ``Tensor`` data directly and
creates new ``Tensor`` variables instead.
"""
flat_param = self.flat_param
self._check_sharded(flat_param)
grad = self.sharded_grad
if grad is None:
for param in chain(flat_param._params, flat_param._shared_params):
param.grad = None
return
self._check_sharded(grad)
for param, shard_param_info, is_grad_none in zip(
flat_param._params,
flat_param._shard_param_infos,
flat_param._is_grad_none_mask,
):
if not shard_param_info.in_shard:
param.grad = None
else:
numel_in_shard = shard_param_info.numel_in_shard
if param.requires_grad and not is_grad_none:
offset = shard_param_info.offset_in_shard
if self._keep_low_precision_grads or param.dtype != grad.dtype:
# NOTE: This is a hack using `.data` to side step the
# check that parameter/gradient dtypes match. Here,
# `param` has full precision; `grad` has low precision.
if param.grad is None:
# `.grad` must have the same shape as `param`
param.grad = torch.empty_like(param)
param.grad.data = grad[
offset : offset + numel_in_shard
].reshape(param.shape)
else:
param.grad = grad[offset : offset + numel_in_shard].reshape(
param.shape
)
else:
param.grad = None
assert flat_param._shared_params is not None
for param, (_, _, _, prim_param_name, prim_module, _) in zip(
flat_param._shared_params, flat_param._shared_param_infos
):
in_sharded_flat_param = hasattr(prim_module, prim_param_name)
if in_sharded_flat_param and param.requires_grad:
prim_param = getattr(prim_module, prim_param_name)
param.grad = prim_param.grad # share the same reference
else:
param.grad = None
@no_type_check
@torch.no_grad()
def _writeback_orig_params(self) -> bool:
    """
    Write back any parameters that changed storage to the handle's ``FlatParameter``.

    Iterates over the original parameters and writes back any parameters
    that changed storages (due to a non-inplace operator) to the handle's
    ``FlatParameter``. This method preserves the ``FlatParameter`` 's
    device even if an original parameter's device changes.

    Raises:
        RuntimeError: If an original parameter or gradient changes storages
            but no longer has the expected flattened shape.
        AssertionError: If sharded views were skipped and a parameter was
            changed between forward and backward.
        NotImplementedError: If a shared parameter no longer refers to the
            same object as its primary parameter.

    Returns: ``True`` if some writeback happened, and ``False`` otherwise.
    """
    if (
        self.uses_sharded_strategy
        and not self.is_sharded(self.flat_param)
        and not self._skipped_use_sharded_views
    ):
        # For `NO_SHARD`, we may still need to writeback
        return False
    flat_param = self.flat_param
    wroteback = False
    if self._skipped_use_sharded_views and self.uses_sharded_strategy:
        # NOTE: We must use the unsharded flat parameter from which the
        # unsharded views were computed, not the one from the current
        # calling context (`_get_padded_unsharded_flat_param()`) since that
        # may be different (e.g. the model changed from train to eval).
        flat_param_tensor = self._unsharded_flat_param_for_skipped_views
        _p_assert(
            _data_ptr_allocated(flat_param_tensor),
            "If skipped using sharded views, the unsharded flat parameter "
            "should be allocated",
        )
    else:
        flat_param_tensor = flat_param
    # NOTE: Since this method is called in the pre-unshard, which is only
    # called during computation in the pre-forward or pre-backward, the
    # sharded gradient should be guaranteed to be in `.grad`, not in
    # `._saved_grad_shard`.
    flat_param_grad = (
        flat_param.grad
        if self.uses_sharded_strategy or not self._offload_params
        else flat_param._cpu_grad
    )
    for i, (
        param,
        (in_shard, offset_in_shard, numel_in_shard, _, _),
        (param_name, module, _),
    ) in enumerate(
        zip(
            flat_param._params,
            flat_param._shard_param_infos,
            flat_param._param_infos,
        )
    ):
        if not in_shard:
            continue
        if not hasattr(module, param_name):
            # Do not writeback if original parameters are deregistered
            # (e.g. during model checkpointing)
            continue

        # Check for parameter writeback
        if self._skipped_use_sharded_views:
            # Compare against the `Tensor` saved in the forward instead of
            # the original parameter variable
            param = flat_param._tensors[i]
            _p_assert(
                param is not None,
                f"Expects to have saved tensor for {flat_param._fqns[i]}",
            )
        param_changed = getattr(module, param_name) is not param
        needs_param_writeback = (
            param_changed  # changed parameter variable itself
            or not _same_storage(param, flat_param_tensor)
        )
        if self._skipped_use_sharded_views and (
            param_changed or needs_param_writeback
        ):
            raise AssertionError(
                "FSDP does not support changing the parameters between "
                f"forward and backward for {self._sharding_strategy}"
            )
        if param_changed:
            # NOTE: The gradient is not preserved after a parameter change.
            param = getattr(module, param_name)
            flat_param._params[i] = param
        if needs_param_writeback:
            expected_shape = torch.Size([numel_in_shard])
            self._writeback_tensor(
                param, flat_param, i, expected_shape, offset_in_shard, True
            )
            wroteback = True

        # Check for gradient writeback
        # NOTE(review): none of the gradient writebacks below set
        # `wroteback`, so the return value only reflects parameter
        # writebacks -- confirm this is intended.
        if self._skipped_use_sharded_views:
            # Skip the writeback check because we do not expose gradients
            # when we skipped using sharded views
            continue
        if param.grad is None and flat_param.grad is not None:
            # Passing `None` as the source zeroes the corresponding region
            # of the flat gradient
            expected_shape = torch.Size([numel_in_shard])
            self._writeback_tensor(
                None, flat_param.grad, i, expected_shape, offset_in_shard, False
            )
        elif param.grad is not None:
            # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
            # memory and owns the gradient storage, so it will never
            # require gradient writeback.
            if not self.uses_sharded_strategy and self._offload_params:
                # Explicitly continue to handle the case of `no_sync()`,
                # where `param.grad` is a view into the GPU gradient
                # referenced by `flat_param.grad`, while `flat_param_grad`
                # is `flat_param._cpu_grad`, which is on CPU
                continue

            needs_grad_writeback = flat_param_grad is None or not _same_storage(
                param.grad, flat_param_grad
            )
            if needs_grad_writeback:
                if flat_param_grad is None:
                    # No flat gradient yet: start from zeros shaped like the
                    # flat parameter
                    flat_param_grad = torch.zeros_like(flat_param)
                expected_shape = torch.Size([numel_in_shard])
                self._writeback_tensor(
                    param.grad,
                    flat_param_grad,
                    i,
                    expected_shape,
                    offset_in_shard,
                    False,
                )
                flat_param.grad = flat_param_grad
                # Re-read `.grad` so later iterations compare against the
                # tensor now held by the flat parameter
                flat_param_grad = flat_param.grad

    # TODO: If we want to handle shared parameters, we need to re-generate
    # the shared parameter data structures in case sharedness changed.
    for i, (
        param_name,
        module,
        _,
        prim_param_name,
        prim_module,
        _,
    ) in enumerate(flat_param._shared_param_infos):
        if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
            raise NotImplementedError(
                "Changing shared parameters is not supported yet"
            )
    return wroteback
def _writeback_tensor(
self,
src_tensor: Optional[Tensor],
dst_tensor: Tensor,
tensor_index: int,
expected_shape: torch.Size,
offset: int,
is_param: bool, # else gradient
) -> None:
"""
Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``.
``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if
``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
instead of copying. ``tensor_index`` gives the index of ``src_tensor``
in the metadata structures.
Raises:
RuntimeError: If the ``src_tensor`` does not have the expected
shape.
"""
_p_assert(
len(expected_shape) == 1,
f"Expects a 1D expected shape but got {expected_shape}",
)
if self._debug_level == dist.DebugLevel.INFO:
rank = self.rank if hasattr(self, "rank") else dist.get_rank()
src_shape = src_tensor.shape if src_tensor is not None else None
src_device = src_tensor.device if src_tensor is not None else None
warnings.warn(
f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
f"writeback in {self._training_state}\n"
f"expected shape={expected_shape} shape={src_shape} "
f"expected device={dst_tensor.device} device={src_device}"
)
if src_tensor is not None and src_tensor.shape != expected_shape:
# NOTE: Gradient shape mismatch is not possible in practice since
# the gradient shape is enforced to match that of the parameter and
# we already check for parameter shape mismatch.
raise RuntimeError(
f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
)
if src_tensor is not None:
dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
else:
dst_tensor[offset : offset + expected_shape.numel()].zero_()
assert self.flat_param._is_grad_none_mask is not None
self.flat_param._is_grad_none_mask[tensor_index] = True
def _reset_flat_param_grad_info_if_needed(self):
"""
Reset ``flat_param.grad`` if needed.
When ``use_orig_params=True``:
(1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
original parameters' ``.grad`` are ``None``, and
(2) sets ``flat_param.requires_grad=False`` if *none* of the original
parameters require gradient.
For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
which case we want to free the gradients as soon after the
``zero_grad()`` call as possible.
"""
if not self._use_orig_params:
return
flat_param = self.flat_param
assert flat_param._params is not None # mypy
all_grad_none = True
requires_grad = False
for param in flat_param._params:
all_grad_none &= param.grad is None
requires_grad |= param.requires_grad
if all_grad_none:
flat_param.grad = None
# As long as one parameter requires gradient, then the flat parameter
# must require gradient
flat_param.requires_grad = requires_grad
def _deregister_orig_params(self):
for param_info in self.flat_param._param_infos:
param_name, module, _ = param_info
if hasattr(module, param_name):
delattr(module, param_name)
for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
if hasattr(module, param_name):
delattr(module, param_name)
###########
# HELPERS #
###########
def flat_param_to(self, *args, **kwargs):
    """
    Wrap an in-place call to ``.to()`` for ``self.flat_param``.

    All arguments are forwarded to ``flat_param.to()`` and the result is
    assigned to ``flat_param.data``. With ``use_orig_params=True``, the
    sharded or unsharded views are then re-established.
    """
    flat_param = self.flat_param
    flat_param.data = flat_param.to(*args, **kwargs)
    if not self._use_orig_params:
        return
    # Refresh the views because their storage may have changed
    if self.is_sharded(self.flat_param):
        self._use_sharded_views()
        return
    self._use_unsharded_views(as_params=True)
def _get_modules(self) -> Set[nn.Module]:
"""Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
return {pi.module for pi in self.flat_param._param_infos}.union(
{spi.module for spi in self.flat_param._shared_param_infos}
)
def is_sharded(self, tensor: Tensor) -> bool:
    """
    Return whether ``tensor`` is *currently* sharded.

    ``tensor`` counts as sharded when its size equals
    ``flat_param._sharded_size``. For ``NO_SHARD``, we choose to have this
    always return ``False`` for clarity.
    """
    # `_sharded_size` is defined iff `handle.shard()` has been called
    if hasattr(self.flat_param, "_sharded_size") and self.uses_sharded_strategy:
        expected_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
        return tensor.size() == expected_size
    return False
def param_module_names(self) -> Iterator[Tuple[str, str]]:
    """
    Yield ``(param_name, module_name)`` for every parameter of the flat parameter.

    Entries from ``flat_param._param_infos`` are yielded first, followed by
    entries from ``flat_param._shared_param_infos``. The names are unpacked
    directly from each info tuple, with no intermediate list.
    """
    flat_param = self.flat_param
    for param_name, _, module_name in flat_param._param_infos:
        yield (param_name, module_name)
    # Shared entries carry three extra fields after the first three; only the
    # leading parameter name and module name are reported
    for param_name, _, module_name, _, _, _ in flat_param._shared_param_infos:
        yield (param_name, module_name)
def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
    """
    Yield ``(param_name, module_name)`` for every shared parameter.

    The pairs come from the first and third fields of each entry in
    ``flat_param._shared_param_infos``, unpacked directly with no
    intermediate list.
    """
    for param_name, _, module_name, _, _, _ in self.flat_param._shared_param_infos:
        yield (param_name, module_name)
@property
def _fqns_in_shard(self) -> List[str]:
"""Return the FQNs of the parameters present in this rank's shard."""
fqns_in_shard: List[str] = []
for fqn, shard_param_info in zip(
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
):
if shard_param_info.in_shard:
fqns_in_shard.append(fqn)
return fqns_in_shard
@property
def sharded_grad(self) -> Optional[Tensor]:
"""Return the handle's sharded gradient."""
flat_param = self.flat_param
# Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
# - CPU offloading: `_cpu_grad`
# - No CPU offloading + sharded strategies: `_saved_grad_shard`
# - No CPU offloading + `NO_SHARD`: `grad`
grad: Optional[Tensor]
if hasattr(flat_param, "_cpu_grad"):
grad = flat_param._cpu_grad # type: ignore[attr-defined]
elif hasattr(flat_param, "_saved_grad_shard"):
# In the post-backward hook, the sharded gradient is still in
# `_saved_grad_shard`.
grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
else:
# If in IDLE or in FORWARD states, then there may be an
# (accumulated) gradient. If accessed in IDLE, then this should
# be due to re-registering the original parameters (e.g. in state
# dict load).
_p_assert(
flat_param.grad is None
or not self.uses_sharded_strategy
or self._training_state
in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
"Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
"unless in IDLE or FORWARD",
)
grad = flat_param.grad
return grad
def _reset_is_grad_none(self) -> None:
"""
Reset ``_is_grad_none_mask`` as needed.
This method should only be
called in the post-backward after gradient computation, in which case
if a parameter requires gradient, then it will surely receive a
gradient and we may reset its mask entry to ``False``.
"""
if not self._use_orig_params:
return
_p_assert(
self._training_state == HandleTrainingState.BACKWARD_POST,
"Expects to only be called in the post-backward after gradient computation",
)
flat_param = self.flat_param
assert flat_param._params is not None # mypy
for i, param in enumerate(flat_param._params): # type: ignore[arg-type]
# As long as the parameter requires gradient, it should receive a
# meaningful gradient (even if the gradient happens to be zeros)
if param.requires_grad:
assert flat_param._is_grad_none_mask is not None # mypy
flat_param._is_grad_none_mask[i] = False
#######################
# CHECKS & INVARIANTS #
#######################
def _check_sharded_strategy(self):
_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
def _check_on_compute_device(self, tensor: Tensor):
_p_assert(
tensor.device == self.device,
f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
)
def _check_on_cpu(self, tensor: Tensor):
_p_assert(
tensor.device == torch.device("cpu"),
f"Expects tensor to be on CPU but got {tensor.device}",
)
@staticmethod
def _check_storage_freed(tensor: Tensor):
    """
    Assert via ``_p_assert`` that ``tensor``'s storage has size 0.

    The check is skipped while TorchDynamo is compiling.
    """
    # Compile does not resize during trace
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        return
    _p_assert(
        _same_storage_size(tensor, 0),
        "Expects storage to be freed but got storage with size > 0",
    )
@staticmethod
def _check_storage_allocated(tensor: Tensor):
    """Assert via ``_p_assert`` that ``_storage_size_allocated(tensor)`` holds."""
    is_allocated = _storage_size_allocated(tensor)
    _p_assert(is_allocated, "Expects storage to be allocated")
def _check_low_precision_shard(self):
_p_assert(
self._uses_param_mixed_precision,
"Not using low precision for parameters",
)
_p_assert(
getattr(self.flat_param, "_mp_shard", None) is not None,
"Expects `_mp_shard` to exist",
)
device = self.flat_param._mp_shard.device # type: ignore[attr-defined]
_p_assert(
device == self.device,
f"Expects the low precision shard to be on {self.device} but got {device}",
)
def _check_unsharded(self, tensor: Tensor):
msg_prefix = "Expects tensor to be unsharded "
_p_assert(tensor is not None, msg_prefix + "but got `None`")
unsharded_size = self.flat_param._unpadded_unsharded_size
_p_assert(
tensor.size() == unsharded_size,
msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
)
def _check_sharded(self, tensor: Tensor):
msg_prefix = "Expects tensor to be sharded "
_p_assert(tensor is not None, msg_prefix + "but got `None`")
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
_p_assert(
tensor.size() == sharded_size,
msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
)
##############
# PROPERTIES #
##############
@property
def uses_sharded_strategy(self) -> bool:
    """Return whether the sharding strategy is anything other than ``NO_SHARD``."""
    strategy = self._sharding_strategy
    return strategy != HandleShardingStrategy.NO_SHARD
@property
def _uses_param_mixed_precision(self) -> bool:
return self._fwd_bwd_param_dtype != self._orig_param_dtype
@property
def _uses_reduce_mixed_precision(self) -> bool:
return self._reduce_dtype != self._orig_param_dtype
@property
def _force_full_precision(self) -> bool:
return (
self._uses_param_mixed_precision or self._uses_reduce_mixed_precision
) and (
self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
or
# Also disable mixed precision in model eval mode, if configured
(not self._fully_sharded_module.training and self._use_full_prec_in_eval)
)
@property
def _skipped_use_sharded_views(self) -> bool:
"""
This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``.
This returns if this handle is
currently in a state where it has skipped using sharded views, in which
case it can restore view invariants via ``_use_sharded_views()``.
"""
return self._unsharded_flat_param_for_skipped_views is not None
# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
def _unsafe_setattr_param(
module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
module._parameters[param_name] = param
# This bypasses any overrides in case `module` is an instance of an
# `nn.Module` subclass
super(nn.Module, module).__setattr__(param_name, param)
def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
module._parameters.pop(param_name, None)
# This bypasses any overrides in case `module` is an instance of an
# `nn.Module` subclass
super(nn.Module, module).__setattr__(param_name, tensor)
def _safe_setattr_tensor_or_param(
module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
):
# Call `delattr()` and `setattr()` to go through `nn.Module` checks
if hasattr(module, param_name):
delattr(module, param_name)
setattr(module, param_name, tensor_or_param)
def _convert_to_params(
tensors: List[Union[torch.Tensor, nn.Parameter]]
) -> List[nn.Parameter]:
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
def _is_truly_contiguous(x: Tensor) -> bool:
# Special case: Pytorch thinks that 1x1 channels_last convolution weights are
# both contiguous and channels_last contiguous at the same time.
# CuDNN does not agree though and refuses to select faster kernels.
# It is the reason of having the extra check here.
return x.stride(-1) == 1 and x.is_contiguous()
def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
return (
param_or_tensor.detach()
if isinstance(param_or_tensor, nn.Parameter)
else param_or_tensor
)
def _get_aligned_numel(unsharded_dtype: torch.dtype):
    """
    Return how many elements of ``unsharded_dtype`` fit in the alignment width.

    The result is the 16-byte alignment floor-divided by the dtype's element
    size as reported by ``_get_dtype_size``.
    """
    # NOTE: This alignment constraint comes from TorchInductor.
    ALIGNMENT = 16  # bytes
    return ALIGNMENT // _get_dtype_size(unsharded_dtype)
@functools.lru_cache(8)
def _get_dtype_size(dtype):
return torch.empty((), dtype=dtype).element_size()
def _construct_padding_tensor(
    padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
):
    """
    Return a 1D tensor of ``padding_numel`` elements filled with the padding value.

    The tensor is built as ``torch.ones(...)`` (with the given ``dtype``,
    ``requires_grad`` and ``device``) multiplied by
    ``_FLAT_PARAM_PADDING_VALUE``.
    """
    # NOTE: Set the padding value as a magic number for debuggability. The
    # value itself should never be used in any user-facing computation.
    ones = torch.ones(
        (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
    )
    return ones * _FLAT_PARAM_PADDING_VALUE
# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning
# messasge is passed in)
@functools.lru_cache(1)
def _warn_skip_writeback_check(log: logging.Logger, warning: str):
logger.warning(warning)
# Use `lru_cache(1)` to only log the warning once
@functools.lru_cache(1)
def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
logger.warning(warning)
# Use `lru_cache(1)` to only log the warning once
@functools.lru_cache(1)
def _warn_use_fake_reduce(log: logging.Logger, warning: str):
logger.warning(warning)
def _same_storage(a, b):
# Params are DTensors in backward
# with SHARD_GRAD_OP + TP
from torch.distributed.tensor import DTensor
if isinstance(a, DTensor):
a = a._local_tensor
if isinstance(b, DTensor):
b = b._local_tensor
return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
def _same_storage_size(a: torch.Tensor, b: int):
return a.untyped_storage().size() // a.element_size() == b
def _storage_size_allocated(tensor: Tensor):
storage_size: int = tensor.untyped_storage().size()
return storage_size > 0
|