1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
|
# Owner(s): ["module: dynamo"]
import io
import logging
import warnings
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
# Module-level logger: the helper functions below log through it and
# IgnoreLogsTests captures its records with assertLogs.
logger = logging.getLogger(__name__)
# A second, unrelated logger. The tests pass its methods to
# ignore_logger_methods and expect calls made on `logger` to be unaffected.
logger_test = logging.getLogger("test")
def f_info(x):
    """Double ``x``, log "moo" at INFO on the module logger, then square.

    The ``logger.info`` call sits between the two arithmetic steps so the
    tests can observe how it is handled when this function is compiled.
    """
    doubled = x + x
    logger.info("moo")
    return doubled * doubled
def f_isEnabledFor(x):
    """Double ``x``, log "moo" only if INFO is enabled, then square.

    Same shape as ``f_info`` except that the ``logger.info`` call is guarded
    by ``logger.isEnabledFor(logging.INFO)``.
    """
    doubled = x + x
    if logger.isEnabledFor(logging.INFO):
        logger.info("moo")
    return doubled * doubled
@instantiate_parametrized_tests
class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
    """Tests for the ``ignore_logger_methods`` Dynamo config option."""

    @parametrize(
        "ignore_method, fn, should_ignore_logger",
        [
            (None, f_info, False),
            (logger_test.info, f_info, False),
            (None, f_isEnabledFor, False),
            (logger_test.isEnabledFor, f_isEnabledFor, False),
            (logger.info, f_info, True),
            (logging.Logger.info, f_info, True),
            (logger.isEnabledFor, f_isEnabledFor, True),
            (logging.Logger.isEnabledFor, f_isEnabledFor, True),
        ],
    )
    def test_ignore_logger(self, ignore_method, fn, should_ignore_logger):
        """Compile ``fn`` with ``ignore_method`` configured as ignored.

        Args:
            ignore_method: The single entry placed in ``ignore_logger_methods``
                (``None``, a bound logger method, or a ``logging.Logger``
                method).
            fn: The function under test (``f_info`` or ``f_isEnabledFor``).
            should_ignore_logger: Whether the "moo" record is expected to be
                absent from the captured logs and no graph break recorded.
        """
        counters.clear()
        inp = torch.randn(3, 3)
        eager_result = fn(inp)

        with torch._dynamo.config.patch(ignore_logger_methods={ignore_method}):
            compiled_fn = torch.compile(backend="eager")(fn)
            with self.assertLogs(logger, level="INFO") as log_ctx:
                # Emit one record unconditionally: assertLogs fails when
                # nothing is logged, which is the expected outcome for the
                # "ignored" parametrizations.
                logger.info("call logger info to avoid error")
                compiled_result = compiled_fn(inp)
            # Records are formatted "LEVEL:name:message"; keep the message.
            messages = [record.split(":", 2)[2] for record in log_ctx.output]

        self.assertTrue(same(eager_result, compiled_result))

        if not should_ignore_logger:
            self.assertIn("moo", messages)
            self.assertEqual(len(counters["graph_break"]), 1)
            return

        self.assertNotIn("moo", messages)
        self.assertEqual(len(counters["graph_break"]), 0)
class ReorderLogsTests(torch._dynamo.test_case.TestCase):
def test_dont_reorder_print(self):
def f(x):
x = x + x
print("moo")
x = x * x
return x
counters.clear()
x = torch.randn(3, 3)
opt_f = torch.compile(backend="eager")(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(x)
self.assertTrue(same(orig_out, opt_out))
self.assertEqual(printed_output, "moo")
self.assertEqual(len(counters["graph_break"]), 1)
@torch._dynamo.config.patch(reorderable_logging_functions={print})
def test_reorder_print(self):
def f(x):
print("moo")
x1 = x + x
print(x1)
x2 = x1 * x1
print(1, 2, 3)
x3 = x2 + x2
return (x1, x3)
x = torch.ones(3, 3)
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(x)
self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3")
self.assertTrue(same(orig_out, opt_out))
@torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn})
def test_reorder_warnings(self):
import warnings
def f(x):
x1 = x + x
warnings.warn("moo")
x2 = x1 * x1
warnings.warn(f"{x2}")
x3 = x2 + x2
return x3
x = torch.ones(3, 3)
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
with warnings.catch_warnings(record=True) as w:
opt_out = opt_f(x)
warning_messages = [str(i.message) for i in w]
orig_out = f(x)
self.assertTrue(same(orig_out, opt_out))
self.assertIn("moo", warning_messages)
@torch._dynamo.config.patch(reorderable_logging_functions={print})
def test_reorder_print_graph_break(self):
def f(x):
x1 = x + x
print(f"res: {x1}")
x2 = x1 * x1
torch._dynamo.graph_break()
x3 = x2 + x2
print(1, 2, 3)
return x3
x = torch.ones(3, 3)
opt_f = torch.compile(backend="eager")(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(x)
self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3")
self.assertTrue(same(orig_out, opt_out))
def test_reorder_custom_log_fn(self):
custom_logs = []
def custom_log(s: str):
torch._dynamo.graph_break()
custom_logs.append(s)
def f(x):
custom_log("moo")
x1 = x + x
custom_log(f"{x1}")
return x + x
x = torch.ones(3, 3)
counters.clear()
with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}):
opt_f = torch.compile(backend="eager")(f)
opt_out = opt_f(x)
self.assertEqual(sum(counters["graph_break"].values()), 1)
self.assertEqual(custom_logs[0], "moo")
self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}")
@torch._dynamo.config.patch(reorderable_logging_functions={print})
def test_constant_mutation(self):
def f(x):
alist = [x]
alist.append(x + 1)
print(alist[-1])
alist[0].sum().item() # graph break
res = alist.pop()
print(alist[-1])
res.sum().item() # graph break
return res
inputs = (torch.tensor([1]),)
counters.clear()
opt_f = torch.compile(backend="eager")(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(*inputs)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(*inputs)
self.assertEqual(printed_output, "tensor([2])\ntensor([1])")
self.assertTrue(same(orig_out, opt_out))
graph_break_key = counters["graph_break"].keys()
self.assertEqual(len(graph_break_key), 1)
self.assertEqual(next(iter(graph_break_key)), "Tensor.item")
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    # run_tests is only needed for direct script execution.
    from torch._dynamo.test_case import run_tests

    run_tests()
|