1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
# mypy: allow-untyped-defs
import itertools
import operator
from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set, Tuple
import torch
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import (
check_subgraphs_connected,
get_source_partitions,
SourcePartition,
)
# Public names exported by ``from <module> import *``. The underscore-prefixed
# helpers defined in this file are not listed here.
__all__ = [
"find_sequential_partitions",
"get_control_flow_submodules",
"get_equivalent_types",
"update_equivalent_types_dict",
]
# Default equivalence groups. Every set lists targets (module classes,
# functional ops, ``operator`` functions and method-name strings) that are
# treated as interchangeable when matching partition types: asking for any one
# member also matches the other members of its set (see _get_matching_types).
# The whole list can be replaced through update_equivalent_types_dict().
_EQUIVALENT_TYPES: List[Set] = [
{torch.nn.Conv1d, torch.nn.functional.conv1d},
{torch.nn.Conv2d, torch.nn.functional.conv2d},
{torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
{torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
{torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
{torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
{torch.add, operator.add, operator.iadd, "add", "add_"},
{torch.mul, operator.mul, operator.imul, "mul", "mul_"},
]
def _create_equivalent_types_dict():
    """Build the per-type lookup table from the current ``_EQUIVALENT_TYPES``.

    Returns:
        A dict mapping every member of every equivalence group to a list of
        all members of that group (the member itself included). Each key gets
        its own freshly created list.
    """
    return {
        member: list(group)
        for group in _EQUIVALENT_TYPES
        for member in group
    }
# Member -> list-of-equivalents lookup, built at import time from the default
# groups and rebuilt whenever update_equivalent_types_dict() is called.
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
def get_equivalent_types() -> List[Set]:
    """Return the list of equivalence groups currently in effect.

    The module-level list object itself is returned, not a copy.
    """
    current_groups = _EQUIVALENT_TYPES
    return current_groups
def update_equivalent_types_dict(customized_equivalent_types=None):
    """Replace the equivalence groups and rebuild the derived lookup table.

    Intended for users who want to customize ``_EQUIVALENT_TYPES`` and
    ``_EQUIVALENT_TYPES_DICT``.

    Args:
        customized_equivalent_types: The new collection of equivalence groups.
            It is stored as the new ``_EQUIVALENT_TYPES`` without copying.

    Raises:
        ValueError: If ``customized_equivalent_types`` is None.
    """
    global _EQUIVALENT_TYPES, _EQUIVALENT_TYPES_DICT

    if customized_equivalent_types is None:
        raise ValueError("customized_equivalent_types should not be None")

    # The groups must be swapped first: the dict builder reads the global.
    _EQUIVALENT_TYPES = customized_equivalent_types
    _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
def _partitions_sequential(partitions: Sequence[SourcePartition]):
prev_partition = None
for partition in partitions:
if prev_partition is not None and not check_subgraphs_connected(
prev_partition, partition
):
return False
prev_partition = partition
return True
def _get_matching_types(partition_type):
    """List the types that should be matched when ``partition_type`` is requested.

    Returns:
        A new list starting with ``partition_type`` itself, followed by the
        members of its equivalence group from ``_EQUIVALENT_TYPES_DICT`` when
        it has one (the group includes the type itself, so it appears twice
        in that case).
    """
    equivalents = _EQUIVALENT_TYPES_DICT.get(partition_type, ())
    return [partition_type, *equivalents]
def _valid_type_sequence(partition_types: List[Any]):
partition_types_set = set() # type: ignore[var-annotated]
for partition_type in partition_types:
matching_types = _get_matching_types(partition_type)
matching_types_set = set(matching_types)
if len(partition_types_set & matching_types_set) > 0:
return False
partition_types_set |= matching_types_set
return True
def find_sequential_partitions(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
    include_functional_equivalent=True,
    filter_fn: Optional[Callable[[Node], bool]] = None,
):
    """Find chains of source partitions matching ``partition_types`` in order.

    For each requested type, source partitions are collected from ``gm.graph``
    with ``get_source_partitions``, matching the type itself plus any
    equivalents registered in ``_EQUIVALENT_TYPES_DICT``. Every combination
    that takes one partition per requested type is then kept only if each
    partition is connected to the one that follows it.

    Args:
        gm: Graph module whose graph is searched.
        partition_types: Types to look for, in the expected order.
        include_functional_equivalent: Not consulted by this function;
            registered equivalents are always included.
        filter_fn: Forwarded unchanged to ``get_source_partitions``.

    Returns:
        A list of tuples, each holding one partition per requested type, in
        the same order as ``partition_types``.

    Raises:
        ValueError: If two entries of ``partition_types`` are identical or
            equivalent to each other.
    """
    if not _valid_type_sequence(partition_types):
        raise ValueError(
            f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
        )

    # One flat list of partitions per requested type, in request order. The
    # validation above guarantees the requested types are distinct.
    partitions_per_type = []
    for requested_type in partition_types:
        found = get_source_partitions(
            gm.graph, _get_matching_types(requested_type), filter_fn
        )
        partitions_per_type.append(
            list(itertools.chain.from_iterable(found.values()))
        )

    # Try every one-per-type combination and keep the connected chains.
    all_combinations = itertools.product(*partitions_per_type)
    return list(filter(_partitions_sequential, all_combinations))
def _get_submodule(
graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
submod_node = node.args[arg_index]
assert isinstance(submod_node, torch.fx.Node)
assert submod_node.op == "get_attr"
assert isinstance(submod_node.target, str)
submodule = graph_module.get_submodule(submod_node.target)
# pyre-ignore
return submod_node.target, submodule, node
def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.nn.Module, torch.fx.Node]]:
    """
    Collect the submodules used by control flow operations in the top-level
    graph of ``graph_module`` (nested submodules are not searched).

    Only ``call_function`` nodes are inspected. For a node targeting
    ``torch.ops.higher_order.cond`` the submodules referenced by arguments 1
    and 2 are collected; for ``torch.ops.higher_order.map_impl`` the one
    referenced by argument 0.

    Returns a list of tuples of (name under which the submodule is stored in
    the graph module, the submodule itself, the fx node that uses it), in
    graph order.
    """
    found = []
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue

        # Positions in node.args that hold references to submodules.
        submodule_arg_indices = []
        if node.target is torch.ops.higher_order.cond:
            submodule_arg_indices += [1, 2]
        if node.target is torch.ops.higher_order.map_impl:
            submodule_arg_indices.append(0)

        found.extend(
            _get_submodule(graph_module, node, index)
            for index in submodule_arg_indices
        )
    return found
|