File: test_fsdp_fx.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (122 lines) | stat: -rw-r--r-- 4,502 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Owner(s): ["oncall: distributed"]

from typing import Any

import torch
from torch.distributed.fsdp._symbolic_trace import _init_execution_info, _patch_tracer
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TestCase,
)


class Model(torch.nn.Module):
    """Toy network covering the situations the symbolic-tracing test checks.

    It mixes parameters used directly in ``forward`` (``weight1``/``weight2``),
    child ``Linear`` layers with and without bias, a nested ``Sequential``
    (``layer2``), one ``ReLU`` instance reused at several call sites
    (``relu``), a child called twice (``layer0``) and a parameter that
    ``forward`` never touches (``weight_unused``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order matters: it fixes both the sequence of random
        # draws and the order reported by ``named_parameters()``.
        self.weight1 = torch.nn.Parameter(torch.randn(6, 6))
        self.weight2 = torch.nn.Parameter(torch.randn(6, 6))
        self.weight_unused = torch.nn.Parameter(torch.randn(2, 2))
        self.layer0 = torch.nn.Linear(6, 6)
        self.layer1 = torch.nn.Linear(6, 6, bias=False)
        bottleneck = [
            torch.nn.Linear(6, 3, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 6, bias=False),
        ]
        self.layer2 = torch.nn.Sequential(*bottleneck)
        self.relu = torch.nn.ReLU()

    def forward(self, x: Any, run_all_layers: bool):
        """Run the toy network on ``x``.

        When ``run_all_layers`` is false the result is
        ``relu(layer2(relu(layer0(x)))) @ weight1``. When it is true,
        ``layer1`` and ``weight2`` are applied as well. The returned value is
        then ``relu(layer0(x))`` recomputed from the original input, so the
        extra work presumably exists only so a trace records those calls.
        """
        hidden = self.relu(self.layer0(x))
        hidden = self.relu(self.layer2(hidden))
        out = hidden @ self.weight1
        if not run_all_layers:
            return out
        out = self.relu(self.layer1(out))
        out = out @ self.weight2
        # used to test the case where a module is called more than once:
        # layer0 runs again on the original input and its result is returned
        return self.relu(self.layer0(x))


class TestSymbolicTracing(TestCase):
    # This test only traces ``Model`` symbolically on CPU and performs no
    # collective communication. It therefore derives from the single-process
    # ``TestCase``. The multi-process ``FSDPTest`` harness would launch one
    # worker per rank and initialize a process group just to repeat identical,
    # purely local checks.
    def test_symbolic_tracing_outputs(self):
        """
        Test ``execution_info.module_forward_order``,
        ``execution_info.module_to_execution_infos`` and
        ``execution_info.param_exec_order`` after running ``tracer.trace()``
        inside ``_patch_tracer``. Also check that ``_patch_tracer`` leaves
        the tracer's ``call_module`` / ``create_proxy`` unchanged on exit.
        """
        model = Model()
        tracer = torch.fx.Tracer()
        execution_info = _init_execution_info(model)
        original_call_module = tracer.call_module
        original_create_proxy = tracer.create_proxy
        with _patch_tracer(
            tracer=tracer, root_module=model, execution_info=execution_info
        ):
            # Passing ``run_all_layers=True`` as a concrete arg makes the trace
            # take the branch that calls ``layer1``, uses ``weight2`` and calls
            # ``layer0`` a second time.
            concrete_args = {"run_all_layers": True}
            tracer.trace(model, concrete_args)
        # the member functions of tracer should not be changed
        self.assertEqual(original_call_module, tracer.call_module)
        self.assertEqual(original_create_proxy, tracer.create_proxy)
        # test execution_info.module_forward_order: every module call in
        # forward order. Repeated calls (the shared ``relu``, ``layer0``) are
        # included, and ``layer2``'s children follow ``layer2`` itself.
        correct_module_forward_order = [
            model,
            model.layer0,
            model.relu,
            model.layer2,
            model.layer2[0],
            model.layer2[1],
            model.layer2[2],
            model.relu,
            model.layer1,
            model.relu,
            model.layer0,
            model.relu,
        ]
        self.assertEqual(
            execution_info.module_forward_order, correct_module_forward_order
        )
        # test execution_info.module_to_execution_infos: for each module, the
        # (module, named parameters) pairs in the order its forward used them.
        # The root lists each child call with that child's full parameter
        # list. The root's own directly-used parameters (``weight1``,
        # ``weight2``) appear under ``model`` itself.
        self.assertEqual(
            execution_info.module_to_execution_infos[model],
            [
                (model.layer0, list(model.layer0.named_parameters())),
                (model.layer2, list(model.layer2.named_parameters())),
                (model, [("weight1", model.weight1)]),
                (model.layer1, list(model.layer1.named_parameters())),
                (model, [("weight2", model.weight2)]),
                (model.layer0, list(model.layer0.named_parameters())),
            ],
        )
        self.assertEqual(
            execution_info.module_to_execution_infos[model.layer0],
            [(model.layer0, list(model.layer0.named_parameters()))],
        )
        self.assertEqual(
            execution_info.module_to_execution_infos[model.layer1],
            [(model.layer1, list(model.layer1.named_parameters()))],
        )
        # ``layer2`` lists only its children that own parameters; the
        # parameter-free ``ReLU`` at ``layer2[1]`` does not appear.
        self.assertEqual(
            execution_info.module_to_execution_infos[model.layer2],
            [
                (model.layer2[0], list(model.layer2[0].named_parameters())),
                (model.layer2[2], list(model.layer2[2].named_parameters())),
            ],
        )
        # ``relu`` owns no parameters, so nothing is recorded for it.
        self.assertEqual(execution_info.module_to_execution_infos[model.relu], [])
        # test execution_info.param_exec_order: parameters in first-use order.
        # ``weight_unused`` is never used by forward and must be absent.
        # The second ``layer0`` call adds no duplicates.
        correct_param_order = [
            model.layer0.weight,
            model.layer0.bias,
            model.layer2[0].weight,
            model.layer2[2].weight,
            model.weight1,
            model.layer1.weight,
            model.weight2,
        ]
        self.assertEqual(execution_info.param_exec_order, correct_param_order)


# Hand the test class to torch's parametrization helper. The helper is built
# to also work as a class decorator and hands back the class it receives, so
# rebinding the name is equivalent to a bare call. No test in this class is
# currently parametrized.
TestSymbolicTracing = instantiate_parametrized_tests(TestSymbolicTracing)

if __name__ == "__main__":
    # Script entry point: run the suite through torch's test runner.
    run_tests()