# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
def my_custom_function(x):
return x + 1
class RunDiffGuardTests(torch._dynamo.test_case.TestCase):
    def test_bool_recompile(self):
        def fn(x, y, c):
            if c:
                return x * y
            else:
                return x + y

        opt_fn = torch.compile(fn, backend="inductor")

        x = 2 * torch.ones(4)
        y = 3 * torch.ones(4)
        ref1 = opt_fn(x, y, True)
        ref2 = opt_fn(x, y, False)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res2 = opt_fn(x, y, False)
            res1 = opt_fn(x, y, True)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
    def test_tensor_recompile(self):
        def fn(x, y):
            return x * y

        opt_fn = torch.compile(fn, backend="eager")

        x = torch.randn(4, dtype=torch.float32)
        y = torch.randn(4, dtype=torch.float32)
        ref1 = opt_fn(x, y)

        x64 = torch.randn(4, dtype=torch.float64)
        y64 = torch.randn(4, dtype=torch.float64)
        ref2 = opt_fn(x64, y64)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res1 = opt_fn(x, y)
            res2 = opt_fn(x64, y64)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
    def test_post_recompile(self):
        class Foo:
            a = 4
            b = 5

        foo = Foo()

        def fn(x):
            return x + foo.a + foo.b

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        foo.a = 11
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            # Set it back to the original value; both cached entries are reused
            # without triggering a recompile.
            foo.a = 4
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

            foo.a = 11
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

        # Check that we are back to the original behavior: outside the stance,
        # a changed attribute triggers a recompile again.
        foo.b = 8
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)
    def test_fail_on_tensor_shape_change(self):
        def fn(dt):
            return dt["x"] + 1

        x = torch.randn(4)
        dt = {}
        dt["x"] = x

        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(dt)

        with self.assertRaisesRegex(
            RuntimeError, "Recompilation triggered with skip_guard_eval_unsafe stance"
        ):
            with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
                # A new tensor shape requires a recompile, which is not
                # allowed under the skip_guard_eval_unsafe stance.
                x = torch.randn(4, 4)
                dt["x"] = x
                opt_fn(dt)
    def test_cache_line_pickup(self):
        def fn(x, a=None, b=None):
            x = x * 3
            if a:
                x = x * 5
            if b:
                x = x * 7
            return x

        opt_fn = torch.compile(fn, backend="eager")

        x = torch.ones(4)
        ref1 = opt_fn(x, a=None, b=None)
        ref2 = opt_fn(x, a=1, b=None)
        ref3 = opt_fn(x, a=1, b=1)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res1 = opt_fn(x, a=None, b=None)
            res2 = opt_fn(x, a=1, b=None)
            res3 = opt_fn(x, a=1, b=1)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref3, res3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()