File: subgraph_matcher.h

pytorch 1.7.1-7
#pragma once

#include <torch/csrc/jit/ir/ir.h>

#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {

/**
 * \brief A structure describing a match of a pattern in a graph.
 *
 * The structure contains an anchor node, from which the match was found, and
 * match-maps for nodes and values. A match-map specifies the correspondence
 * between nodes in the pattern graph (match-map keys) and nodes in the actual
 * graph (match-map values). We keep such maps for both nodes and values.
 */
struct Match {
  Node* anchor;
  std::unordered_map<const Node*, Node*> nodes_map;
  std::unordered_map<const Value*, Value*> values_map;
};

/**
 * \brief Find all matches of a \p pattern in a \p graph.
 *
 * The function returns a vector of match-descriptors (see description of
 * `struct Match`).
 *
 * Matching rules:
 *  - The pattern graph must contain a single block.
 *  - Matched subgraphs do not span multiple blocks.
 *  - No uses outside the match are allowed, except for Param and Return nodes.
 *    Basically, we're matching hammocks, not arbitrary subgraphs.
 *  - The pattern graph must return only one value (i.e. it must have a single
 *    node leading to return).
 *  - Nodes that are not used in computation of the return value in the pattern
 *    graph are ignored during matching (in other words, we're essentially
 *    performing dead code elimination on the pattern).
 *  - Pattern graph nodes cannot alias. TODO: the check is not implemented yet.
 *  - Aliasing nodes in the graph cannot constitute a match (i.e. in all found
 *    matches no nodes in the subgraph alias with each other). TODO: the check
 *    is not implemented yet.
 *  - The matcher will not mutate either the pattern graph or the matched graph,
 *    but the latter is taken as non-const so that Match may contain non-const
 *    pointers. This enables clients of this API to use Match to drive
 *    mutations.
 */
std::vector<Match> TORCH_API
findPatternMatches(const Graph& pattern, Graph& graph);

} // namespace jit
} // namespace torch
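
// Illustration only: a minimal, self-contained sketch of how an anchor node and
// a nodes_map relate, as described for `struct Match` above. It is NOT the
// matcher implemented behind findPatternMatches. ToyNode, ToyMatch,
// toyMatchNode, toyFindMatches and the "param" kind are hypothetical names
// invented for this sketch, and it assumes a pattern whose single returned
// node is passed in as the root. It omits values_map, the single-block rule,
// the "no uses outside the match" rule, and all aliasing checks.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a graph node: an operator kind plus its inputs.
struct ToyNode {
  std::string kind;
  std::vector<ToyNode*> inputs;
};

// Mirrors the shape of Match, without values_map.
struct ToyMatch {
  ToyNode* anchor;
  std::unordered_map<const ToyNode*, ToyNode*> nodes_map;
};

// Try to match pattern node `p` against graph node `g`, recursing through
// inputs. In this toy, a pattern node of kind "param" matches any graph node
// and is not recorded in the map.
inline bool toyMatchNode(
    const ToyNode* p,
    ToyNode* g,
    std::unordered_map<const ToyNode*, ToyNode*>& nodes_map) {
  if (p->kind == "param") {
    return true;
  }
  auto it = nodes_map.find(p);
  if (it != nodes_map.end()) {
    return it->second == g; // pattern node already bound: bindings must agree
  }
  if (p->kind != g->kind || p->inputs.size() != g->inputs.size()) {
    return false;
  }
  nodes_map[p] = g; // key: pattern node, value: node in the actual graph
  for (std::size_t i = 0; i < p->inputs.size(); ++i) {
    if (!toyMatchNode(p->inputs[i], g->inputs[i], nodes_map)) {
      return false;
    }
  }
  return true;
}

// Treat every graph node as a candidate anchor for the pattern's root node
// and keep the candidates for which the whole pattern matches.
inline std::vector<ToyMatch> toyFindMatches(
    const ToyNode& pattern_root,
    const std::vector<ToyNode*>& graph_nodes) {
  std::vector<ToyMatch> matches;
  for (ToyNode* candidate : graph_nodes) {
    ToyMatch m{candidate, {}};
    if (toyMatchNode(&pattern_root, candidate, m.nodes_map)) {
      matches.push_back(m);
    }
  }
  return matches;
}
```

// In this sketch each match's anchor is the graph node bound to the pattern
// root, and nodes_map is keyed by pattern nodes, so a client can look up
// which graph node a given pattern node was bound to.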