File: subgraph_rewrite.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (141 lines) | stat: -rw-r--r-- 4,278 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>

namespace torch {
namespace jit {

// Registers the built-in rewrite rules. At the moment this is a single rule
// that turns an `aten::conv` whose result feeds an `aten::relu` into one
// `aten::convrelu` node taking the same three inputs.
void SubgraphRewriter::RegisterDefaultPatterns() {
  // TODO: Add actual patterns (like Conv-Relu).

  // IR text of the subgraph to look for.
  const std::string conv_then_relu = R"IR(
graph(%x, %w, %b):
  %c = aten::conv(%x, %w, %b)
  %r = aten::relu(%c)
  return (%r))IR";

  // IR text of the subgraph that takes its place.
  const std::string fused_conv_relu = R"IR(
graph(%x, %w, %b):
  %r = aten::convrelu(%x, %w, %b)
  return (%r))IR";

  RegisterRewritePattern(conv_then_relu, fused_conv_relu);
}

// Appends one rewrite rule to `patterns_`.
//
// `pattern` is the IR text of the subgraph to search for and `replacement`
// is the IR text of the subgraph to substitute. Both strings are stored as
// given; nothing is parsed or validated here.
void SubgraphRewriter::RegisterRewritePattern(
    const std::string& pattern,
    const std::string& replacement) {
  patterns_.push_back(RewritePatternDescr{pattern, replacement});
}

// Runs all registered rewrite rules over the graph of every method of
// `module`, then returns `module`.
Module SubgraphRewriter::runOnModule(const Module& module) {
  // Drop any deletions still recorded from an earlier run.
  nodes_to_delete_.clear();

  for (const auto& method : module.get_methods()) {
    // runOnGraph takes the graph handle by non-const reference, so it needs
    // a named local to bind to.
    std::shared_ptr<Graph> method_graph = method.function().graph();
    runOnGraph(method_graph);
  }

  return module;
}

void SubgraphRewriter::runOnGraph(
    std::shared_ptr<Graph>& graph,
    const std::vector<MatchFilter>& filters) {
  for (const RewritePatternDescr& pattern : patterns_) {
    rewriteSinglePatternOnGraph(graph, pattern, filters);
  }
}

void SubgraphRewriter::rewriteSinglePatternOnGraph(
    std::shared_ptr<Graph>& graph,
    const RewritePatternDescr& pattern,
    const std::vector<MatchFilter>& filters) {
  std::unordered_map<Value*, Value*> rewrite_map;
  std::vector<Value*> values_to_rewrite;

  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;
  parseIR(pattern.pattern, &pattern_graph, vmap);

  Graph replacement_graph;
  parseIR(pattern.replacement, &replacement_graph);

  const auto& matches = findPatternMatches(pattern_graph, *graph);
  for (const Match& match : matches) {
    if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) {
          return f(match, vmap);
        })) {
      continue;
    }
    // Matches might overlap with each other, in that case some of the nodes in
    // the current match might have already been used in another folded pattern.
    // We need to skip such matches.
    if (overlapsWithPreviousMatches(&match)) {
      continue;
    }

    // Figure out what values we need to use as inputs and outputs for the
    // replacement subgraph. These would be inputs and outputs of the subgraph
    // we matched.
    std::vector<Value*> inputs, outputs;
    for (Value* v : pattern_graph.inputs()) {
      inputs.push_back(match.values_map.at(v));
    }
    for (Value* v : pattern_graph.outputs()) {
      outputs.push_back(match.values_map.at(v));
    }

    // Insert a clone of replacement subgraph after the matched subgraph.
    // `inputs` vector holds values that we would use as incoming values to the
    // new subgraph, and we will get `new_outputs` vector containing values
    // produced by this new subgraph - we will then rewrite old outputs with the
    // new ones.
    WithInsertPoint insert_point(match.anchor);
    std::vector<Value*> new_outputs =
        insertGraph(*graph, replacement_graph, inputs);

    // Record all planned rewritings
    AT_ASSERT(outputs.size() == new_outputs.size());
    for (size_t idx = 0; idx < outputs.size(); idx++) {
      values_to_rewrite.push_back(outputs[idx]);
      rewrite_map[outputs[idx]] = new_outputs[idx];
    }
    // Record all planned deletions
    for (Node* pattern_n : pattern_graph.nodes()) {
      if (match.nodes_map.count(pattern_n)) {
        Node* n = match.nodes_map.at(pattern_n);
        nodes_to_delete_.insert(n);
      }
    }
  }

  // Perform planned rewritings
  for (auto v : values_to_rewrite) {
    v->replaceAllUsesWith(rewrite_map.at(v));
  }

  // Perform planned deletions
  for (auto n : nodes_to_delete_) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete_) {
    n->destroy();
  }
  nodes_to_delete_.clear();
}

bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
  for (auto n : match->nodes_map) {
    if (nodes_to_delete_.count(n.second)) {
      return true;
    }
  }
  return false;
}

// Convenience entry point: creates a SubgraphRewriter, registers the default
// patterns on it, runs it over `module` and returns the result of that run.
Module PatternBasedRewrite(const Module& module) {
  // TODO: Deep-copy the module
  SubgraphRewriter rewriter;
  rewriter.RegisterDefaultPatterns();
  return rewriter.runOnModule(module);
}

} // namespace jit
} // namespace torch