File: test_fsdp_fx.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (119 lines) | stat: -rw-r--r-- 4,691 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.fsdp._trace_utils import _ExecOrderTracer
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


class Model(torch.nn.Module):
    """Small fixture module for checking FX execution-order tracing.

    It combines several call patterns the tracer has to record:

    * a nested ``Sequential`` (``layer2``) whose children are called in turn,
    * a single ``ReLU`` instance (``relu``) invoked several times,
    * bare parameters (``weight1``/``weight2``) consumed via ``@``,
    * a parameter (``weight_unused``) that ``forward`` never touches, and
    * a submodule (``layer0``) called twice when ``run_all_layers`` is true.
    """

    def __init__(self) -> None:
        super().__init__()
        nn = torch.nn
        width, bottleneck = 6, 3
        # Statement order fixes both parameter registration order (and thus
        # ``named_parameters()`` order) and the RNG draws, so keep it stable.
        self.weight1 = nn.Parameter(torch.randn(width, width))
        self.weight2 = nn.Parameter(torch.randn(width, width))
        # Intentionally unused by ``forward``.
        self.weight_unused = nn.Parameter(torch.randn(2, 2))
        self.layer0 = nn.Linear(width, width)
        self.layer1 = nn.Linear(width, width, bias=False)
        self.layer2 = nn.Sequential(
            nn.Linear(width, bottleneck, bias=False),
            nn.ReLU(),
            nn.Linear(bottleneck, width, bias=False),
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, run_all_layers: bool) -> torch.Tensor:
        """Run the fixture network.

        Args:
            x: Input whose last dimension is 6 (the in-features of ``layer0``).
            run_all_layers: When false, stop after the ``weight1`` matmul.
                When true, also run ``layer1`` and the ``weight2`` matmul, then
                call ``layer0`` a second time on ``x``.

        Returns:
            With ``run_all_layers`` false, ``relu(layer2(relu(layer0(x)))) @ weight1``.
            With it true, ``relu(layer0(x))``; the intermediate results are
            computed but discarded.
        """
        act = self.relu
        out = act(self.layer0(x))
        out = act(self.layer2(out))
        out = out @ self.weight1
        if not run_all_layers:
            return out
        out = act(self.layer1(out))
        out = out @ self.weight2
        # Calling ``layer0`` again exercises how repeated use of one module
        # is recorded. Its output replaces ``out``.
        return act(self.layer0(x))


class TestSymbolicTracing(TestCase):
    def test_symbolic_tracing_outputs(self):
        """
        Trace ``Model`` with ``torch.fx.Tracer.trace()`` while the tracer is
        patched by ``_ExecOrderTracer.patch_tracer()``, then compare the
        recorded execution info against hand-computed expectations. Also check
        that the patch is undone once the context manager exits.
        """
        model = Model()
        fx_tracer = torch.fx.Tracer()
        call_module_before = fx_tracer.call_module
        create_proxy_before = fx_tracer.create_proxy
        order_tracer = _ExecOrderTracer()
        with order_tracer.patch_tracer(tracer=fx_tracer, root_module=model):
            fx_tracer.trace(model, concrete_args={"run_all_layers": True})

        # Leaving the context must restore the tracer's original methods.
        self.assertEqual(call_module_before, fx_tracer.call_module)
        self.assertEqual(create_proxy_before, fx_tracer.create_proxy)

        def usage(module):
            # Expected usage entry when a module's own call consumes all of its params.
            return (module, list(module.named_parameters()))

        layer0, layer1, layer2, relu = (
            model.layer0,
            model.layer1,
            model.layer2,
            model.relu,
        )
        info = order_tracer.exec_info

        # Every module call in execution order: the root first, repeated
        # modules once per call, and ``layer2`` followed by its children.
        self.assertEqual(
            info.module_forward_order,
            [model, layer0, relu, layer2, *layer2, relu, layer1, relu, layer0, relu],
        )

        # From the root's perspective: child-module calls plus the two bare
        # parameter matmuls, with ``layer0`` appearing twice.
        self.assertEqual(
            info.module_to_param_usage_infos[model],
            [
                usage(layer0),
                usage(layer2),
                (model, [("weight1", model.weight1)]),
                usage(layer1),
                (model, [("weight2", model.weight2)]),
                usage(layer0),
            ],
        )
        for leaf in (layer0, layer1):
            self.assertEqual(info.module_to_param_usage_infos[leaf], [usage(leaf)])
        # The parameterless ReLU inside ``layer2`` contributes no entry.
        self.assertEqual(
            info.module_to_param_usage_infos[layer2],
            [usage(layer2[0]), usage(layer2[2])],
        )
        self.assertEqual(info.module_to_param_usage_infos[relu], [])

        # First-use order of parameters; ``weight_unused`` never shows up.
        expected_params = [
            layer0.weight,
            layer0.bias,
            layer2[0].weight,
            layer2[2].weight,
            model.weight1,
            layer1.weight,
            model.weight2,
        ]
        self.assertEqual(info.param_forward_order, expected_params)
        # ``visited_params`` holds exactly the ordered parameters, with no duplicates.
        self.assertEqual(len(info.visited_params), len(info.param_forward_order))
        self.assertEqual(info.visited_params, set(info.param_forward_order))


# Backends for which device-specific variants of ``TestSymbolicTracing`` are
# generated. The test method itself takes no ``device`` argument.
devices = (
    "cuda",
    "hpu",
)
instantiate_device_type_tests(TestSymbolicTracing, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()