1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308
|
# Owner(s): ["module: dynamo"]
import contextlib
import dis
import unittest
from typing import List
import torch
import torch._dynamo.test_case
from torch.testing._internal.common_utils import IS_FBCODE
def _filter_instructions(instructions, opname):
return list(filter(lambda x: x.opname == opname, instructions))
class ReconstructTest(torch._dynamo.test_case.TestCase):
    """Tests for how Dynamo reconstructs dicts in the bytecode it generates.

    Most tests compile a small function with ``torch.compile`` and use a
    bytecode hook to inspect the generated code. The hook counts the
    ``BUILD_MAP`` instructions and reads their ``argval``, which is the number
    of key/value pairs built. Each test then checks that the compiled
    function leaves the dict in the same state as eager execution.
    """

    @contextlib.contextmanager
    def register_bytecode_hook(self, fn):
        """Install a Dynamo bytecode hook for the duration of a ``with`` block.

        Args:
            fn: called with ``list(dis.get_instructions(out_code))`` for each
                code object Dynamo produces while the hook is registered.

        ``torch._dynamo.reset()`` runs before the hook is registered. The hook
        is removed in a ``finally`` clause so it never outlives the block.
        """

        def hook(code, out_code):
            fn(list(dis.get_instructions(out_code)))
            # A non-None return value replaces Dynamo's generated bytecode.
            # Returning the original ``code`` here would make the "compiled"
            # function run plain, untransformed Python. The result comparisons
            # would then never exercise dict reconstruction.
            return None

        torch._dynamo.reset()
        handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
        try:
            yield
        finally:
            handle.remove()

    def test_ConstDict_optimize_reconstruct(self):
        """
        Emit code to reconstruct only the key that changed
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct only d[40]
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t}
        d_opt = d.copy()
        f(d, t)
        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_ConstDict_pop_reconstruct(self):
        """
        If something is popped from the dict, we reconstruct everything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything: the surviving keys 1 and 40
            self.assertEqual(build_map[0].argval, 2)

        def f(d, t):
            d.pop(2)
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)
        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    @unittest.expectedFailure
    def test_ConstDict_popitem_reconstruct(self):
        """
        If something is popped from the dict via popitem, we reconstruct
        everything.

        Marked as an expected failure. This is presumably because dict.popitem
        graph-breaks (see the _graph_break variant), which fullgraph=True
        rejects. Confirm before relying on this.
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything: only one key survives popitem()
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d.popitem()

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)
        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_ConstDict_popitem_reconstruct_graph_break(self):
        """
        If something is popped from the dict, we reconstruct everything.
        Calling dict.popitem will graph break.
        """

        def f(d, t):
            d.popitem()

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)
        opt_f = torch.compile(backend="eager")(f)
        opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_ConstDict_del_reconstruct(self):
        """
        If something is deleted from the dict, we reconstruct everything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything: the surviving keys 1 and 40
            self.assertEqual(build_map[0].argval, 2)

        def f(d, t):
            del d[2]
            d[40] = t + 1

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)
        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_ConstDict_get_reconstruct(self):
        """
        dict.get shouldn't affect anything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # only the updated key 456 is rebuilt
            self.assertEqual(build_map[0].argval, 1)
            load_const = _filter_instructions(instructions, "LOAD_CONST")
            # Compare against the loaded constant values. dis.Instruction is a
            # namedtuple, so testing 123 against the Instruction objects
            # themselves could never match and the check would be vacuous.
            self.assertNotIn(123, [inst.argval for inst in load_const])

        def f(d, t):
            d[456] = d.get(456) + t

        t = torch.randn(3, 4)
        d = {123: t, 456: t + 1}
        d_opt = d.copy()
        f(d, t)
        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_ConstDict_clear_reconstruct(self):
        """
        If dict.clear() is used, we reconstruct everything
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything: only key 3 exists after clear()
            self.assertEqual(build_map[0].argval, 1)

        def f(d, t):
            d.clear()
            d[3] = t + 3

        t = torch.randn(3, 4)
        d = {1: t, 2: t + 1}
        d_opt = d.copy()
        f(d, t)
        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            opt_f(d_opt, t)
        self.assertEqual(d, d_opt)

    def test_create_dict_reconstruct(self):
        """
        If dict is created inside a function, everything needs to be reconstructed
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            self.assertEqual(len(build_map), 1)
            # reconstruct everything: both keys of the new dict
            self.assertEqual(build_map[0].argval, 2)

        def f(t):
            return {1: t, 2: t + 1}

        t = torch.randn(3, 4)
        d = f(t)
        with self.register_bytecode_hook(hook):
            opt_f = torch.compile(f, backend="eager", fullgraph=True)
            d_opt = opt_f(t)
        self.assertEqual(d, d_opt)

    @unittest.skipIf(
        IS_FBCODE, "capturing functional_call is not enabled by default in FB_CODE"
    )
    def test_functional_call_reconstruct(self):
        """
        Dynamo shouldn't codegen any key/value when functional_call is used
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            # don't reconstruct anything
            self.assertEqual(len(build_map), 0)

        m = torch.nn.Linear(3, 3)
        new_bias = torch.randn(3)
        new_weight = torch.randn(3, 3)

        def fn(new_weight, new_bias, x):
            return torch.func.functional_call(
                m, {"weight": new_weight, "bias": new_bias}, x
            )

        x = torch.randn(2, 3)
        expected = torch.nn.functional.linear(x, new_weight, new_bias)
        with self.register_bytecode_hook(hook):
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            got = opt_fn(new_weight, new_bias, x)
        self.assertEqual(expected, got)

    @unittest.skipIf(
        IS_FBCODE, "capturing functional_call is not enabled by default in FB_CODE"
    )
    def test_functional_call_reconstruct_2(self):
        """
        Dynamo shouldn't codegen any key/value when functional_call is used
        with a nested module's state_dict
        """

        def hook(instructions: List[dis.Instruction]):
            build_map = _filter_instructions(instructions, "BUILD_MAP")
            # don't reconstruct anything
            self.assertEqual(len(build_map), 0)

        class DummyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.ModuleDict(
                    {
                        "b": torch.nn.ModuleDict(
                            {
                                "c": torch.nn.ModuleDict(
                                    {
                                        "d": torch.nn.ModuleDict(
                                            {"e": torch.nn.Linear(10, 10, bias=False)}
                                        )
                                    }
                                )
                            }
                        )
                    }
                )

            def forward(self, x):
                return self.a.b.c.d.e(x)

        model = DummyModule()

        def fn(model, states, x):
            return torch.func.functional_call(model, states, x)

        states = model.state_dict()
        # Input width must match the innermost Linear(10, 10).
        x = torch.randn(10, 10)
        expected = fn(model, states, x)
        with self.register_bytecode_hook(hook):
            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
            got = opt_fn(model, states, x)
        self.assertEqual(expected, got)
if __name__ == "__main__":
    # When executed as a script, hand control to torch._dynamo.test_case's runner.
    from torch._dynamo import test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()
|