File: test_lazy_graph_module.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (279 lines) | stat: -rw-r--r-- 8,557 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
# Owner(s): ["oncall: fx"]

import contextlib
import pickle
from io import BytesIO
from unittest.mock import patch

import torch
import torch._export
from torch import fx
from torch.fx._lazy_graph_module import (
    _LazyGraphModule,
    _make_graph_module,
    _use_lazy_graph_module,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests, TestCase


class TestLazyGraphModule(TestCase):
    exit_stack = None

    @classmethod
    def setUpClass(cls):
        cls.exit_stack = contextlib.ExitStack()
        cls.exit_stack.enter_context(_use_lazy_graph_module(True))

    @classmethod
    def tearDownClass(cls):
        cls.exit_stack.close()

    @staticmethod
    def replace_sin_with_cos(gm):
        for n in gm.graph.nodes:
            if n.target == "sin":
                n.target = "cos"

    def test_replace_sin_with_cos(self):
        def f(x):
            return x.sin()

        x = torch.randn(2, 3)
        gm = fx.symbolic_trace(f)
        self.replace_sin_with_cos(gm)

        gm.recompile()
        expected = x.cos()
        actual = gm(x)

        self.assertTrue(torch.allclose(expected, actual))
        code = gm.print_readable(False)
        self.assertTrue("cos()" in code)
        self.assertTrue(isinstance(gm, _LazyGraphModule))

    def test_call_forward_directly(self):
        def f(x):
            return x.sin()

        x = torch.randn(2, 3)
        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.replace_sin_with_cos(gm)
        gm.recompile()
        expected = x.cos()
        actual = gm.forward(x)

        self.assertTrue(torch.allclose(expected, actual))

    def test_needs_recompile(self):
        """
        Make sure needs_recompile() return the corrent state.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        gm(torch.randn(2, 3))
        self.assertFalse(gm._needs_recompile())

    def test_multi_recompile(self):
        """
        Cover the case that multiple recompilation happens.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        x = torch.randn(2, 3)
        # trigger the first recompilation
        self.assertTrue(torch.allclose(x.sin(), gm(x)))
        self.assertFalse(gm._needs_recompile())

        self.replace_sin_with_cos(gm)
        self.assertFalse(gm._needs_recompile())
        gm.recompile()
        self.assertTrue(gm._needs_recompile())
        # trigger the second recompilation
        self.assertTrue(torch.allclose(x.cos(), gm(x)))
        self.assertFalse(gm._needs_recompile())

    def test_accessing_code_cause_recompiling(self):
        """
        Make sure we recompile if we have not done that yet when we access the code
        property of a GraphModule.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        # should trigger a recompilation
        code = gm.code
        self.assertTrue("sin" in code)
        self.assertFalse(gm._needs_recompile())

    def test_graph_module_str(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue("sin" in str(gm))

    def test_recapture_with_make_fx(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        gm2 = make_fx(gm)(torch.randn(2, 3))
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
        self.assertTrue(gm2._needs_recompile())

        # make_fx will cal foward method of gm. That clears the _needs_recompile()
        # flag.
        self.assertFalse(gm._needs_recompile())

    def test_recapture_with_symbolic_trace(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        gm2 = fx.symbolic_trace(gm)

        # the lazy recompilcation is already realized. We realize the
        # recompilation in the beginning of symbolic_trace since symbolic_trace can not
        # handle the tracing of lazy recompilation.
        self.assertFalse(gm._needs_recompile())
        self.assertTrue(gm2._needs_recompile())

    def test_recapture_with_dynamo(self):
        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        self.assertTrue(gm._needs_recompile())
        torch.compile(gm)(torch.rand(2, 3))

        # dynamo calls gm.forward with eval hook installed. That will trigger
        # the real recompilation.
        self.assertFalse(gm._needs_recompile())

    def test_save_lazy_foward(self):
        """
        Save the lazy forward method and call it repeatly. Make sure we
        don't recompile for each such call.
        """

        def f(x):
            return x.sin()

        orig_gm_recompile = fx.GraphModule.recompile
        recompile_count = 0

        def mock_gm_recompile(self):
            nonlocal recompile_count
            recompile_count += 1
            return orig_gm_recompile(self)

        with patch.object(fx.GraphModule, "recompile", mock_gm_recompile):
            gm = fx.symbolic_trace(f)
            self.assertTrue(isinstance(gm, _LazyGraphModule))
            saved_fwd = gm.forward

            x = torch.rand(2, 3)
            for _ in range(10):
                saved_fwd(x)

        self.assertEqual(recompile_count, 1)

    def test_pickle(self):
        """
        Fx graph cache need the ability to pickle GraphModule/_LazyGraphModule.
        """

        def f(x):
            return x.sin()

        gm = fx.symbolic_trace(f)
        self.assertTrue(isinstance(gm, _LazyGraphModule))
        serialized = pickle.dumps(gm)
        gm2 = pickle.loads(serialized)
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
        self.assertTrue("sin" in gm2.code)

    def test_make_graph_module(self):
        gm = fx.symbolic_trace(lambda x: x.sin())
        self.assertTrue(isinstance(gm, _LazyGraphModule))

        gm1 = _make_graph_module(
            gm, gm.graph, class_name="MyGraphModule", graph_module_cls=fx.GraphModule
        )
        self.assertFalse(isinstance(gm1, _LazyGraphModule))
        self.assertTrue(gm1.__class__.__name__ == "MyGraphModule")

        gm2 = _make_graph_module(gm, gm.graph)
        self.assertTrue(isinstance(gm2, _LazyGraphModule))
        self.assertTrue(gm2.__class__.__name__ == "GraphModule")

    def test_package_fx_simple(self):
        """
        Copied from test/package/test_package_fx.py to make sure LazyGraphModule
        works with torch.package.
        """

        class SimpleTest(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x + 3.0)

        st = SimpleTest()
        traced = fx.symbolic_trace(st)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)

        f.seek(0)
        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")
        input = torch.rand(2, 3)
        self.assertEqual(loaded_traced(input), traced(input))

    def test_dynamo_innermost_fn(self):
        """
        Repro for https://github.com/pytorch/pytorch/issues/121198 .
        """

        def f(x):
            return x * 2

        gm = torch.fx.symbolic_trace(f)
        lazy_gm = torch.fx._lazy_graph_module._LazyGraphModule.from_graphmodule(gm)

        wrapped_forward = torch._dynamo.disable(gm.forward)
        got_inner_forward = torch._dynamo.eval_frame.innermost_fn(wrapped_forward)
        assert hasattr(got_inner_forward, "__self__")

        wrapped_lazy_forward = torch._dynamo.disable(lazy_gm.forward)
        got_lazy_inner_forward = torch._dynamo.eval_frame.innermost_fn(
            wrapped_lazy_forward
        )
        assert hasattr(got_lazy_inner_forward, "__self__")


if __name__ == "__main__":
    # Allow running this file directly as a script; run_tests is imported
    # from torch.testing._internal.common_utils.
    run_tests()