# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.utils import disable_cache_limit
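
# Helpers used below (descriptive notes, not part of the tested behavior):
#   - disable_cache_limit() lifts dynamo's recompile caps (cache_size_limit and
#     accumulated_cache_size_limit) so the loops here can recompile freely.
#   - CompileCounter is a test backend that counts compiled frames via its
#     frame_count attribute while running the captured graphs eagerly.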

# NB: do NOT include this test class in test_dynamic_shapes.py
class ConfigTests(torch._dynamo.test_case.TestCase):
    @disable_cache_limit()
    def test_no_automatic_dynamic(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_static = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=False, assume_static_by_default=True
        ):
            opt_fn = torch.compile(fn, backend=cnt_static)
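            # sizes 2..11 are ten distinct static shapes; with automatic
            # dynamic disabled, every new size fails the shape guard and
            # forces a full recompile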
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        self.assertEqual(cnt_static.frame_count, 10)

    @disable_cache_limit()
    def test_automatic_dynamic(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=True, assume_static_by_default=True
        ):
            opt_fn = torch.compile(fn, backend=cnt_dynamic)
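            # the first call specializes to size 2; the size change on the
            # second call recompiles once with symbolic (dynamic) shapes,
            # and that dynamic graph then serves sizes 4..11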
            # NB: must not use sizes 0 or 1, as those values are always
            # specialized and would never become dynamic
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        # two graphs rather than ten: one static, then one dynamic
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @disable_cache_limit()
    def test_no_assume_static_by_default(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=True, assume_static_by_default=False
        ):
            opt_fn = torch.compile(fn, backend=cnt_dynamic)
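            # with assume_static_by_default=False, sizes are treated as
            # symbolic starting from the very first compilation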
            # NB: must not use sizes 0 or 1, as those values are always
            # specialized and would never become dynamic
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        # only one graph, since the first compilation was already dynamic
        # (no need to wait for a recompile)
        self.assertEqual(cnt_dynamic.frame_count, 1)
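
    # Keys whose values cannot change compiled output are treated as
    # "compile ignored": they are excluded from recompile guards and from the
    # config hash exercised in test_config_hash below.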
    def test_config_compile_ignored(self):
        # Remove from this list if no longer relevant
        dynamo_guarded_config_ignorelist = {
"log_file_name",
"verbose",
"verify_correctness", # will not affect model, will raise RuntimeError
# (no silent change to compilation behaviour)
"cache_size_limit",
"accumulated_cache_size_limit",
"replay_record_enabled",
"cprofile", # only wraps _compile, not graph
"repro_after",
"repro_level",
"repro_forward_only",
"repro_tolerance",
"same_two_models_use_fp64",
"error_on_recompile", # safe because: will throw error
"report_guard_failures",
"base_dir", # used for minifying / logging
"DEBUG_DIR_VAR_NAME",
"debug_dir_root",
}
for k in dynamo_guarded_config_ignorelist:
assert k in torch._dynamo.config._compile_ignored_keys, k
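
    # get_hash() should digest only non-ignored config values: patching a
    # compile-ignored key ("verbose") must leave the hash unchanged, while
    # patching a guarded key ("suppress_errors") must change it.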
    def test_config_hash(self):
        config = torch._dynamo.config
        starting_hash = config.get_hash()
with config.patch({"verbose": not config.verbose}):
new_hash = config.get_hash()
assert "verbose" in config._compile_ignored_keys
assert new_hash == starting_hash
new_hash = config.get_hash()
assert new_hash == starting_hash
with config.patch({"suppress_errors": not config.suppress_errors}):
changed_hash = config.get_hash()
assert "suppress_errors" not in config._compile_ignored_keys
assert changed_hash != starting_hash
# Test nested patch
with config.patch({"verbose": not config.verbose}):
inner_changed_hash = config.get_hash()
assert inner_changed_hash == changed_hash
assert inner_changed_hash != starting_hash
newest_hash = config.get_hash()
assert changed_hash != newest_hash
assert newest_hash == starting_hash


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()