File: test_fx_param_shape_control_flow.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (155 lines) | stat: -rw-r--r-- 5,002 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Owner(s): ["module: fx"]

import unittest
import torch
import torch.fx

from torch.testing._internal.common_utils import TestCase


class MyModuleBase(torch.nn.Module):
    """Base module computing ``torch.mm(x, matrix)``, optionally followed by relu.

    Subclasses supply ``self.param`` (or override ``get_mul_matrix``) and must
    implement ``no_relu``; the Python ``if`` in ``forward`` branches on that
    predicate, which is the data-dependent control flow the FX tests exercise.
    """

    def forward(self, x):
        matrix = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrix)
        else:
            return torch.relu(torch.mm(x, matrix))

    def get_mul_matrix(self):
        """Return the right-hand operand of the matmul; defaults to ``self.param``."""
        return self.param

    def no_relu(self):
        """Return True to skip the relu in ``forward``. Subclasses must override."""
        # NotImplementedError is the conventional signal for an unimplemented
        # hook; it subclasses Exception, so ``except Exception`` still catches it.
        raise NotImplementedError("not implemented")

class MyModuleParamShape(MyModuleBase):
    """Skips the relu when ``param`` has fewer than 10 rows.

    The row count is read through ``Tensor.shape``.
    """

    def __init__(self, in_channels):
        super().__init__()
        init_values = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(init_values)

    def no_relu(self):
        row_count = self.param.shape[0]
        return row_count < 10


class MyModuleParamSize(MyModuleBase):
    """Skips the relu when ``param`` has fewer than 10 rows.

    The row count is read through ``Tensor.size()``.
    """

    def __init__(self, in_channels):
        super().__init__()
        shape = (in_channels, 3)
        self.param = torch.nn.Parameter(torch.randn(*shape))

    def no_relu(self):
        leading_dim = self.param.size()[0]
        return leading_dim < 10


class MyModuleParamDim(MyModuleBase):
    """Branches on the rank of ``param`` as reported by ``Tensor.dim()``.

    A 3-D ``param`` uses its first slice as the matmul operand and skips the
    relu; any other rank uses ``param`` directly and applies the relu.
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        if self.param.dim() == 3:
            # take the first matrix along the leading dimension
            return self.param[0]
        return self.param

    def no_relu(self):
        is_3d = self.param.dim() == 3
        return is_3d


class MyModuleParamNDim(MyModuleBase):
    """Branches on the rank of ``param`` as reported by ``Tensor.ndim``.

    A 3-D ``param`` uses its first slice as the matmul operand and skips the
    relu; any other rank uses ``param`` directly and applies the relu.
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        if self.param.ndim == 3:
            # take the first matrix along the leading dimension
            return self.param[0]
        return self.param

    def no_relu(self):
        is_3d = self.param.ndim == 3
        return is_3d


class MyModuleParamNumEl(MyModuleBase):
    """Skips the relu when ``param.numel()`` is below 30 (a 10 x 3 parameter)."""

    def __init__(self, in_channels):
        super().__init__()
        init_values = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(init_values)

    def no_relu(self):
        limit = 10 * 3
        return self.param.numel() < limit



class MyModuleParamNElement(MyModuleBase):
    """Skips the relu when ``param.nelement()`` is below 30 (a 10 x 3 parameter)."""

    def __init__(self, in_channels):
        super().__init__()
        init_values = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(init_values)

    def no_relu(self):
        limit = 10 * 3
        return self.param.nelement() < limit



class TestConstParamShapeInControlFlow(TestCase):
    """Check that ``Tracer(param_shapes_constant=True)`` resolves shape-based branches."""

    def _trace_and_check(self, mod, apply_relu):
        """Trace ``mod`` with constant param shapes and check eager/traced outputs.

        Args:
            mod: module whose ``forward`` computes ``mm(x, get_mul_matrix())``,
                optionally followed by relu.
            apply_relu: whether the expected result includes the relu.

        Returns:
            The traced ``torch.fx.Graph``.
        """
        weight = mod.get_mul_matrix()
        # Size the input from the module's own matrix rather than hard-coding
        # the width, so the helper works for any parameter shape.
        x = torch.randn(10, weight.shape[0])
        expected = torch.mm(x, weight)
        if apply_relu:
            expected = torch.relu(expected)
        torch.testing.assert_close(mod(x), expected)

        tracer = torch.fx.Tracer(param_shapes_constant=True)
        graph = tracer.trace(mod)

        # verify the graph module calculates the same result
        graph_mod = torch.fx.GraphModule(mod, graph)
        torch.testing.assert_close(graph_mod(x), expected)
        return graph

    def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
        """
        Verify one module only does a mm op while the other
        performs both mm and relu ops in cascade.

        Args:
            mm_only_mod: module expected to take the relu-free branch.
            relu_mod: module expected to take the mm + relu branch.
        """
        traced_graph = self._trace_and_check(mm_only_mod, apply_relu=False)
        traced_graph2 = self._trace_and_check(relu_mod, apply_relu=True)

        graph1_node_targets = [n.target for n in traced_graph.nodes]
        graph2_node_targets = [n.target for n in traced_graph2.nodes]

        # both graphs contain the mm; only the second has an extra relu call node
        self.assertIn(torch.mm, graph1_node_targets)
        self.assertIn(torch.mm, graph2_node_targets)
        self.assertNotIn(torch.relu, graph1_node_targets)
        self.assertIn(torch.relu, graph2_node_targets)

    def test_param_shape_const(self):
        mymod = MyModuleParamShape(in_channels=5)
        mymod2 = MyModuleParamShape(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_size_const(self):
        mymod = MyModuleParamSize(in_channels=5)
        mymod2 = MyModuleParamSize(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_dim_const(self):
        mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_ndim_const(self):
        mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_numel_const(self):
        mymod = MyModuleParamNumEl(in_channels=5)
        mymod2 = MyModuleParamNumEl(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_nelement_const(self):
        mymod = MyModuleParamNElement(in_channels=5)
        mymod2 = MyModuleParamNElement(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)


if __name__ == "__main__":
    # Run this file's test cases when it is executed directly as a script.
    unittest.main()