File: subgraph_rewrite.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (222 lines) | stat: -rw-r--r-- 7,412 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>

#include <c10/util/irange.h>

#include <algorithm>
#include <utility>

namespace torch {
namespace jit {

namespace {
// Copies debug info (source range and, when present, the callstack) from the
// graph nodes of a concrete match onto the nodes of the replacement graph.
// The replacement graph is cloned into the target graph after this runs, so
// the clones end up carrying the debug info of the code they replace.
//
// @param input_nodes       Pattern-graph nodes to skip entirely (e.g. the
//                          pattern's parameter node, which produces the
//                          pattern inputs).
// @param m                 The match; its `nodes_map` maps pattern-graph
//                          nodes to nodes of the graph being rewritten.
// @param pattern_node_map  Maps replacement-graph nodes (keys) to
//                          pattern-graph nodes (values). The replacement
//                          nodes are mutated in place.
void update_source_range_and_cs_ptr(
    const std::set<const Node*>& input_nodes,
    const Match& m,
    std::unordered_map<Node*, Node*>& pattern_node_map) {
  for (auto& entry : pattern_node_map) {
    Node* replacement_node = entry.first;
    Node* pattern_node = entry.second;
    if (input_nodes.count(pattern_node)) {
      continue;
    }
    // Only pattern nodes that were matched have an entry in nodes_map. A
    // pattern node without one (for instance the parameter node producing
    // the pattern inputs, if the caller did not list it in `input_nodes`)
    // has no matched node to take debug info from. Skip it rather than
    // letting nodes_map.at() throw std::out_of_range.
    auto matched = m.nodes_map.find(pattern_node);
    if (matched == m.nodes_map.end()) {
      continue;
    }
    Node* orig_node = matched->second;
    replacement_node->setSourceRange(orig_node->sourceRange());
    const auto callstack = orig_node->callstack();
    if (callstack) {
      replacement_node->setCallStack(callstack.value());
    }
  }
}
} // namespace

void SubgraphRewriter::RegisterDefaultPatterns() {
  // TODO: Add actual patterns (like Conv-Relu).

  // Matches an aten::conv whose result is consumed by an aten::relu.
  const std::string conv_then_relu = R"IR(
graph(%x, %w, %b):
  %c = aten::conv(%x, %w, %b)
  %r = aten::relu(%c)
  return (%r))IR";

  // Replaces the pair with a single aten::convrelu over the same inputs.
  const std::string single_convrelu = R"IR(
graph(%x, %w, %b):
  %r = aten::convrelu(%x, %w, %b)
  return (%r))IR";

  // The node producing replacement value %r inherits the source range and
  // callstack of the matched node that produced pattern value %c.
  RegisterRewritePattern(conv_then_relu, single_convrelu, {{"r", "c"}});
}

// Registers a rewrite rule: occurrences of `pattern` are replaced with
// `replacement` when the rewriter runs.
//
// @param pattern           IR text of the subgraph to search for.
// @param replacement       IR text of the subgraph to insert instead.
// @param value_name_pairs  (replacement value name, pattern value name)
//                          pairs. The node producing each named replacement
//                          value takes its source range and callstack from
//                          the matched node producing the named pattern
//                          value. If a replacement name appears more than
//                          once, the first pair wins (unordered_map range
//                          construction does not overwrite existing keys).
void SubgraphRewriter::RegisterRewritePattern(
    const std::string& pattern,
    const std::string& replacement,
    const std::vector<std::pair<std::string, std::string>>& value_name_pairs) {
  std::unordered_map<std::string, std::string> value_name_map(
      value_name_pairs.begin(), value_name_pairs.end());
  // Move the freshly built map into the descriptor, then move the descriptor
  // into patterns_. This avoids two needless deep copies.
  RewritePatternDescr d = {pattern, replacement, std::move(value_name_map)};
  patterns_.push_back(std::move(d));
}

// Runs every registered pattern over the graph of each method of `module`.
// The graphs obtained from the module's methods are rewritten directly; the
// returned Module is simply `module` itself (no deep copy is made here).
Module SubgraphRewriter::runOnModule(const Module& module) {
  // Discard any deletions left pending from an earlier run.
  nodes_to_delete_.clear();
  for (const auto& method : module.get_methods()) {
    auto method_graph = toGraphFunction(method.function()).graph();
    runOnGraph(method_graph);
  }
  return module;
}

void SubgraphRewriter::runOnGraph(
    std::shared_ptr<Graph>& graph,
    const std::vector<MatchFilter>& filters) {
  for (const RewritePatternDescr& pattern : patterns_) {
    rewriteSinglePatternOnGraph(graph, pattern, filters);
  }
}

