#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch {
namespace jit {
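// Fuse decomposed linear patterns (aten::addmm, aten::matmul + aten::add_,
// bare aten::matmul on a transposed weight) back into a single aten::linear
// call, then swap functional linear calls as well.
//
// Minimal usage sketch (assuming a graph obtained from a scripted module):
//   std::shared_ptr<Graph> g = module.get_method("forward").graph();
//   FuseLinear(g);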
void FuseLinear(std::shared_ptr<Graph>& graph) {
  std::string addmm_pattern = R"IR(
    graph(%input, %weight_t, %bias, %beta, %alpha):
        %res = aten::addmm(%bias, %input, %weight_t, %beta, %alpha)
        return (%res))IR";
  std::string fused_linear_addmm = R"IR(
    graph(%input, %weight_t, %bias, %beta, %alpha):
        %weight = aten::t(%weight_t)
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";
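  // aten::addmm computes beta * bias + alpha * (input @ weight_t), so the
  // rewrite to aten::linear is only valid when both scaling factors are 1;
  // %beta is checked here, %alpha by aten_add_alpha_is_one from the
  // included helpers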
  auto beta_is_one = [](const Match& match,
                        const std::unordered_map<std::string, Value*>& vmap) {
    return is_int_constant(match, vmap, "beta", 1);
  };
  // check %weight_t is produced by `aten::t` to make sure
  // we can transform the pattern to `aten::linear`
  auto weight_transposed =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        const auto& match_vmap = match.values_map;
        auto v = match_vmap.at(vmap.at("weight_t"));
        return v->node()->kind() == Symbol::aten("t");
      };
  // replace the addmm pattern with aten::linear
  SubgraphRewriter addmm_to_linear;
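  // value_mappings pairs each value created by the replacement pattern with
  // a value from the match, so debug info (source range, callstack) carries
  // over to the new nodes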
  std::vector<std::pair<std::string, std::string>> value_mappings(
      {{"weight", "res"}, {"res", "res"}});
  addmm_to_linear.RegisterRewritePattern(
      addmm_pattern, fused_linear_addmm, value_mappings);
  addmm_to_linear.runOnGraph(
      graph, {aten_add_alpha_is_one, beta_is_one, weight_transposed});
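  // a matmul followed by an (in-place) add of the bias is also a linear
  // layer, provided the add's %alpha is 1 and the weight comes from aten::t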
  std::string matmul_add_pattern = R"IR(
    graph(%input, %weight_t, %bias, %alpha):
        %output = aten::matmul(%input, %weight_t)
        %res = aten::add_(%output, %bias, %alpha)
        return (%res))IR";
  std::string fused_linear_matmul = R"IR(
    graph(%input, %weight_t, %bias, %alpha):
        %weight = aten::t(%weight_t)
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";
  value_mappings = {{"weight", "output"}, {"res", "output"}};
  // replace the matmul + add pattern with aten::linear
  SubgraphRewriter matmuladd_to_linear;
  matmuladd_to_linear.RegisterRewritePattern(
      matmul_add_pattern, fused_linear_matmul, value_mappings);
  matmuladd_to_linear.runOnGraph(
      graph, {aten_add_alpha_is_one, weight_transposed});
  std::string matmul_pattern = R"IR(
    graph(%input, %weight_t):
        %output = aten::matmul(%input, %weight_t)
        return (%output))IR";
  std::string fused_linear_bias_none = R"IR(
    graph(%input, %weight_t):
        %weight = aten::t(%weight_t)
        %bias: Tensor? = prim::Constant()
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";
  // replace a bare matmul (no bias) with aten::linear, passing bias=None
  SubgraphRewriter matmul_to_linear;
  matmul_to_linear.RegisterRewritePattern(
      matmul_pattern, fused_linear_bias_none, value_mappings);
  matmul_to_linear.runOnGraph(graph, weight_transposed);
  // clean up extra transpose for the weight of aten::linear
  std::string linear_weight_extra_transpose = R"IR(
    graph(%input, %weight, %bias):
        %weight_t1 = aten::t(%weight)
        %weight_t2 = aten::t(%weight_t1)
        %res = aten::linear(%input, %weight_t2, %bias)
        return (%res))IR";
  std::string linear_weight_no_transpose = R"IR(
    graph(%input, %weight, %bias):
        %res = aten::linear(%input, %weight, %bias)
        return (%res))IR";
  value_mappings = {{"res", "res"}};
  SubgraphRewriter cleanup;
  cleanup.RegisterRewritePattern(
      linear_weight_extra_transpose,
      linear_weight_no_transpose,
      value_mappings);
  cleanup.runOnGraph(graph);
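  // also swap prim::CallFunction calls to the functional linear into
  // aten::linear, so downstream passes only need to handle one op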
  SwapFunctionalLinear(graph);
}
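// Recursively swap functional linear calls with aten::linear in every method
// of the module and of its submodules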
void SwapFunctionalLinear(Module& module) {
  for (auto& method : module.get_methods()) {
    std::shared_ptr<Graph> g = method.graph();
    SwapFunctionalLinear(g);
  }
  for (Module m : module.children()) {
    SwapFunctionalLinear(m);
  }
}
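// Replace prim::CallFunction(%linear, %input, %weight, %bias) with
// aten::linear(%input, %weight, %bias) when the callee is the functional
// `linear`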
void SwapFunctionalLinear(std::shared_ptr<Graph>& graph) {
  std::string functional_linear = R"(
graph(%linear, %input, %weight, %bias):
  %r = prim::CallFunction(%linear, %input, %weight, %bias)
  return (%r) )";
  std::string aten_linear = R"(
graph(%linear, %input, %weight, %bias):
  %r = aten::linear(%input, %weight, %bias)
  return (%r) )";
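  // only rewrite when the function being called is actually named "linear"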
  auto filter = [](const Match& match,
                   const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto linear = graph_rewrite_helper::getValue("linear", match_vmap, vmap);
    auto func_name = graph_rewrite_helper::getFuncName(linear);
    return func_name == "linear";
  };
  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(functional_linear, aten_linear);
  rewriter.runOnGraph(graph, filter);
}
} // namespace jit
} // namespace torch