1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
|
# Owner(s): ["module: inductor"]
import unittest
from unittest import mock
import torch
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import config, memory
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
class Foo(torch.nn.Module):
    """
    Toy model made of two matmul chains that share only the input ``x``.

    One chain computes ``(x @ w1) @ w3`` and the other ``(x @ w2) @ w4``; the
    sums of both results are added together. Because the chains are
    independent, the scheduler is free to interleave them, which is what the
    peak-memory reordering tests exercise.

    The default compiled graph is
    graph():
        ...
        %op0 : [num_users=2] = call_function[...](args = (%primals_2, %primals_1), ...)
        %op1 : [num_users=2] = call_function[...](args = (%primals_2, %primals_3), ...)
        %op2 : [num_users=1] = call_function[...](args = (%op0, %primals_4), ...)
        %op3 : [num_users=1] = call_function[...](args = (%op1, %primals_5), ...)
        %op4 : [num_users=1] = call_function[...](args = (%op2,), ...)
        %op5 : [num_users=1] = call_function[...](args = (%op3,), ...)
        %op6_op7 : [num_users=1] = call_function[...](args = (%op5, %op4), ...)
    """

    def __init__(self):
        super().__init__()
        param = torch.nn.Parameter
        # All weights are ones, so outputs are easy to predict by hand.
        self.w1 = param(torch.ones(1, 10))
        self.w2 = param(torch.ones(1, 1))
        self.w3 = param(torch.ones(10, 1))
        self.w4 = param(torch.ones(1, 10))

    def forward(self, x):
        # NOTE: the order of these matmuls determines the buffer numbering the
        # tests check against, so it must not be changed.
        wide = torch.matmul(x, self.w1)  # last dim 1 -> 10
        narrow = torch.matmul(x, self.w2)  # last dim stays 1
        left = torch.matmul(wide, self.w3)  # last dim 10 -> 1
        right = torch.matmul(narrow, self.w4)  # last dim 1 -> 10
        return left.sum() + right.sum()
