File: test_fsdp_exec_order.py

# Owner(s): ["oncall: distributed"]

import sys
import warnings
from contextlib import suppress

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Model(torch.nn.Module):
    """
    Model that supports two computation paths: `layer0` -> `layer1` and
    `layer0` -> `layer2`. Notably, both `layer1` and `layer2` have 36 elements
    when flattened, which means that their corresponding all-gathers and
    reduce-scatters may be silently matched if we do not perform any checks.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(5, 6)
        self.layer1 = torch.nn.Linear(6, 6, bias=False)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(6, 3, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 6, bias=False),
        )
        self.relu = torch.nn.ReLU()
        self.use_alt_path = False
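        # Start on the normal path: freeze `layer2` so that only `layer0` and
        # `layer1` participate in gradient computation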
        for param in self.layer2.parameters():
            param.requires_grad = False

    def forward(self, x):
        # `layer0` -> `layer1` (normal)
        # `layer0` -> `layer2` (alternate)
        z = self.relu(self.layer0(x))
        z = (
            self.relu(self.layer2(z))
            if self.use_alt_path
            else self.relu(self.layer1(z))
        )
        return z

    def get_input(self, device: torch.device):
        return (torch.randn((8, 5)).to(device),)

    def get_loss(self, input, output):
        return output.sum()

    def run_backward(self, loss):
        loss.backward()

    def flip_path(self):
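        # Switch the active computation path (and which parameters require
        # gradients), which changes the forward all-gather order across the
        # FSDP-wrapped submodules on subsequent iterations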
        params_to_freeze = (
            self.layer2.parameters() if self.use_alt_path else self.layer1.parameters()
        )
        params_to_unfreeze = (
            self.layer1.parameters() if self.use_alt_path else self.layer2.parameters()
        )
        for param in params_to_freeze:
            param.requires_grad = False
        for param in params_to_unfreeze:
            param.requires_grad = True
        self.use_alt_path = not self.use_alt_path

    @staticmethod
    def wrap(sharding_strategy: ShardingStrategy, device: torch.device):
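        # Wrap `layer1` and `layer2` as separate nested FSDP instances so that
        # each issues its own all-gather, making the forward execution order
        # observable to FSDP's execution-order validation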
        model = Model()
        model.layer1 = FSDP(model.layer1, sharding_strategy=sharding_strategy)
        model.layer2 = FSDP(model.layer2, sharding_strategy=sharding_strategy)
        fsdp_model = FSDP(model, sharding_strategy=sharding_strategy)
        return fsdp_model.to(device)


class TestFSDPExecOrder(FSDPTest):
    def setUp(self):
        super().setUp()

    @property
    def device(self):
        return torch.device("cuda")

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    def test_invalid_first_iter_order(
        self,
        sharding_strategy: ShardingStrategy,
    ):
        """Tests that FSDP errors if the all-gather order differs across ranks
        in the first iteration."""
        dist.set_debug_level(dist.DebugLevel.INFO)
        # Rank 0 runs the forward pass in one order, and all other ranks run
        # it in a different order
        fsdp_model = Model.wrap(sharding_strategy, self.device)
        if self.rank != 0:
            fsdp_model.flip_path()
        inp = fsdp_model.module.get_input(self.device)
        # Match the error message with the following prefix
        error_regex = "^(Forward order differs across ranks)"
        with self.assertRaisesRegex(RuntimeError, error_regex):
            fsdp_model(*inp)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    @parametrize("iters_before_path_change", [1, 3])
    def test_invalid_later_iter_order(
        self,
        sharding_strategy: ShardingStrategy,
        iters_before_path_change: int,
    ):
        """Tests that FSDP warns the user if the all-gather order changes after
        the first iteration."""
        dist.set_debug_level(dist.DebugLevel.INFO)
        # For the first `iters_before_path_change` iterations, all ranks run
        # in the same order; afterward, all ranks except rank 0 switch to a
        # different order
        fsdp_model = Model.wrap(sharding_strategy, self.device)
        for _ in range(iters_before_path_change):
            inp = fsdp_model.module.get_input(self.device)
            output = fsdp_model(*inp)
            loss = fsdp_model.module.get_loss(inp, output).to(self.device)
            fsdp_model.module.run_backward(loss)
        # Match the warning message with the following prefix
        regex = (
            "^(Forward order differs from that of the first iteration "
            f"on rank {self.rank}. Collectives are unchecked and may give "
            "incorrect results or hang)"
        )
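        # Only nonzero ranks flip the path, so only they should see the
        # warning; rank 0 keeps its first-iteration order and warns nothing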
        context = (
            self.assertWarnsRegex(
                expected_warning=UserWarning,
                expected_regex=regex,
            )
            if self.rank != 0
            else suppress()
        )
        if self.rank != 0:
            fsdp_model.flip_path()
        inp = fsdp_model.module.get_input(self.device)
        # Expect a warning for the forward pass all-gather
        with context:
            output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(self.device)
        fsdp_model.module.run_backward(loss)
        # Run an additional iteration to check that there are no more warnings
        inp = fsdp_model.module.get_input(self.device)
        output = fsdp_model(*inp)
        loss = fsdp_model.module.get_loss(inp, output).to(self.device)
        fsdp_model.module.run_backward(loss)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    def test_train_eval(self, sharding_strategy: ShardingStrategy):
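        """Tests that FSDP does not issue the execution-order warning when
        alternating between ``train()`` and ``eval()`` mode across epochs."""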
        dist.set_debug_level(dist.DebugLevel.INFO)
        fsdp_model = Model.wrap(sharding_strategy, self.device)
        NUM_ITERS = 3
        NUM_EPOCHS = 2
        with warnings.catch_warnings(record=True) as w:  # records warnings to `w`
            for _ in range(NUM_EPOCHS):
                fsdp_model.train()
                for _ in range(NUM_ITERS):
                    inp = fsdp_model.module.get_input(self.device)
                    output = fsdp_model(*inp)
                    loss = fsdp_model.module.get_loss(inp, output).to(self.device)
                    fsdp_model.module.run_backward(loss)
                fsdp_model.eval()
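                # Eval iterations run the forward pass only (no backward pass)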
                for _ in range(NUM_ITERS):
                    inp = fsdp_model.module.get_input(self.device)
                    output = fsdp_model(*inp)
                    fsdp_model.module.get_loss(inp, output).to(self.device)
        # Check that the order validation warning was not issued (errors do not
        # need to be checked since they will be directly reported)
        warning_prefix = "Forward order differs"
        for warning in w:
            if str(warning.message).startswith(warning_prefix):
                raise AssertionError(
                    f"Warning was incorrectly issued: {warning.message}"
                )
        # If FSDP still validated the forward execution order in eval mode,
        # then the check above would raise an `AssertionError` for both
        # sharding strategies


instantiate_parametrized_tests(TestFSDPExecOrder)

if __name__ == "__main__":
    run_tests()