File: test_fsdp_param_exec_order_wrap.py

# Owner(s): ["oncall: distributed"]

from typing import Any, Callable

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._symbolic_trace import TracingConfig
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import always_wrap_policy, ParamExecOrderWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(6, 6)
        self.layer1 = torch.nn.Linear(6, 6, bias=False)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(6, 3, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 6, bias=False),
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x: Any, use_all_params: bool = True):
        # Forward executes `layer0` -> `layer2` -> `layer1`, so the forward
        # execution order is NOT consistent with the model definition order.
        z = self.relu(self.layer0(x))
        z = self.relu(self.layer2(z))
        if use_all_params:
            z = self.relu(self.layer1(z))
        return z

    def get_input(self, device: torch.device):
        return (torch.randn((8, 6)).to(device),)

    def get_loss(self, input, output):
        return (output - input[0]).sum()

    @staticmethod
    def wrap(
        sharding_strategy: ShardingStrategy,
        device: torch.device,
        wrap_policy: Callable,
    ) -> torch.nn.Module:
        model = Model()
        fsdp_model = FSDP(
            model, auto_wrap_policy=wrap_policy, sharding_strategy=sharding_strategy
        )
        return fsdp_model.to(device)


class TestFSDPExecOrder(FSDPTest):
    @property
    def device(self):
        return torch.device("cuda")

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    def test_fsdp_flatten_params_exec_order(
        self,
        sharding_strategy: ShardingStrategy,
    ):
        """
        Tests ``_fsdp_params_exec_order`` with ``ParamExecOrderWrapPolicy``
        after running one forward and one backward pass. Here ``torch.fx``
        symbolic tracing is not enabled inside ``ParamExecOrderWrapPolicy``.
        """
        wrap_policy = ParamExecOrderWrapPolicy(init_policy=always_wrap_policy)
        fsdp_model = Model.wrap(sharding_strategy, self.device, wrap_policy=wrap_policy)
        self.assertTrue(fsdp_model._is_param_exec_order_prep_stage())
        # run one iteration to record the execution ordering
        input = fsdp_model.module.get_input(self.device)
        output = fsdp_model(*input)
        loss = fsdp_model.module.get_loss(input, output).to(self.device)
        loss.backward()
        params_list = list(fsdp_model.parameters())
        # Since the forward execution order is NOT consistent with
        # the model definition order, the ordering in ``_fsdp_params_exec_order``
        # should differ from the ``named_parameters`` ordering.
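        # Note: with ``always_wrap_policy``, ``params_list`` is expected to follow the
        # definition order (roughly: layer0, layer1, then layer2's linear layers), so the
        # indices below reorder it to match the forward order layer0 -> layer2 -> layer1.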
        self.assertEqual(
            fsdp_model._fsdp_params_exec_order,
            [params_list[0], params_list[2], params_list[3], params_list[1]],
        )
        self.assertTrue(fsdp_model._use_param_exec_order_policy())
        self.assertFalse(fsdp_model._is_param_exec_order_prep_stage())

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
    )
    def test_fsdp_flatten_params_exec_order_symbolic_trace(
        self,
        sharding_strategy: ShardingStrategy,
    ):
        """
        Tests ``ParamExecOrderWrapPolicy`` with symbolic tracing.
        With symbolic tracing enabled, ``_is_param_exec_order_prep_stage``
        should always be set to False.
        """
        wrap_policy = ParamExecOrderWrapPolicy(
            init_policy=always_wrap_policy,
            tracing_config=TracingConfig(concrete_args={"use_all_params": False}),
        )
        fsdp_model = Model.wrap(
            sharding_strategy,
            self.device,
            wrap_policy=wrap_policy,
        )
        params_list = list(fsdp_model.parameters())
        # Since the forward execution order is NOT consistent with the model definition
        # order, the ordering in ``_fsdp_params_exec_order`` should differ from the
        # ``named_parameters`` ordering.
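        # Note: with ``use_all_params=False`` fixed as a concrete arg, ``layer1`` is not
        # part of the traced forward, so its flat parameter (index 1) is expected to be
        # absent from the recorded order.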
        self.assertEqual(
            fsdp_model._fsdp_params_exec_order,
            [params_list[0], params_list[2], params_list[3]],
        )
        self.assertTrue(fsdp_model._use_param_exec_order_policy())
        self.assertFalse(fsdp_model._is_param_exec_order_prep_stage())


instantiate_parametrized_tests(TestFSDPExecOrder)

if __name__ == "__main__":
    run_tests()