1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
|
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace caffe2 {
namespace transform {
/**
 * Graph representation of an operator.
 */
struct TORCH_API Node {
 public:
  // Default constructor; required so that containers (e.g. Graph's node list
  // via resize) can default-construct Nodes.
  Node() = default;

  // Builds a node from its operator, activity flag and adjacency maps.
  // The maps are taken by value and moved into the members, so callers that
  // pass temporaries (or std::move their maps) pay for no extra copy.
  Node(
      const OperatorDef& op,
      bool active,
      std::map<int, std::vector<string>> parents,
      std::map<int, std::vector<string>> children)
      : op(op),
        active(active),
        parents(std::move(parents)),
        children(std::move(children)) {}

  // The OperatorDef which this node represents.
  OperatorDef op;

  // Keeps track of whether the operator has been deleted through a
  // transformation (false once deactivated).
  bool active = true;

  // Adjacency maps: idx -> blob_list, where
  //   idx       = index of the neighboring node (presumably its position in
  //               the owning Graph's node list — confirm against graph.cc)
  //   blob_list = the blob names that connect the two nodes.
  // `parents` is keyed by parent-node index, `children` by child-node index.
  std::map<int, std::vector<string>> parents;
  std::map<int, std::vector<string>> children;
};
/**
* Graph representation of a Netdef.
*/
struct TORCH_API Graph {
public:
/**
* Given a subgraph, gets all of the parents of the subgraph, as well as
* their associated blob names. Sorted by blob names.
*
* <string, int> := (name of blob writing into subgraph,
* index of node that writes into subgraph using that blob)
*/
const std::vector<std::pair<string, int>> GetSubgraphInput(
const std::vector<int>& subgraph);
/**
* Given a subgraph, gets all of the children of the subgraph, as well as
* their associated blob names. Sorted by blob names.
*
* <string, int> := (name of blob reading from subgraph,
* index of node that reads from subgraph using that blob)
*/
const std::vector<std::pair<string, int>> GetSubgraphOutput(
const std::vector<int>& subgraph);
/**
* Graph generation.
* Given a netdef, returns a Graph.
*
* Each node represents an operator.
* An edge exists between two nodes if the parent op writes to a blob, which
* is the input of the child blob, with no other op writing to the blob in
* between the execution order.
*
* Time Complexity: O(E), where E is the number of blobs
*/
explicit Graph(const NetDef& net_def);
/**
* Generates a NetDef Representation for the current graph.
* Nodes are visited in topological order, which is proper Opdef ordering.
* TODO(benz):
* There exists conflicts with repeated blob names, where topological sorting
* is not sufficient for correct netdef representation, unless blobs are
* renamed.
* For example, if after a transformation, We have operator ancestry:
* A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the
* same blob name, then A, B, D, E, C is a correct topological ordering,
* but D will write to the blob that C reads from, instead of B.
* Currently believe that there will always be ambiguity unless blobs are
* renamed.
* This is solved by performing SSA on all transformed blob names.
*/
NetDef GetNetDef();
/**
* Deactivate a subgraph, and get rid of all edges into this subgraph.
*/
void DeactivateSubgraph(std::vector<int> subgraph);
size_t size() const {
return nodes_.size();
}
void push_node(const Node& new_node) {
return nodes_.push_back(new_node);
}
void resize_nodes(size_t new_size) {
nodes_.resize(new_size);
}
// Index safe, less verbose way to access nodes
inline const Node& node(size_t idx) const {
return nodes_.at(idx);
}
inline Node& node(size_t idx) {
return nodes_.at(idx);
}
inline bool is_node_active(size_t idx) {
return node(idx).active;
}
inline const std::set<string>& external_input() const {
return external_input_;
}
inline const std::set<string>& external_output() const {
return external_output_;
}
private:
const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper(
bool from_children,
const std::vector<int>& match);
// Stores the netdef representation. Is updated upon calls to GetNetDef.
NetDef netdef_;
// Stores which blobs the graph reads from, and writes to.
std::set<string> external_input_;
std::set<string> external_output_;
// Keeps track of all the Operators currently within graph, even if inactive.
std::vector<Node> nodes_;
};
} // namespace transform
/**
 * Adds an operator def (built from `op_type`, `inputs` and `outputs`) to the
 * NetDef pointed to by `netdef_ptr`.
 *
 * Returns a pointer to the added OperatorDef so the caller can fill in any
 * extra fields afterwards (such as device_option).
 */
TORCH_API OperatorDef* AddOp(
    NetDef* netdef_ptr,
    string op_type,
    std::vector<string> inputs,
    std::vector<string> outputs);
// Pattern-aware string comparison used for operator matching.
//
// The pattern may use `*` and `|`, so operator types, engines, or any other
// string-valued property can be matched flexibly. `p` is presumably the
// pattern and `s` the string under test (confirm against the definition).
//
// Example: to match an operator that is either Conv or FC, give "Conv|FC"
// as the type() of the pattern op.
TORCH_API bool MatchStrings(string p, string s);
// Checks that every named argument present in the pattern op `p_op` also
// exists in `g_op` and is equal in value there.
TORCH_API bool MatchArguments(
    const OperatorDef& p_op,
    const OperatorDef& g_op);
} // namespace caffe2
|