1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.node import Node
__all__ = [
    "get_source_partitions",
    "check_subgraphs_connected",
    "SourcePartition",
]
# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` (or `DEBUG`) to see more verbose logs from this module.
def _init_logger() -> logging.Logger:
    """Create and configure the logger used by this module.

    The level name comes from the ``PYTORCH_MATCHER_LOGLEVEL`` environment
    variable (default ``WARNING``), is upper-cased, and is applied to both the
    logger and its console handler. Each record is printed as
    ``"<filename> > <message>"``.
    """
    module_logger = logging.getLogger(__name__)
    level_name = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper()
    module_logger.setLevel(level_name)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(filename)s > %(message)s"))
    stream_handler.setLevel(level_name)

    module_logger.addHandler(stream_handler)
    # Emit only through the handler attached above; do not forward records
    # to ancestor loggers.
    module_logger.propagate = False
    return module_logger


logger = _init_logger()
@compatibility(is_backward_compatible=False)
@dataclass
class SourcePartition:
    """
    A group of graph nodes that were all decomposed from a single source,
    together with the nodes connecting that group to the rest of the graph.
    """

    # Nodes that make up this partition.
    nodes: List[Node]
    # The source these nodes decomposed from: a function or a leaf module type,
    # or the bare function-name string when get_source_partitions matched the
    # nodes through "torch_fn" metadata.
    source: Any
    # Nodes elsewhere in the graph that are needed as inputs to the partition.
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that have at least one user outside the
    # partition.
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters used by the partition; get_source_partitions fills this with
    # the partition's own "get_attr" nodes.
    params: List[Node] = field(default_factory=list)
@compatibility(is_backward_compatible=False) # type: ignore[misc]
def get_source_partitions(
graph: Graph,
wanted_sources: List[Any],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Dict[Any, List[SourcePartition]]:
"""
Args:
graph: The graph we want to partition
wanted_sources: List of sources of nodes that were decomposed from this
source. This can be a function (ex. torch.nn.functional.linear) or a
leaf module type (ex. torch.nn.Linear).
Returns:
Dictionary mapping sources that were given to a list of SourcePartitions
that correspond to the list of nodes that were decomposed from the given
source.
"""
modules: Dict[Type, Dict[str, List[Node]]] = {}
for node in graph.nodes:
# The metadata source_fn should contain a tuple of a unique name for the
# source, and the source function if the node is decomposed from a
# function, or the type of module if the node is decomposed from a leaf
# module
# TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can
# be different from "source_fn_stack", for example for the add_ node
# decomposed from batch norm. We should remove the check on "source_fn_stack"
# after we fix "torch_fn". T199561090
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
torch_fn := node.meta.get("torch_fn", None)
) is not None:
node_fqn, source_fn = torch_fn
source_fn_name = source_fn.split(".")[1]
if source_fn_name in wanted_sources:
diff_modules = modules.setdefault(source_fn_name, {})
partition = diff_modules.setdefault(node_fqn, [])
partition.append(node)
if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None:
source_fn = source_fn_st[-1]
if source_fn[1] in wanted_sources:
diff_modules = modules.setdefault(source_fn[1], {})
partition = diff_modules.setdefault(source_fn[0], [])
partition.append(node)
def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition:
input_nodes = set()
output_nodes = set()
params = set()
for node in nodes:
for arg in node.args:
if isinstance(arg, Node) and arg not in nodes:
input_nodes.add(arg)
if node.op == "get_attr":
params.add(node)
for user in node.users.keys():
if user not in nodes:
output_nodes.add(node)
return SourcePartition(
nodes,
module_type,
list(input_nodes),
list(output_nodes),
list(params), # type: ignore[arg-type]
)
ret: Dict[Type[Any], List[SourcePartition]] = {}
if filter_fn:
# for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the
# filter condition
filtered_modules = {}
for tp, name_to_partition in modules.items():
filtered_name_to_partition = {
name: partition
for name, partition in name_to_partition.items()
if all(map(filter_fn, partition))
}
filtered_modules[tp] = filtered_name_to_partition
modules = filtered_modules
for k, v in modules.items():
ret[k] = [make_partition(partition, k) for partition in v.values()]
return ret
@compatibility(is_backward_compatible=False) # type: ignore[misc]
def check_subgraphs_connected(
subgraph1: SourcePartition, subgraph2: SourcePartition
) -> bool:
"""
Given two subgraphs A and B (in the form of a list of nodes), checks if
A has nodes connecting to at least one node in B -- aka there exists a node
in B that uses a node in A (not the other way around).
"""
for node in reversed(subgraph1.nodes):
for user in node.users.keys():
if user in subgraph2.nodes:
return True
return False
|