1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
|
# Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch._dynamo.utils import dynamo_timed
from torch.testing._internal.common_utils import TemporaryFileName
class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
def test_dynamo_timed_profiling_isolated(self):
# dynamo_timed functions should appear in profile traces.
def inner_fn(x):
with dynamo_timed("inner_fn"):
return x.sin()
def outer_fn(x, y):
return inner_fn(x) * y
x, y = (torch.rand((2, 2)) for _ in range(2))
with torch.profiler.profile(with_stack=False) as prof:
outer_fn(x, y)
self.assertTrue(
any("inner_fn (dynamo_timed)" in evt.name for evt in prof.events())
)
def test_dynamo_timed_profiling_backend_compile(self):
# dynamo_timed functions should appear in profile traces.
# this checks whether these actually appear in actual dynamo execution.
# "backend_compile" is just chosen as an example; if it gets renamed
# this test can be replaced or deleted
fn_name = "call_user_compiler"
def fn(x, y):
return x.sin() * y.cos()
x, y = (torch.rand((2, 2)) for _ in range(2))
with torch.profiler.profile(with_stack=False) as prof:
torch.compile(fn, backend="aot_eager")(x, y)
self.assertTrue(
any(f"{fn_name} (dynamo_timed)" in evt.name for evt in prof.events())
)
@patch.object(torch._dynamo.config, "assume_static_by_default", False)
def test_profile_dynamic_shapes_runtime(self):
def fn(x, y, z):
return x @ y + z
opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True, fullgraph=True)
inputs = [
(torch.rand(a, b), torch.rand(b, c), torch.rand(a, c))
for (a, b, c) in [(15, 16, 17), (15, 15, 16), (16, 16, 16)]
]
opt_fn(*inputs[0])
opt_fn(*inputs[1])
with torch.profiler.profile(record_shapes=True):
opt_fn(*inputs[2])
@patch.object(torch._dynamo.config, "assume_static_by_default", False)
def test_profile_dynamic_shapes_compilation(self):
def fn(x, y, z):
return x @ y + z
opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True, fullgraph=True)
inputs = (torch.rand(15, 16), torch.rand(16, 17), torch.rand(15, 17))
with torch.profiler.profile(record_shapes=True):
opt_fn(*inputs)
@patch.object(torch._dynamo.config, "assume_static_by_default", False)
def test_profile_dynamic_shapes_list_compilation(self):
def fn(x, y, z):
return torch.cat([x, y], dim=0) + z
opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True, fullgraph=True)
inputs = (torch.rand(4, 16), torch.rand(12, 16), torch.rand(16, 16))
with torch.profiler.profile(record_shapes=True):
opt_fn(*inputs)
def test_execution_trace_dynamic_shapes(self):
def fn(x, y, z):
return x @ y + z
et = torch.profiler.ExecutionTraceObserver()
opt_fn = torch.compile(fn, dynamic=True, backend="aot_eager")
inputs = [torch.rand((4, 4)) for _ in range(3)]
with TemporaryFileName() as fname:
et.register_callback(fname)
et.start()
out = opt_fn(*inputs)
et.stop()
et.unregister_callback()
def test_profiler_cache_lookup(self):
def fn(x):
y = x**2
y = y + 2
z = y**3
return z
for profiler, get_events in (
(torch.autograd.profiler.profile, lambda prof: prof.function_events),
(torch.profiler.profiler.profile, lambda prof: prof.events()),
):
x = torch.randn((2, 2), requires_grad=True)
ref = fn(x)
opt_fn = torch.compile(fn, backend="aot_eager")
# warmup
opt_fn(x)
with profiler() as prof:
res = opt_fn(x)
events = list(
filter(
lambda event: "TorchDynamo Cache Lookup" in event.name,
get_events(prof),
)
)
self.assertEqual(ref, res)
self.assertTrue(
len(events) == 1,
"Expected one lookup profiler event for one opt_fn run",
)
def test_profiler_cache_lookup_profiler_step(self):
def fn(x, y, z):
return torch.add(torch.sub(x, y), z)
opt_fn = torch.compile(fn, backend="aot_eager")
(
x,
y,
z,
) = (torch.rand(4, 4) for _ in range(3))
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=2, warmup=2, active=2, repeat=1)
)
for _ in range(10):
opt_fn(x, y, z)
prof.step()
self.assertTrue(
any(e.name == "TorchDynamo Cache Lookup" for e in prof.events())
)
def test_profiler_dynamo_compiled_region(self):
def fn(x, y):
r = y.sum(dim=1)
print(r.shape)
return x * r
with torch.profiler.profile() as prof:
fn_c = torch.compile(fn)
fn_c(
torch.randn(10),
torch.randn(10, 10),
)
fn_c(
torch.randn(10),
torch.randn(10, 15),
)
annotations = [e.name for e in prof.events() if "Compiled" in e.name]
self.assertEqual(
annotations,
[
"Torch-Compiled Region: 0/0",
"Torch-Compiled Region: 1/0",
"Torch-Compiled Region: 0/1",
"Torch-Compiled Region: 1/0",
],
)
if __name__ == "__main__":
    # Executed as a script: hand control to the Dynamo test runner, reached
    # through the torch._dynamo.test_case module imported at the top of the file.
    torch._dynamo.test_case.run_tests()
|