1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
namespace torch {
namespace jit {
// Registers the built-in rewrite patterns on this rewriter.
//
// At the moment exactly one pair is registered: the pattern text describes an
// `aten::conv` whose result feeds an `aten::relu`, and the replacement text
// describes a single `aten::convrelu` taking the same three inputs.
void SubgraphRewriter::RegisterDefaultPatterns() {
  // TODO: Add actual patterns (like Conv-Relu).

  // The IR texts are raw strings; their contents (including line breaks) are
  // passed on verbatim.
  const std::string conv_then_relu = R"IR(
graph(%x, %w, %b):
%c = aten::conv(%x, %w, %b)
%r = aten::relu(%c)
return (%r))IR";
  const std::string fused_conv_relu = R"IR(
graph(%x, %w, %b):
%r = aten::convrelu(%x, %w, %b)
return (%r))IR";

  RegisterRewritePattern(conv_then_relu, fused_conv_relu);
}
// Appends one (pattern, replacement) pair to the list of registered rewrites.
//
// pattern:     IR text describing the subgraph to look for.
// replacement: IR text describing the subgraph to put in its place.
//
// Both texts are stored unchanged; they are parsed later, when
// rewriteSinglePatternOnGraph runs.
void SubgraphRewriter::RegisterRewritePattern(
    const std::string& pattern,
    const std::string& replacement) {
  // Hand push_back a temporary so the descriptor can be moved into the
  // container instead of copying a named local (and both strings) a second
  // time.
  patterns_.push_back(RewritePatternDescr{pattern, replacement});
}
// Runs the registered patterns over the graph of every method of `module`
// and returns `module`.
//
// Deletion bookkeeping left over from any earlier run is discarded before the
// methods are visited.
Module SubgraphRewriter::runOnModule(const Module& module) {
  nodes_to_delete_.clear();

  for (const auto& method : module.get_methods()) {
    // runOnGraph takes its graph by non-const lvalue reference, so the handle
    // has to live in a named local.
    std::shared_ptr<Graph> method_graph = method.function().graph();
    runOnGraph(method_graph);
  }

  return module;
}
// Applies every registered pattern to `graph`, one pattern at a time and in
// registration order.
//
// graph:   graph to rewrite.
// filters: predicates forwarded unchanged to rewriteSinglePatternOnGraph,
//          where a match is rewritten only if all of them accept it.
void SubgraphRewriter::runOnGraph(
    std::shared_ptr<Graph>& graph,
    const std::vector<MatchFilter>& filters) {
  for (const RewritePatternDescr& descr : patterns_) {
    rewriteSinglePatternOnGraph(graph, descr, filters);
  }
}
void SubgraphRewriter::rewriteSinglePatternOnGraph(
std::shared_ptr<Graph>& graph,
const RewritePatternDescr& pattern,
const std::vector<MatchFilter>& filters) {
std::unordered_map<Value*, Value*> rewrite_map;
std::vector<Value*> values_to_rewrite;
Graph pattern_graph;
std::unordered_map<std::string, Value*> vmap;
parseIR(pattern.pattern, &pattern_graph, vmap);
Graph replacement_graph;
parseIR(pattern.replacement, &replacement_graph);
const auto& matches = findPatternMatches(pattern_graph, *graph);
for (const Match& match : matches) {
if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) {
return f(match, vmap);
})) {
continue;
}
// Matches might overlap with each other, in that case some of the nodes in
// the current match might have already been used in another folded pattern.
// We need to skip such matches.
if (overlapsWithPreviousMatches(&match)) {
continue;
}
// Figure out what values we need to use as inputs and outputs for the
// replacement subgraph. These would be inputs and outputs of the subgraph
// we matched.
std::vector<Value*> inputs, outputs;
for (Value* v : pattern_graph.inputs()) {
inputs.push_back(match.values_map.at(v));
}
for (Value* v : pattern_graph.outputs()) {
outputs.push_back(match.values_map.at(v));
}
// Insert a clone of replacement subgraph after the matched subgraph.
// `inputs` vector holds values that we would use as incoming values to the
// new subgraph, and we will get `new_outputs` vector containing values
// produced by this new subgraph - we will then rewrite old outputs with the
// new ones.
WithInsertPoint insert_point(match.anchor);
std::vector<Value*> new_outputs =
insertGraph(*graph, replacement_graph, inputs);
// Record all planned rewritings
AT_ASSERT(outputs.size() == new_outputs.size());
for (size_t idx = 0; idx < outputs.size(); idx++) {
values_to_rewrite.push_back(outputs[idx]);
rewrite_map[outputs[idx]] = new_outputs[idx];
}
// Record all planned deletions
for (Node* pattern_n : pattern_graph.nodes()) {
if (match.nodes_map.count(pattern_n)) {
Node* n = match.nodes_map.at(pattern_n);
nodes_to_delete_.insert(n);
}
}
}
// Perform planned rewritings
for (auto v : values_to_rewrite) {
v->replaceAllUsesWith(rewrite_map.at(v));
}
// Perform planned deletions
for (auto n : nodes_to_delete_) {
n->removeAllInputs();
}
for (auto n : nodes_to_delete_) {
n->destroy();
}
nodes_to_delete_.clear();
}
// Reports whether `match` shares a node with a match that was already
// accepted.
//
// Returns true as soon as one of the graph nodes in `match->nodes_map` is
// found in the set of nodes scheduled for deletion, and false if none is.
bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
  for (const auto& entry : match->nodes_map) {
    // `second` is the graph node that the pattern node was matched to.
    if (nodes_to_delete_.count(entry.second) != 0) {
      return true;
    }
  }
  return false;
}
// Convenience entry point: creates a rewriter, registers the default
// patterns on it, runs it over `module`, and returns what runOnModule
// returns.
Module PatternBasedRewrite(const Module& module) {
  // TODO: Deep-copy the module
  SubgraphRewriter rewriter;
  rewriter.RegisterDefaultPatterns();
  return rewriter.runOnModule(module);
}
} // namespace jit
} // namespace torch
|