1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
|
import copy
from queue import SimpleQueue
from typing import List, Dict, Tuple
import torch.fx
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
from torch.fx.passes.utils import lift_subgraph_as_module
def topo_sort(nodes: NodeList) -> NodeList:
    """
    Order ``nodes`` so that each node comes after the nodes it consumes.

    Only dependencies between members of ``nodes`` are considered: an input
    that is not itself in ``nodes`` does not constrain the order. Nodes with no
    in-set producers are released first, in their original order, and further
    nodes are released first-in first-out as their in-set producers are
    emitted.

    Args:
        nodes (NodeList): nodes to order.

    Returns:
        NodeList: the nodes in a dependency-respecting order.

    Raises:
        AssertionError: if the number of emitted nodes differs from ``len(nodes)``.
    """
    # Count, for every node, how many of its producers are also in ``nodes``.
    remaining = {node: 0 for node in nodes}
    for node in nodes:
        remaining[node] += sum(1 for producer in node.all_input_nodes if producer in remaining)

    # ``ordered`` doubles as the FIFO work queue: entries before ``cursor`` have
    # been processed, entries from ``cursor`` onward are ready but not expanded.
    ordered: NodeList = [node for node in nodes if remaining[node] == 0]
    cursor = 0
    while cursor < len(ordered):
        ready = ordered[cursor]
        cursor += 1
        for consumer in ready.users:
            if consumer not in remaining:
                continue
            remaining[consumer] -= 1
            if remaining[consumer] == 0:
                ordered.append(consumer)

    assert len(nodes) == len(ordered), "topological sorted nodes doesn't have same length as input nodes"
    return ordered
def validate_partition(partition: NodeList) -> bool:
    """
    Check that fusing ``partition`` into one node would not create a dependency cycle.

    A cycle appears when a value leaves the partition, flows through one or
    more nodes outside it, and then re-enters the partition. Starting from
    every external consumer of a partition node, this walks the graph forward
    along ``Node.users``. The partition is rejected if any partition node is
    reachable.

    The traversal uses an explicit stack rather than recursion, so long
    dependency chains cannot exceed Python's recursion limit.

    Args:
        partition (NodeList): nodes proposed for fusion.

    Returns:
        bool: True for a valid partition (no cycle), False otherwise.
    """
    partition_set = set(partition)

    # Seed the walk with nodes outside the partition that directly consume a
    # partition value; these would become users of the fused node's outputs.
    stack: NodeList = [
        user_node
        for node in partition_set
        for user_node in node.users
        if user_node not in partition_set
    ]

    visited: NodeSet = set()
    while stack:
        current = stack.pop()
        if current in partition_set:
            # Left the partition and came back: fusing would form a cycle.
            return False
        if current in visited:
            continue
        visited.add(current)
        stack.extend(user_node for user_node in current.users if user_node not in visited)
    return True
