1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
|
import torch
import torch.fx as fx
from torch.utils._pytree import tree_flatten
# Shorthand for the ATen operator namespace (torch.ops.aten).
aten = torch.ops.aten
def get_aten_target(node):
    """Return the overload packet behind ``node.target`` when it has one.

    If ``node.target`` exposes an ``overloadpacket`` attribute, that attribute
    is returned; otherwise ``node.target`` itself is returned unchanged.
    """
    target = node.target
    # Fall back to the target itself when it carries no overload packet.
    return getattr(target, 'overloadpacket', target)
# Random operations. fx_graph_cse copies calls to these unchanged instead of
# eliminating them as common subexpressions.
rand_ops = [
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
]
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    """Return a new ``fx.Graph`` with common subexpression elimination applied.

    Nodes of ``fx_g`` are visited in order.  ``placeholder``, ``output`` and
    ``get_attr`` nodes, and nodes whose target (as resolved by
    ``get_aten_target``) is listed in ``rand_ops``, are always copied.  Any
    other node is dropped when an earlier copied node has the same target and
    the same (already substituted) args/kwargs; its users are pointed at that
    earlier node instead.

    Literal arguments are compared together with their type, so ``op(x, 1)``
    and ``op(x, 1.0)`` (or ``op(x, True)``) are never treated as the same
    expression even though ``1 == 1.0 == True`` and their hashes agree.

    Args:
        fx_g: The graph to run CSE over.  It is only read, never modified.

    Returns:
        A new ``fx.Graph`` containing the surviving nodes.
    """
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash key to a node in the new graph
    token_map = {}  # map from hash key to the token of the node in hash_env

    def substitute(arg_tree):
        """Flatten ``arg_tree`` and map old-graph nodes to their new-graph copies.

        Returns a tuple of ``(type, leaf)`` pairs plus the tree spec; the spec
        can be used to reconstruct nested lists/dictionaries.
        """
        leaves, spec = tree_flatten(arg_tree)
        typed_leaves = []
        for leaf in leaves:
            if isinstance(leaf, torch.fx.node.Node) and leaf in env:
                leaf = env[leaf]
            # Keep the type next to the value: int, float and bool literals
            # that compare equal must not be merged into one expression.
            typed_leaves.append((type(leaf), leaf))
        return tuple(typed_leaves), spec

    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new
        # graph without change; do not CSE away random operations either.
        if n.op in ('placeholder', 'output', 'get_attr') or get_aten_target(n) in rand_ops:
            env[n] = new_graph.node_copy(n, lambda x: env[x])
            continue

        # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
        # substitute args and kwargs members to their mapping in env if exists
        args, args_spec = substitute(n.args)
        kwargs, kwargs_spec = substitute(n.kwargs)

        # each token corresponds to a unique node
        # nodes with the same token can be substituted
        token = {"target": n.target, "args": args, "args_spec": args_spec,
                 "kwargs": kwargs, "kwargs_spec": kwargs_spec}

        # hash substituted args to a number, do not hash specs because specs are not hashable
        hash_val = (n.target, hash((args, kwargs)))

        # check if a node has a substitute and can be eliminated; the token
        # comparison guards against two different argument lists sharing a hash
        hash_val_in_hash_env = hash_val in hash_env
        if hash_val_in_hash_env and token_map[hash_val] == token:
            env[n] = hash_env[hash_val]
            continue

        new_node = new_graph.node_copy(n, lambda x: env[x])
        env[n] = new_node
        # Only the first node seen for a hash key is remembered.
        if not hash_val_in_hash_env:
            hash_env[hash_val] = new_node
            token_map[hash_val] = token

    return new_graph
def strip_overloads(gm):
    """
    Replace every overload target in :attr:`gm`'s graph with its overload packet.

    The graph module is modified in place and recompiled afterwards.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    overload_type = torch._ops.OpOverload
    for graph_node in gm.graph.nodes:
        current_target = graph_node.target
        # Targets that are not a specific overload are left untouched.
        if not isinstance(current_target, overload_type):
            continue
        graph_node.target = current_target.overloadpacket
    gm.recompile()
def get_placeholders(graph):
    """Return a list of the nodes in ``graph.nodes`` whose op is 'placeholder', in order."""
    return [node for node in graph.nodes if node.op == 'placeholder']
def get_outputs(graph):
    """Return the flattened leaves of the first 'output' node's first argument.

    Raises:
        AssertionError: if ``graph.nodes`` contains no node with op 'output'.
    """
    output_node = next((n for n in graph.nodes if n.op == 'output'), None)
    if output_node is None:
        raise AssertionError("No output node found")
    # tree_flatten returns (leaves, spec); only the leaves are wanted here.
    leaves, _ = tree_flatten(output_node.args[0])
    return leaves
|