File: test_fsdp_optim_state.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (2015 lines) | stat: -rw-r--r-- 79,984 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
# Owner(s): ["oncall: distributed"]

import bisect
import sys
from copy import deepcopy
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_WRAPPED_MODULE,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    OptimStateKeyType,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictSettings,
    StateDictType,
)
from torch.distributed.optim import _NamedOptimizer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    DEVICEInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


# The two state-dict types (full and sharded).
# NOTE(review): presumably used to parametrize tests later in this file -- confirm.
STATE_DICT_TYPES = [StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT]

# Skip this whole test module (exit status 0, so it does not count as a
# failure) when torch.distributed is not available in this build.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# Likewise skip under dev-ASAN builds, for the reason given in the message.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class _OSDCommMethod(Enum):
    """Method for communicating the optimizer state dict for internal tests."""

    BROADCAST_OBJECT_LIST = auto()
    SCATTER_FULL_OSD = auto()
    FLATTEN_SHARDED_OSD = auto()
    OPTIM_STATE_DICT = auto()


class _ModelClass(Enum):
    """Different model type to test."""

    NESTED = auto()
    TRANSFORMER = auto()


class Bias(torch.nn.Module):
    """Adds a learnable 1-D bias of length ``dim`` to its input.

    The global RNG is reseeded with 0 right before the bias is drawn from
    ``torch.randn``, so every instance with the same ``dim`` starts identical.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        assert dim > 0
        torch.manual_seed(0)
        initial_values = torch.randn((dim,))
        self.bias = torch.nn.Parameter(initial_values)

    def forward(self, x):
        shifted = x + self.bias
        return shifted


class BlockA(torch.nn.Module):
    """
    Nested module used to build an interesting structure for FSDP wrapping.

    Attributes in registration order::

        BlockA
            bias_module0 (Bias)
                bias
            weight
            bias_module1 (Bias)
                bias

    Forward computes ``bias_module1(relu(bias_module0(x @ weight)))``.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        assert in_dim > 0 and out_dim > 0
        torch.manual_seed(0)
        # Keep this statement order: it fixes both the RNG draws (each Bias
        # reseeds with 0) and the module/parameter registration order.
        self.bias_module0 = Bias(out_dim)
        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
        self.bias_module1 = Bias(out_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        projected = x @ self.weight
        # The ReLU between the two biases gives them different gradients
        activated = self.relu(self.bias_module0(projected))
        return self.bias_module1(activated)


class BlockB(torch.nn.Module):
    """
    Nested module used to build an interesting structure for FSDP wrapping.

    Attributes in registration order::

        BlockB
            weight
            bias_module0 (Bias)
                bias
            bias_module1 (Bias)
                bias

    Forward computes ``bias_module1(relu(bias_module0(x @ weight)))``.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        assert in_dim > 0 and out_dim > 0
        torch.manual_seed(0)
        # Keep this statement order: it fixes both the RNG draws (each Bias
        # reseeds with 0) and the module/parameter registration order.
        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
        self.bias_module0 = Bias(out_dim)
        self.bias_module1 = Bias(out_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.bias_module0(x @ self.weight)
        # The ReLU between the two biases gives them different gradients
        hidden = self.relu(hidden)
        return self.bias_module1(hidden)


class NestedModel(torch.nn.Module):
    """Nested model with a non-trivial structure for FSDP optimizer-state tests.

    Structure (attribute: module)::

        block0: BlockB(5, 3)
        block1: BlockB(3, 7)
        bias:   Parameter of shape (5,)
        block2: Sequential(BlockA(7, 9), BlockA(9, 9), BlockB(9, 5))
        relu:   ReLU

    The static ``wrap*`` helpers apply FSDP to selected submodules, and the
    ``param_group*`` helpers return parameter lists whose order deviates from
    ``model.parameters()``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.block0 = BlockB(5, 3)
        self.block1 = BlockB(3, 7)
        self.bias = torch.nn.Parameter(torch.randn((5,)))
        self.block2 = torch.nn.Sequential(
            BlockA(7, 9),
            BlockA(9, 9),
            BlockB(9, 5),
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x) -> torch.Tensor:
        x = self.relu(self.block0(x))
        x = self.relu(self.block1(x))
        x = self.relu(self.block2(x))
        x = x + self.bias
        return x

    def get_input(self, device):
        """Returns a 1-tuple holding a random ``(8, 5)`` input moved to ``device``."""
        BATCH_SIZE = 8
        return (torch.randn((BATCH_SIZE, 5)).to(device),)

    def get_loss(self, inp, output):
        """Returns ``output.sum()``; ``inp`` is accepted but unused."""
        return output.sum()

    def run_backward(self, loss):
        """Runs ``loss.backward()``."""
        loss.backward()

    @staticmethod
    def wrap(
        model: torch.nn.Module,
        group: Optional[dist.ProcessGroup] = None,
        ignore_modules: bool = False,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Wraps selected submodules of ``model`` with FSDP in place and returns ``model``.

        Args:
            model: ``NestedModel`` whose submodules are replaced by FSDP wrappers.
            group: Process group passed to every FSDP constructor.
            ignore_modules: If ``True``, ``model.block2[2].bias_module0`` is
                passed as an ignored module when wrapping ``model.block2[2]``.
            fsdp_kwargs: Extra keyword arguments forwarded to every FSDP
                constructor (``None`` means none).
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        # Flatten Bias0; then flatten weight and Bias1 together into `block1`
        model.block1.bias_module0 = FSDP(
            model.block1.bias_module0,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block1 = FSDP(model.block1, process_group=group, **fsdp_kwargs)
        # Flatten Bias0; flatten Bias1; then flatten weight into `block2[1]`
        model.block2[1].bias_module0 = FSDP(
            model.block2[1].bias_module0,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block2[1].bias_module1 = FSDP(
            model.block2[1].bias_module1,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block2[1] = FSDP(model.block2[1], process_group=group, **fsdp_kwargs)
        # Flatten weight, Bias, bias into `block2[2]`
        ignored_modules = [model.block2[2].bias_module0] if ignore_modules else None
        model.block2[2] = FSDP(
            model.block2[2],
            process_group=group,
            ignored_modules=ignored_modules,
            **fsdp_kwargs,
        )
        return model

    @staticmethod
    def wrap_alt(
        model: torch.nn.Module,
        group: Optional[dist.ProcessGroup] = None,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Wraps ``model.block0.bias_module0`` and then ``model.block0`` with
        FSDP in place and returns ``model``.

        ``group`` and ``fsdp_kwargs`` are forwarded to both FSDP constructors.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        model.block0.bias_module0 = FSDP(
            model.block0.bias_module0,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block0 = FSDP(model.block0, process_group=group, **fsdp_kwargs)
        return model

    @staticmethod
    def wrap_with_unmanaged_params(
        model,
        add_to_fsdp_module: bool,
        group=None,
    ) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]:
        """Registers unmanaged parameters before wrapping with :meth:`wrap`.

        A single ``(5, 5)`` parameter named ``unmanaged_param`` is registered
        on ``model.block2[2]`` if ``add_to_fsdp_module`` else on ``model``.

        Returns:
            ``(wrapped_model, [unmanaged_param])``.
        """
        device = next(model.parameters()).device
        unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device))
        # Either register the parameter to a module to be wrapped with FSDP
        # (`model.block2[2]`) or a module not to be wrapped with FSDP (`model`)
        register_module = model.block2[2] if add_to_fsdp_module else model
        register_module.register_parameter(
            "unmanaged_param",
            unmanaged_param,
        )
        # For simplicity, we only add a single unmanaged parameter, but should
        # be easy to generalize if needed
        return NestedModel.wrap(model, group), [unmanaged_param]

    @staticmethod
    def add_unmanaged_param_entry(osd, unmanaged_param, step) -> None:
        """Adds an entry for the unmanaged parameter ``unmanaged_param``
        assuming Adam optimizer and a single parameter group.

        The parameter receives the smallest non-negative ID missing from
        ``osd["param_groups"][0]["params"]`` (the ID one past the end when no
        ID is missing). A fresh state entry with ``step``, ``exp_avg`` and
        ``exp_avg_sq`` is added under that ID, and the ID is inserted into the
        parameter group in sorted order.

        Args:
            osd: Optimizer state dict, mutated in place. The first parameter
                group's ``params`` must be strictly increasing, and ``state``
                must already contain at least one entry (its first value is
                used to pick the device for the new tensors).
            unmanaged_param: Parameter whose shape the new state tensors take.
            step: Value stored as a float tensor under ``"step"``.

        Raises:
            AssertionError: If the parameter IDs are not strictly increasing.
        """
        # The unmanaged parameters should be passed to this method in
        # `model.parameters()` order since their parameter IDs will be assigned
        # in order of the skipped IDs
        param_ids = osd["param_groups"][0]["params"]
        for prev_id, next_id in zip(param_ids, param_ids[1:]):
            assert next_id > prev_id, f"Invalid IDs: {prev_id} {next_id}"
        # Take the smallest unused non-negative ID. Among the
        # `len(param_ids) + 1` candidates at least one is free, so this always
        # succeeds -- including when ID 0 is the skipped one, which a scan of
        # adjacent differences alone misses (it would reuse an existing ID).
        used_ids = set(param_ids)
        unmanaged_param_id = next(
            i for i in range(len(param_ids) + 1) if i not in used_ids
        )
        # Add a state entry for the unmanaged parameter on the device of the
        # first value of the first existing state entry (assumed a tensor)
        state_device = next(iter(next(iter(osd["state"].values())).values())).device
        osd["state"][unmanaged_param_id] = {
            "step": torch.tensor(float(step), device=state_device),
            "exp_avg": torch.randn(unmanaged_param.shape, device=state_device),
            "exp_avg_sq": torch.randn(unmanaged_param.shape, device=state_device),
        }
        # Insert the ID into the parameter group in order
        bisect.insort(osd["param_groups"][0]["params"], unmanaged_param_id)

    # NOTE: We exclude `self.bias` from either parameter group to test the
    # case where the optimizer input does not include all model parameters
    def param_group0(self) -> List[torch.nn.Parameter]:
        # Use `block1`'s parameters for the first parameter group to deviate
        # from the `model.parameters()` order
        return list(self.block1.parameters())

    def param_group1(self) -> List[torch.nn.Parameter]:
        # Deviate from the `model.parameters()` order further by rearranging
        # `block2`'s parameters to be before `block0`'s parameters
        return list(self.block2.parameters()) + list(self.block0.parameters())


# Simple and boring model to test interface and some corner cases that do not
# require complicated wrapping strategy.
class TestDummyModel(torch.nn.Module):
    """Four-stage MLP mapping 8 features to 8 features (8 -> 16 -> 32 -> 64 -> 8).

    Args:
        no_grad: If ``True``, the first linear layer's weight and bias are
            created with ``requires_grad=False``.
    """

    def __init__(self, no_grad: bool = False):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net1[0].weight.requires_grad = not no_grad
        self.net1[0].bias.requires_grad = not no_grad
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self, device="cuda"):
        """Returns a random ``(8, 8)`` input on ``device`` (CUDA by default).

        ``device`` was previously hard-coded to ``"cuda"``; the default keeps
        existing callers unchanged while allowing CPU-only use.
        """
        return torch.rand(8, 8, device=device)


class TestFSDPOptimState(FSDPTest):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Dispatch table: each `_ModelClass` maps to the bound method that
        # builds that model together with its optimizer and optimizer input
        nested_init = self._init_nested_model
        transformer_init = self._init_transformer_model
        self._model_class = {
            _ModelClass.NESTED: nested_init,
            _ModelClass.TRANSFORMER: transformer_init,
        }

    def _init_nested_model(
        self,
        wrap: bool,
        wrap_alt: bool = False,  # ignored if `wrap=False`
        device: torch.device = torch.device("cuda"),
        group=None,
        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        use_multiple_param_groups: bool = False,
        use_diff_optim_inputs: bool = False,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Builds a ``NestedModel`` and an ``optim_class`` optimizer (``lr=0.01``).

        Args:
            wrap: Whether to apply FSDP via ``NestedModel.wrap()`` (or
                ``NestedModel.wrap_alt()`` when ``wrap_alt=True``).
            wrap_alt: Selects the alternative wrapping; ignored if
                ``wrap=False``.
            device: Device that the model is moved to before any wrapping.
            group: Process group forwarded to the wrapping helper.
            optim_class: Optimizer class to instantiate.
            use_multiple_param_groups: If ``True``, the optimizer input is two
                parameter groups built from ``param_group0()`` and
                ``param_group1()`` (the latter with ``weight_decay=0.9``);
                otherwise it is ``list(model.parameters())``.
            use_diff_optim_inputs: If ``True``, odd ranks reverse the
                parameter order (within each group when using groups).
            fsdp_kwargs: Extra keyword arguments forwarded to the wrapping
                helper.

        Returns:
            ``(model, optim, optim_input)``.
        """
        model = NestedModel().to(device)
        if wrap:
            if wrap_alt:
                model = NestedModel.wrap_alt(model, group, fsdp_kwargs)
            else:
                model = NestedModel.wrap(model, group, fsdp_kwargs=fsdp_kwargs)

        if use_multiple_param_groups:
            optim_input = [
                {"params": model.param_group0()},
                {"params": model.param_group1(), "weight_decay": 0.9},
            ]
        else:
            optim_input = list(model.parameters())

        # Odd ranks see the parameters in reverse so that the optimizer input
        # differs across ranks
        if use_diff_optim_inputs and self.rank % 2 == 1:
            if isinstance(optim_input[0], dict):
                for group_dict in optim_input:
                    group_dict["params"] = group_dict["params"][::-1]
            else:
                optim_input = optim_input[::-1]

        optim = optim_class(optim_input, lr=0.01)
        return model, optim, optim_input

    def _init_transformer_model(
        self,
        wrap: bool,
        device: torch.device = torch.device("cuda"),
        group=None,
        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        use_multiple_param_groups: bool = False,
        use_diff_optim_inputs: bool = False,
    ):
        """Builds a deterministic ``TransformerWithSharedParams`` and an
        ``optim_class`` optimizer over all of its parameters (``lr=0.01``).

        The model is initialized with ``FSDPInitMode.RECURSIVE`` when
        ``wrap=True`` and ``FSDPInitMode.NO_FSDP`` otherwise. ``group``
        defaults to the default process group. ``device`` is accepted for
        signature parity with ``_init_nested_model()`` and is not used here.

        Returns:
            ``(model, optim, None)``; there is no separate optimizer input.

        Raises:
            NotImplementedError: If ``use_multiple_param_groups`` or
                ``use_diff_optim_inputs`` is set. Both are kept only for
                parity with ``_init_nested_model()``. The transformer is
                wrapped with FSDP at the top level, leaving a single flat
                parameter, so these options would be vacuous.
        """
        if use_multiple_param_groups or use_diff_optim_inputs:
            raise NotImplementedError
        process_group = (
            dist.distributed_c10d._get_default_group() if group is None else group
        )
        init_mode = FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP
        model = TransformerWithSharedParams.init(
            process_group,
            init_mode,
            DEVICEInitMode.DEVICE_BEFORE,
            deterministic=True,
        )
        optim = optim_class(model.parameters(), lr=0.01)
        return model, optim, None

    def _step_model(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        device: torch.device = torch.device("cuda"),
        num_iters: int = 1,
    ) -> List[float]:
        """Runs ``num_iters`` training iterations and returns each iteration's
        loss as a Python float.

        Each iteration zeroes the gradients, runs forward and backward, then
        steps the optimizer. The global RNG is reseeded up front for
        determinism. Input creation, loss computation, and the backward call
        are delegated to ``model.module`` when that attribute exists, and
        otherwise to ``model`` itself.
        """
        torch.manual_seed(0)  # reseed for determinism
        iteration_losses: List[float] = []
        inner = getattr(model, "module", model)
        for _ in range(num_iters):
            optim.zero_grad()
            inputs = inner.get_input(device)
            outputs = model(*inputs)
            loss = inner.get_loss(inputs, outputs).to(device)
            iteration_losses.append(loss.item())
            inner.run_backward(loss)
            optim.step()
        return iteration_losses

    def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None):
        """Broadcasts ``full_osd`` from rank 0 across ``group`` and returns the
        resulting object.

        This stands in for a ``torch.save()``/``torch.load()`` round trip so
        that every rank ends up holding the full optimizer state dict.
        """
        payload = [full_osd]
        dist.broadcast_object_list(payload, src=0, group=group)
        return payload[0]

    def _are_equal_states(
        self,
        state1: Dict[str, Any],
        state2: Dict[str, Any],
    ) -> bool:
        """Returns whether ``state1`` and ``state2`` contain the same mappings.

        Both dicts must have the same keys, and each pair of values must have
        exactly the same type. Tensor values must also share a shape and be
        elementwise close (``torch.isclose`` defaults, compared on CPU).
        Non-tensor values must compare equal with ``==``.
        """
        if state1.keys() != state2.keys():
            return False
        for state_name, lhs in state1.items():
            rhs = state2[state_name]
            if type(lhs) is not type(rhs):
                return False
            if not torch.is_tensor(lhs):
                # Non-tensor state: plain equality
                if lhs != rhs:
                    return False
                continue
            assert torch.is_tensor(rhs)
            # Compare on CPU so that the check does not depend on device placement
            lhs_cpu = lhs.cpu()
            rhs_cpu = rhs.cpu()
            if lhs_cpu.shape != rhs_cpu.shape:
                return False
            if not torch.all(torch.isclose(lhs_cpu, rhs_cpu)):
                return False
        return True

    def _check_same_state(
        self,
        fsdp_osd,
        ref_osd,
        check_same_param_keys: bool,
    ):
        """Checks that ``fsdp_osd`` and ``ref_osd`` have the same "state" part.

        Each per-parameter state in ``fsdp_osd`` is first passed through
        ``_gather_state_dict()`` before being compared.

        If ``check_same_param_keys=True``, checks that the parameter keys
        match (e.g. when both should be parameter names). It then compares
        each state value against the reference value under the same key.

        Otherwise, the parameter keys are not checked (e.g. IDs vs. names).
        Instead, the per-parameter states are matched one-to-one by value
        using ``_are_equal_states()``.
        """
        assert "state" in ref_osd
        self.assertIn("state", fsdp_osd)
        ref_osd_state = ref_osd["state"]
        fsdp_osd_state = {
            k: _gather_state_dict(v) for k, v in fsdp_osd["state"].items()
        }

        if check_same_param_keys:
            # Check parameter keys are the same first for earlier erroring
            ref_osd_param_ids = set(ref_osd_state.keys())
            fsdp_osd_param_ids = set(fsdp_osd_state.keys())
            self.assertEqual(
                ref_osd_param_ids,
                fsdp_osd_param_ids,
                f"Rank {self.rank}: {(ref_osd_param_ids, fsdp_osd_param_ids)}",
            )
            # Check state values are the same
            for param_id, param_state in fsdp_osd_state.items():
                ref_param_state = ref_osd_state[param_id]
                for state_name, value in param_state.items():
                    self.assertEqual(value, ref_param_state[state_name])
            return
        # Otherwise, only require the parameter keys to be isomorphic (e.g.
        # between IDs and names), so match the per-parameter states by value
        fsdp_param_states = list(fsdp_osd_state.values())
        unmatched_ref_param_states = list(ref_osd_state.values())
        self.assertEqual(len(unmatched_ref_param_states), len(fsdp_param_states))
        # Use brute-force quadratic-time comparison since it is hard to hash a
        # tensor by value instead of by object. A reference state may match
        # more than one FSDP state in toy edge cases (e.g. multiple biases), so
        # each matched reference state is consumed. Together with the equal
        # lengths, this makes the matching one-to-one.
        for fsdp_param_state in fsdp_param_states:
            match_idx = next(
                (
                    idx
                    for idx, ref_param_state in enumerate(unmatched_ref_param_states)
                    if self._are_equal_states(fsdp_param_state, ref_param_state)
                ),
                None,
            )
            self.assertIsNotNone(
                match_idx,
                f"Rank {self.rank}: no matching reference optimizer state",
            )
            unmatched_ref_param_states.pop(match_idx)

    def _check_same_param_groups(
        self,
        full_osd,
        ref_osd,
        check_same_param_keys: bool,
    ):
        """Checks that ``full_osd`` and ``ref_osd`` have the same
        "param_groups" part.

        Both must have the same number of parameter groups. Corresponding
        groups must have the same hyperparameter keys and values. If
        ``check_same_param_keys=True``, the groups' "params" entries must
        also match (e.g. when both should be parameter names). Otherwise, the
        "params" entries are not compared.
        """
        assert "param_groups" in ref_osd
        self.assertIn("param_groups", full_osd)
        ref_osd_param_groups = ref_osd["param_groups"]
        full_osd_param_groups = full_osd["param_groups"]
        # Compare the group counts explicitly. Otherwise `zip()` below would
        # silently ignore extra trailing groups on either side.
        self.assertEqual(len(full_osd_param_groups), len(ref_osd_param_groups))
        for full_osd_pg, ref_osd_pg in zip(
            full_osd_param_groups,
            ref_osd_param_groups,
        ):
            self.assertEqual(
                set(full_osd_pg.keys()),
                set(ref_osd_pg.keys()),
            )
            for name, full_osd_value in full_osd_pg.items():
                if name == "params" and not check_same_param_keys:
                    continue
                self.assertEqual(full_osd_value, ref_osd_pg[name])

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    @parametrize("use_multiple_param_groups", [False, True])
    @parametrize("rank0_only", [False, True])
    @parametrize("use_diff_optim_inputs", [False, True])
    def test_optim_state_dict_nested(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        rank0_only: bool,
        use_diff_optim_inputs: bool,
    ) -> None:
        """
        Tests :meth:`full_optim_state_dict` and :meth:`sharded_optim_state_dict`
        by comparing the returned dict for an FSDP-wrapped model with that of
        an equivalent non-wrapped model. Runs with and without passing the
        optimizer input.

        The comparison ignores the parameter keys, since the FSDP and normal
        optimizer state dicts key by names and IDs, respectively. Thus this
        test can pass even if parameter keys are mapped to the wrong values.
        That mapping is exercised by the save/load tests.
        """
        subtest_config = {"use_optim_input": [False, True]}
        fixed_kwargs = {
            "state_dict_type": state_dict_type,
            "use_multiple_param_groups": use_multiple_param_groups,
            "rank0_only": rank0_only,
            "use_diff_optim_inputs": use_diff_optim_inputs,
        }
        self.run_subtests(
            subtest_config,
            self._test_optim_state_dict_nested,
            **fixed_kwargs,
        )

    def _test_optim_state_dict_nested(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        rank0_only: bool,
        use_diff_optim_inputs: bool,
        use_optim_input: bool,
    ) -> None:
        """Compares the full/sharded optimizer state dict of an FSDP-wrapped
        ``NestedModel`` with the local optimizer state dict of an unwrapped
        ``NestedModel`` trained for the same number of iterations."""
        if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT:
            return  # not supported
        num_iters = 3
        model_kwargs = {
            "use_multiple_param_groups": use_multiple_param_groups,
            "use_diff_optim_inputs": use_diff_optim_inputs,
        }

        model1, optim1, optim_input = self._init_nested_model(
            wrap=True, **model_kwargs
        )
        losses1 = self._step_model(model1, optim1, num_iters=num_iters)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            osd_args = (
                (model1, optim1, optim_input) if use_optim_input else (model1, optim1)
            )
            fsdp_osd = FSDP.full_optim_state_dict(*osd_args, rank0_only=rank0_only)
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)

        # Non-target ranks get an empty state dict
        if rank0_only and self.rank != 0:
            self.assertEqual(len(fsdp_osd), 0)
            return

        model2, optim2, _ = self._init_nested_model(wrap=False, **model_kwargs)
        losses2 = self._step_model(model2, optim2, num_iters=num_iters)
        ref_osd = optim2.state_dict()
        # Check the losses to eliminate model drift as a source of error
        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"

        # Parameter keys are not compared: the full/sharded optimizer state
        # dict keys by parameter names, whereas the unwrapped reference keys by
        # parameter IDs
        for check_fn in (self._check_same_param_groups, self._check_same_state):
            check_fn(fsdp_osd, ref_osd, check_same_param_keys=False)

    @skip_if_lt_x_gpu(2)
    def test_full_optim_state_dict_keys(self):
        """Tests that the parameter keys returned by
        :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
        full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
        instances and ignored modules."""
        device = torch.device("cuda")
        model = NestedModel().to(device)
        wrapped_model = NestedModel.wrap(model, ignore_modules=True)

        # Checkpoint-wrap every `nn.Sequential` so that the test also verifies
        # that both optim_state_dict and state_dict strip checkpoint prefixes
        def is_sequential(module):
            return isinstance(module, torch.nn.Sequential)

        apply_activation_checkpointing(model, check_fn=is_sequential)
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
        self._step_model(model, optim, device)

        optim_state_dict = FSDP.full_optim_state_dict(
            wrapped_model, optim, rank0_only=False
        )
        with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT):
            model_state_dict = wrapped_model.state_dict()
        optim_param_state = optim_state_dict["state"]
        self.assertEqual(optim_param_state.keys(), model_state_dict.keys())
        # The checkpoint-wrapper prefix must not survive in any key
        for param_key in optim_param_state:
            self.assertNotIn(_CHECKPOINT_WRAPPED_MODULE, param_key)

    @skip_if_lt_x_gpu(2)
    def test_full_optim_state_dict_nested_invalid(self):
        """Tests that :meth:`full_optim_state_dict` raises an error when
        nonzero ranks are missing the optimizer state for parameters on rank
        0."""
        device = torch.device("cuda")
        model = NestedModel.wrap(NestedModel().to(device), None)
        all_params = list(model.parameters())
        # Nonzero ranks drop their last parameter, so they lack some of the
        # optimizer state that rank 0 has
        optim_input = all_params if self.rank == 0 else all_params[:-1]
        optim = torch.optim.Adam(optim_input, lr=1e-3)
        self._step_model(model, optim, num_iters=3)
        expected_msg = (
            "FSDP currently requires each rank to have at least the "
            "optimizer states needed by rank 0's optimizer but some ranks "
            "are missing some of those states"
        )
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            FSDP.full_optim_state_dict(model, optim)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_multiple_param_groups", [False, True])
    @parametrize("wrap_alt", [False, True])
    @parametrize("use_diff_optim_inputs", [False, True])
    def test_shard_full_optim_state_dict_nested(
        self,
        use_multiple_param_groups: bool,
        wrap_alt: bool,
        use_diff_optim_inputs: bool,
    ):
        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
        with nested FSDP instances. Then tests the same save/load round trip
        through the optimizer-state-dict API with a full state dict
        (default ``FullOptimStateDictConfig()``)."""
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=False,
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(),
            ),
            # Forward the parametrized value, as the sibling scatter and
            # halve-world-size tests do. Hard-coding `False` left the
            # multiple-param-group case untested on this path.
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=False,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_shard_full_optim_state_dict_nested_halve_world_size(self):
        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
        with nested FSDP instances when loading into a new process group with
        halved world size."""
        # Only the "harder" settings are run to keep CI costs down
        shared_kwargs = {
            "use_multiple_param_groups": True,
            "halve_world_size": True,
            "use_diff_optim_inputs": True,
            "wrap_alt": True,
            "num_iters": 3,
        }
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
            **shared_kwargs,
        )

        full_settings = StateDictSettings(
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(),
        )
        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=full_settings,
            **shared_kwargs,
        )

    @skip_if_lt_x_gpu(2)
    def test_shard_full_optim_state_dict_transformer(self) -> None:
        """Tests :meth:`shard_full_optim_state_dict` for an FSDP-root
        transformer model with shared parameters."""
        shared_kwargs = {
            "use_multiple_param_groups": False,
            "halve_world_size": True,
            "use_diff_optim_inputs": False,
            "num_iters": 3,
        }
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.TRANSFORMER,
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
            **shared_kwargs,
        )

        full_settings = StateDictSettings(
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(),
        )
        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.TRANSFORMER,
            state_dict_settings=full_settings,
            **shared_kwargs,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("use_multiple_param_groups", [False, True])
    @parametrize("wrap_alt", [False, True])
    @parametrize("use_diff_optim_inputs", [False, True])
    def test_scatter_full_optim_state_dict_nested(
        self,
        use_multiple_param_groups: bool,
        wrap_alt: bool,
        use_diff_optim_inputs: bool,
    ):
        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
        model with nested FSDP instances."""
        shared_kwargs = {
            "use_multiple_param_groups": use_multiple_param_groups,
            "halve_world_size": False,
            "use_diff_optim_inputs": use_diff_optim_inputs,
            "wrap_alt": wrap_alt,
            "num_iters": 3,
        }
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
            **shared_kwargs,
        )

        rank0_settings = StateDictSettings(
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(rank0_only=True),
        )
        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=rank0_settings,
            **shared_kwargs,
        )

    @skip_if_lt_x_gpu(2)
    def test_scatter_full_optim_state_dict_nested_halve_world_size(self):
        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
        model with nested FSDP instances when loading into a new process group
        with halved world size."""
        # Only the "harder" settings are run to keep CI costs down
        shared_kwargs = {
            "use_multiple_param_groups": True,
            "halve_world_size": True,
            "use_diff_optim_inputs": True,
            "wrap_alt": True,
            "num_iters": 3,
        }
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
            **shared_kwargs,
        )

        rank0_settings = StateDictSettings(
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(rank0_only=True),
        )
        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=rank0_settings,
            **shared_kwargs,
        )

    @skip_if_lt_x_gpu(2)
    def test_scatter_full_optim_state_dict_transformer(self) -> None:
        """Tests :meth:`scatter_full_optim_state_dict` for an FSDP-root
        transformer model with shared parameters."""
        shared_kwargs = {
            "use_multiple_param_groups": False,
            "halve_world_size": True,
            "use_diff_optim_inputs": False,
            "num_iters": 3,
        }
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.TRANSFORMER,
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
            **shared_kwargs,
        )

        rank0_settings = StateDictSettings(
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(rank0_only=True),
        )
        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.TRANSFORMER,
            state_dict_settings=rank0_settings,
            **shared_kwargs,
        )

    @skip_if_lt_x_gpu(2)
    def test_flatten_sharded_optim_state_dict_nested(self) -> None:
        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
        nested model."""
        shared_kwargs = {
            "use_multiple_param_groups": False,
            "halve_world_size": False,
            "use_diff_optim_inputs": False,
            "wrap_alt": True,
            "num_iters": 3,
        }
        self._test_load_optim_state(
            _ModelClass.NESTED,
            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
            use_optim_input=False,
            **shared_kwargs,
        )

        sharded_settings = StateDictSettings(
            StateDictType.SHARDED_STATE_DICT,
            ShardedStateDictConfig(),
            ShardedOptimStateDictConfig(),
        )
        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=sharded_settings,
            **shared_kwargs,
        )

    @skip_if_lt_x_gpu(2)
    def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
        transformer model."""
        shared_kwargs = {
            "use_multiple_param_groups": False,
            "halve_world_size": False,
            "use_diff_optim_inputs": False,
            "num_iters": 3,
        }
        self._test_load_optim_state(
            _ModelClass.TRANSFORMER,
            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
            use_optim_input=False,
            **shared_kwargs,
        )

        sharded_settings = StateDictSettings(
            StateDictType.SHARDED_STATE_DICT,
            ShardedStateDictConfig(),
            ShardedOptimStateDictConfig(),
        )
        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.TRANSFORMER,
            state_dict_settings=sharded_settings,
            **shared_kwargs,
        )

    @skip_if_lt_x_gpu(2)
    def test_use_orig_params(self) -> None:
        """Tests :meth:`optim_state_dict` for an FSDP-root nested model with
        ``use_orig_params=True``. Covers full (with and without
        ``rank0_only``) and sharded optimizer state dicts."""

        def run_with_orig_params(subtest_config, state_dict_settings, **extra):
            # A fresh `fsdp_kwargs` dict is built for every call
            self.run_subtests(
                subtest_config,
                self._test_load_optim_state_with_optim_state_dict,
                model_class=_ModelClass.NESTED,
                state_dict_settings=state_dict_settings,
                use_multiple_param_groups=False,
                use_diff_optim_inputs=False,
                num_iters=3,
                fsdp_kwargs={"use_orig_params": True},
                **extra,
            )

        run_with_orig_params(
            {"halve_world_size": [True, False], "wrap_alt": [True, False]},
            StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(),
            ),
        )
        run_with_orig_params(
            {"halve_world_size": [True, False], "wrap_alt": [True, False]},
            StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(rank0_only=True),
            ),
        )
        run_with_orig_params(
            {"wrap_alt": [True, False]},
            StateDictSettings(
                StateDictType.SHARDED_STATE_DICT,
                ShardedStateDictConfig(),
                ShardedOptimStateDictConfig(),
            ),
            # We cannot test halve_world_size with SHARDED_STATE_DICT.
            halve_world_size=False,
        )

    def _test_load_optim_state(
        self,
        model_class: _ModelClass,
        use_multiple_param_groups: bool,
        halve_world_size: bool,
        osd_comm_method: _OSDCommMethod,
        use_diff_optim_inputs: bool,
        use_optim_input: bool,
        num_iters: int,
        **new_model_kwargs,
    ) -> None:
        """
        (1) Runs a wrapped model with full world size for ``num_iters``
        iterations and takes its full/sharded optimizer state dict;
        (2) initializes a second wrapped model, in a process group with halved
        world size if ``halve_world_size=True`` (else the default group).
        It may use a different FSDP wrapping scheme (based on
        ``new_model_kwargs``) and, if ``use_diff_optim_inputs=True``, different
        optimizer inputs across ranks. It is run for ``num_iters`` iterations,
        and its own full/sharded optimizer state dict is taken;
        (3) converts both state dicts into sharded optimizer state dicts for
        the second model, using the method selected by ``osd_comm_method``;
        (4) checks that both converted state dicts match the second model's
        local optimizer state dict, meaning that either could have
        equivalently been loaded into the local optimizer; and
        (5) as a sanity check, loads the second model's converted state dict
        and runs ``num_iters`` more iterations.

        When ``use_optim_input=True``, the optimizer input (rather than only
        the optimizer) is also passed to the state-dict APIs. With
        ``halve_world_size=True``, ranks outside the new group (the odd ranks)
        return right after the group is created.
        """
        initializer = self._model_class[model_class]
        # Select the API that produces the state dicts to be re-sharded:
        # `full_optim_state_dict` serves both the BROADCAST_OBJECT_LIST and
        # SCATTER_FULL_OSD paths
        if osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
            osd_method = FSDP.optim_state_dict
        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
            osd_method = FSDP.sharded_optim_state_dict
        else:
            osd_method = FSDP.full_optim_state_dict

        # First, run a wrapped model with full world size for a few iterations
        model1, optim1, optim_input1 = initializer(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=num_iters)
        # NOTE(review): `rank0_only` is not passed here, so the API default
        # applies; presumably this is why the broadcast/scatter paths below
        # source the full dict from rank 0 -- confirm
        fsdp_osd1 = (
            osd_method(model1, optim1, optim_input1)
            if use_optim_input
            else osd_method(model1, optim1)
        )
        if halve_world_size:
            # Create a new process group with halved world size
            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
            # Every rank calls `new_group()` before non-members return
            # (presumably required since all processes must enter it)
            new_group = dist.new_group(ranks=new_group_ranks)
            if self.rank not in new_group_ranks:
                return
        else:
            # Continue using the same group and hence world size
            new_group = dist.distributed_c10d._get_default_group()
        # Second, run a wrapped model with (possibly) halved world size and
        # (possibly) differing `optim_input` across ranks
        model2, optim2, optim_input2 = initializer(
            wrap=True,
            group=new_group,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
        )
        self._step_model(model2, optim2, num_iters=num_iters)
        fsdp_osd2 = (
            osd_method(model2, optim2, optim_input2, group=new_group)
            if use_optim_input
            else osd_method(model2, optim2, group=new_group)
        )
        # Compute two sharded optim state dicts: (1) for the first model
        # according to the second model and (2) for the second model according
        # to the second model
        # NOTE(review): there is no `else` branch; an unhandled
        # `osd_comm_method` would leave `sharded_osd1`/`sharded_osd2` unbound
        if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST:
            # Give every rank in `new_group` the full dict before sharding it
            # locally
            fsdp_osd1 = self._broadcast_full_osd(fsdp_osd1, group=new_group)
            sharded_osd1 = (
                FSDP.shard_full_optim_state_dict(
                    fsdp_osd1, model2, optim_input=optim_input2
                )
                if use_optim_input
                else FSDP.shard_full_optim_state_dict(fsdp_osd1, model2, optim=optim2)
            )
            fsdp_osd2 = self._broadcast_full_osd(fsdp_osd2, group=new_group)
            sharded_osd2 = (
                FSDP.shard_full_optim_state_dict(
                    fsdp_osd2, model2, optim_input=optim_input2
                )
                if use_optim_input
                else FSDP.shard_full_optim_state_dict(fsdp_osd2, model2, optim=optim2)
            )
        elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD:
            # Only rank 0 supplies the full dict; other ranks pass `None`
            sharded_osd1 = (
                FSDP.scatter_full_optim_state_dict(
                    fsdp_osd1 if self.rank == 0 else None,
                    model2,
                    optim_input=optim_input2,
                    group=new_group,
                )
                if use_optim_input
                else FSDP.scatter_full_optim_state_dict(
                    fsdp_osd1 if self.rank == 0 else None,
                    model2,
                    optim=optim2,
                    group=new_group,
                )
            )
            sharded_osd2 = (
                FSDP.scatter_full_optim_state_dict(
                    fsdp_osd2 if self.rank == 0 else None,
                    model2,
                    optim_input=optim_input2,
                    group=new_group,
                )
                if use_optim_input
                else FSDP.scatter_full_optim_state_dict(
                    fsdp_osd2 if self.rank == 0 else None,
                    model2,
                    optim=optim2,
                    group=new_group,
                )
            )
        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
            sharded_osd1 = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd1,
                model2,
                optim=optim2,
            )
            sharded_osd2 = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd2,
                model2,
                optim=optim2,
            )
        elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
            sharded_osd1 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd1)
            sharded_osd2 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd2)

        # As a sanity check, check that sharding the second model's full/sharded
        # optimizer state dict according to itself is equivalent to its local
        # optimizer's state dict
        local_osd2 = optim2.state_dict()
        check_same_param_keys = True  # should all have matching parameter IDs
        self._check_same_param_groups(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        # Check that sharding the first model's full/sharded optimizer state dict
        # according to the second model is equivalent to the second model's
        # local optimizer state dict
        self._check_same_param_groups(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim2.load_state_dict(sharded_osd2)
        self._step_model(model2, optim2, num_iters=num_iters)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    @parametrize("add_to_fsdp_module", [False, True])
    def test_shard_full_optim_state_dict_unmanaged_params(
        self,
        state_dict_type: StateDictType,
        add_to_fsdp_module: bool,
    ):
        """
        Tests :meth:`shard_full_optim_state_dict` (full state dict) and
        :meth:`flatten_sharded_optim_state_dict` (sharded state dict) when
        there are unmanaged parameters.
          - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
          added to a module to be wrapped with FSDP, in which case there should
          be an error since we require that all unflattened parameters
          comprising a flat parameter have the same scalar state (e.g. Adam
          "step") but the added parameter is missing its entry.
          - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
          added to a module not to be wrapped with FSDP, in which case there
          should be no error (emulating model parallel use cases where some
          parameters may be managed externally to FSDP).
        We do not separately test unmanaged parameters for
        :meth:`scatter_full_optim_state_dict` to save CI cost since it calls
        into the same subroutine :meth:`_flatten_optim_state_dict`.
        """
        # The sharded-state-dict path is only exercised with ``optim``, not
        # with the deprecated ``optim_input`` argument
        use_optim_input = (
            [False]
            if state_dict_type == StateDictType.SHARDED_STATE_DICT
            else [False, True]
        )
        self.run_subtests(
            {"use_optim_input": use_optim_input},
            self._test_shard_full_optim_state_dict_unmanaged_params,
            state_dict_type=state_dict_type,
            add_to_fsdp_module=add_to_fsdp_module,
        )

    def _test_shard_full_optim_state_dict_unmanaged_params(
        self,
        state_dict_type: StateDictType,
        add_to_fsdp_module: bool,
        use_optim_input: bool,
    ):
        """
        Saves an optimizer state dict from a normally wrapped model, then
        flattens it for a structurally identical model that additionally has
        unmanaged parameters. Flattening must raise a ``ValueError`` when the
        unmanaged parameters live in an FSDP-wrapped module, and must succeed
        (and be loadable once entries are added for the unmanaged parameters)
        otherwise.
        """
        NUM_ITERS = 1
        is_full = state_dict_type == StateDictType.FULL_STATE_DICT

        # Source model: a normal wrapped model stepped once to populate state
        model, optim, optim_input = self._init_nested_model(wrap=True)
        self._step_model(model, optim, num_iters=NUM_ITERS)
        if not is_full:
            fsdp_osd = FSDP.sharded_optim_state_dict(model, optim)
        elif use_optim_input:
            # Save on all ranks to avoid having to broadcast from rank 0
            fsdp_osd = FSDP.full_optim_state_dict(
                model, optim, optim_input, rank0_only=False
            )
        else:
            fsdp_osd = FSDP.full_optim_state_dict(model, optim, rank0_only=False)

        # Target model: same structure plus unmanaged parameters; this is the
        # model for which we want to load the saved state
        device = torch.device("cuda")
        model = NestedModel().to(device)
        model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
            model,
            add_to_fsdp_module,
        )
        optim_input = list(model.parameters())
        optim = torch.optim.Adam(optim_input, lr=1e-3)

        def flatten_for_target():
            # Use the API that matches how ``fsdp_osd`` was produced
            if not is_full:
                return FSDP.flatten_sharded_optim_state_dict(
                    fsdp_osd, model, optim=optim
                )
            if use_optim_input:
                return FSDP.shard_full_optim_state_dict(
                    fsdp_osd, model, optim_input=optim_input
                )
            return FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim=optim)

        if add_to_fsdp_module:
            # Placing the unmanaged parameters inside an FSDP-wrapped module
            # yields a flat parameter made of some unflattened parameters with
            # zero-dimensional tensor state (i.e. Adam "step") and others
            # without any (the unmanaged ones), which must be rejected
            error_prefix = (
                "^(All unflattened parameters comprising a "
                "single flat parameter must have scalar state with the "
                "same value and dtype)"
            )
            with self.assertRaisesRegex(ValueError, error_prefix):
                flatten_for_target()
            return

        # Unmanaged parameters outside FSDP are simply ignored (model
        # parallelism use cases where some parameters are managed externally)
        flattened_osd = flatten_for_target()
        # The unmanaged parameters need entries before the dict can be loaded
        for unmanaged_param in unmanaged_params:
            NestedModel.add_unmanaged_param_entry(
                flattened_osd,
                unmanaged_param,
                NUM_ITERS,
            )
        optim.load_state_dict(flattened_osd)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    @parametrize("use_multiple_param_groups", [False, True])
    def test_rekey_optim_state_dict_to_ids(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
    ):
        """
        Tests :meth:`rekey_optim_state_dict` with parameter IDs as the new
        keys: a wrapped model (with FSDP modules) rekeys its optimizer state
        dict so that it matches the one of an equivalent non-wrapped model
        (without FSDP modules).
        """
        is_sharded = state_dict_type == StateDictType.SHARDED_STATE_DICT
        # The sharded path is only run with ``optim``, never ``optim_input``
        optim_input_options = [False] if is_sharded else [False, True]
        self.run_subtests(
            {"use_optim_input": optim_input_options},
            self._test_rekey_optim_state_dict_to_ids,
            state_dict_type=state_dict_type,
            use_multiple_param_groups=use_multiple_param_groups,
        )

    def _test_rekey_optim_state_dict_to_ids(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        use_optim_input: bool,
    ):
        """
        Subtest for :meth:`test_rekey_optim_state_dict_to_ids`.

        Rekeys a wrapped model's full/sharded optimizer state dict by
        parameter ID according to a non-wrapped model, and compares the result
        with the non-wrapped model's own optimizer state dict.

        The GPU-count requirement is already enforced by the
        ``skip_if_lt_x_gpu`` decorator on the public test that runs this
        helper via ``run_subtests``, so decorating the helper as well is
        redundant.
        """
        NUM_ITERS = 3
        # Run a wrapped model for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = (
                FSDP.full_optim_state_dict(model1, optim1, optim_input1)
                if use_optim_input
                else FSDP.full_optim_state_dict(model1, optim1)
            )
            # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
            # have the full state dict
            fsdp_osd = self._broadcast_full_osd(fsdp_osd)
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)
        # Run a non-wrapped model for a few iterations
        model2, optim2, optim_input2 = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)
        # Re-key the wrapped model's optimizer state dict using parameter IDs
        # according to the non-wrapped model
        rekeyed_osd = (
            FSDP.rekey_optim_state_dict(
                fsdp_osd,
                OptimStateKeyType.PARAM_ID,
                model2,
                optim_input=optim_input2,
            )
            if use_optim_input
            else FSDP.rekey_optim_state_dict(
                fsdp_osd,
                OptimStateKeyType.PARAM_ID,
                model2,
                optim=optim2,
            )
        )
        # Check that the re-keyed dict and actual dict are the same
        osd = optim2.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            rekeyed_osd,
            osd,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            rekeyed_osd,
            osd,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations.
        # Only the full state dict is loaded back into the non-wrapped
        # optimizer. NOTE(review): presumably the sharded form (holding
        # sharded tensor state) cannot be loaded directly here; confirm.
        if state_dict_type != StateDictType.SHARDED_STATE_DICT:
            optim2.load_state_dict(rekeyed_osd)
            self._step_model(model2, optim2, num_iters=NUM_ITERS)

    @skip_if_lt_x_gpu(2)
    def test_rekey_optim_state_dict_to_names(self):
        """
        Tests :meth:`rekey_optim_state_dict` with parameter names as the new
        keys. A non-wrapped model (without FSDP modules) rekeys its optimizer
        state dict into the form expected from :meth:`full_optim_state_dict`.
        The result is sharded with :meth:`shard_full_optim_state_dict` and
        must match the per-rank optimizer state dict of a wrapped model (with
        FSDP modules).
        """
        subtest_config = {"use_optim_input": [False, True]}
        self.run_subtests(
            subtest_config,
            self._test_rekey_optim_state_dict_to_names,
            use_multiple_param_groups=False,
        )

    def _test_rekey_optim_state_dict_to_names(
        self,
        use_multiple_param_groups: bool,
        use_optim_input: bool,
    ):
        """
        Rekeys a non-wrapped model's optimizer state dict by parameter name,
        shards it for an equivalent wrapped model, and checks the result
        against the wrapped model's own per-rank optimizer state dict.
        """
        NUM_ITERS = 3
        # Wrapped model, run for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        # Non-wrapped model, run for a few iterations
        model2, optim2, optim_input2 = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)

        # Pick the keyword that identifies each model's optimized parameters
        if use_optim_input:
            wrapped_kwargs = {"optim_input": optim_input1}
            nonwrapped_kwargs = {"optim_input": optim_input2}
        else:
            wrapped_kwargs = {"optim": optim1}
            nonwrapped_kwargs = {"optim": optim2}

        # Rekey the non-wrapped model's optimizer state dict by parameter
        # name (still according to the non-wrapped model itself)
        nonwrapped_osd = optim2.state_dict()
        rekeyed_osd = FSDP.rekey_optim_state_dict(
            nonwrapped_osd,
            OptimStateKeyType.PARAM_NAME,
            model2,
            **nonwrapped_kwargs,
        )
        # Shard it for the wrapped model, mapping it back to (flattened)
        # parameter IDs
        sharded_osd = FSDP.shard_full_optim_state_dict(
            rekeyed_osd,
            model1,
            **wrapped_kwargs,
        )
        # The sharded result must equal the wrapped model's per-rank state
        wrapped_osd = optim1.state_dict()
        for check in (self._check_same_param_groups, self._check_same_state):
            check(sharded_osd, wrapped_osd, check_same_param_keys=True)
        # Sanity check: the sharded dict loads and training can continue
        optim1.load_state_dict(sharded_osd)
        self._step_model(model1, optim1, num_iters=NUM_ITERS)

    @skip_if_lt_x_gpu(2)
    def test_optim_input_warning(self):
        """Tests that passing the ``optim_input`` argument into optimizer state
        checkpointing APIs issues a ``FutureWarning``."""
        # `_run_on_all_optim_state_apis` does not pass `optim_input` to the
        # sharded APIs, so they are not expected to warn; skip them and check
        # every other method
        methods_without_optim_input = (
            "sharded_optim_state_dict",
            "flatten_sharded_optim_state_dict",
        )

        def should_check_method(method_name: str):
            return method_name not in methods_without_optim_input

        def get_warning_context():
            warning_regex = "`optim_input` argument is deprecated"
            return self.assertWarnsRegex(
                expected_warning=FutureWarning, expected_regex=warning_regex
            )

        self._run_on_all_optim_state_apis(
            should_check_method, get_warning_context, fsdp_kwargs=None
        )

    def _run_on_all_optim_state_apis(
        self,
        should_check_method_fn: Callable[[str], bool],
        context_fn: Callable,
        fsdp_kwargs: Optional[Dict[str, Any]],
    ):
        """
        Runs through all optimizer state checkpointing APIs with a context
        manager instantiated by ``context_fn``. Certain APIs can be skipped
        via ``should_check_method_fn``, which gets passed the string name of
        the method.

        Every API except the two sharded ones is passed ``optim_input``. Each
        call runs inside its own ``context_fn()`` context, which may suppress
        an exception raised by the call. Later calls then use the most
        recently produced state dict, or ``{}`` if none was produced.
        """
        wrapped_model, wrapped_optim, wrapped_optim_input = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=False,
            fsdp_kwargs=fsdp_kwargs,
        )
        self._step_model(wrapped_model, wrapped_optim, num_iters=2)

        # Fallback in case the producing call below is skipped or errors
        # inside a context that suppresses the exception
        fsdp_osd: Dict[str, Any] = {}

        # Sharded optim state dict
        if should_check_method_fn("sharded_optim_state_dict"):
            with context_fn():
                fsdp_osd = FSDP.sharded_optim_state_dict(
                    wrapped_model,
                    wrapped_optim,
                )
        if should_check_method_fn("flatten_sharded_optim_state_dict"):
            with context_fn():
                FSDP.flatten_sharded_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    wrapped_optim,
                )
        # Full optim state dict
        if should_check_method_fn("full_optim_state_dict"):
            with context_fn():
                fsdp_osd = FSDP.full_optim_state_dict(
                    wrapped_model,
                    wrapped_optim,
                    optim_input=wrapped_optim_input,
                    rank0_only=False,
                )
        if should_check_method_fn("shard_full_optim_state_dict"):
            with context_fn():
                FSDP.shard_full_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    optim_input=wrapped_optim_input,
                )
        if should_check_method_fn("scatter_full_optim_state_dict"):
            with context_fn():
                FSDP.scatter_full_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    optim_input=wrapped_optim_input,
                )
        # Rekey optim state dict
        (
            nonwrapped_model,
            nonwrapped_optim,
            nonwrapped_optim_input,
        ) = self._init_nested_model(wrap=False, use_multiple_param_groups=False)
        if should_check_method_fn("rekey_optim_state_dict"):
            with context_fn():
                FSDP.rekey_optim_state_dict(
                    fsdp_osd,  # from `full_optim_state_dict()`
                    OptimStateKeyType.PARAM_ID,
                    nonwrapped_model,
                    optim_input=nonwrapped_optim_input,
                )
        self._step_model(nonwrapped_model, nonwrapped_optim, num_iters=2)
        osd = nonwrapped_optim.state_dict()
        if should_check_method_fn("rekey_optim_state_dict"):
            with context_fn():
                FSDP.rekey_optim_state_dict(
                    osd,
                    OptimStateKeyType.PARAM_NAME,
                    nonwrapped_model,
                    optim_input=nonwrapped_optim_input,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    def test_save_load_without_0th_param_state(self, state_dict_type: StateDictType):
        """
        Tests saving and loading an optim state dict for Adam optimizer (i.e.
        any optimizer with a "step" key in its state) when the first parameter
        does not have optimizer state (e.g. unused or frozen).
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(5, 5)
                self.lin2 = nn.Linear(5, 5)
                self.relu = nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Do not use `lin1`, whose parameter is the 0th one passed to
                # the optimizer and hence the one checked for "step" state to
                # see if it is tensor or float
                return self.relu(self.lin2(x))

        device = torch.device("cuda")
        model = Model().cuda()
        model.lin1 = FSDP(model.lin1)
        model.lin2 = FSDP(model.lin2)
        fsdp_model = FSDP(model)
        # Adam, or any optimizer with "step" state
        optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        def run_iteration():
            # One forward/backward/optimizer step on random input
            inp = torch.randn((2, 5), device=device)
            loss = fsdp_model(inp).sum()
            loss.backward()
            optim.step()

        # Run an iteration to construct optimizer state
        run_iteration()

        # Check that save and load does not error
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(fsdp_model, optim, rank0_only=False)
            flattened_osd = FSDP.shard_full_optim_state_dict(fsdp_osd, fsdp_model)
        elif state_dict_type == StateDictType.SHARDED_STATE_DICT:
            fsdp_osd = FSDP.sharded_optim_state_dict(fsdp_model, optim)
            flattened_osd = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd, fsdp_model, optim
            )
        else:
            # Fail loudly rather than with an UnboundLocalError below
            raise ValueError(f"Unsupported state_dict_type: {state_dict_type}")
        # `__setstate__()` will check the 0th parameter to see if "step" is
        # represented as a tensor or float, so it is imperative that its state
        # is non-empty.
        optim.load_state_dict(flattened_osd)

        # Run an iteration as a sanity check
        run_iteration()

    @skip_if_lt_x_gpu(2)
    def test_compatible_with_trec(self):
        """
        Tests that ``FSDP.optim_state_dict`` gives matching results for a plain
        ``torch.optim.Adam`` and a ``_NamedOptimizer`` on identical models. It
        also tests that ``FSDP.optim_state_dict_to_load`` with
        ``is_named_optimizer=True`` restores the ``_NamedOptimizer``'s saved
        state.

        Each model mixes an FSDP-wrapped dense part (``use_orig_params=True``)
        with a rank-local "sparse" module that is not wrapped by FSDP. This
        presumably emulates a TorchRec-style model-parallel setup ("trec").
        """

        class DenseModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
                self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
                self.net3 = nn.Linear(32, 64)
                self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

            def forward(self, x):
                return self.net4(self.net3(self.net2(self.net1(x))))

        class FakeMPModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(0)
                self.dense = FSDP(DenseModel().cuda(), use_orig_params=True)
                if dist.get_rank() == 0:
                    self.sparse0 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
                else:
                    self.sparse1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

            def forward(self, x):
                if dist.get_rank() == 0:
                    sparse = self.sparse0(x)
                else:
                    sparse = self.sparse1(x)
                dist.all_reduce(sparse)
                return self.dense(sparse)

        models = [FakeMPModel().cuda(), FakeMPModel().cuda()]
        optims = [
            torch.optim.Adam(models[0].parameters(), lr=1e-2),
            _NamedOptimizer(
                models[1].named_parameters(),
                torch.optim.Adam,
                [{"params": models[1].parameters()}],
                models[1],
                lr=1e-2,
            ),
        ]
        state_dicts = []

        # Train one batch and check that both optim state dicts match.
        batch = torch.rand(5, 8, device=torch.device("cuda"))
        for model, optim in zip(models, optims):
            # Eagerly initialize the states by stepping on zero gradients.
            # A plain tensor can be assigned as a gradient; the deprecated
            # `torch.autograd.Variable` wrapper is unnecessary.
            for param in model.parameters():
                if param.requires_grad:
                    param.grad = torch.zeros_like(param)
            optim.step()
            loss = model(batch).sum()
            loss.backward()
            optim.step()
            state_dicts.append(deepcopy(FSDP.optim_state_dict(model, optim)))

        self._check_same_param_groups(
            state_dicts[0], state_dicts[1], check_same_param_keys=False
        )
        self._check_same_state(
            state_dicts[0], state_dicts[1], check_same_param_keys=True
        )

        # Give `optims[1]` a different state.
        for _ in range(5):
            batch = torch.rand(5, 8).cuda()
            loss = models[1](batch).sum()
            loss.backward()
            optims[1].step()

        # Load the saved state back to check that `optim_state_dict_to_load` works.
        state_dict_to_load = FSDP.optim_state_dict_to_load(
            models[1], optims[1], state_dicts[1], is_named_optimizer=True
        )
        optims[1].load_state_dict(state_dict_to_load)
        state_dicts[1] = FSDP.optim_state_dict(models[1], optims[1])

        self._check_same_param_groups(
            state_dicts[0], state_dicts[1], check_same_param_keys=False
        )
        self._check_same_state(
            state_dicts[0], state_dicts[1], check_same_param_keys=True
        )

    @skip_if_lt_x_gpu(2)
    def test_optim_state_without_param_groups(self):
        """
        Tests the optimizer state APIs with a state dict that has no
        ``"param_groups"``. ``FSDP.optim_state_dict`` must accept such a dict
        and not add the key back. After ``FSDP.optim_state_dict_to_load``, and
        with the original param groups re-added manually, loading must restore
        the original optimizer state.
        """

        class SimpleModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(0)
                self.net1 = nn.Sequential(nn.Linear(2, 4), nn.ReLU())

            def forward(self, x):
                return self.net1(x)

        model = FSDP(SimpleModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        batch = torch.rand(3, 2, device=torch.device("cuda"))

        def train_step():
            # Step on zero gradients (a plain tensor can be assigned as a
            # gradient; the deprecated `torch.autograd.Variable` wrapper is
            # unnecessary), then run forward/backward on `batch`
            for param in model.parameters():
                if param.requires_grad:
                    param.grad = torch.zeros_like(param)
            optim.step()
            loss = model(batch).sum()
            loss.backward()

        # Train one step to save original optimizer state dict and original optimizer param groups.
        train_step()

        original_osd = deepcopy(optim.state_dict())
        original_osd_no_param_groups = deepcopy(original_osd)
        # manually remove param_groups from optimizer state dict
        original_param_groups = deepcopy(
            original_osd_no_param_groups.pop("param_groups")
        )
        # passing the osd without param_groups to FSDP
        original_fsdp_optim_state_dict = deepcopy(
            FSDP.optim_state_dict(
                model, optim, optim_state_dict=original_osd_no_param_groups
            )
        )
        # check the state_dict sharded by FSDP does not contain param_groups.
        self.assertEqual(None, original_fsdp_optim_state_dict.get("param_groups"))

        # train another step to put optim in a different state.
        train_step()

        state_dict_to_load = FSDP.optim_state_dict_to_load(
            model, optim, original_fsdp_optim_state_dict
        )
        # manually add param_groups to state_dict_to_load before loading the optimizer state
        state_dict_to_load["param_groups"] = original_param_groups
        optim.load_state_dict(state_dict_to_load)
        self.assertEqual(original_osd, optim.state_dict())

        fsdp_optim_state = FSDP.optim_state_dict(model, optim)
        self._check_same_state(
            original_fsdp_optim_state_dict, fsdp_optim_state, check_same_param_keys=True
        )
        self.assertEqual(original_param_groups, optim.state_dict()["param_groups"])

    @skip_if_lt_x_gpu(2)
    def test_with_empty_optimizer_state(self):
        """
        Tests that ``FSDP.optim_state_dict`` on an optimizer that has not
        stepped yet returns the same ``"state"`` entry as the optimizer's own
        state dict.
        """
        fsdp_model = FSDP(TestDummyModel().cuda())
        adam = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
        local_osd = adam.state_dict()
        gathered_osd = FSDP.optim_state_dict(fsdp_model, adam)
        self.assertEqual(gathered_osd["state"], local_osd["state"])

    def _test_load_optim_state_with_optim_state_dict(
        self,
        model_class: _ModelClass,
        state_dict_settings: StateDictSettings,
        use_multiple_param_groups: bool,
        halve_world_size: bool,
        use_diff_optim_inputs: bool,
        num_iters: int,
        **new_model_kwargs,
    ):
        """
        (1) Runs a model with full world size for K iterations to generate a
        full/sharded optimizer state dict;
        (2) initializes a model with halved world size and possibly different
        FSDP wrapping scheme (based on ``new_model_kwargs``);
        (3) loads the full/sharded optimizer state dict from (1) according to the
        halved-world-size model;
        (4) runs the halved-world-size model for K iterations; and
        (5) checks that the sharded optimizer state dict from (3) matches the
        halved-world-size model's local optimizer state dict, meaning that the
        former could have equivalently been loaded into the local optimizer.

        Here K is ``num_iters``. The world size is only halved when
        ``halve_world_size=True``; otherwise the default process group is
        reused. Both models are configured with ``FSDP.set_state_dict_type``
        from ``state_dict_settings`` and go through ``FSDP.optim_state_dict``
        / ``FSDP.optim_state_dict_to_load``. Ranks outside the halved group
        return early after saving the first model's state dict.
        """
        initializer = self._model_class[model_class]

        # First, run a wrapped model with full world size for a few iterations
        # (`optim_input1` is not needed by the `optim_state_dict` APIs used here)
        model1, optim1, optim_input1 = initializer(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        FSDP.set_state_dict_type(
            model1,
            state_dict_settings.state_dict_type,
            state_dict_settings.state_dict_config,
            state_dict_settings.optim_state_dict_config,
        )
        self._step_model(model1, optim1, num_iters=num_iters)
        fsdp_osd1 = FSDP.optim_state_dict(model1, optim1)
        if halve_world_size:
            # Create a new process group with halved world size
            # (even ranks only; all ranks must take part in `new_group`)
            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
            new_group = dist.new_group(ranks=new_group_ranks)
            # Ranks outside the new group take no part in the rest of the test
            if self.rank not in new_group_ranks:
                return
        else:
            # Continue using the same group and hence world size
            new_group = dist.distributed_c10d._get_default_group()
        # Second, run a wrapped model with (possibly) halved world size and
        # (possibly) differing `optim_input` across ranks
        model2, optim2, optim_input2 = initializer(
            wrap=True,
            group=new_group,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
        )
        FSDP.set_state_dict_type(
            model2,
            state_dict_settings.state_dict_type,
            state_dict_settings.state_dict_config,
            state_dict_settings.optim_state_dict_config,
        )
        self._step_model(model2, optim2, num_iters=num_iters)
        fsdp_osd2 = FSDP.optim_state_dict(model2, optim2, group=new_group)
        # Convert the second model's state dict into loadable form according
        # to the second model itself. The first model's state dict is
        # converted the same way further below (`sharded_osd1`).
        # NOTE(review): these calls take `group=new_group` and may
        # communicate, so their order should stay identical across ranks.
        sharded_osd2 = FSDP.optim_state_dict_to_load(
            model2, optim2, fsdp_osd2, group=new_group
        )

        # As a sanity check, check that sharding the second model's full/sharded
        # optimizer state dict according to itself is equivalent to its local
        # optimizer's state dict
        local_osd2 = optim2.state_dict()
        self._check_same_param_groups(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=True,
        )
        self._check_same_state(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=True,
        )
        # Check that sharding the first model's full/sharded optimizer state dict
        # according to the second model is equivalent to the second model's
        # local optimizer state dict
        sharded_osd1 = FSDP.optim_state_dict_to_load(
            model2, optim2, fsdp_osd1, group=new_group
        )
        self._check_same_param_groups(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=True,
        )
        self._check_same_state(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=True,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim2.load_state_dict(sharded_osd2)
        self._step_model(model2, optim2, num_iters=num_iters)

    @skip_if_lt_x_gpu(2)
    def test_interface_arguments(self):
        """
        Tests the optional arguments of ``FSDP.optim_state_dict``
        (``optim_state_dict=``) and ``FSDP.optim_state_dict_to_load``
        (``load_directly=``). It also checks the type and device of the
        resulting state tensors under three state-dict configurations: the
        default one, sharded without CPU offload, and full with ``rank0_only``.
        """
        model = FSDP(TestDummyModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        def step():
            loss = model(model.get_input())
            loss.backward(loss)
            optim.step()

        step()
        original_osd = deepcopy(optim.state_dict())
        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
        self._check_same_state(
            FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
        )
        step()
        # With `load_directly=True` the converted state is loaded into `optim`
        # itself, so its state must match the pre-second-step snapshot again
        FSDP.optim_state_dict_to_load(model, optim, osd, load_directly=True)
        self._check_same_state(
            optim.state_dict(), original_osd, check_same_param_keys=True
        )

        # Test the default setting.
        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
        for state in osd["state"].values():
            for s in state.values():
                self.assertFalse(isinstance(s, ShardedTensor))
                self.assertFalse(s.is_cuda)

        # Test sharded state_dict without offload_to_cpu
        with FSDP.state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            ShardedStateDictConfig(),
            ShardedOptimStateDictConfig(offload_to_cpu=False),
        ):
            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            for state in osd["state"].values():
                for s in state.values():
                    # Scalar state (e.g. Adam "step") is not sharded
                    if s.dim() == 0:
                        continue
                    self.assertTrue(isinstance(s, ShardedTensor))
                    # `_local_shards` may be empty on some ranks (presumably
                    # when this rank's chunk is empty); indexing `[0]`
                    # unconditionally would then raise IndexError
                    if s._local_shards:
                        self.assertTrue(s._local_shards[0].tensor.is_cuda)

        # Test full state_dict with rank0_only
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(
                offload_to_cpu=True,
                rank0_only=True,
            ),
        ):
            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            if dist.get_rank() > 0:
                self.assertEqual(osd, {})
            else:
                for state in osd["state"].values():
                    for s in state.values():
                        if s.dim() == 0:
                            continue
                        self.assertFalse(s.is_cuda)
                        self.assertFalse(isinstance(s, ShardedTensor))

    @skip_if_lt_x_gpu(2)
    def test_state_dict_with_none_tensor_state(self):
        """Non-tensor per-parameter optimizer state (a float and ``None``)
        injected into the local optim state dict must survive the
        ``FSDP.optim_state_dict`` -> ``FSDP.optim_state_dict_to_load`` round
        trip, for both Adam (tensor state) and SGD, with and without
        ``use_orig_params``.
        """

        def _run_test(use_orig_params, optimizer_has_tensor_state):
            fsdp_model = FSDP(
                TestDummyModel().cuda(), use_orig_params=use_orig_params
            )
            if optimizer_has_tensor_state:
                optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
            else:
                optim = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)

            def _train_one_step():
                out = fsdp_model(fsdp_model.get_input())
                out.backward(out)
                optim.step()

            _train_one_step()
            local_osd = deepcopy(optim.state_dict())
            # Attach custom non-tensor entries to every per-parameter state.
            for param_state in local_osd["state"].values():
                param_state.update(value1=2.74, value2=None)

            fsdp_osd = FSDP.optim_state_dict(
                fsdp_model, optim, optim_state_dict=local_osd
            )
            loaded_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, fsdp_osd)
            for param_state in loaded_osd["state"].values():
                self.assertEqual(param_state["value1"], 2.74)
                self.assertEqual(param_state["value2"], None)

        subtest_config = {
            "use_orig_params": [False, True],
            "optimizer_has_tensor_state": [False, True],
        }
        self.run_subtests(subtest_config, _run_test)

    @skip_if_lt_x_gpu(2)
    def test_with_no_shard(self):
        """With ``ShardingStrategy.NO_SHARD``, converting the optimizer state
        dict through FSDP and loading it back must leave
        ``optim.state_dict()`` exactly as it was.
        """

        def _run_test(use_orig_params: bool) -> None:
            fsdp_model = FSDP(
                TestDummyModel().cuda(),
                sharding_strategy=ShardingStrategy.NO_SHARD,
                use_orig_params=use_orig_params,
            )
            optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

            # A single forward/backward/step so the optimizer has state.
            out = fsdp_model(fsdp_model.get_input())
            out.backward(out)
            optim.step()

            before = deepcopy(optim.state_dict())
            # Round trip: FSDP optim state dict -> loadable form -> optimizer.
            optim.load_state_dict(
                FSDP.optim_state_dict_to_load(
                    fsdp_model, optim, FSDP.optim_state_dict(fsdp_model, optim)
                )
            )
            self.assertEqual(before, optim.state_dict())

        self.run_subtests({"use_orig_params": [False, True]}, _run_test)

    @skip_if_lt_x_gpu(2)
    def test_no_grad(self):
        """Round-trip the FSDP optimizer state dict each iteration while
        toggling ``requires_grad`` on ``net1[0]``.

        With ``use_orig_params=True``, ``net1[0]`` is frozen on even iterations
        and trainable on odd ones, so it receives no gradient on some steps.
        After each step, loading the FSDP optim state dict back directly into
        the optimizer must leave its local state dict unchanged.
        """
        model = TestDummyModel(no_grad=True).cuda()
        fsdp_model = FSDP(deepcopy(model), use_orig_params=True)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        for i in range(5):
            # Freeze net1[0] on even iterations, unfreeze it on odd ones.
            trainable = i % 2 == 1
            fsdp_model.net1[0].weight.requires_grad = trainable
            fsdp_model.net1[0].bias.requires_grad = trainable
            batch = fsdp_model.get_input()
            loss = fsdp_model(batch).sum()
            loss.backward()
            fsdp_optim.step()
            orig_state_dict = deepcopy(fsdp_optim.state_dict())
            # Compute the FSDP optim state dict once and reuse it, instead of
            # discarding this result and recomputing it for the load.
            optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim)
            FSDP.optim_state_dict_to_load(
                fsdp_model,
                fsdp_optim,
                optim_state_dict,
                load_directly=True,
            )

            self._check_same_state(
                fsdp_optim.state_dict(),
                orig_state_dict,
                check_same_param_keys=True,
            )


# Generate the concrete test methods for any parametrized tests declared on
# TestFSDPOptimState (PyTorch common_utils helper). This must run at module
# level after the class is defined so the generated tests are collected.
instantiate_parametrized_tests(TestFSDPOptimState)

if __name__ == "__main__":
    # Entry point when this file is executed directly as a script.
    run_tests()