# Owner(s): ["module: dynamo"]
import unittest
import weakref
import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._logging
from torch._dynamo.exc import FailOnRecompileLimitHit
from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings
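# These tests exercise dynamo's recompilation UX: cache-size limits, the
# warnings emitted when a limit is hit, and the guard-failure messages that
# explain why a recompile happened.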
class RecompileUxTests(torch._dynamo.test_case.TestCase):
# TODO(whc) dynamo actually recompiles one more time than the cache limit
cache_limit = 1
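    # Patch the cache size limit once for the whole class; the base TestCase's
    # exit stack is expected to undo the patch at class teardown.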
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._exit_stack.enter_context(
torch._dynamo.config.patch("cache_size_limit", cls.cache_limit)
)
def test_drop_cache_on_skip(self):
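        # When the cache limit is hit and the frame is skipped, the cached
        # compiled function should be dropped and collected, firing the
        # weakref.finalize callback attached in the backend below.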
def model(x, i):
return x + i
attached = False
triggered = False
def trigger():
nonlocal triggered
triggered = True
def compiler(gm, input):
nonlocal attached
f = gm.forward
assert not attached
# NB: making this a weakref.ref causes the cycle to no
# longer be promptly GC'ed
weakref.finalize(f, trigger)
attached = True
return f
x = torch.randn(2)
for i in range(2):
opt_model = torch.compile(model, backend=compiler)
opt_model(x, i)
self.assertTrue(triggered)
def test_loop_torture(self):
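        # Every call sees a different random `iters`, so each one would need a
        # fresh compile; the backend should only be invoked up to the cache
        # limit (checked at the bottom of this test).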
def loop_torture(input, iters):
out = input
# randint itself causes one graph break
for _ in range(iters):
out += input
return out
compile_counter = torch._dynamo.testing.CompileCounter()
for _ in range(10):
x = torch.randn(3)
iters = torch.randint(low=0, high=1000, size=())
opt_loop_torture = torch.compile(loop_torture, backend=compile_counter)
opt_loop_torture(x, iters)
        # Currently we recompile each time; we'd probably like to bail out
        # quickly and warn instead.
# TODO(whc) these checks fail on py37. Why?
# self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit)
# self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit)
# compile_counter only sees frames that were fed to the backend compiler,
# which is a subset of counters["frames"]["ok"] -- probably because
# counters["frames"]["ok"] includes frames not containing torch ops?
self.assertEqual(compile_counter.frame_count, self.cache_limit)
@torch._dynamo.config.patch("automatic_dynamic_shapes", False)
def test_dynamic_input(self):
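        # With automatic_dynamic_shapes disabled, each new batch size needs its
        # own compile; compilation should stop at the cache limit and a single
        # warning should be logged.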
def model(input):
return input + input
expected_recompiles = 2
compile_counter = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch("cache_size_limit", expected_recompiles):
with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
for _ in range(10):
bsz = torch.randint(low=0, high=1000, size=())
x = torch.randn((bsz, 3, 4))
opt_model = torch.compile(model, backend=compile_counter)
opt_model(x)
self.assertEqual(compile_counter.frame_count, expected_recompiles)
self.assertEqual(len(logs.records), 1)
print(logs.records[0])
self.assertTrue(
logs.records[0]
.getMessage()
.startswith("torch._dynamo hit config.cache_size_limit")
)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_nvfuser_guards(self):
        # We may want to model dynamo's guards closely enough on nvfuser's
        # ProfilingExecutor guards that dynamo is in charge of all recompilations
        # at the top level, which would let us simplify the underlying
        # torchscript executor.
def func(a, b, c):
return a + b * c
a = torch.rand(3, 4, 5, device="cuda")
b = torch.rand(3, 4, 5, device="cuda")
b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5)
b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1)
c = torch.rand(3, 4, 5, device="cuda")
compile_counter = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch("cache_size_limit", 2):
opt_func = torch.compile(func, backend=compile_counter)
opt_func(a, b, c) # warmup
self.assertEqual(compile_counter.frame_count, 1)
opt_func(a, b, c) # no guard fail or recompile
self.assertEqual(compile_counter.frame_count, 1)
opt_func(a, b_v, c) # a view should not cause nvfuser recompile
self.assertEqual(compile_counter.frame_count, 1)
opt_func(a, b_p, c) # a permutation should cause recompile
self.assertEqual(compile_counter.frame_count, 2)
def assert_single_log_contains(self, logs, contains_str):
self.assertEqual(len(logs.records), 1)
self.assertTrue(
logs.records[0].getMessage().find(contains_str) > 0,
msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
)
def test_verbose_tensor_check(self):
def func(a):
# Warning: choose a function here whose meta implementation lives
# entirely in C++. If you do a Python one, Dynamo will dive into
# torch._refs which is OK but it will muddy up the warnings
return torch.add(a, 4)
def cache_fail_test(cached_input, missed_input, expected_failure):
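            # Warm up the cache with `cached_input`, then call with
            # `missed_input` and check that the recompile warning mentions
            # `expected_failure` as the reason.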
            # TODO(whc) maybe it's hacky to have a 'test within a test', but this seemed convenient
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
opt_func = torch.compile(func, backend="eager")
# warmup
opt_func(cached_input)
with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
opt_func = torch.compile(func, backend="eager")
opt_func(missed_input)
self.assert_single_log_contains(logs, expected_failure)
a = torch.rand(3, 4, 5)
cache_fail_test(
a,
a[0:2, :, :],
"tensor 'L['a']' size mismatch at index 0. expected 3, actual 2",
)
cache_fail_test(
a,
a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
"tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
)
cache_fail_test(
a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
)
cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.")
cache_fail_test(
a,
a.to(torch.float16),
"tensor 'L['a']' dtype mismatch. expected Float, actual Half",
)
a_grad = a.clone()
a_grad.requires_grad = True
cache_fail_test(
a,
a_grad,
"tensor 'L['a']' requires_grad mismatch. expected requires_grad=0",
)
def test_mismatched_type(self):
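        # Passing an int where a tensor was seen during warmup should produce a
        # guard failure about the argument's type.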
a = torch.rand(3, 4, 5)
b = torch.rand(3, 4, 5)
def func(a, b):
return a + b
opt_func = torch.compile(func, backend="eager")
# warmup
opt_func(a, b)
with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
opt_func = torch.compile(func, backend="eager")
opt_func(a, 1)
self.assert_single_log_contains(
logs,
"expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
)
@torch._dynamo.config.patch(cache_size_limit=1, fail_on_cache_limit_hit=True)
def test_fail_on_cache_limit_hit(self):
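        # With a cache size limit of 1, the second call takes the other branch
        # and would need a recompile, so it should raise instead.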
@torch.compile(backend="eager")
def func(b, a):
if a:
return b * 2
else:
return b + 1
func(torch.randn(5), True)
with self.assertRaises(FailOnRecompileLimitHit):
func(torch.randn(5), False)
@torch._dynamo.config.patch("cache_size_limit", 32)
def test_multiple_guard_fails(self):
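        # Each call uses a new input length, so later calls fail the size
        # guards of every previously compiled entry; guard_fail_fn should
        # collect each of those failure reasons.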
failure_reasons = []
def guard_fail_fn(failure):
failure_reasons.append(failure[0])
def f(x):
return torch.relu(x)
opt_f = torch._dynamo.optimize(
backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
)(f)
for i in range(5):
failure_reasons.clear()
opt_f(torch.randn(8 + i))
failure_str = "\n".join(failure_reasons)
for line in """\
tensor 'L['x']' size mismatch at index 0. expected 11, actual 12
tensor 'L['x']' size mismatch at index 0. expected 10, actual 12
tensor 'L['x']' size mismatch at index 0. expected 9, actual 12
tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
"\n"
):
self.assertIn(
line,
failure_str,
)
@torch._dynamo.config.patch("cache_size_limit", 32)
def test_multiple_guard_fails_report_all(self):
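        # With recompiles_verbose logging enabled, guard failures against all
        # cached entries should be reported, not just the most recent one.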
with log_settings(kwargs_to_settings(recompiles_verbose=True)):
failure_reasons = []
def guard_fail_fn(failure):
failure_reasons.append(failure[0])
def f(x):
return torch.ones(len(x), x[-1])
opt_f = torch._dynamo.optimize(
backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
)(f)
opt_f([4, 5, 6])
def filter_reasons():
return "\n".join(
[
line
for line in "\n".join(failure_reasons).splitlines()
if not line.startswith("___check_type_id")
]
)
failure_reasons.clear()
opt_f([7, 8])
for line in """\
len(L['x']) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())
failure_reasons.clear()
opt_f([9])
for line in """\
len(L['x']) == 2
len(L['x']) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()