1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
|
from typing import Dict, Tuple, Any
import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.utils._pytree import tree_flatten
from torch.fx import GraphModule, Graph
from torch.fx import Node
aten = torch.ops.aten

# Stateful ops are banned from CSE. They are split into two groups so callers
# can combine or extend them as needed.
rand_ops = {
    aten.dropout,
    aten._fused_dropout,
    aten._standard_gamma,
    aten.bernoulli,
    aten.multinomial,
    aten.native_dropout,
    aten.normal,
    aten.poisson,
    aten.binomial,
    aten.rrelu,
    aten.rand_like,
    aten.rand,
    aten.randint,
    aten.randn,
    aten.randperm,
}

inplace_ops = {
    aten.add_,
    aten.sub_,
    aten.mul_,
    aten.div_,
    aten.pow_,
    aten.lerp_,
    aten.relu_,
    aten.sigmoid_,
    aten.tanh_,
}
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def get_CSE_banned_ops():
    """Return a new set holding every op from ``rand_ops`` and ``inplace_ops``.

    The result is a fresh set, so callers may mutate it without affecting the
    module-level ``rand_ops`` / ``inplace_ops`` sets.
    """
    return rand_ops | inplace_ops
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):
    """Common-subexpression elimination over the nodes of a ``torch.fx`` graph."""

    def __init__(self, banned_ops=None):
        """
        This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the
        connectivity between fx.Node.

        For functional dialects, user would only need to specify the random ops in ban list.

        Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects.
        If your dialect contains stateful operators, please customize the banned_ops.

        Args:
            banned_ops: collection of targets that must never be merged. It is
                queried with ``in`` using the node's target (or the target's
                ``overloadpacket`` when it has one). Defaults to an empty set.
        """
        if banned_ops is None:
            banned_ops = set()
        self.banned_ops = banned_ops
        super().__init__()

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph

        Two nodes are merged only when they have the same target and the same
        flattened args/kwargs, where every leaf must match in both value and
        concrete type. The type check matters because ``1 == 1.0 == True`` in
        Python (with equal hashes), yet e.g. ``mul(x, 1)`` and ``mul(x, 1.0)``
        are different computations.

        Nodes whose arguments contain an unhashable leaf are copied unchanged
        instead of aborting the pass.

        Note: every node that is not a placeholder / output / get_attr and is
        not banned goes through the same matching, whatever its ``op`` is; the
        pass is written with ``call_function`` nodes in mind.

        Args:
            graph_module: the module whose graph is rewritten. It is not
                mutated; nodes are copied into a fresh ``Graph``.

        Returns:
            ``PassResult(new_graph_module, modified)`` where ``modified`` is
            True iff at least one node was replaced by an earlier equivalent.

        Example usage:

            from torch.fx.experimental.proxy_tensor import make_fx
            def f(a):
                b = a * a
                c = a * a
                return b+c

            p = CSEPass()
            traced_graph = make_fx(f)(torch.tensor(1))
            print(traced_graph)
            result = p(traced_graph)
            print(result.graph_module)
        """
        def get_aten_target(node):
            # OpOverload targets are banned via their overload packet.
            if hasattr(node.target, 'overloadpacket'):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[Any, int], Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[Any, int], Dict[str, Any]] = {}  # map from hash to token

        def substitute(arg_list):
            # Substitute args and kwargs members to their mapping in env if exists.
            # The spec can be used to reconstruct nested lists/dictionaries.
            # Each leaf is paired with its concrete type so that values which
            # merely compare equal (1, 1.0, True) stay distinct.
            leaves, spec = tree_flatten(arg_list)
            resolved = [
                env[leaf] if isinstance(leaf, Node) and leaf in env else leaf
                for leaf in leaves
            ]
            return tuple((type(leaf), leaf) for leaf in resolved), spec

        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new graph without change.
            # Do not CSE away banned (e.g. random) operations.
            if n.op in ('placeholder', 'output', 'get_attr') or get_aten_target(n) in self.banned_ops:
                env[n] = new_graph.node_copy(n, lambda x: env[x])
                continue

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # Each token corresponds to a unique computation;
            # nodes with the same token can be substituted.
            token = {"target": n.target, "args": args, "args_spec": args_spec,
                     "kwargs": kwargs, "kwargs_spec": kwargs_spec}

            # Hash substituted args to a number; do not hash specs because specs are not hashable.
            try:
                hash_arg = hash((args, kwargs))
            except TypeError:
                # A leaf is unhashable, so this node cannot be keyed: keep it as is.
                env[n] = new_graph.node_copy(n, lambda x: env[x])
                continue
            hash_val = (n.target, hash_arg)

            # Check if the node has a substitute and can be eliminated.
            existing = hash_env.get(hash_val)
            if existing is not None and token_map[hash_val] == token:
                modified = True  # substitution happens and the graph is modified
                env[n] = existing
                continue

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            # On a hash collision with a different token the first entry is kept.
            if existing is None:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)
|