# File: test_zero_redundancy_optimizer.py
# Source package: pytorch 1.13.1+dfsg-4 (Debian, area: main, suite: bookworm)
# File content: 1381 lines, 53,393 bytes
# Owner(s): ["oncall: distributed"]

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os
import sys
import unittest
from contextlib import suppress
from typing import Any, List, cast

import numpy as np

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD, AdamW
from torch.testing._internal import common_distributed
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    TEST_WITH_ASAN,
    TEST_WITH_DEV_DBG_ASAN,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

# Use GLOO on GPU when running CUDA + Windows
def _get_backend_for_tests():
    return (
        dist.Backend.NCCL if not IS_WINDOWS and torch.cuda.is_available()
        # Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when
        # no GPUs are available.
        else dist.Backend.GLOO
    )

BACKEND = _get_backend_for_tests()

@unittest.skipIf(
    TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work."
)
class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase):
    """Common fixture for the ZeroRedundancyOptimizer test cases.

    Subclasses override ``device`` and ``world_size`` as needed and call
    ``dist_init()`` at the start of each test to join the process group.
    """

    def setUp(self):
        """Export the world size and spawn the test processes."""
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        self._spawn_processes()

    @property
    def device(self):
        """Device on which tests place tensors: CUDA if available, else CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @property
    def world_size(self):
        """Number of ranks used by the tests of this class."""
        return 1

    def tearDown(self):
        """Run the base-class cleanup, then release the process group and
        remove the rendezvous file on a best-effort basis."""
        # Previously the base-class hook was skipped entirely.
        # NOTE(review): `MultiProcessTestCase.tearDown()` is expected to
        # terminate the processes spawned in `setUp()` -- confirm against
        # torch.testing._internal.common_distributed.
        super().tearDown()
        # Destroying a process group that was never initialized trips an
        # assertion; that is not an error for cleanup purposes
        with suppress(AssertionError):
            torch.distributed.destroy_process_group()
        # The file backing the `FileStore` may already be gone
        with suppress(OSError):
            os.remove(self.file_name)

    def dist_init(self, rank, world_size=-1, backend=BACKEND):
        """Initialize the default process group through a ``FileStore``.

        Args:
            rank: rank of the calling process.
            world_size: number of participating ranks; any value below 1
                means "use ``self.world_size``".
            backend: ``torch.distributed`` backend to initialize.

        Returns:
            Whatever ``dist.init_process_group()`` returns.
        """
        if world_size < 1:
            world_size = self.world_size
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend, store=store, rank=rank, world_size=world_size,
        )


# TODO: sandcastle_skip_if does not work here.
@unittest.skipIf(
    TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work."
)
class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
    def test_state_dict(self):
        """Exercise the ``state_dict()`` / ``load_state_dict()`` interface of
        ZeroRedundancyOptimizer on a single rank: the consolidated state must
        follow the standard PyTorch layout and survive a reload."""
        self.dist_init(self.rank)
        LR1, LR2 = 0.1, 0.01
        MOMENTUM = 0.9
        RECIPIENT_RANK = 0  # with a world size of 1, rank 0 is the only rank
        device = self.device
        x = torch.tensor([1.0], device=device, requires_grad=True)
        zero_optim = ZeroRedundancyOptimizer(
            [x], optimizer_class=SGD, lr=LR1, momentum=MOMENTUM,
        )
        x.backward()
        zero_optim.step()
        # First step with a gradient of 1: x = 1.0 - LR1 * 1.0
        self.assertEqual(x, torch.tensor([0.9], device=device))
        self.assertEqual(
            zero_optim.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=device),
        )

        zero_optim.zero_grad()
        zero_optim.consolidate_state_dict(to=RECIPIENT_RANK)
        state_dict = zero_optim.state_dict()

        # The top-level keys must be the ones PyTorch optimizers expose
        for top_level_key in ("param_groups", "state"):
            self.assertIn(top_level_key, state_dict.keys())

        # The saved parameter group must carry the expected hyperparameters
        saved_group = state_dict["param_groups"][0]
        self.assertEqual(saved_group["lr"], 0.1)
        self.assertEqual(saved_group["momentum"], 0.9)
        self.assertFalse(saved_group["nesterov"])
        self.assertEqual(saved_group["weight_decay"], 0.0)
        self.assertEqual(saved_group["dampening"], 0.0)

        # Apart from "params", the saved group and the live `param_groups`
        # attribute must agree entry by entry
        for key, saved_value in saved_group.items():
            if key == "params":
                continue
            self.assertEqual(saved_value, zero_optim.param_groups[0][key])

        # Reload into a fresh optimizer created with a different learning
        # rate; the state must come back with the same values and device
        zero_optim = ZeroRedundancyOptimizer(
            [x], optimizer_class=SGD, lr=LR2,
        )
        zero_optim.load_state_dict(state_dict)
        self.assertEqual(
            zero_optim.optim.state[x]["momentum_buffer"],
            torch.tensor([1.0], device=device),
        )

        # We should be using `LR1` and not `LR2` after reloading, both within
        # the optimizer and as exposed by the `param_groups` attribute
        self.assertEqual(zero_optim.param_groups[0]["lr"], LR1)
        x.backward()
        zero_optim.step()
        # buffer = 0.9 * 1.0 + 1.0 = 1.9, hence x = 0.9 - 0.1 * 1.9
        self.assertEqual(x, torch.tensor([0.71], device=device))
        self.assertEqual(
            zero_optim.optim.state[x]["momentum_buffer"],
            torch.tensor([1.9], device=device),
        )

        # Check that the exposed `param_groups` are on the proper device
        exposed_param = zero_optim.param_groups[0]["params"][0]
        self.assertEqual(exposed_param.device, x.device)

    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer.

        A tensor optimized by ZeRO-wrapped SGD and a twin optimized by plain
        SGD are driven by identical ``StepLR`` schedules and must stay equal
        after every iteration.
        """
        self.dist_init(self.rank)
        NUM_ITERS = 5
        LR = 0.01
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR)
        o2 = torch.optim.SGD([x2], lr=LR)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(NUM_ITERS):
            # Clear stale gradients *before* the backward pass.  Zeroing
            # them between `backward()` and `step()` would make every step
            # a no-op, so the scheduled learning rate would never affect
            # the parameters and the comparison below would be vacuous.
            o.zero_grad()
            x.backward()
            o.step()
            s.step()
            o2.zero_grad()
            x2.backward()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_step_with_kwargs(self):
        """Verify that extra keyword arguments given to ``step()`` are
        forwarded to the wrapped optimizer."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithStepKWArg(torch.optim.SGD):
            # Leaves a marker in the caller-supplied list on every step
            def step(self, closure=None, kwarg=None):
                super().step()
                kwarg.append(5)

        markers: List[Any] = []
        param = torch.tensor([1.0], device=self.device, requires_grad=True)
        zero_optim = ZeroRedundancyOptimizer(
            [param], optimizer_class=SGDWithStepKWArg, lr=LR,
        )
        param.backward()
        zero_optim.step(0, kwarg=markers)
        # The inner `step()` ran exactly once and applied a plain SGD update
        self.assertEqual(markers, [5])
        expected = torch.tensor([0.9], device=self.device)
        self.assertEqual(param, expected)

    def test_step_with_extra_inner_key(self):
        """Verify that a key added to ``param_groups`` by the wrapped
        optimizer during ``step()`` becomes visible through ZeRO's own
        ``param_groups``."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithNewKey(torch.optim.SGD):
            # Dummy optimizer that injects a new param-group key on each step
            def step(self, closure=None):
                super().step()
                self.param_groups[0]["new_key"] = 0.1

        param = torch.tensor([1.0], device=self.device, requires_grad=True)
        zero_optim = ZeroRedundancyOptimizer(
            [param], optimizer_class=SGDWithNewKey, lr=LR,
        )
        param.backward()
        zero_optim.step()
        first_group = zero_optim.param_groups[0]
        self.assertEqual(first_group["new_key"], 0.1)
        expected = torch.tensor([0.9], device=self.device)
        self.assertEqual(param, expected)

    def test_step_without_closure(self):
        """Verify that ``step()`` works when the wrapped optimizer's own
        ``step()`` accepts no closure argument."""
        self.dist_init(self.rank)
        LR = 0.1

        class SGDWithoutClosure(torch.optim.SGD):
            # `step()` deliberately has no `closure` parameter
            def step(self):
                return super().step()

        param = torch.tensor([1.0], device=self.device, requires_grad=True)
        zero_optim = ZeroRedundancyOptimizer(
            [param], optimizer_class=SGDWithoutClosure, lr=LR,
        )
        param.backward()
        zero_optim.step()
        expected = torch.tensor([0.9], device=self.device)
        self.assertEqual(param, expected)

    def test_zero_grad(self):
        """Check that the ``zero_grad`` method is properly handled.

        After a backward pass both the weight and the bias of a 1x1 linear
        layer must hold a non-zero gradient, and ``zero_grad()`` on the
        ZeroRedundancyOptimizer must clear both of them.
        """
        self.dist_init(self.rank)
        LR = 0.01
        x = torch.rand(1)
        m = torch.nn.Linear(1, 1)
        o = ZeroRedundancyOptimizer(m.parameters(), optimizer_class=SGD, lr=LR)
        y = m(x)
        y.backward(x)
        # Precondition: *each* parameter has a populated gradient.  The
        # second check used to repeat the weight check, leaving the bias
        # gradient unverified before it was cleared.
        self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
        self.assertNotEqual(m.bias.grad, torch.zeros_like(m.bias))
        o.zero_grad()
        # Each gradient has a single element, so it is falsy both when it was
        # zeroed in place and when it was reset to `None`
        self.assertFalse(m.weight.grad)
        self.assertFalse(m.bias.grad)

    def test_constructor(self):
        """Check the robustness of the ZeroRedundancyOptimizer constructor by
        passing different values for the ``params`` argument.

        Invalid inputs must raise the listed exception type; valid inputs
        must construct without error.  Construction with multiple parameter
        groups is then checked more closely.
        """
        self.dist_init(self.rank)
        LR = 0.01
        m = torch.nn.Sequential(
            torch.nn.Linear(5, 10),
            torch.nn.Linear(10, 10),
            torch.nn.Linear(10, 10),
        )
        # Test various constructor inputs in the form: (input, expected error)
        ctor_inputs = [
            ([], ValueError),                          # empty parameter list
            (torch.randn(1), TypeError),               # non-iterable: `torch.Tensor`
            (1.2, TypeError),                          # non-iterable: `float`
            ([
                {"params": [l.weight for l in m]},
                {"params": [l.bias for l in m]},
            ], None),                                  # iterable of dict
            (list(m.parameters()) + [42], TypeError),  # iterable containing invalid type
            (m.parameters(), None),                    # `params` as a generator
            (list(m.parameters()), None)               # `params` as a list
        ]
        for ctor_input, error in ctor_inputs:
            # `suppress()` without arguments serves as a no-op context for
            # the inputs that are expected to succeed
            context = self.assertRaises(error) if error else suppress()
            with context:
                ZeroRedundancyOptimizer(
                    ctor_input, optimizer_class=SGD, lr=LR,
                )

        # Test constructing with multiple parameter groups more thoroughly
        WD = 0.01
        BETAS = (0.9, 0.999)
        EPS = 1e-8
        params = [
            {"params": [l.weight for l in m], "weight_decay": 0.},
            {"params": [l.bias for l in m], "weight_decay": WD},
        ]
        o = ZeroRedundancyOptimizer(
            params, optimizer_class=AdamW,
            lr=LR, betas=BETAS, eps=EPS,
        )
        # Use unittest assertions rather than bare `assert` statements, which
        # are stripped when Python runs with `-O`
        self.assertEqual(
            len(o.param_groups), 2,
            msg=f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}",
        )
        self.assertEqual(
            len(o.optim.param_groups), 2,
            msg="Expected 2 local optimizer param groups, but got "
            f"{len(o.optim.param_groups)}",
        )

    def test_same_dense_param_type(self):
        """Verify that the ZeroRedundancyOptimizer constructor rejects
        parameter lists containing sparse tensors or mixed dense types.

        NOTE: This test should be removed once support for sparse parameters
        and varying parameter types is added.
        """
        self.dist_init(self.rank)
        LR = 0.01
        sparse_only = [torch.sparse_coo_tensor(size=(2, 3))]
        mixed_dense = [torch.FloatTensor(1), torch.DoubleTensor(1)]
        dense_then_sparse = [
            torch.FloatTensor(1),
            torch.FloatTensor(1),
            torch.sparse_coo_tensor(size=(2, 3)),
        ]
        # Every one of these parameter lists must be refused
        for bad_params in (sparse_only, mixed_dense, dense_then_sparse):
            with self.assertRaises(ValueError):
                ZeroRedundancyOptimizer(
                    bad_params, optimizer_class=SGD, lr=LR,
                )


