1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
namespace torch {
namespace jit {
namespace {
// Rewrites add -> relu sequences in `graph` into the fused ATen ops
// aten::_add_relu / aten::_add_relu_ using a SubgraphRewriter, then runs the
// rewriter once over the graph. Rules are registered in this order:
//   1. aten::add  then aten::relu   -> aten::_add_relu
//   2. aten::add  then aten::relu_  -> aten::_add_relu
//   3. aten::add_ then aten::relu_  -> aten::_add_relu_
//   4. aten::add with a fourth %out operand then aten::relu_
//      -> aten::_add_relu with the same %out operand
//
// NB: aten::add_ + aten::relu and out-variant add + aten::relu are
// intentionally not rewritten. The add in those sequences mutates a tensor in
// place while the relu does not, and a single fused op cannot reproduce that:
// the out-of-place fused op would drop the add's in-place mutation, and the
// in-place fused op would also store the relu result into the mutated tensor.
void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
  // Replacement shared by the two functional-add rules (1 and 2).
  const std::string fused_add_relu = R"(
graph(%a, %b, %alpha):
%res = aten::_add_relu(%a, %b, %alpha)
return (%res))";

  // One (pattern IR, replacement IR) pair per rewrite rule.
  struct RewriteRule {
    std::string pattern;
    std::string replacement;
  };

  const RewriteRule rules[] = {
      // Rule 1: add followed by functional relu.
      {R"(
graph(%a, %b, %alpha):
%add_res = aten::add(%a, %b, %alpha)
%res = aten::relu(%add_res)
return (%res))",
       fused_add_relu},
      // Rule 2: add followed by in-place relu_.
      {R"(
graph(%a, %b, %alpha):
%add_res = aten::add(%a, %b, %alpha)
%res = aten::relu_(%add_res)
return (%res))",
       fused_add_relu},
      // Rule 3: in-place add_ followed by in-place relu_.
      {R"(
graph(%a, %b, %alpha):
%add_res = aten::add_(%a, %b, %alpha)
%res = aten::relu_(%add_res)
return (%res))",
       R"(
graph(%a, %b, %alpha):
%res = aten::_add_relu_(%a, %b, %alpha)
return (%res))"},
      // Rule 4: add with an explicit %out operand followed by relu_.
      {R"(
graph(%a, %b, %alpha, %out):
%add_res = aten::add(%a, %b, %alpha, %out)
%res = aten::relu_(%add_res)
return (%res))",
       R"(
graph(%a, %b, %alpha, %out):
%res = aten::_add_relu(%a, %b, %alpha, %out)
return (%res))"},
  };

  SubgraphRewriter rewriter;
  // Registration order is preserved: rules are tried in array order.
  for (const RewriteRule& rule : rules) {
    rewriter.RegisterRewritePattern(rule.pattern, rule.replacement);
  }
  rewriter.runOnGraph(graph);
}
} // namespace
// Graph entry point: applies the add+relu fusion rewrites to `graph`.
void FuseAddRelu(std::shared_ptr<Graph>& graph) {
  fuseAddReluImpl(graph);
}

// Module entry point: fetches the graph of the module's "forward" method and
// applies the same add+relu fusion rewrites to it. Only "forward" is visited;
// no other methods of `module` are touched here.
void FuseAddRelu(script::Module& module) {
  // Held in a named local because the graph overload takes an lvalue ref.
  auto forward_graph = module.get_method("forward").graph();
  FuseAddRelu(forward_graph);
}
} // namespace jit
} // namespace torch
|