void SubgraphRewriter::rewriteSinglePatternOnGraph(
    std::shared_ptr<Graph>& graph,
    const RewritePatternDescr& pattern,
    const std::vector<MatchFilter>& filters) {
  std::unordered_map<Value*, Value*> rewrite_map;
  std::vector<Value*> values_to_rewrite;

  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;
  parseIR(pattern.pattern, &pattern_graph, vmap);

  Graph replacement_graph;
  std::unordered_map<std::string, Value*> vmap_replacement;
  parseIR(pattern.replacement, &replacement_graph, vmap_replacement);

  // First construct map of Node*-to-Node*
  // This maps Nodes in replacement graph to nodes in pattern graph
  // given the value_name_map, which maps value names from repalcement
  // pattern to value name in pattern
  std::unordered_map<Node*, Node*> pattern_node_map;
  std::set<const Node*> pattern_input_nodes;
  for (auto& it : vmap_replacement) {
    const auto& replacement_value_name = it.first;
    Node* replacement_value_node = it.second->node();
    if (pattern.value_name_map.count(replacement_value_name)) {
      const auto& pattern_value_name =
          pattern.value_name_map.at(replacement_value_name);
      TORCH_CHECK(
          vmap.count(pattern_value_name),
          "Value must be found in the replacement graph.");
      Node* pattern_value_node = vmap.at(pattern_value_name)->node();
      pattern_node_map.emplace(replacement_value_node, pattern_value_node);
    }
  }

  const auto& matches = findPatternMatches(pattern_graph, *graph);
  for (const Match& match : matches) {
    if (!std::all_of(filters.begin(), filters.end(), [&](const MatchFilter& f) {
          return f(match, vmap);
        })) {
      continue;
    }
    // Matches might overlap with each other, in that case some of the nodes in
    // the current match might have already been used in another folded pattern.
    // We need to skip such matches.
    if (overlapsWithPreviousMatches(&match)) {
      continue;
    }

    // Figure out what values we need to use as inputs and outputs for the
    // replacement subgraph and where the replacement subgraph needs to be
    // inserted.
    Node* ins_point = nullptr;
    std::vector<Value*> inputs, outputs;
    for (Value* v : pattern_graph.inputs()) {
      Value* input = match.values_map.at(v);
      if (!ins_point || ins_point->isBefore(input->node())) {
        ins_point = input->node();
      }
      inputs.push_back(input);
    }
    AT_ASSERT(ins_point);

    // Check that the insertion point we've chosen precedes all the uses of the
    // outputs - otherwise the replacement is incorrect and we have to skip it.
    bool ins_point_before_uses = true;
    for (Value* v : pattern_graph.outputs()) {
      Value* output = match.values_map.at(v);
      outputs.push_back(match.values_map.at(v));

      for (const Use& u : output->uses()) {
        if (u.user->isBefore(ins_point)) {
          ins_point_before_uses = false;
          break;
        }
      }
    }

    if (!ins_point_before_uses) {
      continue;
    }

    // Before rewriting the graph, update source range and callstack
    // info of the replacement pattern graph so that the rewritten graph
    // has the updated info
    update_source_range_and_cs_ptr(
        pattern_input_nodes, match, pattern_node_map);
    // Insert a clone of replacement subgraph.
    // `inputs` vector holds values that we would use as incoming values to the
    // new subgraph, and we will get `new_outputs` vector containing values
    // produced by this new subgraph - we will then rewrite old outputs with the
    // new ones.
    WithInsertPoint insert_point(ins_point->next());
    std::vector<Value*> new_outputs =
        insertGraph(*graph, replacement_graph, inputs);

    // Record all planned rewritings
    AT_ASSERT(outputs.size() == new_outputs.size());
    for (const auto idx : c10::irange(outputs.size())) {
      values_to_rewrite.push_back(outputs[idx]);
      rewrite_map[outputs[idx]] =
          new_outputs[idx]->setType(outputs[idx]->type());
    }
    // Record all planned deletions
    for (Node* pattern_n : pattern_graph.nodes()) {
      if (match.nodes_map.count(pattern_n)) {
        Node* n = match.nodes_map.at(pattern_n);
        nodes_to_delete_.insert(n);
      }
    }
  }

  // Perform planned rewritings
  for (auto v : values_to_rewrite) {
    v->replaceAllUsesWith(rewrite_map.at(v));
  }

  // Perform planned deletions
  for (auto n : nodes_to_delete_) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete_) {
    n->destroy();
  }
  nodes_to_delete_.clear();
}

bool SubgraphRewriter::overlapsWithPreviousMatches(const Match* match) {
  for (auto n : match->nodes_map) {
    if (nodes_to_delete_.count(n.second)) {
      return true;
    }
  }
  return false;
}

// Applies the default rewrite patterns to every method of `module` and
// returns the result. No deep copy is made (see TODO), so the method graphs
// reachable from `module` are presumably modified in place — confirm before
// relying on `module` being left untouched.
Module PatternBasedRewrite(const Module& module) {
  // TODO: Deep-copy the module
  SubgraphRewriter rewriter;
  rewriter.RegisterDefaultPatterns();
  Module rewritten = rewriter.runOnModule(module);
  return rewritten;
}

} // namespace jit
} // namespace torch