File: fuse.py

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (68 lines) | stat: -rw-r--r-- 2,219 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from torch._fx import (
    GraphModule
)

from torch._fx.graph import (
    Graph,
    map_arg,
)

from .pattern_utils import (
    is_match,
    get_fusion_patterns,
)

from .fusion_patterns import *  # noqa: F401

import copy
class Fuser:
    """Build a fused version of a model's graph from the registered patterns.

    ``fuse`` stores two attributes on the instance:

    * ``modules`` -- ``dict(model.named_modules())`` of the model being fused.
    * ``fused_graph`` -- the ``Graph`` that the output nodes are emitted into.

    The instance itself is passed to every pattern handler, both when the
    handler is constructed (``value(self, node)``) and when it is asked to
    emit its fused node (``obj.fuse(self, load_arg)``); presumably the
    handlers read the attributes above -- verify against ``fusion_patterns``.
    """

    def fuse(self, model, inplace=False):
        """Return a ``GraphModule`` in which matched node groups are fused.

        Args:
            model: Object exposing ``graph`` and ``named_modules()``;
                presumably a ``GraphModule`` -- confirm against callers.
            inplace: When false (the default) ``model`` is deep-copied and the
                copy is used from then on; when true the given object is used
                directly as the root of the returned ``GraphModule``.

        Returns:
            ``GraphModule(root, self.fused_graph)`` where ``root`` is the
            (possibly copied) model.
        """
        if not inplace:
            model = copy.deepcopy(model)
        input_root = model
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        fusion_patterns = get_fusion_patterns()
        # node name -> (root node of its match, handler object)
        fusion_pairs = self._find_matches(
            input_root, input_graph, fusion_patterns)
        self.fused_graph = Graph()
        # Maps a node name in the input graph to the value that stands in for
        # it in the fused graph.
        env = {}

        def load_arg(a):
            # Translate references to input-graph nodes into their fused-graph
            # counterparts recorded in ``env``.
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = fusion_pairs.get(node.name, (None, None))
            if root_node is node:
                # Root of a matched pattern: the handler emits the fused node.
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                # Not part of any match: copy the node across unchanged.
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # A node that belongs to a match but is not its root gets no entry
            # in ``env`` and is therefore left out of the fused graph.

        self.fused_graph.output(load_arg(input_graph.result))
        model = GraphModule(input_root, self.fused_graph)
        return model

    def _find_matches(self, root, graph, patterns):
        """Map every node covered by a fusion pattern to its match.

        Nodes are visited in reverse order of ``graph.nodes``. For each node
        that is not already part of a match, the patterns are tried in the
        mapping's iteration order and only the first one accepted by
        ``is_match`` is applied.

        Args:
            root: Object whose ``named_modules()`` are handed to ``is_match``.
            graph: Object whose ``nodes`` are scanned.
            patterns: Mapping from pattern to a callable invoked as
                ``value(self, node)`` to build the match value.

        Returns:
            Dict from node name to ``(root_node, match_value)``, where
            ``root_node`` is the node at which the pattern matched.
        """
        modules = dict(root.named_modules())
        match_map = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match):
            # A tuple pattern is (pattern for this node, *patterns for args);
            # walk it in lockstep with the node's arguments.
            if isinstance(pattern, tuple):
                head, *arg_patterns = pattern
                apply_match(head, node, match)
                for subpattern, arg in zip(arg_patterns, node.args):
                    apply_match(subpattern, arg, match)
            elif node.name not in match_map:
                # A node keeps the first match that claimed it.
                match_map[node.name] = match

        for node in reversed(graph.nodes):
            if node.name in match_map:
                continue
            for pattern, value in patterns.items():
                if is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, value(self, node)))
                    # Stop at the first matching pattern. A later pattern
                    # could not reclaim this root (it is already recorded),
                    # but it could still claim further argument nodes, which
                    # ``fuse`` would then drop even though the winning
                    # handler never fused them.
                    break

        return match_map