# Owner(s): ["module: inductor"]
from typing import Any, Callable
import torch
from torch._inductor.fx_passes.pre_grad import (
linear_permute_fusion,
linear_transpose,
permute_linear_fusion,
permute_matmul_fusion,
sink_cat_after_pointwise,
transpose_linear,
transpose_matmul,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.fx.passes.shape_prop import ShapeProp
PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule]
def chain_passes(*passes: PassFunc) -> PassFunc:
    """Compose *passes* into one pass that applies them left to right.

    Before each pass runs, shape metadata is (re)propagated over the
    current graph module via ShapeProp, since individual passes may
    depend on up-to-date shape information. Non-GraphModule inputs
    (e.g. a plain function handed to symbolic_trace as the first pass)
    skip the propagation step.
    """
    def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule:
        result = module
        for step in passes:
            if isinstance(result, torch.fx.GraphModule):
                ShapeProp(result).propagate(*input)
            result = step(result)
        return result
    return parent_pass
def count_call(module: torch.fx.GraphModule, op: str, target_op: Any) -> int:
    """Return how many nodes in *module*'s graph match (op, target_op)."""
    matches = 0
    for node in module.graph.nodes:
        if node.op == op and node.target == target_op:
            matches += 1
    return matches
def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int:
    """Count `call_function` nodes in *module* whose target is *target_op*."""
    return count_call(module, op="call_function", target_op=target_op)
def count_call_method(module: torch.fx.GraphModule, target_op: Any) -> int:
    """Count `call_method` nodes in *module* whose target is *target_op*."""
    return count_call(module, op="call_method", target_op=target_op)
class TestFxFusion(TestCase):
    """Tests for torch._inductor's pre-grad FX fusion/rewrite passes.

    Each test symbolically traces a small function or module, applies one
    rewrite pass, then checks (a) numerical equivalence against eager
    execution and (b) that the expected node substitution actually
    happened, via the node-counting helpers above.
    """
    def test_sink_cat_after_pointwise(self):
        # The helpers below cover the different calling conventions of
        # torch.cat (positional dim, keyword dim, keyword `tensors`, with
        # and without intermediate view() calls) so each argument-handling
        # path of the pass is exercised.
        def test_kwarg(x, y):
            return torch.cat([x, y], dim=-1).view(-1).view(128).tanh()
        def test_arg(x, y):
            return torch.cat([x, y], -1).view(-1).view(128).tanh()
        def test_arg2(x, y):
            return torch.cat([x, y]).view(-1).view(128).tanh()
        def test_kwarg2(x, y):
            return torch.cat(tensors=[x, y], dim=0).tanh()
        def test_kwarg3(x, y):
            return torch.cat(tensors=[x, y], dim=0).view(128).tanh()
        trace_func = chain_passes(torch.fx.symbolic_trace, sink_cat_after_pointwise)
        inputs = [
            torch.randn(8, 8),
            torch.randn(8, 8),
        ]
        for f in [test_kwarg, test_arg, test_arg2, test_kwarg2, test_kwarg3]:
            traced = trace_func(f, inputs)
            # The rewrite must not change numerics.
            torch.testing.assert_close(f(*inputs), traced(*inputs))
            # After sinking, tanh is expected once per cat input (two
            # inputs here), hence two `tanh` call_method nodes.
            self.assertEqual(count_call_method(traced, "tanh"), 2)
    def test_linear_permute_fusion(self):
        class TestModule(torch.nn.Module):
            # Minimal module producing the linear -> permute(0, 2, 1)
            # pattern that linear_permute_fusion rewrites.
            def __init__(self, k: int, n: int, has_bias: bool):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(n, k))
                self.has_bias = has_bias
                if has_bias:
                    self.bias = torch.nn.Parameter(torch.randn(n))
            def forward(self, input: torch.Tensor):
                if self.has_bias:
                    a0 = torch.nn.functional.linear(input, self.weight, self.bias)
                else:
                    a0 = torch.nn.functional.linear(input, self.weight)
                b0 = a0.permute(0, 2, 1)
                return b0
        m, k, n = 16, 8, 4
        trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion)
        # Run with and without bias to cover both linear() call forms.
        for has_bias in [True, False]:
            module = TestModule(k, n, has_bias).eval()
            input = torch.randn(6, m, k)
            traced = trace_func(module, [input])
            num_linear = count_call_function(traced, torch.nn.functional.linear)
            num_linear_transpose = count_call_function(traced, linear_transpose)
            # The plain linear node must be gone, replaced by exactly one
            # fused linear_transpose node; numerics must be unchanged.
            self.assertEqual(num_linear, 0)
            self.assertEqual(num_linear_transpose, 1)
            torch.testing.assert_close(module(input), traced(input))
    def test_permute_linear_fusion(self):
        class TestModule(torch.nn.Module):
            # Minimal module producing the permute(0, 2, 1) -> linear
            # pattern that permute_linear_fusion rewrites.
            def __init__(self, k: int, n: int, has_bias: bool):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(n, k))
                self.has_bias = has_bias
                if has_bias:
                    self.bias = torch.nn.Parameter(torch.randn(n))
            def forward(self, input: torch.Tensor):
                input1 = input.permute(0, 2, 1)
                if self.has_bias:
                    return torch.nn.functional.linear(input1, self.weight, self.bias)
                return torch.nn.functional.linear(input1, self.weight)
        m, k, n = 16, 8, 4
        trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion)
        # Run with and without bias to cover both linear() call forms.
        for has_bias in [True, False]:
            module = TestModule(k, n, has_bias).eval()
            input = torch.randn(6, k, m)
            traced = trace_func(module, [input])
            num_linear = count_call_function(traced, torch.nn.functional.linear)
            num_transpose_linear = count_call_function(traced, transpose_linear)
            # linear must be replaced by exactly one fused transpose_linear
            # node; numerics must be unchanged.
            self.assertEqual(num_linear, 0)
            self.assertEqual(num_transpose_linear, 1)
            torch.testing.assert_close(module(input), traced(input))
    def test_permute_bmm_fusion(self):
        class TestModule(torch.nn.Module):
            # Minimal module producing the permute(0, 2, 1) -> bmm pattern
            # that permute_matmul_fusion rewrites.
            def __init__(self, batch: int, k: int, n: int):
                super().__init__()
                self.other = torch.randn(batch, k, n)
            def forward(self, input: torch.Tensor):
                input1 = input.permute(0, 2, 1)
                output = torch.bmm(input1, self.other)
                return output
        batch, m, k, n = 6, 16, 8, 4
        trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion)
        module = TestModule(batch, k, n).eval()
        input = torch.randn(batch, k, m)
        traced = trace_func(module, [input])
        num_bmm = count_call_function(traced, torch.bmm)
        num_transpose_matmul = count_call_function(traced, transpose_matmul)
        # bmm must be replaced by exactly one fused transpose_matmul node;
        # numerics must be unchanged.
        self.assertEqual(num_bmm, 0)
        self.assertEqual(num_transpose_matmul, 1)
        torch.testing.assert_close(module(input), traced(input))
if __name__ == "__main__":
    # Dispatch to the inductor test harness when run as a script.
    run_tests()