def fuse_as_graphmodule(gm: GraphModule,
                        nodes: NodeList,
                        module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
    """
    Fuse nodes in graph_module into a GraphModule.

    Args:
        gm (GraphModule): target graph_module
        nodes (List[Node]): list of nodes in `gm` to fuse; must be topologically sorted
        module_name: class name for the fused GraphModule

    Returns:
        fused_gm (GraphModule): fused graph module, whose nodes are copies of `nodes` in `gm`
        original_inputs (Tuple[Node, ...]): nodes outside `nodes` that feed into them, in the
            order their placeholders were created in `fused_gm`
        original_outputs (Tuple[Node, ...]): nodes in `nodes` whose values are used outside
            the partition, in the order they appear in `fused_gm`'s output
    """
    # Hoisted once: `node in gm.graph.nodes` would otherwise rescan the whole graph per node.
    graph_nodes = set(gm.graph.nodes)

    # assumption: nodes are already sorted in topo order
    for node in nodes:
        assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}"
        assert not node._erased, f"{node} has been removed from owning graph"
        assert node in graph_nodes, f"{node} is not found in graph module {gm._get_name()}"

    # validates partition doesn't introduce dependency circles in the graph
    assert validate_partition(nodes), "Invalid partition, found dependency cycles"

    # Membership is tested once per edge below; a set keeps each test O(1).
    nodes_set = set(nodes)

    subgraph = Graph()

    node_to_placeholder: Dict[Node, Node] = {}  # mapping of nodes from old graph to placeholder in new graph
    node_map: Dict[Node, Node] = {}  # mapping of nodes from old graph to new graph

    # handles inputs through graph.node_copy's arg_transform function
    def remap_inputs(x: Node) -> Node:
        # TODO: get_attr nodes outside the partition are lifted to placeholders
        # like any other external input; decide whether they should be copied in.
        if x in nodes_set:
            # x is inside the subgraph; it has already been copied because
            # nodes are copied in topological order.
            return node_map[x]

        if x not in node_to_placeholder:
            # x is outside the subgraph: expose it as a new placeholder.
            placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
            # copy all meta fields, even if some might be irrelevant for a placeholder
            placeholder_node.meta = copy.copy(x.meta)
            node_to_placeholder[x] = placeholder_node

        return node_to_placeholder[x]

    # copy nodes in topological order
    for node in nodes:
        node_map[node] = subgraph.node_copy(node, remap_inputs)

    # A partition node becomes a subgraph output if any of its users lives outside the partition.
    output_mapping: Dict[Node, Node] = {}  # mapping from old output to new output
    for node in nodes:
        if any(user_node not in nodes_set for user_node in node.users):
            output_mapping[node] = node_map[node]

    # outs contain nodes in the new subgraph
    outs = tuple(output_mapping.values())

    # A single output is emitted bare, i.e. output(out); multiple outputs are
    # emitted as a tuple, i.e. output((out_0, out_1, ...)).
    subgraph.output(outs[0] if len(outs) == 1 else outs)

    # lint to ensure correctness
    subgraph.lint()

    fused_gm: GraphModule = lift_subgraph_as_module(gm, subgraph, class_name=module_name)

    # fused_gm's input nodes in the original module
    original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())

    # partition nodes in the original module whose values fused_gm returns
    original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())

    return fused_gm, original_inputs, original_outputs
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
    """
    Wire ``sub_gm`` into ``gm`` in place of the nodes it was built from.

    ``sub_gm`` is registered on ``gm`` under its class name. A ``call_module``
    node taking ``orig_inputs`` as positional args is then added to ``gm.graph``.
    With exactly one original output, every use of that node is redirected to
    the call itself. Otherwise each original output is redirected to a
    ``getitem`` on the call, indexed by its position in ``orig_outputs``. No
    topological placement is done here; ``fuse_by_partitions`` runs
    ``legalize_graph`` afterwards.

    Returns:
        GraphModule: ``gm``, modified in place.
    """
    name = sub_gm.__class__.__name__
    gm.add_submodule(name, sub_gm)

    call_node = gm.graph.call_module(name, args=orig_inputs, kwargs=None)

    if len(orig_outputs) == 1:
        orig_outputs[0].replace_all_uses_with(call_node)
        return gm

    for position, old_output in enumerate(orig_outputs):
        # Indexing a Proxy records a getitem node on the call's result.
        getitem_node = torch.fx.Proxy(call_node)[position].node  # type: ignore[index]
        old_output.replace_all_uses_with(getitem_node)
    return gm
def erase_nodes(gm: GraphModule, nodes: NodeList):
    """
    Remove ``nodes`` from ``gm.graph``, last node first.

    ``nodes`` is expected in topological order. Walking it backwards removes
    each node's in-list consumers before the node itself.
    """
    graph = gm.graph
    for doomed in nodes[::-1]:
        graph.erase_node(doomed)
def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
    """
    Replace each partition of ``gm`` with a call to a freshly built submodule.

    The partition at index ``i`` is topologically ordered, lifted into a
    GraphModule whose class name (and submodule name) is ``fused_<i>``, wired
    into ``gm`` in its place, and then its original nodes are erased.

    Returns:
        GraphModule: ``gm``, modified in place and re-sorted via ``legalize_graph``.
    """
    for index, partition in enumerate(partitions):
        ordered = topo_sort(partition)
        fused, inputs, outputs = fuse_as_graphmodule(gm, ordered, f"fused_{index}")
        insert_subgm(gm, fused, inputs, outputs)
        erase_nodes(gm, ordered)

    # The new call_module/getitem nodes are not topologically placed when
    # inserted; restore a valid order for the whole graph once at the end.
    legalize_graph(gm)
    return gm
|