File: graph_rewriter.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (144 lines) | stat: -rw-r--r-- 5,376 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void GraphRewriter::cleanupSubgraphs() {
  auto curNode = *block_->nodes().rbegin();
  while (curNode != *block_->nodes().rend()) {
    // Save the previous node, since we might delete `curNode` in next block
    auto prevNode = curNode->prev();
    if (llgaHelper_.isLlgaSubgraph(curNode)) {
      // Unmerge subgraph if we don't get every nodes of a partition
      // into the subgraph due to failed alias check
      llgaHelper_.unmergeIfAnyNodeIsMissing(curNode);
    }
    curNode = prevNode;
  }
  for (Node* n : block_->nodes()) {
    for (Block* b : n->blocks()) {
      GraphRewriter(b, graph_, aliasDb_).cleanupSubgraphs();
    }
  }
}

void GraphRewriter::buildupSubgraphs() {
  // We need to run the rewriter multiple times in order to get all merge
  // opportunities. This is because moveBeforeTopologicalValid may reorder
  // nodes to be AFTER the current iteration point. In order to properly
  // consider those nodes for merging, we need run the pass until no changes
  // have been made.
  //
  // Example:
  //   c = f(a, b)
  //   d = f(c)
  //   e = f(d)  <- iter is here, moving upward
  // After c.moveBeforeTopologicallyValid(e), we have:
  //   c = f(a, b)
  //   e = f(d)  <- iter still here
  //   d = f(c)  <- this was node moved on the other side.
  // see [workblocks]
  auto workblocks = buildWorkBlocks();
  for (auto& workblock : workblocks) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      auto workblock_end = workblock.end()->reverseIterator();
      auto workblock_begin = workblock.begin()->reverseIterator();
      for (auto it = workblock_end; it != workblock_begin;) {
        bool changed = false;
        std::tie(it, changed) = scanNode(*it, workblock_begin);
        any_changed |= changed;
      }
    }
  }

  // Construct Subgraphs Recursively
  for (Node* n : block_->nodes()) {
    for (auto subBlock : n->blocks()) {
      GraphRewriter(subBlock, graph_, aliasDb_).buildupSubgraphs();
    }
  }
}

std::vector<WorkBlock> GraphRewriter::buildWorkBlocks() {
  // [workblocks]
  // the IR has many nodes which can never be reordered around, such as a
  // prim::Bailout. if a node N is surrounded by two nodes which cannot be
  // reordered, A and B, then a fusion group that is created from N
  // can only contain nodes from (A, B) The nodes from A to B represent one
  // work block for the subgraph rewriter to work on. By creating these up
  // front, we avoid retraversing the whole graph block any time scanNode
  // returns
  Node* end_bound_node = block_->return_node();
  Node* curr = end_bound_node->prev();
  std::vector<WorkBlock> worklist;
  while (curr != block_->param_node()) {
    // cannot reorder around side effectful nodes
    if (curr->hasSideEffects()) {
      worklist.emplace_back(curr, end_bound_node);
      end_bound_node = curr;
    }
    curr = curr->prev();
  }
  worklist.emplace_back(curr, end_bound_node);
  return worklist;
}

// Tries to grow a subgraph around `consumer` by merging one node found
// between `consumer` and `workblock_begin` (in reverse order) into it.
//
// Returns a pair of (iterator to continue scanning from, whether a merge
// happened). After a merge the iterator points at the merged group, so the
// caller rescans it; otherwise it points one step past `consumer` in reverse
// order.
std::pair<graph_node_list::iterator, bool> GraphRewriter::scanNode(
    Node* consumer,
    graph_node_list::iterator workblock_begin) {
  GRAPH_DEBUG("Scanning ", consumer->kind().toQualString());
  if (!llgaHelper_.shouldConsiderForMerge(consumer)) {
    return std::make_pair(++consumer->reverseIterator(), false);
  }

  // Wrap a plain node into its own subgraph so there is something to merge
  // producers into.
  if (!llgaHelper_.isLlgaSubgraph(consumer)) {
    consumer = llgaHelper_.createSingletonSubgraph(consumer, aliasDb_);
  }

  // Walk the rest of the workblock looking for nodes that the LLGA graph
  // helper places in the same partition. Looking only at the input nodes
  // would not be enough: nodes like B and C below do not share a common
  // input, yet they belong to the same partition. The whole workblock is
  // scanned instead, which might lead to O(n^2) accesses in the worst case.
  //              A
  //      + - - / - \ - - +
  //      |    B     C    |
  //      |    |     |    |
  //      |    D     E    |
  //      + - - \ - / - - +
  //              F
  auto candidate = ++consumer->reverseIterator();
  while (candidate != workblock_begin) {
    if (auto group = tryMerge(consumer, *candidate)) {
      // The merge succeeded, so the `inputs` of the new group may have
      // changed. Hand the group back so it is rescanned for further merging
      // opportunities.
      return std::make_pair(group.value()->reverseIterator(), true);
    }
    ++candidate;
  }
  return std::make_pair(++consumer->reverseIterator(), false);
}

// Attempts to fold `producer` into the `consumer` subgraph. `consumer` must
// already be an LLGA subgraph (asserted). On success `producer` is destroyed
// and the `consumer` group is returned; otherwise c10::nullopt is returned
// without calling mergeNodeIntoSubgraph.
c10::optional<Node*> GraphRewriter::tryMerge(Node* consumer, Node* producer) {
  AT_ASSERT(llgaHelper_.isLlgaSubgraph(consumer));
  if (!llgaHelper_.shouldMerge(producer, consumer)) {
    return c10::nullopt;
  }
  // The alias database is only consulted once the helper has approved the
  // merge.
  if (!aliasDb_.moveBeforeTopologicallyValid(producer, consumer)) {
    return c10::nullopt;
  }
  llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_);
  return consumer;
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch