1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/core/transform.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
/**
 * PatternNetTransform allows you to create transforms using a simple
 * interface.
 *
 * Simply provide a Pattern NetDef and a Replace NetDef, and this Transform
 * will find subgraphs which fit the pattern net and replace each of them with
 * the replace net.
 */
class TORCH_API PatternNetTransform : public Transform {
 public:
  /**
   * Builds the pattern graph and the replace graph from the given NetDefs.
   *
   * The two nets must declare identical external inputs and identical
   * external outputs; both conditions are checked with CAFFE_ENFORCE.
   * The traversal order of the pattern net and its inverse permutation are
   * computed once here and cached in ordered_ops_ / inverse_ops_.
   *
   * @param pattern_net net describing the subgraph to search for.
   * @param replace_net net that is substituted for every match.
   */
  PatternNetTransform(const NetDef& pattern_net, const NetDef& replace_net)
      : p_(transform::Graph(pattern_net)), r_(transform::Graph(replace_net)) {
    // external input and output must match!
    CAFFE_ENFORCE(
        p_.external_input() == r_.external_input(),
        "External inputs do not match!");
    CAFFE_ENFORCE(
        p_.external_output() == r_.external_output(),
        "External outputs do not match!");
    ordered_ops_ = GetPatternTraversalOrder(p_);
    inverse_ops_.resize(ordered_ops_.size());
    for (const auto i : c10::irange(ordered_ops_.size())) {
      // The index is unsigned while the table stores int; make the
      // conversion explicit instead of relying on implicit narrowing.
      inverse_ops_[ordered_ops_[i]] = static_cast<int>(i);
    }
  }

  /** Turns on matching of operator arguments (see argument_match_). */
  void EnableArgumentMatching() {
    argument_match_ = true;
  }

  /** Turns off matching of operator arguments (see argument_match_). */
  void DisableArgumentMatching() {
    argument_match_ = false;
  }

 protected:
  /**
   * We want the final result of subgraph to match the PatternNet in the
   * order of ordered_ops, operator by operator.
   *
   * [[[ ie. g.node(subgraph[i]) should match p.node(ordered_ops[i]) ]]]
   *
   * PatternRule for PatternNetTransform does the following:
   *
   * When trying to insert node idx into subgraph[p_idx],
   * we need to see if the edges between index and the
   * subgraph match the edges between p[ordered_ops[idx]]
   * and p[ordered_ops[0]...ordered_ops[p_idx-1]].
   */
  bool PatternRule(
      const transform::Graph& g,
      const std::vector<int>& subgraph,
      int idx) override;

  /**
   * ValidatorRule for PatternNetTransform does the following:
   *
   * Checks if the size of subgraph and p.size() are the same. That's it!
   */
  bool ValidatorRule(
      const transform::Graph& g,
      const std::vector<int>& subgraph) override;

  /**
   * ReplaceRule for PatternNet Transform does the following:
   *
   * 1) Figure out edge renamings for edges going into/out of the subgraph.
   * That is, for each blob in the pattern graph, what is it called in the
   * matched subgraph?
   *
   * 2) Remove the matched subgraph.
   *
   * 3) Append the replace graph's operators to the graph's operators, and use
   *    the renamings to rename the blob names.
   *
   * 4) Create all the children/parent relationships within the replaced graph,
   *    and stitch together the inputs and outputs into the rest of the graph,
   *    matching the removed subgraph.
   */
  bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
      override;

 private:
  /**
   * This returns a permutation of the Pattern Net's operators.
   * The permutation satisfies this property:
   *    - For any index i > 0, order(i) is a neighbor of some node from
   *      {order(0), ..., order(i-1)}.
   *
   * Why is this important? Consider the following case:
   * PatternNet: 0 ---> 2 <--- 1
   *
   * When we have matched onto [0], and trying to add [1] to our subgraph,
   * we cannot, since PatternMatch only considers neighbors of the current
   * subgraph as a candidate next node.
   *
   * Therefore, we must present the subgraph in an order such that each node is
   * a neighbor of its prefix subgraph. One ordering for the above example is
   * [0, 2, 1].
   */
  std::vector<int> GetPatternTraversalOrder(const transform::Graph& g);

  // Graph of Pattern NetDef
  transform::Graph p_;

  // The Traversal Order of the Pattern Net's Operators
  // This is a permutation of the numbers from {0, ..., p.size()-1}
  std::vector<int> ordered_ops_;

  // The Inverse of the Traversal Order of the Pattern Net's Operators
  // That is, inverse_ops[ordered_ops[i]] == i is always true.
  std::vector<int> inverse_ops_;

  // Graph of Replace NetDef
  transform::Graph r_;

  // This flag determines if the transform will match operator arguments.
  bool argument_match_ = false;

  /**
   * Builds the blob name "transform/<blob_name>_<ssa_id_>".
   *
   * Returned as a plain (non-const) value so that callers can move from the
   * temporary; the function only reads ssa_id_, hence it is const-qualified.
   */
  string TransformBlobWrapper(const string& blob_name) const {
    return "transform/" + blob_name + "_" + c10::to_string(ssa_id_);
  }

  // Numeric suffix that TransformBlobWrapper appends to generated blob names.
  int ssa_id_ = 0;
};
} // namespace caffe2
|