1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
# Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter
# Module-level counters. MyModule's forward and pre-forward hook increment
# them (via `global`) so the tests can observe which branches actually ran.
_variable = _variable_2 = 0
def user_function():
return torch.compiler.is_compiling()
def user_generator():
for _ in range(1):
yield torch.compiler.is_compiling()
return
class MyModule(torch.nn.Module):
    """Toy module whose ``is_compiling`` probe is selected by ``mode``.

    Each mode reaches ``torch.compiler.is_compiling()`` in a different way and
    increments the module-level ``_variable`` counter when it reports True:

    * 1 -- direct call in ``forward`` (``_variable_2`` is incremented otherwise)
    * 2 -- call through the plain helper ``user_function``
    * 3 -- call through a local lambda
    * 4 -- value produced by the generator ``user_generator``
    * 5 -- call inside the forward pre-hook; ``forward`` itself does ``x += 1``
    * 6 -- call through ``user_function`` followed by a graph break
    """

    def __init__(self, mode: int):
        super().__init__()
        self.mode = mode
        # with_kwargs=True: the hook is invoked as hook(module, args, kwargs);
        # pre_forward hands (args, kwargs) back unchanged.
        self.register_forward_pre_hook(self.pre_forward, with_kwargs=True)

    def pre_forward(self, module, args, kwargs):
        """Forward pre-hook: in mode 5, increment ``_variable`` when compiling.

        Returns ``(args, kwargs)`` unchanged.
        """
        # A ``global`` statement applies to the whole function body no matter
        # where it appears, so declare it up front instead of inside the branch.
        global _variable
        if self.mode == 5 and user_function():
            _variable += 1
        return args, kwargs

    def forward(self, x):
        """Run the ``mode``-specific probe and return ``x``.

        Mode 5 applies ``x += 1`` (in place when ``x`` is a tensor); every
        other mode returns ``x`` untouched.
        """
        global _variable, _variable_2

        if self.mode == 1:
            if torch.compiler.is_compiling():
                _variable += 1
            else:
                _variable_2 += 1
        elif self.mode == 2:
            if user_function():
                _variable += 1
        elif self.mode == 3:
            lambda_f = lambda: torch.compiler.is_compiling()  # noqa: E731
            if lambda_f():
                _variable += 1
        elif self.mode == 4:
            for cond in user_generator():
                if cond:
                    _variable += 1
        elif self.mode == 5:
            # This mode's is_compiling probe lives in pre_forward; forward only
            # contributes a tensor op.
            x += 1
        elif self.mode == 6:
            if user_function():
                # The graph break precedes the side effect; the counter must
                # still be incremented once execution resumes. (The test
                # compiles this mode without fullgraph.)
                torch._dynamo.graph_break()
                _variable += 1
        return x
class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
    """Tests for how dynamo handles frames with few or no tensor operations.

    Uses ``self.assertEqual`` rather than bare ``assert`` so the checks are not
    stripped when Python runs with ``-O`` and failures report both values.
    """

    def _compiled_op_count(self, fn, *args):
        """Compile ``fn`` with ``optimize_assert``, call it once, return op count.

        A fresh ``CompileCounter`` backend is created per call, so the result
        reflects only this single invocation of ``fn(*args)``.
        """
        counter = CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(counter)(fn)
        opt_fn(*args)
        return counter.op_count

    def test_add_tensor1(self):
        # tensor + Python int via the `+` operator.
        def fn(a, b):
            return a + b

        self.assertEqual(self._compiled_op_count(fn, torch.randn(4), 5), 1)

    def test_add_tensor2(self):
        # tensor + Python int via torch.add.
        def fn(a, b):
            return torch.add(a, b)

        self.assertEqual(self._compiled_op_count(fn, torch.randn(4), 5), 1)

    def test_add_tensor_list(self):
        # Operands arrive inside a list argument.
        def fn(lst):
            return lst[0] + lst[1]

        self.assertEqual(self._compiled_op_count(fn, [torch.randn(4), 5]), 1)

    def test_add_tensor_dict(self):
        # Operands arrive inside a dict argument.
        def fn(dt):
            return dt["a"] + dt["b"]

        self.assertEqual(
            self._compiled_op_count(fn, {"a": torch.randn(4), "b": 5}), 1
        )

    def test_add_skip(self):
        # Pure Python ints: no tensor op, so nothing is counted.
        def fn(a, b):
            return a + b

        self.assertEqual(self._compiled_op_count(fn, 4, 5), 0)

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_recursive_list(self):
        # A self-referential list passed through a region optimized via the
        # context-manager form; no op is expected to be counted.
        def fn(x):
            return x

        counter = CompileCounter()
        x = []
        x.append(x)
        with torch._dynamo.optimize_assert(counter):
            fn(x)
        self.assertEqual(counter.op_count, 0)

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_custom_list(self):
        def fn(x):
            return x[0] + x[1]

        counter = CompileCounter()

        # list subclass whose __iter__ and __len__ raise, so any attempt to
        # iterate or size the argument fails loudly.
        class Foo(list):
            def __iter__(self):
                raise Exception  # noqa: TRY002

            def __len__(self):
                raise Exception  # noqa: TRY002

        x = Foo()
        x.append(torch.randn(4))
        x.append(torch.randn(4))
        with torch._dynamo.optimize_assert(counter):
            fn(x)
        self.assertEqual(counter.op_count, 0)

    def test_do_not_skip_side_effects(self):
        # https://github.com/pytorch/pytorch/issues/110765
        # Calling torch.compiler.is_compiling() may cause side effects that
        # differ from eager when compiling, so dynamo must commit the graph
        # even if it performs no tensor operation.
        global _variable, _variable_2
        for mode in range(1, 7):
            # subTest reports which mode failed instead of a bare loop failure.
            with self.subTest(mode=mode):
                torch._dynamo.reset()
                _variable = 0
                _variable_2 = 0
                mod = MyModule(mode=mode)
                # Mode 6 calls torch._dynamo.graph_break(), so it is the only
                # mode compiled without fullgraph.
                model = torch.compile(mod, backend="eager", fullgraph=mode != 6)
                self.assertEqual(_variable, 0)
                self.assertEqual(_variable_2, 0)
                for expected in (1, 2):
                    model(torch.tensor([1]))
                    self.assertEqual(_variable, expected)
                    self.assertEqual(_variable_2, 0)
if __name__ == "__main__":
    # Executed directly (not imported): delegate to dynamo's test runner.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
|