1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
|
# Owner(s): ["oncall: fx"]
import contextlib
import pickle
from io import BytesIO
from unittest.mock import patch
import torch
import torch._export
from torch import fx
from torch.fx._lazy_graph_module import (
_LazyGraphModule,
_make_graph_module,
_use_lazy_graph_module,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests, TestCase
class TestLazyGraphModule(TestCase):
exit_stack = None
@classmethod
def setUpClass(cls):
cls.exit_stack = contextlib.ExitStack()
cls.exit_stack.enter_context(_use_lazy_graph_module(True))
@classmethod
def tearDownClass(cls):
cls.exit_stack.close()
@staticmethod
def replace_sin_with_cos(gm):
for n in gm.graph.nodes:
if n.target == "sin":
n.target = "cos"
def test_replace_sin_with_cos(self):
def f(x):
return x.sin()
x = torch.randn(2, 3)
gm = fx.symbolic_trace(f)
self.replace_sin_with_cos(gm)
gm.recompile()
expected = x.cos()
actual = gm(x)
self.assertTrue(torch.allclose(expected, actual))
code = gm.print_readable(False)
self.assertTrue("cos()" in code)
self.assertTrue(isinstance(gm, _LazyGraphModule))
def test_call_forward_directly(self):
def f(x):
return x.sin()
x = torch.randn(2, 3)
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.replace_sin_with_cos(gm)
gm.recompile()
expected = x.cos()
actual = gm.forward(x)
self.assertTrue(torch.allclose(expected, actual))
def test_needs_recompile(self):
"""
Make sure needs_recompile() return the corrent state.
"""
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.assertTrue(gm._needs_recompile())
gm(torch.randn(2, 3))
self.assertFalse(gm._needs_recompile())
def test_multi_recompile(self):
"""
Cover the case that multiple recompilation happens.
"""
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.assertTrue(gm._needs_recompile())
x = torch.randn(2, 3)
# trigger the first recompilation
self.assertTrue(torch.allclose(x.sin(), gm(x)))
self.assertFalse(gm._needs_recompile())
self.replace_sin_with_cos(gm)
self.assertFalse(gm._needs_recompile())
gm.recompile()
self.assertTrue(gm._needs_recompile())
# trigger the second recompilation
self.assertTrue(torch.allclose(x.cos(), gm(x)))
self.assertFalse(gm._needs_recompile())
def test_accessing_code_cause_recompiling(self):
"""
Make sure we recompile if we have not done that yet when we access the code
property of a GraphModule.
"""
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.assertTrue(gm._needs_recompile())
# should trigger a recompilation
code = gm.code
self.assertTrue("sin" in code)
self.assertFalse(gm._needs_recompile())
def test_graph_module_str(self):
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.assertTrue("sin" in str(gm))
def test_recapture_with_make_fx(self):
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.assertTrue(gm._needs_recompile())
gm2 = make_fx(gm)(torch.randn(2, 3))
self.assertTrue(isinstance(gm2, _LazyGraphModule))
self.assertTrue(gm2._needs_recompile())
# make_fx will cal foward method of gm. That clears the _needs_recompile()
# flag.
self.assertFalse(gm._needs_recompile())
def test_recapture_with_symbolic_trace(self):
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.assertTrue(gm._needs_recompile())
gm2 = fx.symbolic_trace(gm)
# the lazy recompilcation is already realized. We realize the
# recompilation in the beginning of symbolic_trace since symbolic_trace can not
# handle the tracing of lazy recompilation.
self.assertFalse(gm._needs_recompile())
self.assertTrue(gm2._needs_recompile())
def test_recapture_with_dynamo(self):
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
self.assertTrue(gm._needs_recompile())
torch.compile(gm)(torch.rand(2, 3))
# dynamo calls gm.forward with eval hook installed. That will trigger
# the real recompilation.
self.assertFalse(gm._needs_recompile())
def test_save_lazy_foward(self):
"""
Save the lazy forward method and call it repeatly. Make sure we
don't recompile for each such call.
"""
def f(x):
return x.sin()
orig_gm_recompile = fx.GraphModule.recompile
recompile_count = 0
def mock_gm_recompile(self):
nonlocal recompile_count
recompile_count += 1
return orig_gm_recompile(self)
with patch.object(fx.GraphModule, "recompile", mock_gm_recompile):
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
saved_fwd = gm.forward
x = torch.rand(2, 3)
for _ in range(10):
saved_fwd(x)
self.assertEqual(recompile_count, 1)
def test_pickle(self):
"""
Fx graph cache need the ability to pickle GraphModule/_LazyGraphModule.
"""
def f(x):
return x.sin()
gm = fx.symbolic_trace(f)
self.assertTrue(isinstance(gm, _LazyGraphModule))
serialized = pickle.dumps(gm)
gm2 = pickle.loads(serialized)
self.assertTrue(isinstance(gm2, _LazyGraphModule))
self.assertTrue("sin" in gm2.code)
def test_make_graph_module(self):
gm = fx.symbolic_trace(lambda x: x.sin())
self.assertTrue(isinstance(gm, _LazyGraphModule))
gm1 = _make_graph_module(
gm, gm.graph, class_name="MyGraphModule", graph_module_cls=fx.GraphModule
)
self.assertFalse(isinstance(gm1, _LazyGraphModule))
self.assertTrue(gm1.__class__.__name__ == "MyGraphModule")
gm2 = _make_graph_module(gm, gm.graph)
self.assertTrue(isinstance(gm2, _LazyGraphModule))
self.assertTrue(gm2.__class__.__name__ == "GraphModule")
def test_package_fx_simple(self):
"""
Copied from test/package/test_package_fx.py to make sure LazyGraphModule
works with torch.package.
"""
class SimpleTest(torch.nn.Module):
def forward(self, x):
return torch.relu(x + 3.0)
st = SimpleTest()
traced = fx.symbolic_trace(st)
f = BytesIO()
with PackageExporter(f) as pe:
pe.save_pickle("model", "model.pkl", traced)
f.seek(0)
pi = PackageImporter(f)
loaded_traced = pi.load_pickle("model", "model.pkl")
input = torch.rand(2, 3)
self.assertEqual(loaded_traced(input), traced(input))
def test_dynamo_innermost_fn(self):
"""
Repro for https://github.com/pytorch/pytorch/issues/121198 .
"""
def f(x):
return x * 2
gm = torch.fx.symbolic_trace(f)
lazy_gm = torch.fx._lazy_graph_module._LazyGraphModule.from_graphmodule(gm)
wrapped_forward = torch._dynamo.disable(gm.forward)
got_inner_forward = torch._dynamo.eval_frame.innermost_fn(wrapped_forward)
assert hasattr(got_inner_forward, "__self__")
wrapped_lazy_forward = torch._dynamo.disable(lazy_gm.forward)
got_lazy_inner_forward = torch._dynamo.eval_frame.innermost_fn(
wrapped_lazy_forward
)
assert hasattr(got_lazy_inner_forward, "__self__")
if __name__ == "__main__":
    # Executing this file directly hands discovery and execution of the
    # tests above to PyTorch's shared test runner.
    run_tests()
|