File: graphmatcher.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (132 lines) | stat: -rw-r--r-- 4,659 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#pragma once

#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "ast.h"
#include "caffe2/opt/converter.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"

namespace nom {
namespace nql {

// Textual criterion that a node's data string is compared against
// (see testMatchPredicate below).
using Criteria = std::string;

// Predicate and pattern-graph types specialized for NNGraph; the pattern
// graph built from an NQL query is a TestMatchGraph.
using TestMatchPredicate = nom::matcher::MatchPredicate<nom::repr::NNGraph>;
using TestMatchGraph = nom::matcher::MatchGraph<nom::repr::NNGraph>;

// One match found by the matcher (see `getMatches`, which produces these).
//
// `subgraph` is the canonical result. `matchMap` is only a convenience index
// from each name used in the query to the NodeRef it matched in `subgraph`.
// That name -> node mapping is injective but not necessarily bijective: if
// the query uses the same name twice, only one of those nodes is recorded.
struct MatchedSubgraph {
  // Holds, at a minimum, every node referenced from matchMap.
  nom::repr::NNGraph::SubgraphType subgraph;

  // Safer accessor for a matchMap entry, with nicer semantics than indexing
  // matchMap directly (defined out of line).
  nom::repr::NNGraph::NodeRef operator[](std::string const& key) const;

  // Query variable name -> node in the dataflow graph.
  std::map<std::string, nom::repr::NNGraph::NodeRef> matchMap;
};

/// \brief Main graph matcher interface.
///
/// This class solves a problem of finding a matching subgraph, which is
/// specified in a text form.
class GraphMatcher {
 public:
  /// \brief Initialize subgraph pattern from \p STR.
  void initFromString(const char* str) {
    genMatcherFromIRStr(str);
  }
  /// \brief Initialize subgraph patter from IR stored in file \p fname.
  void initFromFile(const char* fname) {
    genMatcherFromIRFile(fname);
  }
  /// \brief Try to find the pattern in the given graph \p DF and return true
  /// if it was found.
  bool findSubgraph(nom::repr::NNGraph& df) {
    return doesMatch(df);
  }
  /// \brief Replace the found subgraph with another one.
  void replaceSubgraphWith() {
    CAFFE_THROW("Subgraph replacement is not implemented yet.");
  }
  /// \brief Return the matcher graph.
  TestMatchGraph* getMatcherGraph() {
    return &matchGraph_;
  }
  // TODO: Do we need this, or can we get it from getMatcherGraph?
  TestMatchGraph::NodeRef getMatcher() {
    return matchGraphRootNode_;
  }
  // \brief Return a mapping from IR variable name (std::string) to Node in the
  // matched graph.
  std::unordered_map<std::string, nom::repr::NNGraph::NodeRef> getMatchMap()
      const {
    return matchMap_;
  }

  // \brief Returns a vector of matches.
  std::vector<MatchedSubgraph> getMatches(nom::repr::NNGraph& df) const;

 private:
  std::unordered_map<std::string, nom::repr::NNGraph::NodeRef> matchMap_;
  std::unordered_map<std::string, TestMatchGraph::NodeRef> varMap_;
  std::unordered_map<std::string, TestMatchGraph::NodeRef> callMap_;
  TestMatchGraph matchGraph_;
  TestMatchGraph::NodeRef matchGraphRootNode_;
  bool syntaxIsValid_ = true;

  bool doesMatch(nom::repr::NNGraph& df) {
    if (!syntaxIsValid_) {
      return false;
    }
    matchMap_.clear();
    std::vector<nom::repr::NNGraph::NodeRef> Nodes = df.getMutableNodes();
    for (auto& Node : Nodes) {
      auto match =
          matchGraph_.isSubgraphMatch(Node, matchGraphRootNode_, true, true);
      if (match.isMatch()) {
        // Fill the match map
        auto subgraphMatcherMap = match.getMatchNodeMap();
        for (auto p : varMap_) {
          auto iter = subgraphMatcherMap->find(p.second);
          if (iter != subgraphMatcherMap->end()) {
            matchMap_[p.first] = iter->second;
          }
        }
        for (auto p : callMap_) {
          auto iter = subgraphMatcherMap->find(p.second);
          if (iter != subgraphMatcherMap->end()) {
            matchMap_[p.first] = iter->second;
          }
        }

        return true;
      }
    }
    return false;
  }
  TestMatchGraph::NodeRef genMatcherFromIRFile(const char* fname);
  TestMatchGraph::NodeRef genMatcherFromIRStr(const char* str);
  TestMatchGraph::NodeRef genMatcherFromASTGraph(ASTGraph* ast);
  TestMatchGraph::NodeRef genMatcherFromASTStmt(ASTStmt* stmt);
  TestMatchGraph::NodeRef genMatcherFromASTExpr(ASTExpr* expr, bool insertTemp);
};

// Node matches a criteria (string) if the data string is the same as the
// criteria. Special case: "*" will match any thing.
TestMatchPredicate testMatchPredicate(const Criteria& criteria);

// \brief Return a short string name for the given \param node.
// The function works with both tensors and operators.
std::string getNodeName(const nom::repr::NNGraph::NodeRef);

// \brief Return a string representing the given graph \param g.
// The returned string is a valid NQL query.
std::string convertToNQLString(nom::repr::NNGraph&);

void deallocTokenStrings();

} // namespace nql
} // namespace nom