1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
|
"""
Contains utility functions to check if a pattern is in the graph and return the matching nodes
"""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.ao.quantization.utils import MatchAllNode
from torch.fx import Node
from torch.nn.utils import parametrize
def _match(
    modules: Dict[str, nn.ModuleDict],
    node: Node,
    current: Union[nn.Module, Any],
) -> bool:
    r"""
    Decide whether ``node`` satisfies the single pattern element ``current``.

    The pattern element is interpreted as follows, in this order:

    * ``MatchAllNode`` (or a subclass of it): always matches, whatever
      ``node`` is.
    * an ``nn.Module`` subclass: matches a ``call_module`` node whose target,
      looked up in ``modules``, has exactly that type once parametrizations
      are ignored.
    * any other callable: matches a ``call_function`` node whose target is
      that very object (identity comparison).
    * a string: matches any node whose target equals the string.

    Anything else, or a ``node`` that is not an fx ``Node``, does not match.

    Args:
        modules: mapping from a ``call_module`` target name to the module
            it refers to.
        node: candidate graph node.
        current: the pattern element to test ``node`` against.

    Returns:
        True when ``node`` matches ``current``, False otherwise.
    """
    pattern_is_class = isinstance(current, type)

    # The wildcard is tested first, so it accepts any ``node`` value, even
    # one that is not an fx ``Node``.
    if pattern_is_class and issubclass(current, MatchAllNode):
        return True

    if not isinstance(node, Node):
        return False

    if pattern_is_class and issubclass(current, torch.nn.Module):
        if node.op != "call_module":
            return False
        submodule = modules[node.target]  # type: ignore[index]
        # Compare against the un-parametrized type of the module.
        return parametrize.type_before_parametrizations(submodule) == current

    if callable(current):
        return node.op == "call_function" and node.target is current

    if isinstance(current, str):
        return node.target == current

    return False
def apply_match(
    modules: Dict[str, nn.ModuleDict],
    pattern: Union[Tuple[Any], Any],
    node: Node,
    matched_node_pattern: List[Node],
) -> Optional[List[Node]]:
    r"""
    Return the nodes matched by ``pattern`` starting at ``node``, or None.

    A tuple pattern is matched element by element along a chain of nodes:
    the first element is tested against ``node`` and the remaining elements
    are matched recursively against the users of ``node``. Every user is
    tried in iteration order and the first chain that matches the whole
    pattern is returned. A non-tuple pattern is tested against ``node``
    alone.

    Args:
        modules: mapping from a ``call_module`` target name to the module
            it refers to; forwarded to ``_match``.
        pattern: a single pattern element, or a tuple of pattern elements
            describing a chain of nodes.
        node: the node at which matching starts.
        matched_node_pattern: nodes already matched by earlier elements of
            the pattern. It is not mutated; a new list is returned.

    Returns:
        For a tuple pattern, ``matched_node_pattern`` followed by the nodes
        matched by this call. For a non-tuple pattern, ``[node]``. None when
        there is no match, including when ``pattern`` is an empty tuple.
    """
    if not isinstance(pattern, tuple):
        # NOTE(review): this branch returns only ``[node]`` and does not
        # prepend ``matched_node_pattern``; confirm callers rely on that
        # before changing it.
        return [node] if _match(modules, node, pattern) else None

    if not pattern:
        # Nothing to match against; report "no match" instead of failing on
        # the unpacking below.
        return None

    first, *rest = pattern
    if not _match(modules, node, first):
        return None

    matched_so_far = matched_node_pattern + [node]
    if not rest:
        # ``node`` satisfied the last element of the pattern.
        return matched_so_far

    remaining = tuple(rest)
    # Explore every user, not just the first one: the pattern may continue
    # through any of them.
    for user in node.users:
        result = apply_match(modules, remaining, user, matched_so_far)
        if result is not None:
            return result
    return None
|