1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
|
# Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
class RecompileTests(torch._dynamo.test_case.TestCase):
def test_automatic_dynamic_reduce_recompiles(self):
# Test the counterfactual, lots of recompiles without this config
def foo(x, y):
return x * y
def run_foo_6_times_and_count_recompiles(dynamic=None):
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn([2])
y = torch.randn([2])
opt = torch.compile(foo, backend=cnt, dynamic=dynamic)
opt(x, y)
x = torch.randn([3])
y = torch.randn([3])
opt(x, y)
x = torch.randn([4])
y = torch.randn([4])
opt(x, y)
opt(x, y)
x = torch.randn([5])
y = torch.randn([5])
opt(x, y)
opt(x, y)
x = torch.randn([6])
y = torch.randn([6])
opt(x, y)
return cnt
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_without_automatic():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
return run_foo_6_times_and_count_recompiles()
without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
without = run_foo_6_times_and_count_recompiles(dynamic=False)
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
self.assertEqual(with_automatic.frame_count, 2)
self.assertEqual(with_automatic.op_count, 2)
torch._dynamo.reset()
with_automatic = run_foo_6_times_and_count_recompiles(dynamic=None)
self.assertEqual(with_automatic.frame_count, 2)
self.assertEqual(with_automatic.op_count, 2)
torch._dynamo.reset()
with_dynamic = run_foo_6_times_and_count_recompiles(dynamic=True)
self.assertEqual(with_dynamic.frame_count, 1)
self.assertEqual(with_dynamic.op_count, 1)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def test_recompiles_true_false_flop(self):
# Test the counterfactual, lots of recompiles without this config
def foo(x, y):
if x:
return y * 2
else:
return y * y
def run_foo_6_times_and_count_recompiles():
cnt = torch._dynamo.testing.CompileCounter()
opt = torch.compile(foo, backend=cnt, fullgraph=True)
x = True
y = torch.randn([2])
opt(x, y)
x = False
y = torch.randn([2])
opt(x, y)
x = True
y = torch.randn([3])
opt(x, y)
x = True
y = torch.randn([4])
opt(x, y)
x = True
y = torch.randn([5])
opt(x, y)
return cnt
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_without_automatic():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
return run_foo_6_times_and_count_recompiles()
without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
self.assertEqual(with_automatic.frame_count, 3)
self.assertEqual(with_automatic.op_count, 3)
def test_automatic_dynamic_tensor_scalar_change(self):
# Test the counterfactual, lots of recompiles without this config
def foo(x, y):
return x * y
def run_foo_6_times_and_count_recompiles_swap_types():
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn([2])
y = torch.randn([2])
opt = torch.compile(foo, backend=cnt)
opt(x, y)
x = torch.randn([3])
y = 3
opt(x, y)
x = torch.randn([4])
y = torch.randn([4])
opt(x, y)
opt(x, y)
x = torch.randn([5])
y = 4
opt(x, y)
opt(x, y)
x = torch.randn([6])
y = torch.randn([6])
opt(x, y)
return cnt
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_without_automatic():
return run_foo_6_times_and_count_recompiles_swap_types()
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_with_automatic():
return run_foo_6_times_and_count_recompiles_swap_types()
without = run_without_automatic()
self.assertEqual(without.frame_count, 5)
self.assertEqual(without.op_count, 5)
torch._dynamo.reset()
with_automatic = run_with_automatic()
self.assertEqual(with_automatic.frame_count, 3)
self.assertEqual(with_automatic.op_count, 3)
def test_aliasing_guard_failures(self):
def foo(a, b, c):
a.add_(b)
return c + 1
cnt = torch._dynamo.testing.CompileCounter()
compiled_foo = torch.compile(foo, backend=cnt, fullgraph=True)
x = torch.randn([3])
y = torch.randn([3])
z = torch.randn([3])
cmp_result = compiled_foo(
x.detach().clone(), y.detach().clone(), z.detach().clone()
)
eager_result = foo(x.detach().clone(), y.detach().clone(), z.detach().clone())
self.assertEqual(cmp_result, eager_result)
self.assertEqual(cnt.frame_count, 1)
cmp_result = compiled_foo(
z.detach().clone(), y.detach().clone(), x.detach().clone()
)
eager_result = foo(z.detach().clone(), y.detach().clone(), x.detach().clone())
self.assertEqual(cmp_result, eager_result)
# No recompile, alias preserved
self.assertEqual(cnt.frame_count, 1)
x_clone = x.detach().clone()
cmp_result = compiled_foo(x_clone, y.detach().clone(), x_clone)
x_clone = x.detach().clone()
eager_result = compiled_foo(x_clone, y.detach().clone(), x_clone)
self.assertEqual(cmp_result, eager_result)
# Recompile, alias changed
self.assertEqual(cnt.frame_count, 2)
def test_aliasing_guard_failures_with_globals(self):
g1 = torch.randn([3])
g2 = torch.randn([3])
def foo(a):
a.add_(g1)
return g2 + 1
cnt = torch._dynamo.testing.CompileCounter()
compiled_foo = torch.compile(foo, backend=cnt, fullgraph=True)
z = torch.randn([3])
cmp_result = compiled_foo(z.detach().clone())
eager_result = foo(z.detach().clone())
self.assertEqual(cmp_result, eager_result)
self.assertEqual(cnt.frame_count, 1)
g1 = g1.detach().clone()
cmp_result = compiled_foo(g1)
g1 = g1.detach().clone()
eager_result = compiled_foo(g1)
self.assertEqual(cmp_result, eager_result)
# Recompile, alias changed
self.assertEqual(cnt.frame_count, 2)
def test_dynamic_shape_parameter_recompile(self):
# Test the matrix multiplication with Parameters.
# Without the config assume_parameters_shapes_static_by_default,
# the torch.nn.Parameter shapes are assumed to be static which leads to recompilation
w = torch.nn.Parameter(torch.randn(3, 2))
def foo(x):
return x @ w
def run_foo_6_times_and_count_recompiles():
cnt = torch._dynamo.testing.CompileCounter()
opt = torch.compile(foo, backend=cnt, fullgraph=True)
x = torch.nn.Parameter(torch.randn(1, 3))
opt(x)
x = torch.nn.Parameter(torch.randn(10, 3))
opt(x)
x = torch.nn.Parameter(torch.randn(11, 3))
opt(x)
x = torch.nn.Parameter(torch.randn(15, 3))
opt(x)
x = torch.nn.Parameter(torch.randn(15, 3))
opt(x)
return cnt
@patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_static_comp_default_param():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "force_parameter_static_shapes", True)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_dynamic_comp_default_param():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", False)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_static_comp_dynamic_param():
return run_foo_6_times_and_count_recompiles()
@patch.object(torch._dynamo.config, "force_parameter_static_shapes", False)
@patch.object(torch._dynamo.config, "automatic_dynamic_shapes", True)
@patch.object(torch._dynamo.config, "assume_static_by_default", True)
def run_dynamic_comp_dynamic_param():
return run_foo_6_times_and_count_recompiles()
torch._dynamo.reset()
static_comp_default_param = run_static_comp_default_param()
self.assertEqual(static_comp_default_param.frame_count, 4)
self.assertEqual(static_comp_default_param.op_count, 4)
torch._dynamo.reset()
dynamic_comp_default_param = run_dynamic_comp_default_param()
self.assertEqual(dynamic_comp_default_param.frame_count, 4)
self.assertEqual(dynamic_comp_default_param.op_count, 4)
torch._dynamo.reset()
static_comp_dynamic_param = run_static_comp_dynamic_param()
self.assertEqual(static_comp_dynamic_param.frame_count, 4)
self.assertEqual(static_comp_dynamic_param.op_count, 4)
torch._dynamo.reset()
dynamic_comp_dynamic_param = run_dynamic_comp_dynamic_param()
self.assertEqual(dynamic_comp_dynamic_param.frame_count, 2)
self.assertEqual(dynamic_comp_dynamic_param.op_count, 2)
def test_simple_module_recompile(self):
class SimpleDropout(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.dropout = torch.nn.Dropout(0.5)
self.linear = torch.nn.Linear(10, 1)
def forward(self, x):
return self.dropout(self.linear(x))
model = SimpleDropout()
x = torch.randn(10)
counter = torch._dynamo.testing.CompileCounter()
model = torch.compile(model, backend=counter, fullgraph=True)
for _ in range(20):
model.eval()
model(x)
model.train()
model(x)
self.assertEqual(counter.frame_count, 2)
@patch.object(torch._dynamo.config, "cache_size_limit", 2)
def test_no_recursive_compile_after_cache_limit_hit(self):
def f(x, n):
x = x + n
return g(x, n)
def g(x, n):
x = x + n
return h(x, n)
def h(x, n):
return x + n
counter = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=counter, dynamic=False)
for i in range(10):
opt_f(torch.ones(3), i)
self.assertEqual(counter.frame_count, 2)
def test_automatic_dynamic_on_closed_ints(self):
def f(x):
def g(y):
return y + x
return g
counter = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=counter)
def h(x, g):
return g(x)
for i in range(10):
h(torch.randn(5), f(i))
self.assertEqual(counter.frame_count, 2)
@patch.object(torch._dynamo.config, "cache_size_limit", 2)
def test_run_mode_after_cache_limit_hit(self):
def f(x, n):
x = x + n
if torch._dynamo.is_compiling():
x = x + 1
return g(x, n)
def g(x, n):
x = x + n
if torch._dynamo.is_compiling():
x = x + 2
return x
counter = torch._dynamo.testing.CompileCounter()
opt_f = torch.compile(f, backend=counter, dynamic=False)
# compiles
self.assertEqual(opt_f(torch.ones(3), 0), torch.ones(3) + 3)
self.assertEqual(opt_f(torch.ones(3), 1), torch.ones(3) + 5)
# cache limit hit
self.assertEqual(opt_f(torch.ones(3), 2), torch.ones(3) + 4)
self.assertEqual(opt_f(torch.ones(3), 3), torch.ones(3) + 6)
# run mode
self.assertEqual(opt_f(torch.ones(3), 0), torch.ones(3) + 3)
self.assertEqual(opt_f(torch.ones(3), 1), torch.ones(3) + 5)
self.assertEqual(counter.frame_count, 2)
@torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="unbacked")
def test_automatic_dynamic_shapes_mark_as_unbacked(self):
counter = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=counter)
def f(x):
return x * x
f(torch.randn(3))
f(torch.randn(2))
f(torch.randn(1))
f(torch.randn(0))
self.assertEqual(counter.frame_count, 2) # not three or four!
@torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
def test_automatic_dynamic_shapes_mark_as_oblivious(self):
counter = torch._dynamo.testing.CompileCounter()
def f(x):
if x.size(0) < 10:
return x * 1
else:
return x + 10
opt_f = torch.compile(backend=counter, fullgraph=True)(f)
for i in [3, 2, 1, 0]:
self.assertEqual(f(torch.zeros(i)), opt_f(torch.zeros(i)))
self.assertEqual(counter.frame_count, 2) # not three or four!
@torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
def test_automatic_dynamic_shapes_mark_as_oblivious_fail_counterfactual(self):
counter = torch._dynamo.testing.CompileCounter()
def f(x):
if x.size(0) < 2:
return x * 1
else:
return x + 10
opt_f = torch.compile(backend=counter, fullgraph=True)(f)
opt_f(torch.randn(1))
with self.assertRaises(torch._dynamo.exc.UserError):
opt_f(torch.randn(0))
if __name__ == "__main__":
    # Script entry point: hand this module's tests to dynamo's test runner,
    # reusing the torch._dynamo.test_case module already imported above.
    torch._dynamo.test_case.run_tests()
|