#pragma once

#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "ast.h"
#include "caffe2/opt/converter.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"

namespace nom {
namespace nql {
using Criteria = std::string;
using TestMatchGraph = nom::matcher::MatchGraph<nom::repr::NNGraph>;
using TestMatchPredicate = nom::matcher::MatchPredicate<nom::repr::NNGraph>;
// A single match found by GraphMatcher. Each match holds the matched
// subgraph and a map from the variable names used in the query to NodeRefs
// in that subgraph. Note: the map is injective but not necessarily
// bijective; if you use the same name twice in the query, only one of its
// nodes will be mapped.
//
// See `getMatches` to generate these structs.
struct MatchedSubgraph {
  // A subgraph that contains at least all the nodes in matchMap.
  // This is the canonical match; matchMap is only a convenience.
  nom::repr::NNGraph::SubgraphType subgraph;
  // Looks up the node bound to the query variable \p key. Provides safer
  // access than std::map::operator[], which silently inserts an entry for
  // an unknown key.
  nom::repr::NNGraph::NodeRef operator[](const std::string& key) const;
  // Maps a query variable name to a node in the dataflow graph.
  std::map<std::string, nom::repr::NNGraph::NodeRef> matchMap;
};
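
// Example (a sketch, not taken from the original sources): iterating over
// every match and looking up nodes by query variable name. `graph` (an
// existing nom::repr::NNGraph), `queryText` (an NQL query string) and the
// variable name "x" are hypothetical placeholders; use the name exactly as
// it is spelled in your query.
//
//   nom::nql::GraphMatcher matcher;
//   matcher.initFromString(queryText);
//   for (const auto& m : matcher.getMatches(graph)) {
//     // Node bound to query variable "x" in this particular match.
//     nom::repr::NNGraph::NodeRef x = m["x"];
//     // m.subgraph holds every node that belongs to this match.
//   }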
/// \brief Main graph matcher interface.
///
/// This class finds subgraphs that match a pattern specified in text form
/// (an NQL query).
class GraphMatcher {
 public:
  /// \brief Initialize the subgraph pattern from the query string \p str.
  void initFromString(const char* str) {
    genMatcherFromIRStr(str);
  }
  /// \brief Initialize the subgraph pattern from IR stored in file \p fname.
  void initFromFile(const char* fname) {
    genMatcherFromIRFile(fname);
  }
  /// \brief Try to find the pattern in the given graph \p df and return true
  /// if it was found.
  bool findSubgraph(nom::repr::NNGraph& df) {
    return doesMatch(df);
  }
  /// \brief Replace the found subgraph with another one.
  /// Not implemented yet: always throws.
  void replaceSubgraphWith() {
    CAFFE_THROW("Subgraph replacement is not implemented yet.");
  }
  /// \brief Return the matcher graph.
  TestMatchGraph* getMatcherGraph() {
    return &matchGraph_;
  }
  // TODO: Do we need this, or can we get it from getMatcherGraph?
  /// \brief Return the root node of the matcher graph.
  TestMatchGraph::NodeRef getMatcher() {
    return matchGraphRootNode_;
  }
  /// \brief Return a mapping from IR variable name to node in the matched
  /// graph, as filled in by the most recent findSubgraph call.
  std::unordered_map<std::string, nom::repr::NNGraph::NodeRef> getMatchMap()
      const {
    return matchMap_;
  }
  /// \brief Return every match of the pattern in \p df.
  std::vector<MatchedSubgraph> getMatches(nom::repr::NNGraph& df) const;
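 private:
  // Filled by doesMatch: query name -> matched node in the searched graph.
  std::unordered_map<std::string, nom::repr::NNGraph::NodeRef> matchMap_;
  // Names from the query mapped to their nodes in matchGraph_, for
  // variables and for calls respectively.
  std::unordered_map<std::string, TestMatchGraph::NodeRef> varMap_;
  std::unordered_map<std::string, TestMatchGraph::NodeRef> callMap_;
  TestMatchGraph matchGraph_;
  TestMatchGraph::NodeRef matchGraphRootNode_ = nullptr;
  bool syntaxIsValid_ = true;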
  // Tries every node of \p df as the root of a match and stops at the first
  // one that matches. On success, records in matchMap_ the matched node for
  // every variable and call name of the query.
  bool doesMatch(nom::repr::NNGraph& df) {
    if (!syntaxIsValid_) {
      return false;
    }
    matchMap_.clear();
    auto nodes = df.getMutableNodes();
    for (auto& node : nodes) {
      auto match =
          matchGraph_.isSubgraphMatch(node, matchGraphRootNode_, true, true);
      if (match.isMatch()) {
        // Fill the match map with the nodes bound to query names.
        auto subgraphMatcherMap = match.getMatchNodeMap();
        for (const auto& p : varMap_) {
          auto iter = subgraphMatcherMap->find(p.second);
          if (iter != subgraphMatcherMap->end()) {
            matchMap_[p.first] = iter->second;
          }
        }
        for (const auto& p : callMap_) {
          auto iter = subgraphMatcherMap->find(p.second);
          if (iter != subgraphMatcherMap->end()) {
            matchMap_[p.first] = iter->second;
          }
        }
        return true;
      }
    }
    return false;
  }
  TestMatchGraph::NodeRef genMatcherFromIRFile(const char* fname);
  TestMatchGraph::NodeRef genMatcherFromIRStr(const char* str);
  TestMatchGraph::NodeRef genMatcherFromASTGraph(ASTGraph* ast);
  TestMatchGraph::NodeRef genMatcherFromASTStmt(ASTStmt* stmt);
  TestMatchGraph::NodeRef genMatcherFromASTExpr(ASTExpr* expr, bool insertTemp);
};
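
// Example (a sketch; `graph` and `queryText` are hypothetical placeholders
// for an existing nom::repr::NNGraph and an NQL query string): the
// single-match workflow. findSubgraph stops at the first match, after which
// getMatchMap() maps each variable and call name of the query to its node
// in `graph`.
//
//   nom::nql::GraphMatcher matcher;
//   matcher.initFromString(queryText);
//   if (matcher.findSubgraph(graph)) {
//     for (const auto& entry : matcher.getMatchMap()) {
//       // entry.first: name from the query; entry.second: matched node.
//       std::string matchedName = nom::nql::getNodeName(entry.second);
//     }
//   }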
// A node matches a criterion (string) if its data string is the same as the
// criterion. Special case: "*" matches anything.
TestMatchPredicate testMatchPredicate(const Criteria& criteria);
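
// Example (a sketch based only on the comment above): a predicate built from
// "*" accepts every node, while any other criterion accepts only nodes whose
// data string equals it. "Relu" is a hypothetical operator name.
//
//   TestMatchPredicate any = testMatchPredicate("*");
//   TestMatchPredicate relu = testMatchPredicate("Relu");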
/// \brief Return a short string name for the given \p node.
/// The function works with both tensors and operators.
std::string getNodeName(const nom::repr::NNGraph::NodeRef node);
/// \brief Return a string representing the given graph \p g.
/// The returned string is a valid NQL query.
std::string convertToNQLString(nom::repr::NNGraph& g);
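
// Example (a sketch; `graph` is a hypothetical existing NNGraph): since the
// result is a valid NQL query, it can be logged for debugging or fed back
// into a GraphMatcher. Whether the match succeeds is not specified here.
//
//   std::string query = nom::nql::convertToNQLString(graph);
//   nom::nql::GraphMatcher matcher;
//   matcher.initFromString(query.c_str());
//   bool found = matcher.findSubgraph(graph);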
// Frees the token strings allocated while parsing NQL queries.
void deallocTokenStrings();
} // namespace nql
} // namespace nom