File: pipe_with_ddp_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (147 lines) | stat: -rw-r--r-- 5,166 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torch.distributed as dist

from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import (
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm,
)
from torch.distributed.pipeline.sync import Pipe

class PipeWithDDPTest(RpcAgentTestFixture):
    """Gradient-sync tests for a ``Pipe`` model wrapped in ``DistributedDataParallel``.

    Every rank builds a two-stage pipeline placed on GPUs ``2 * rank`` and
    ``2 * rank + 1``, wraps it in DDP, runs several forward/backward passes
    on rank-specific inputs and then checks that all ranks ended up with
    identical gradients.
    """

    @property
    def world_size(self) -> int:
        """Number of ranks taking part in each test."""
        return 2

    # NCCL backend, one test per Pipe checkpoint mode. The "always" and
    # "except_last" modes are run with DDP's static_graph=True.
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_never(self):
        self._run_basic_test("nccl", "never")

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_never_find_unused(self):
        self._run_basic_test("nccl", "never", find_unused_parameters=True)

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_always(self):
        self._run_basic_test("nccl", "always", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_except_last(self):
        self._run_basic_test("nccl", "except_last", static_graph=True)

    # Gloo backend, same matrix of checkpoint modes as the NCCL tests above.
    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_never(self):
        self._run_basic_test("gloo", "never")

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_never_find_unused(self):
        self._run_basic_test("gloo", "never", find_unused_parameters=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_always(self):
        self._run_basic_test("gloo", "always", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_except_last(self):
        self._run_basic_test("gloo", "except_last", static_graph=True)

    def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False):
        """Train a DDP-wrapped ``Pipe`` for a few steps and compare grads across ranks.

        Args:
            backend: Backend name passed to ``dist.init_process_group``.
            checkpoint: Checkpoint mode passed to ``Pipe``.
            find_unused_parameters: Passed to DDP. When true, the second
                pipeline stage leaves ``fc3`` out of its forward pass, one
                extra forward/backward iteration is run, and the ``fc3``
                gradient is not checked.
            static_graph: Passed to DDP.
        """
        dist.init_process_group(
            backend=backend,
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another
        # pipe between GPU 2 and 3. Both replicas are replicated via DDP.
        first_device = 2 * self.rank
        second_device = 2 * self.rank + 1
        fc1 = nn.Linear(16, 8, bias=False).cuda(first_device)

        class MyModule(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.fc2 = nn.Linear(8, 4, bias=False).cuda(device)
                self.fc3 = nn.Linear(4, 2, bias=False).cuda(device)

            def forward(self, inp):
                # Skipping fc3 leaves its parameters unused in the forward pass.
                if find_unused_parameters:
                    return self.fc2(inp)
                return self.fc3(self.fc2(inp))

        layer2 = MyModule(second_device)
        model = nn.Sequential(fc1, layer2)
        model = Pipe(model, chunks=2, checkpoint=checkpoint)
        model = DistributedDataParallel(
            model,
            find_unused_parameters=find_unused_parameters,
            static_graph=static_graph,
        )

        self._forward_backward(model, first_device)

        # Run forward again for find_unused_parameters to trigger any potential errors.
        if find_unused_parameters:
            self._forward_backward(model, first_device)

        # Run a few more iterations of fwd + bwd to ensure gradient synchronization
        # occurs properly across iterations via delay_all_reduce/bucketized allreduce.
        for _ in range(3):
            self._forward_backward(model, first_device)

        # Check grads
        self._assert_grad_synced(fc1.weight.grad)
        self._assert_grad_synced(layer2.fc2.weight.grad)
        if not find_unused_parameters:
            self._assert_grad_synced(layer2.fc3.weight.grad)

    def _forward_backward(self, model, device):
        """Run one forward/backward pass of ``model`` on a fresh random input."""
        # Ensure inputs are different across ranks to verify that gradient
        # sync indeed occurs.
        model_input = torch.rand(16, 16).cuda(device) * (self.rank + 1)
        # The model's output exposes its tensor through local_value().
        out = model(model_input).local_value()
        out.sum().backward()

    def _assert_grad_synced(self, grad):
        """All-gather ``grad`` from every rank and assert the copies are equal."""
        # all_gather needs one receive buffer per rank, so size the list from
        # world_size rather than hard-coding two entries.
        gathered = [torch.empty_like(grad) for _ in range(self.world_size)]
        dist.all_gather(gathered, grad)
        reference = gathered[0]
        for other in gathered[1:]:
            self.assertEqual(reference, other)