# The tests in this class use very small tensors. With the default
# score_fusion_memory_threshold, inductor makes different fusion decisions
# and generates a different wrapper, so the threshold is overridden to 1 to
# keep the generated code matching what these tests expect.
@config.patch("score_fusion_memory_threshold", 1)
@unittest.skipIf(not HAS_GPU, "requires a GPU backend")
class TestOperatorReorderForPeakMemory(TestCase):
    """Tests for inductor's ``reorder_for_peak_memory`` scheduling pass.

    The ``test_reorder_peak_memory*`` tests compile :class:`Foo` and use
    ``FileCheck`` to pin the order in which buffers are first assigned inside
    the generated ``call`` function: once with the default reordering and once
    for each individual topological-sort method. Each test also checks that
    the compiled model's output matches eager mode.

    The class-level skip is needed because ``setUp`` places the model on
    ``GPU_TYPE`` unconditionally; without it, collecting this file with pytest
    on a machine without a GPU errors instead of skipping.
    """

    def setUp(self):
        super().setUp()
        self.model = Foo().to(GPU_TYPE)
        self.inputs = torch.ones((2048, 1), device=GPU_TYPE)
        # Saved before any test patches memory.reorder_for_peak_memory, so the
        # single-method wrappers can still delegate to the real implementation.
        self.orig_reorder_method = memory.reorder_for_peak_memory

    def _reorder_with_only(self, method):
        """Return a stand-in for ``memory.reorder_for_peak_memory``.

        The stand-in has the same signature as the original. It ignores the
        caller's ``methods`` argument and forwards ``methods=[method]`` to the
        saved original implementation.
        """

        def reorder(
            nodes,
            name_to_buf,
            name_to_fused_node,
            graph_inputs,
            graph_outputs,
            methods=None,
        ):
            return self.orig_reorder_method(
                nodes,
                name_to_buf,
                name_to_fused_node,
                graph_inputs,
                graph_outputs,
                methods=[method],
            )

        return reorder

    def _check_buffer_order(self, code, expected_order):
        """Assert that ``<buf> = `` assignments appear in ``code`` in
        ``expected_order``, after the ``def call(args):`` header."""
        checker = FileCheck().check("def call(args):")
        for buf in expected_order:
            checker = checker.check(f"{buf} = ")
        checker.run(code)

    def _compile_and_check(self, expected_order, method=None):
        """Compile ``self.model``, then check its buffer order and its output.

        Args:
            expected_order: buffer names in the order their first assignment
                must appear in the generated code.
            method: if given, ``memory.reorder_for_peak_memory`` is patched
                while compiling so that only this sorting method is used.
        """
        outp_corr = self.model(self.inputs)
        # torch.compile is lazy: compilation happens on the first call inside
        # run_and_get_triton_code, i.e. inside the patch below when one is used.
        compiled_model = torch.compile(self.model)
        if method is None:
            code = run_and_get_triton_code(compiled_model, self.inputs)
        else:
            with mock.patch.object(
                memory, "reorder_for_peak_memory", self._reorder_with_only(method)
            ):
                code = run_and_get_triton_code(compiled_model, self.inputs)
        self._check_buffer_order(code, expected_order)
        # check for correctness
        outp = compiled_model(self.inputs)
        self.assertTrue(same(outp, outp_corr))

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory(self):
        """Default reordering (all methods) picks this buffer order."""
        self._compile_and_check(
            ("buf1", "buf0", "buf2", "buf4", "buf3", "buf5", "buf7")
        )

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_lpmf(self):
        """Buffer order when only ``topological_sort_lpmf`` is used."""
        self._compile_and_check(
            ("buf1", "buf0", "buf2", "buf4", "buf3", "buf5", "buf7"),
            method=memory.topological_sort_lpmf,
        )

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_bfs(self):
        """Buffer order when only ``topological_sort_bfs`` is used."""
        self._compile_and_check(
            ("buf0", "buf1", "buf2", "buf3", "buf4", "buf5", "buf7"),
            method=memory.topological_sort_bfs,
        )

    @mock.patch.object(config, "reorder_for_peak_memory", True)
    def test_reorder_peak_memory_dfs(self):
        """Buffer order when only ``topological_sort_dfs`` is used."""
        self._compile_and_check(
            ("buf0", "buf2", "buf4", "buf1", "buf3", "buf5", "buf7"),
            method=memory.topological_sort_dfs,
        )

    # NOTE: this test reads torch.cuda memory statistics, so it only runs when
    # CUDA is available, regardless of GPU_TYPE.
    @unittest.skipIf(
        not torch.cuda.is_available()
        or torch.cuda.get_device_properties().total_memory < int(1e10),
        "Need 10GB memory to be safe to run the test",
    )
    def test_fusing_reductions_increase_peak_memory(self):
        """Peak memory must stay below two full matmul outputs.

        ``a @ c`` and ``b @ c`` each produce ``a.size(0) * c.size(1)``
        elements. The bound below is the size of two such tensors. Presumably
        this guards against fusing the two reductions in a way that keeps both
        intermediates alive at once.
        """

        @torch.compile
        def f(a, b, c):
            return (a @ c).sum(dim=-1) + (b @ c).sum(dim=-1)

        a = torch.randn(1024 * 32, 16, device=GPU_TYPE)
        b = torch.randn(1024 * 32, 16, device=GPU_TYPE)
        c = torch.randn(16, 1024 * 32, device=GPU_TYPE)
        torch.cuda.reset_peak_memory_stats()
        f(a, b, c)
        peak_mem = torch.cuda.max_memory_allocated()

        expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2
        self.assertLess(peak_mem, expected_bound)
if __name__ == "__main__":
    # These tests need a GPU; when none is available, exit without running
    # anything.
    if HAS_GPU:
        from torch._inductor.test_case import run_tests

        run_tests()
|