class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
    @property
    def device(self):
        """Device for this rank: the CUDA device whose index equals the rank
        when CUDA is available, otherwise the CPU."""
        if torch.cuda.is_available():
            return torch.device(self.rank)
        return torch.device("cpu")

    @property
    def world_size(self):
        """Number of ranks: the CUDA device count clamped to the range
        [2, 4]."""
        gpu_count = torch.cuda.device_count()
        at_least_two = max(2, gpu_count)
        return min(4, at_least_two)

    @property
    def context(self):
        """Context manager for the test body: selects this rank's CUDA
        device when CUDA is available, otherwise does nothing."""
        if torch.cuda.is_available():
            return torch.cuda.device(self.rank)
        # `suppress()` without arguments acts as a no-op context manager
        return suppress()

    def _check_same_model_params(
        self,
        model_a: torch.nn.Module,
        model_b: torch.nn.Module,
        message: str = "",
    ) -> None:
        """Assert that two models have matching parameters and buffers.

        Parameters are compared pairwise, in iteration order, with
        ``atol=1e-3`` and ``rtol=1e-5``; buffers are compared with the
        default tolerances of ``torch.testing.assert_close``.

        Args:
            model_a: first model.
            model_b: second model.
            message: text appended to the failure message.
        """
        # Parameters: loose tolerances
        param_pairs = zip(model_a.parameters(), model_b.parameters())
        for param_a, param_b in param_pairs:
            failure_text = (
                f"Model parameters differ:\n{param_a} {param_b}\n" + message
            )
            torch.testing.assert_close(
                param_a, param_b, atol=1e-3, rtol=1e-5, msg=failure_text,
            )
        # Buffers: default tolerances
        buffer_pairs = zip(model_a.buffers(), model_b.buffers())
        for buffer_a, buffer_b in buffer_pairs:
            failure_text = (
                f"Model buffers differ:\n{buffer_a} {buffer_b}\n" + message
            )
            torch.testing.assert_close(buffer_a, buffer_b, msg=failure_text)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_step(self):
        """Compare one ``step()`` of ZeroRedundancyOptimizer against plain
        SGD on two identical models whose gradients are averaged across
        ranks by hand."""
        self.dist_init(self.rank, world_size=self.world_size)
        LR = 0.01

        def average_grads(module):
            # Sum every gradient over all ranks, then divide by the world size
            for param in module.parameters():
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                param.grad.data /= self.world_size

        with self.context:
            # Each rank feeds a different input so the averaging matters
            x = torch.tensor([float(self.rank + 1)], device=self.device)
            reference = torch.nn.Linear(1, 1)
            reference.weight.data = torch.tensor([[1.0]])
            reference.bias.data = torch.tensor([2.0])
            reference = reference.to(self.device)
            sharded = copy.deepcopy(reference).to(self.device)

            reference_optim = SGD(reference.parameters(), lr=LR)
            sharded_optim = ZeroRedundancyOptimizer(
                sharded.parameters(), optimizer_class=SGD, lr=LR,
            )

            reference(x).backward(x)
            sharded(x).backward(x)

            average_grads(reference)
            reference_optim.step()
            average_grads(sharded)
            sharded_optim.step()

            self.assertEqual(reference.weight, sharded.weight)
            self.assertEqual(reference.bias, sharded.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_step_with_closure(self):
        """Check that ZeroRedundancyOptimizer properly exposes the
        ``step(closure)`` interface.

        The scenario runs twice, with ``parameters_as_bucket_view`` disabled
        and then enabled, and checks the loss returned by ``step()`` together
        with the resulting weight and bias.
        """
        self.dist_init(self.rank, world_size=self.world_size)

        with self.context:
            for bucket_view in [False, True]:
                # The input depends on the rank; the target is offset from
                # the initial model output by exactly `error`
                x_val = self.rank + 1
                weight = 1.0
                bias = 2.0
                error = 1.0
                target = torch.tensor(
                    [x_val * weight + bias + error],
                    device=self.device,
                )
                loss_fn = torch.nn.L1Loss()

                x = torch.tensor([float(x_val)], device=self.device)
                m = torch.nn.Linear(1, 1)
                m.weight.data = torch.tensor([[weight]])
                m.bias.data = torch.tensor([bias])
                m.to(self.device)

                o = ZeroRedundancyOptimizer(
                    m.parameters(),
                    optimizer_class=SGD,
                    parameters_as_bucket_view=bucket_view,
                    lr=0.1,
                )

                # Populate gradients and average them across ranks.  The
                # closure below zeroes the gradients again before
                # recomputing them.
                y = m(x)
                y.backward(x)
                for p in m.parameters():
                    dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                    p.grad.data /= self.world_size

                def closure():
                    # Re-evaluate the model and return the loss, as expected
                    # from an optimizer closure
                    o.zero_grad()
                    output = m(x)
                    loss = loss_fn(output, target)
                    loss.backward()
                    return loss

                loss = o.step(closure=closure)

                # The returned loss is the L1 distance to the target, i.e.
                # `error`
                self.assertEqual(loss, torch.tensor(error))
                self.assertEqual(m.weight, torch.tensor([[1.1]]))
                self.assertEqual(m.bias, torch.tensor([2.1]))

    @common_distributed.skip_if_no_gpu
    def test_lr_scheduler(self):
        """Check that a normal PyTorch ``lr_scheduler`` is usable with
        ZeroRedundancyOptimizer.

        A tensor optimized by ZeRO-wrapped SGD and a twin optimized by plain
        SGD are driven by identical ``StepLR`` schedules and must stay equal
        after every iteration.
        """
        self.dist_init(self.rank)
        NUM_ITERS = 5
        LR = 0.01
        x = torch.tensor([1.0], device=self.device, requires_grad=True)
        x2 = torch.tensor([1.0], device=self.device, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=LR)
        o2 = torch.optim.SGD([x2], lr=LR)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(NUM_ITERS):
            # Clear stale gradients *before* the backward pass.  Zeroing
            # them between `backward()` and `step()` would make every step
            # a no-op, so the scheduled learning rate would never affect
            # the parameters and the comparison below would be vacuous.
            o.zero_grad()
            x.backward()
            o.step()
            s.step()
            o2.zero_grad()
            x2.backward()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)

    def test_sharding(self):
        """
        Verify how ZeroRedundancyOptimizer shards parameters at construction
        time: the element count held by the local optimizer must equal the
        sum of one set of sizes.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01
        sizes = [9, 7, 5, 3]
        # One full set of sizes per rank
        params = [torch.rand(size, 1) for size in sizes * self.world_size]
        zero_optim = ZeroRedundancyOptimizer(
            params, optimizer_class=SGD, lr=LR,
        )
        local_params = zero_optim.optim.param_groups[0]["params"]
        local_numel = sum(p.numel() for p in local_params)
        self.assertEqual(local_numel, sum(sizes))

    def test_add_param_group(self):
        """Verify that ZeroRedundancyOptimizer accepts a parameter group added
        after construction and that the new group shows up in both ZeRO's and
        the local optimizer's ``param_groups``.

        NOTE: The correctness of this test depends on the ZeRO implementation
        using the sorted-greedy partitioning algorithm. For details, see
        ``ZeroRedundancyOptimizer._partition_parameters()`` in
        zero_redundancy_optimizer.py.
        """
        self.dist_init(self.rank)
        LR = 0.01

        def all_trainable():
            # Scenario 1: every initial parameter is trainable
            sizes = [9, 7, 5, 3]
            # Hold back the last size; a parameter of that size is added
            # afterwards as its own group
            initial_sizes = (sizes * self.world_size)[:-1]
            params = [torch.rand(size, 1) for size in initial_sizes]

            # Trainable parameters are the ones factored into the size-based
            # parameter partitioning
            for param in params:
                param.requires_grad = True

            zero_optim = ZeroRedundancyOptimizer(
                params, optimizer_class=SGD, lr=LR,
            )
            self.assertEqual(len(zero_optim.param_groups), 1)
            zero_optim.add_param_group({"params": [torch.rand(3, 1)]})
            # The new group must land in the partition that makes all
            # partitions hold the same number of elements
            self.assertEqual(len(zero_optim.param_groups), 2)
            local_numel = sum(
                param.numel()
                for group in zero_optim.optim.param_groups
                for param in group["params"]
            )
            self.assertEqual(local_numel, sum(sizes))
            self.assertEqual(len(zero_optim.optim.param_groups), 2)

        def some_trainable():
            # Scenario 2: pathological config whose first, big parameter is
            # not trainable
            params = [torch.rand(size, 1) for size in [100, 3, 5, 2, 6, 4]]

            # All but the first parameter are trainable and therefore
            # factored into the size-based parameter partitioning
            for param in params[1:]:
                param.requires_grad = True

            zero_optim = ZeroRedundancyOptimizer(
                params, optimizer_class=SGD, lr=LR,
            )
            self.assertEqual(len(zero_optim.param_groups), 1)
            zero_optim.add_param_group({"params": [torch.rand(3, 1)]})
            self.assertEqual(len(zero_optim.param_groups), 2)
            self.assertEqual(len(zero_optim.optim.param_groups), 2)

        all_trainable()
        some_trainable()

    @common_distributed.skip_if_no_gpu
    def test_multiple_param_groups(self):
        """
        Verify that three setups train identically: ZeRO built with two
        parameter groups upfront, ZeRO with the second group added after
        construction, and a non-sharded optimizer.
        """
        self.dist_init(self.rank)
        BATCH_SIZE, NUM_ITERS = 8, 3
        INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 5, 10, 5
        WD, LR = 0.01, 0.01

        def weights_of(model):
            return [layer.weight for layer in model]

        def biases_of(model):
            return [layer.bias for layer in model]

        # Three copies of the same initial model, one per setup
        model1 = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        )
        model2 = copy.deepcopy(model1)
        model3 = copy.deepcopy(model1)
        model1 = model1.to(self.device)
        model2 = model2.to(self.device)
        model3 = model3.to(self.device)
        batches = [
            torch.randn(BATCH_SIZE, INPUT_DIM).to(self.device)
            for _ in range(NUM_ITERS)
        ]
        # `optim1`: both parameter groups given upfront
        optim1 = ZeroRedundancyOptimizer(
            [
                {"params": weights_of(model1), "weight_decay": 0.},
                {"params": biases_of(model1), "weight_decay": WD},
            ],
            optimizer_class=AdamW, lr=LR,
        )
        # `optim2`: the bias group is added after construction
        optim2 = ZeroRedundancyOptimizer(
            weights_of(model2),
            optimizer_class=AdamW, lr=LR, weight_decay=0.,
        )
        optim2.add_param_group(
            {"params": biases_of(model2), "weight_decay": WD}
        )
        # `optim3`: non-sharded reference optimizer
        optim3 = AdamW(
            [
                {"params": weights_of(model3), "weight_decay": 0.},
                {"params": biases_of(model3), "weight_decay": WD},
            ], lr=LR,
        )
        setups = (
            (model1, optim1), (model2, optim2), (model3, optim3),
        )
        # Train all three setups on the same batches and compare every layer
        # after each iteration
        for batch in batches:
            for model, optim in setups:
                optim.zero_grad()
                loss = model(batch).sum()
                loss.backward()
                optim.step()
            for layer1, layer2, layer3 in zip(model1, model2, model3):
                torch.testing.assert_close(layer1.weight, layer2.weight)
                torch.testing.assert_close(layer1.weight, layer3.weight)
                torch.testing.assert_close(layer1.bias, layer2.bias)
                torch.testing.assert_close(layer1.bias, layer3.bias)

    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    def test_collect_shards(self):
        """Exercise state consolidation: gather the sharded optimizer state
        on a reference rank, check its size there, then broadcast it and
        load it on every rank."""
        self.dist_init(self.rank)
        LR = 1e-3
        MOMENTUM = 0.99
        BATCH_SIZE, INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 3, 20, 10, 5
        REFERENCE_RANK = 0
        device = self.device
        target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=device)
        inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=device)
        layers = [
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ]
        model = torch.nn.Sequential(*layers).to(device)
        loss_fn = torch.nn.L1Loss()
        loss_fn.to(device)
        # Momentum ensures that there exists optimizer state to shard
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,
        )

        def closure():
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), target)
            loss.backward()
            return loss

        # A dummy step brings the optimizer state into existence
        optimizer.step(closure=closure)

        # Gather the full optimizer state on the reference rank
        optimizer.consolidate_state_dict(to=REFERENCE_RANK)
        optimizer_state_dict = {}
        if self.rank == REFERENCE_RANK:
            # The consolidated state must hold one entry per model parameter
            optimizer_state_dict = optimizer.state_dict()
            num_params = len(list(model.parameters()))
            self.assertEqual(len(optimizer_state_dict["state"]), num_params)

        # Every rank must be able to load the broadcast state without any
        # exceptions
        optimizer_state_dict = _broadcast_object(
            optimizer_state_dict,
            src_rank=REFERENCE_RANK,
            group=dist.group.WORLD,
            device=device,
        )
        optimizer.load_state_dict(optimizer_state_dict)

    def test_nondefault_process_group(self):
        """Exercise ZeroRedundancyOptimizer on a non-default process group
        that contains only the even global ranks."""
        # A world smaller than this makes the test trivial, so log and leave
        MIN_WORLD_SIZE = 4
        if self.world_size < MIN_WORLD_SIZE:
            common_distributed.logger.info(
                "Skipping `test_nondefault_process_group()` since world size "
                f"of {self.world_size} is less than {MIN_WORLD_SIZE}"
            )
            return
        BACKEND = dist.Backend.GLOO
        self.dist_init(self.rank, self.world_size, BACKEND)
        # The Gloo backend supports both CPU and GPU, so only pick a GPU when
        # there is one available for every rank
        one_gpu_per_rank = (
            torch.cuda.is_available()
            and torch.cuda.device_count() >= self.world_size
        )
        device = torch.device(self.rank if one_gpu_per_rank else "cpu")
        # Build the even-rank subgroup so that global and group-local ranks
        # do not necessarily coincide; the group is created on every rank
        # before the non-members return
        subgroup_ranks = list(range(0, self.world_size, 2))
        process_group = dist.new_group(ranks=subgroup_ranks, backend=BACKEND)
        if self.rank not in subgroup_ranks:
            return

        # Seed each rank differently so that every rank sees different data,
        # which is what makes the cross-rank parameter check meaningful
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)

        EPOCHS, BATCH_SIZE = 5, 3
        INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM = 20, 10, 5
        LR = 1e-3
        MOMENTUM = 0.99
        REFERENCE_RANK = 0
        assert REFERENCE_RANK in subgroup_ranks, \
            "Reference rank must be in the new process group"
        is_reference = self.rank == REFERENCE_RANK
        loss_fn = torch.nn.L1Loss().to(device)

        model = torch.nn.Sequential(
            torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
            torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
        ).to(device)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=SGD,
            lr=LR,
            momentum=MOMENTUM,  # ensure there exists state to shard
            process_group=process_group,
        )

        def assert_params_match_across_ranks():
            # Gather every parameter onto the reference rank and compare all
            # of the copies against the first gathered one
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if is_reference:
                        gathered = [param.clone() for _ in subgroup_ranks]
                    else:
                        gathered = []
                    dist.gather(
                        param, gathered, dst=REFERENCE_RANK,
                        group=process_group,
                    )
                    if not is_reference:
                        continue
                    expected = gathered[0]
                    for other in gathered[1:]:
                        torch.testing.assert_close(
                            expected,
                            other,
                            msg="Models differ between ranks",
                        )

        for _ in range(EPOCHS):
            target = torch.rand((BATCH_SIZE, OUTPUT_DIM), device=device)
            inputs = torch.rand((BATCH_SIZE, INPUT_DIM), device=device)

            def closure():
                optimizer.zero_grad()
                loss = loss_fn(model(inputs), target)
                # NOTE(review): the loss is scaled by the global world size
                # although it is only reduced over the subgroup -- confirm
                # whether the subgroup size was intended
                loss /= self.world_size
                loss.backward()
                dist.all_reduce(loss, group=process_group)
                return loss

            optimizer.step(closure=closure)
            assert_params_match_across_ranks()

    @common_distributed.skip_if_no_gpu
    @parametrize(
        "optimizer_class_str",
        ["Adam", "AdamW", "SGD"],
        # Use string to appease the internal test name parser
    )
    @parametrize("maximize", [False, True])
    def test_local_optimizer_parity(
        self,
        optimizer_class_str: str,
        maximize: bool,
    ):
        """With both models wrapped in DDP, verify that stepping a plain local
        optimizer and stepping the same optimizer wrapped in
        ZeroRedundancyOptimizer produce the same losses and parameters, and
        that their ``state_dict`` checkpoints can be loaded into each other."""
        self.dist_init(self.rank)
        BATCHES = 20
        BATCH_SIZE = 64
        LR = 1e-3
        INPUT_DIM = 2
        HIDDEN_DIM = 3
        OUTPUT_DIM = 3
        torch.manual_seed(self.rank)
        np.random.seed(self.rank)
        supported_optimizers = {
            "Adam": torch.optim.Adam,
            "AdamW": torch.optim.AdamW,
            "SGD": torch.optim.SGD,
        }
        assert optimizer_class_str in supported_optimizers, \
            f"Unsupported optimizer class: {optimizer_class_str}"
        optimizer_class = supported_optimizers[optimizer_class_str]

        with self.context:
            # Base model carrying a buffer whose value differs on each rank
            model = torch.nn.Sequential(
                torch.nn.Linear(INPUT_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
                torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
            ).to(self.device)
            model.register_buffer(
                "test_buffer", torch.ones((1), device=self.device) * self.rank,
            )
            # Both optimizers receive identical keyword arguments
            optimizer_kwargs = {"lr": LR}
            if maximize:
                optimizer_kwargs["maximize"] = True
            # DDP + ZeRO on the base model
            sharded_optimizer = ZeroRedundancyOptimizer(
                params=model.parameters(), optimizer_class=optimizer_class,
                **optimizer_kwargs,
            )
            sharded_ddp_model = DDP(
                module=model, device_ids=[self.rank],
                broadcast_buffers=True, find_unused_parameters=True,
            )
            # DDP + local optimizer on a deep copy of the base model
            local_model = copy.deepcopy(model).to(self.device)
            ddp_optimizer = optimizer_class(
                local_model.parameters(), **optimizer_kwargs,
            )
            ddp_model = DDP(
                local_model, device_ids=[self.rank],
                broadcast_buffers=True, find_unused_parameters=True,
            )
            # The two models must already agree right after construction
            self._check_same_model_params(
                sharded_ddp_model, ddp_model,
                "Models differ from the start",
            )

            def check_step():
                batch = torch.rand((BATCH_SIZE, INPUT_DIM))

                def make_closure(optim, net):
                    # Closure that recomputes the loss on `batch` for the
                    # given optimizer/model pair
                    def closure():
                        optim.zero_grad()
                        loss = net(batch).abs().sum()
                        loss.backward()
                        return loss
                    return closure

                # Step the local optimizer first and ZeRO second
                loss_ddp = cast(
                    torch.Tensor,
                    ddp_optimizer.step(
                        closure=make_closure(ddp_optimizer, ddp_model),
                    ),
                )
                loss_sharded_optim = cast(
                    torch.Tensor,
                    sharded_optimizer.step(
                        closure=make_closure(
                            sharded_optimizer, sharded_ddp_model,
                        ),
                    ),
                )
                torch.testing.assert_close(
                    loss_ddp, loss_sharded_optim,
                    msg="Losses differ between local optimizer and ZeRO",
                )
                self._check_same_model_params(
                    sharded_ddp_model, ddp_model,
                    "Models differ after a step",
                )

            # Check that parity is maintained
            for i in range(BATCHES):
                check_step()
                # During the second half of the batches, toggle whether the
                # first parameter is trainable to further test parity
                if i > BATCHES // 2:
                    trainable = bool(i % 2)
                    for wrapped in (ddp_model, sharded_ddp_model):
                        next(wrapped.parameters()).requires_grad = trainable

            # Check that the `state_dict` checkpoints are compatible between
            # the local optimizer and ZeRO
            REFERENCE_RANK = 0
            # - Get states: the consolidated ZeRO state is only read on the
            #   reference rank and is then broadcast to all ranks
            ddp_state_dict = ddp_optimizer.state_dict()
            sharded_optimizer.consolidate_state_dict(to=REFERENCE_RANK)
            if self.rank == REFERENCE_RANK:
                broadcast_list = [sharded_optimizer.state_dict()]
            else:
                broadcast_list = [{}]
            dist.broadcast_object_list(
                broadcast_list, src=REFERENCE_RANK,
                group=dist.group.WORLD,
            )
            sharded_optim_state_dict = broadcast_list[0]

            # - Cross-load the states, then run one step and check that the
            #   models are still the same
            ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)
            ddp_optimizer.load_state_dict(sharded_optim_state_dict)
            sharded_optimizer.load_state_dict(ddp_state_dict)
            check_step()

            # - Reload their respective states, then run one step and check
            #   that the models are still the same
            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
            check_step()

    def _test_zero_join(self, device):
        """Check that the ZeRO join hook allows training with uneven inputs
        when using the given device.

        DDP with a local Adam optimizer is run first under
        ``ddp_model.join()`` and its per-iteration gradients and parameters
        are recorded on the last-joining rank. They are then broadcast to
        every rank and replayed through a ZeroRedundancyOptimizer inside a
        ``Join`` context, checking after every step that the ZeRO parameters
        match the recorded DDP parameters.

        Args:
            device: ``torch.device`` to run on; a ``"cuda"`` device type uses
                the backend returned by ``_get_backend_for_tests()`` and any
                other device type uses Gloo.
        """
        NUM_INPUTS = 3
        NUM_EPOCHS = 2
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        world_size = self.world_size
        is_gpu = device.type == "cuda"
        backend = _get_backend_for_tests() if is_gpu else dist.Backend.GLOO
        self.dist_init(rank, world_size, backend)

        model = torch.nn.Sequential(
            torch.nn.Linear(2, 3),
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
        )
        model.to(device)

        # DDP ensures correct gradients in data parallel training, so DDP with
        # local optimizers on uneven inputs should be equivalent to ZeRO on
        # uneven inputs with gradients being manually set
        ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
        local_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        zero_model = copy.deepcopy(model)
        zero_model.to(device)
        zero_optim = ZeroRedundancyOptimizer(
            zero_model.parameters(), torch.optim.Adam, lr=LR,
        )
        loss_fn = torch.nn.MSELoss()

        # Use uneven inputs: rank i has i extra inputs
        inputs = [
            torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)
        ]
        labels = torch.randn(20, 3).to(device)

        # Save the gradients and parameters from DDP as the ground truth; do
        # so on the last-joining rank (in this case, the largest rank)
        is_last_joining_rank = rank == world_size - 1
        grads_at_each_iter = []
        params_at_each_iter = []
        with ddp_model.join():
            for _ in range(NUM_EPOCHS):
                for inp in inputs:
                    output = ddp_model(inp)
                    loss_fn(output, labels).backward()
                    if is_last_joining_rank:
                        # Snapshot the gradients that the upcoming step uses
                        grads = [
                            p.grad.detach().clone().to(device)
                            for p in ddp_model.parameters()
                        ]
                    local_optim.step()
                    if is_last_joining_rank:
                        # Snapshot the parameters produced by that step
                        params = [
                            p.detach().clone().to(device)
                            for p in ddp_model.parameters()
                        ]
                        grads_at_each_iter.append(grads)
                        params_at_each_iter.append(params)

        # Broadcast the saved gradients and parameters to all of the other
        # ranks (which joined early)
        grads_and_params = [grads_at_each_iter, params_at_each_iter]
        grads_and_params = _broadcast_object(
            grads_and_params, src_rank=world_size - 1, group=dist.group.WORLD,
            device=device,
        )
        grads_at_each_iter = grads_and_params[0]
        params_at_each_iter = grads_and_params[1]
        # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
        # once the latter supports loading to the destination device instead
        # of the source device

        # A process must still set the remaining gradients after joining, so we
        # define a join hook to do this before the ZeRO join hook
        class _JoinGradInfo():
            def __init__(self, grads):
                self.grads = grads  # remaining gradients to set (in order)
                self.index = 0

        class _SetGradsJoinHook(JoinHook):
            def __init__(self, zero_optim, grads):
                zero_optim._join_grad_info = _JoinGradInfo(grads)
                self.zero = zero_optim
                super().__init__()

            def main_hook(self):
                # Install the next saved set of gradients on the ZeRO
                # parameters and advance the cursor
                join_grad_info = self.zero._join_grad_info
                grads = join_grad_info.grads[join_grad_info.index]
                join_grad_info.index += 1
                for p, grad in zip(self.zero._all_params, grads):
                    p.grad = grad.detach().clone().to(device)

        class _GradientSetter(Joinable):
            def __init__(self):
                super().__init__()

            def join_hook(self, **kwargs):
                assert "zero_optim" in kwargs
                assert "grads" in kwargs
                zero_optim = kwargs["zero_optim"]
                grads = kwargs["grads"]
                return _SetGradsJoinHook(zero_optim, grads)

            @property
            def join_device(self):
                return device

            @property
            def join_process_group(self):
                return dist.group.WORLD

        num_grads_after_joining = NUM_EPOCHS * (world_size - rank - 1)
        # Slice from an explicit start index: `[-n:]` would select the whole
        # list instead of nothing when `n == 0` (i.e. on the last-joining rank)
        first_remaining = len(grads_at_each_iter) - num_grads_after_joining
        grads = grads_at_each_iter[first_remaining:]
        gradient_setter = _GradientSetter()
        step_idx = 0
        with Join(
            [gradient_setter, zero_optim], zero_optim=zero_optim, grads=grads,
        ):
            for _epoch in range(NUM_EPOCHS):
                # Only the number of local inputs matters here since the
                # gradients are replayed from the saved DDP run
                for _ in inputs:
                    # Notify join context that this process has not joined
                    Join.notify_join_context(gradient_setter)
                    # Set gradients manually
                    for p, grad in zip(
                        zero_model.parameters(), grads_at_each_iter[step_idx],
                    ):
                        p.grad = grad.detach().clone().to(device)
                    # Perform optimizer step and check parity
                    zero_optim.step()
                    for p, ddp_p in zip(
                        zero_model.parameters(), params_at_each_iter[step_idx],
                    ):
                        torch.testing.assert_close(
                            p, ddp_p,
                            msg="Parameters differ between using ZeRO and "
                            "local optimizer",
                        )
                    step_idx += 1

    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    def test_zero_join_gpu(self):
        """Run the uneven-inputs ZeRO join hook check (``_test_zero_join``)
        on ``self.device``; skipped when no GPU is available."""
        join_device = self.device
        self._test_zero_join(join_device)

    @common_distributed.requires_gloo()
    def test_zero_join_cpu(self):
        """Run the uneven-inputs ZeRO join hook check (``_test_zero_join``)
        on the CPU device."""
        cpu_device = torch.device("cpu")
        self._test_zero_join(cpu_device)

    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
        """Check parity between a DDP-wrapped model-parallel model optimized
        with ZeroRedundancyOptimizer and an equivalent single-device model
        optimized with a local Adam optimizer.

        Each of the two participating ranks places the model's two linear
        layers on devices ``2 * rank`` and ``2 * rank + 1``. After every
        optimizer step the losses and the parameters of both models are
        compared.

        Args:
            parameters_as_bucket_view: forwarded to ZeroRedundancyOptimizer's
                ``parameters_as_bucket_view`` argument.
        """
        # Use two processes each with two GPUs
        assert self.rank < 2
        NUM_EPOCHS = 2
        NUM_INPUTS = 4
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        class ModelParallelModel(torch.nn.Module):
            def __init__(self, dev0, dev1):
                super().__init__()
                self.dev0 = dev0
                self.dev1 = dev1
                self.net0 = torch.nn.Linear(10, 10).to(dev0)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5).to(dev1)

            def forward(self, x):
                # Move the activations to whichever device holds the next layer
                x = x.to(self.dev0)
                x = self.relu(self.net0(x))
                x = x.to(self.dev1)
                return self.net1(x)

        class LocalModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net0 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.net1(self.relu(self.net0(x)))

        dev0 = torch.device(2 * self.rank)
        dev1 = torch.device(2 * self.rank + 1)
        mp_model = ModelParallelModel(dev0, dev1)
        ddp_model = DDP(mp_model)
        local_model = LocalModel().to(dev0)

        # Ensure the parameters are the same across the two models
        def copy_param(p):
            return torch.nn.Parameter(p.detach().clone().to(dev0))

        local_model.net0.weight = copy_param(mp_model.net0.weight)
        local_model.net0.bias = copy_param(mp_model.net0.bias)
        local_model.net1.weight = copy_param(mp_model.net1.weight)
        local_model.net1.bias = copy_param(mp_model.net1.bias)

        # Compare parity between DDP with model parallelism using ZeRO and
        # a local model using a local optimizer
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            parameters_as_bucket_view=parameters_as_bucket_view,
            lr=LR,
        )
        local_optim = torch.optim.Adam(local_model.parameters(), lr=LR)
        inputs = [torch.randn(20, 10).to(dev0) for _ in range(NUM_INPUTS)]

        for _ in range(NUM_EPOCHS):
            for inp in inputs:
                # Both closures are invoked within this iteration, so reading
                # the loop variable `inp` from the enclosing scope is safe
                def closure_local():
                    local_optim.zero_grad()
                    local_loss = local_model(inp).abs().sum()
                    local_loss.backward()
                    return local_loss

                def closure_ddp():
                    zero_optim.zero_grad()
                    ddp_loss = ddp_model(inp).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                local_loss = cast(
                    torch.Tensor, local_optim.step(closure=closure_local)
                )
                ddp_loss = cast(
                    torch.Tensor, zero_optim.step(closure=closure_ddp)
                )

                # Increased tolerances are needed to pass when using TF32
                # See: https://github.com/pytorch/pytorch/issues/67764
                # The failure message must be passed via `msg`; a trailing
                # `, "..."` after the call only builds a discarded tuple
                torch.testing.assert_close(
                    local_loss.cpu(), ddp_loss.cpu(), rtol=1e-03, atol=1e-08,
                    msg="Losses differ between local optimizer and ZeRO",
                )

                for local_p, ddp_p in zip(
                    local_model.parameters(),
                    ddp_model.parameters()
                ):
                    torch.testing.assert_close(
                        local_p.cpu(), ddp_p.cpu(), rtol=1e-03, atol=1e-04,
                        msg="Models differ after a step",
                    )

    @common_distributed.skip_if_lt_x_gpu(4)
    @parametrize("parameters_as_bucket_view", [False, True])
    def test_zero_model_parallel(
        self,
        parameters_as_bucket_view: bool,
    ):
        """Verify ZeRO under model parallelism, i.e. when the layers of the
        model are placed on different devices."""
        # Only ranks 0 and 1 take part in this test
        if self.rank >= 2:
            return
        from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor
        if parameters_as_bucket_view:
            # DDP + ReplicatedTensor is disabled in this case since
            # ZeroRedundancyOptimizer then modifies the model parameters in
            # place
            context = _ddp_replicated_tensor(False)
        else:
            context = suppress()
        with context:
            self.dist_init(self.rank, world_size=2)
            self._test_zero_model_parallel(parameters_as_bucket_view)

    def _test_ddp_zero_overlap(
        self,
        device,
        hook_constructor,
        gradient_as_bucket_view,
        static_graph,
        **kwargs,
    ):
        """Compare a DDP model whose optimizer step is overlapped with ZeRO
        (through a communication hook built by ``hook_constructor``) against
        a DDP model stepped by a local SGD optimizer.

        For every model under test, the parameters of both variants must
        match before and after training, and the overlapped model's
        parameters must have changed. Extra ``kwargs`` are forwarded to
        ``hook_constructor``.
        """
        SGD_LR = 0.01
        SGD_MOMENTUM = 0.9
        SGD_WEIGHT_DECAY = 0.001
        NUM_INPUTS = 5
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        is_gpu = device.type == "cuda"
        if is_gpu:
            torch.cuda.set_device(device)
        device_ids = [rank] if is_gpu else None
        sgd_kwargs = {
            "lr": SGD_LR,
            "momentum": SGD_MOMENTUM,
            "weight_decay": SGD_WEIGHT_DECAY,
        }

        def random_inputs(*shape):
            return [
                torch.randn(*shape).to(device) for _ in range(NUM_INPUTS)
            ]

        def wrap_in_ddp(base_model):
            # Each DDP instance wraps its own deep copy of the base model
            wrapped = DDP(
                copy.deepcopy(base_model).to(device),
                device_ids=device_ids,
                gradient_as_bucket_view=gradient_as_bucket_view
            )
            if static_graph:
                wrapped._set_static_graph()
            return wrapped

        def assert_params_equal(model_a, model_b):
            for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
                self.assertEqual(p_a, p_b)

        def forward_backward(ddp_model, batch):
            ddp_model(batch).sum().backward()

        # Each model is constructed before its inputs are drawn
        linear_model = torch.nn.Sequential(
            torch.nn.Linear(1000, 2000),
            torch.nn.Linear(2000, 500),
        )
        models_to_test = [(linear_model, random_inputs(1, 1000))]
        if HAS_TORCHVISION:
            resnet_model = torchvision.models.resnet50()
            models_to_test.append(
                (resnet_model, random_inputs(1, 3, 3, 1000))
            )

        # NOTE: Overlapping currently requires 2 or 3 warmup iterations to
        # ensure DDP buckets have been rebuilt (depending on the value of
        # `static_graph`)
        num_warmup_inputs = 3 if static_graph else 2

        for base_model, batches in models_to_test:
            # Enable determinism in cudnn operators
            with torch.backends.cudnn.flags(
                enabled=True, deterministic=True, benchmark=False
            ):
                # DDP model whose optimizer step is overlapped with ZeRO via
                # the communication hook
                ddp_model_overlap = wrap_in_ddp(base_model)
                zero_optim = ZeroRedundancyOptimizer(
                    ddp_model_overlap.parameters(),
                    optimizer_class=torch.optim.SGD,
                    overlap_with_ddp=True,
                    **sgd_kwargs,
                )
                overlap_hook = hook_constructor(
                    allreduce_hook, ddp_model_overlap, zero_optim,
                    **kwargs,
                )
                ddp_model_overlap.register_comm_hook(None, overlap_hook)

                # DDP model stepped by a local optimizer
                ddp_model_local = wrap_in_ddp(base_model)
                local_optim = torch.optim.SGD(
                    ddp_model_local.parameters(), **sgd_kwargs,
                )

                # Both variants must start from the same parameters
                assert_params_equal(ddp_model_overlap, ddp_model_local)

                # Keep a copy of the initial parameters to later confirm that
                # training actually changed them
                init_params_overlap = copy.deepcopy(
                    list(ddp_model_overlap.parameters())
                )

                # Ensure that this test runs independently
                dist.barrier()

                # Overlapped model: warmup passes first, then the real
                # iterations; `zero_optim.step()` is never called explicitly
                for batch in batches[:num_warmup_inputs]:
                    forward_backward(ddp_model_overlap, batch)
                for batch in batches:
                    zero_optim.zero_grad()
                    forward_backward(ddp_model_overlap, batch)

                # Local-optimizer model: explicit step on every iteration
                for batch in batches:
                    local_optim.zero_grad()
                    forward_backward(ddp_model_local, batch)
                    local_optim.step()
                dist.barrier()

                # Both variants must end with equal parameters ...
                assert_params_equal(ddp_model_overlap, ddp_model_local)

                # ... and those parameters must differ from the initial ones
                self.assertNotEqual(
                    init_params_overlap, list(ddp_model_overlap.parameters()),
                )

                # Ensure that this test runs independently
                dist.barrier()

    # NOTE: Functional optimizers are not currently supported on Windows, so
    # the test is skipped there.
    @common_distributed.skip_if_win32()
    @common_distributed.requires_nccl()
    @common_distributed.skip_if_no_gpu
    @common_distributed.skip_if_rocm
    @parametrize(
        "use_gpu",
        [True],
        # Add `False` once the Gloo sync issue causing hangs is fixed
        # See: https://github.com/pytorch/pytorch/issues/62300
    )
    @parametrize("use_interleaved_hook", [False, True])
    @parametrize("gradient_as_bucket_view", [False, True])
    @parametrize("static_graph", [False, True])
    @parametrize("shard_buckets", [False, True])
    def test_ddp_zero_overlap(
        self,
        use_gpu: bool,
        use_interleaved_hook: bool,
        gradient_as_bucket_view: bool,
        static_graph: bool,
        shard_buckets: bool,
    ):
        """
        Verify that DDP overlapped with ZeRO reaches parity with DDP using a
        local optimizer, for the overlap method selected by
        ``use_interleaved_hook`` and ``shard_buckets`` and for the given ZeRO
        and DDP arguments.
        """
        if use_gpu:
            device = torch.device(self.rank)
        else:
            device = torch.device("cpu")
        backend = _get_backend_for_tests()
        self.dist_init(self.rank, self.world_size, backend)
        if use_interleaved_hook:
            hook_constructor = hook_with_zero_step_interleaved
        else:
            hook_constructor = hook_with_zero_step

        # ZeroRedundancyOptimizer modifies the model parameters in place, so
        # DDP + ReplicatedTensor is disabled for the duration of the check.
        from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor
        with _ddp_replicated_tensor(False):
            self._test_ddp_zero_overlap(
                device,
                hook_constructor,
                gradient_as_bucket_view,
                static_graph,
                shard_buckets=shard_buckets,
            )


# Pass both test classes to `instantiate_parametrized_tests` (presumably this
# expands their `@parametrize`-decorated tests into concrete test methods --
# confirm against the helper's documentation)
instantiate_parametrized_tests(TestZeroRedundancyOptimizerSingleRank)
instantiate_parametrized_tests(TestZeroRedundancyOptimizerDistributed)

if __name__ == "__main__":
    # NOTE: Use `run_tests()` here rather than invoking `unittest` directly;
    # otherwise the tests are not properly registered
    run_tests()