1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
|
#include "caffe2/transforms/common_subexpression_elimination.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/proto/caffe2_pb.h"
#include <c10/util/irange.h>
namespace caffe2 {
using transform::Graph;
using transform::Node;
// Checks if the node at model_idx and the node at candidate_idx are
// "common subexpressions". That is, do they have the same function, and
// take in the exact same input. If so, then their function is duplicated.
bool are_nodes_common(const Graph& g, int model_idx, int candidate_idx) {
// We need the candidate operator to match this model_op.
const Node& model_node = g.node(model_idx);
const Node& candidate_node = g.node(candidate_idx);
// Types need to match.
if (model_node.op.type() != candidate_node.op.type()) {
return false;
}
// Arguments need to match.
if (!MatchArguments(model_node.op, candidate_node.op)) {
return false;
}
// Inputs need to match.
if (model_node.op.input_size() != candidate_node.op.input_size()) {
return false;
}
// If any input_blob name is different, this is not okay.
for (int i = 0; i < model_node.op.input_size(); i++) {
if (candidate_node.op.input(i) != model_node.op.input(i)) {
return false;
}
}
// Now, we also need to check that each blob comes from the same parent, or
// if they are external (isn't in parents). This is equivalent to a
// map equality (since parent edges can only contain up to one blob).
if (model_node.parents.size() != candidate_node.parents.size() ||
!std::equal(
model_node.parents.begin(),
model_node.parents.end(),
candidate_node.parents.begin())) {
return false;
}
// Output size have to match too.
if (model_node.op.output_size() != candidate_node.op.output_size()) {
return false;
}
return true;
}
bool CommonSubexpressionEliminationTransform::PatternRule(
const Graph& g,
const std::vector<int>& subgraph,
int idx) {
if (subgraph.size() == 0) {
if (IsAllowed(g.node(idx).op.type()))
return true;
return false;
}
return are_nodes_common(g, subgraph.at(0), idx);
}
// As long as we have matched more than 2 ops, it is worth eliminating.
bool CommonSubexpressionEliminationTransform::ValidatorRule(
const Graph& /*g*/,
const std::vector<int>& subgraph) {
if (subgraph.size() >= 2) {
return true;
}
return false;
}
bool CommonSubexpressionEliminationTransform::ReplaceRule(
const std::vector<int>& subgraph,
Graph* g_ptr) {
CHECK(g_ptr);
auto& g = *g_ptr;
// We're gonna make a new node, with the same input as all of the ones in
// subgraph, but with their combined children.
int new_idx = g.size();
OperatorDef new_op = g.node(subgraph[0]).op;
// We will need to rename the output blobs.
new_op.clear_output();
for (const auto& blob : g.node(subgraph[0]).op.output()) {
new_op.add_output("transform/" + blob);
}
// Need to set up the parents.
const auto& new_op_parents = g.node(subgraph[0]).parents;
for (auto& parent : new_op_parents) {
int parent_idx = parent.first;
// Make the parents acknowledge us as its new child.
g.node(parent_idx).children[new_idx] = new_op_parents.at(parent_idx);
// Make the parents disown all our outdated siblings.
for (const auto i : c10::irange(subgraph.size())) {
g.node(parent_idx).children.erase(subgraph[i]);
}
}
// Add the node now.
g.push_node(
Node(new_op, true, new_op_parents, std::map<int, std::vector<string>>()));
// Now, we need to populate the child edges.
for (const int x : subgraph) {
// Figure out what the subgraph's node's blobs correspond to in new_op
// This is easy, since their indices match.
std::map<string, string> output_renamings;
for (int i = 0; i < new_op.output_size(); i++) {
output_renamings[g.node(x).op.output(i)] = g.node(new_idx).op.output(i);
}
// Now, time to add the old node's children to new_op
for (auto& child : g.node(x).children) {
int child_idx = child.first;
std::vector<string> blobs = child.second;
// rename the old blobs, and use them for our new edge.
for (string& blob : blobs) {
blob = output_renamings.at(blob);
}
// create this new edge
g.node(new_idx).children[child_idx] = blobs;
g.node(child_idx).parents[new_idx] = blobs;
// delete the old edge
g.node(child_idx).parents.erase(x);
// need to rename the inputs of the children too.
for (int i = 0; i < g.node(child_idx).op.input_size(); i++) {
string blob = g.node(child_idx).op.input(i);
if (output_renamings.count(blob) > 0) {
g.node(child_idx).op.set_input(i, output_renamings.at(blob));
}
}
}
}
g.DeactivateSubgraph(subgraph);
return true;
}
REGISTER_TRANSFORM(
CommonSubexpressionElimination,
CommonSubexpressionEliminationTransform);
} // namespace caffe2
|