1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
|
# Owner(s): ["module: inductor"]
import contextlib
import torch
from torch._inductor.dependencies import MemoryDep
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout, Pointwise
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
class TestDependencies(InductorTestCase):
def _create_buffer(self, name, shape, dtype=torch.float32):
return Buffer(
name=name,
layout=FixedLayout(torch.device(GPU_TYPE), dtype=dtype, size=shape),
)
def setUp(self):
super().setUp()
class DummyModule(torch.nn.Module):
def forward(self, x):
return x * 2
self._gm = torch.fx.symbolic_trace(DummyModule())
self._graph = GraphLowering(self._gm)
self._stack = contextlib.ExitStack()
self._stack.enter_context(V.set_graph_handler(self._graph))
def tearDown(self):
self._stack.close()
super().tearDown()
def test_bucketize_dependencies_no_sorter(self):
offsets = self._create_buffer("offsets", (1025,), torch.int32)
def inner_fn(index):
idx = index[0]
return ops.bucketize(
values=idx,
boundaries=(
offsets.get_name(),
offsets.get_size()[-1],
offsets.get_size()[0] * offsets.get_stride()[0],
offsets.get_stride()[-1],
),
boundary_indices=0,
indexing_dtype=torch.int32,
right=True,
)
pointwise = Pointwise.create(
device=torch.device(GPU_TYPE),
dtype=torch.int32,
inner_fn=inner_fn,
ranges=[1024 * 4],
)
self.assertEqual(len(pointwise.get_reads()), 1)
def test_bucketize_dependencies_sorter(self):
offsets = self._create_buffer("offsets", (1025,), torch.int32)
sorter = self._create_buffer("sorter", (1025,), torch.int32)
def inner_fn(index):
idx = index[0]
return ops.bucketize(
values=idx,
boundaries=(
offsets.get_name(),
offsets.get_size()[-1],
offsets.get_size()[0] * offsets.get_stride()[0],
offsets.get_stride()[-1],
),
boundary_indices=0,
indexing_dtype=torch.int32,
right=True,
sorter=(
sorter.get_name(),
sorter.get_stride()[-1],
),
sorter_indices=0,
)
pointwise = Pointwise.create(
device=torch.device(GPU_TYPE),
dtype=torch.int32,
inner_fn=inner_fn,
ranges=[1024 * 4],
)
self.assertEqual(len(pointwise.get_reads()), 2)
def test_get_offset(self):
x = sympy_index_symbol("x")
y = sympy_index_symbol("y")
var_ranges = {
x: 1024,
y: 2048,
}
dep1 = MemoryDep(
"dep1",
x * 2048 + y,
list(var_ranges.keys()),
list(var_ranges.values()),
)
dep2 = MemoryDep(
"dep2",
x * 2048 + y + 1024,
list(var_ranges.keys()),
list(var_ranges.values()),
)
self.assertEqual(dep1.get_offset(), 0)
self.assertEqual(dep2.get_offset(), 1024)
def test_normalize_with_stride_order_equal(self):
x = sympy_index_symbol("x")
y = sympy_index_symbol("y")
var_ranges = {
x: 1024,
y: 2048,
}
loop_order1 = MemoryDep(
"access_the_same_buffer",
x * 2048 + y,
[x, y],
[1024, 2048],
)
loop_order2 = MemoryDep(
"access_the_same_buffer",
x * 2048 + y,
[y, x],
[2048, 1024],
)
self.assertTrue(loop_order1 != loop_order2)
normalized_loop_order1 = loop_order1.normalize_with_stride_order()
normalized_loop_order2 = loop_order2.normalize_with_stride_order()
self.assertTrue(normalized_loop_order1 == normalized_loop_order2)
def test_normalize_with_stride_order_unequal(self):
x = sympy_index_symbol("x")
y = sympy_index_symbol("y")
var_ranges = {
x: 1024,
y: 2048,
}
loop_order1 = MemoryDep(
"access_the_same_buffer",
x * 2048 + y,
[x, y],
[1024, 2048],
)
loop_order2 = MemoryDep(
"access_the_same_buffer",
x * 2048 + y + 5,
[y, x],
[2048, 1024],
)
self.assertTrue(loop_order1 != loop_order2)
normalized_loop_order1 = loop_order1.normalize_with_stride_order()
normalized_loop_order2 = loop_order2.normalize_with_stride_order()
# unequal due to different offset
self.assertTrue(normalized_loop_order1 != normalized_loop_order2)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_CPU and HAS_GPU:
run_tests("sympy")
|