File: compile_utils.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (90 lines) | stat: -rw-r--r-- 3,532 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

import torch
import torch.fx as fx
from torch.utils._pytree import tree_flatten

# Shorthand for the ATen operator namespace; used below to name the random ops.
aten = torch.ops.aten


def get_aten_target(node):
    """Return the overload packet behind ``node.target`` when it has one.

    A target exposing an ``overloadpacket`` attribute is collapsed to that
    packet; any other target is returned unchanged.
    """
    target = node.target
    return getattr(target, 'overloadpacket', target)


# ATen operators treated as random: fx_graph_cse copies nodes that target
# them as-is instead of merging duplicates.
rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]


def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    """Return a new ``fx.Graph`` with common subexpression elimination applied.

    The input graph is not modified. Nodes are visited in graph order and
    copied into a fresh graph; a node is dropped (and its users rewired to an
    earlier copy) when an earlier node has the same target and the same
    arguments after substitution.

    ``placeholder``, ``output`` and ``get_attr`` nodes, as well as nodes whose
    target is listed in ``rand_ops``, are always copied and never merged.

    Argument values are compared together with their Python type, so calls
    that differ only in e.g. ``1`` vs ``1.0`` vs ``True`` (which compare and
    hash equal in Python) are kept as distinct nodes.

    Args:
        fx_g: The graph to deduplicate.

    Returns:
        A new ``fx.Graph`` containing the deduplicated nodes.
    """
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash key to the first new-graph node seen with it
    token_map = {}  # map from hash key to the token of that first node

    def substitute(arg_list):
        # Flatten nested lists/dicts into leaves and replace old-graph nodes
        # with their new-graph counterparts when one exists in env. The spec
        # describes the nesting so structurally different args never match.
        leaves, spec = tree_flatten(arg_list)
        leaves = tuple(
            env[v] if isinstance(v, torch.fx.node.Node) and v in env else v
            for v in leaves
        )
        return leaves, spec

    def with_types(leaves):
        # Pair every leaf with its type: hash(1) == hash(1.0) == hash(True)
        # and 1 == 1.0 == True, so values alone cannot tell such args apart.
        return tuple((v, type(v)) for v in leaves)

    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new
        # graph without change; random operations must not be CSE'd away.
        if n.op in ('placeholder', 'output', 'get_attr') or get_aten_target(n) in rand_ops:
            env[n] = new_graph.node_copy(n, lambda x: env[x])
            continue

        # n.op == 'call_function'; 'call_module' / 'call_method' are not
        # expected here.
        args, args_spec = substitute(n.args)
        kwargs, kwargs_spec = substitute(n.kwargs)
        typed_args = with_types(args)
        typed_kwargs = with_types(kwargs)

        # Each token corresponds to a unique computation; nodes with equal
        # tokens can be substituted for one another.
        token = {"target": n.target, "args": typed_args, "args_spec": args_spec,
                 "kwargs": typed_kwargs, "kwargs_spec": kwargs_spec}

        # Hash the substituted, type-tagged args. Specs are left out of the
        # hash because they are not hashable; they are compared via the token.
        hash_val = (n.target, hash((typed_args, typed_kwargs)))

        # Reuse the earlier node only when the full token matches, which
        # guards against hash collisions.
        hash_val_in_hash_env = hash_val in hash_env
        if hash_val_in_hash_env and token_map[hash_val] == token:
            env[n] = hash_env[hash_val]
            continue

        new_node = new_graph.node_copy(n, lambda x: env[x])
        env[n] = new_node
        # On a collision the first node registered under this key is kept.
        if not hash_val_in_hash_env:
            hash_env[hash_val] = new_node
            token_map[hash_val] = token

    return new_graph


def strip_overloads(gm):
    """
    Replace every overload-specific node target in :attr:`gm` with its packet.

    Nodes whose target is a ``torch._ops.OpOverload`` are rewritten in place
    to point at ``target.overloadpacket``; the module is recompiled afterwards.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    overloaded_nodes = (
        n for n in gm.graph.nodes
        if isinstance(n.target, torch._ops.OpOverload)
    )
    for n in overloaded_nodes:
        n.target = n.target.overloadpacket
    gm.recompile()


def get_placeholders(graph):
    """Return the graph's ``placeholder`` nodes as a list, in graph order."""
    return [node for node in graph.nodes if node.op == 'placeholder']

def get_outputs(graph):
    """Return the flattened leaves of the first ``output`` node's first argument.

    Raises:
        AssertionError: If the graph contains no ``output`` node.
    """
    output_node = next((n for n in graph.nodes if n.op == 'output'), None)
    if output_node is None:
        raise AssertionError("No output node found")
    leaves, _ = tree_flatten(output_node.args[0])
    return leaves