# mypy: allow-untyped-defs
from __future__ import annotations

import operator
from typing import Dict, Optional, TYPE_CHECKING, Union

import torch
from torch.export.exported_program import ConstantArgument, TensorArgument
from torch.fx.passes.infra.pass_base import PassBase, PassResult

if TYPE_CHECKING:
    from torch.export.exported_program import ModuleCallSignature
    from torch.export.graph_signature import ExportGraphSignature


__all__ = ["CollectTracepointsPass"]


class CollectTracepointsPass(PassBase):
    """
    Collects the module-call tracepoints (``_export_tracepoint`` nodes) that
    torch.export inserts around submodule calls, uses them to populate the
    inputs/outputs of the corresponding module call signatures, and then
    removes the tracepoint nodes from the graph.
    """

    def __init__(
        self, specs: Dict[str, ModuleCallSignature], sig: ExportGraphSignature
    ) -> None:
        super().__init__()
        self.specs = specs
        self.sig = sig

    def call(self, gm: torch.fx.GraphModule) -> Optional[PassResult]:
        def get_arg_spec(arg) -> Union[TensorArgument, ConstantArgument]:
            # Map a tracepoint argument to its signature spec: fx Nodes
            # carrying tensor values become TensorArguments, and non-node
            # arguments are recorded as constants.
            if isinstance(arg, torch.fx.Node):
                if isinstance(arg.meta.get("val"), torch.Tensor):
                    return TensorArgument(name=arg.name)
                else:
                    raise AssertionError(
                        "Symint input is not implemented yet for submodule call signature."
                    )
            else:
                return ConstantArgument(name="", value=arg)
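
        # First sweep over each submodule graph: an output tracepoint marks
        # the end of a module call, so any following node that still carries
        # the same nn_module_stack is really outside the call; pop the
        # innermost (current call's) entry from its stack.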
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            nn_module_stack = None
            for node in module.graph.nodes:
                if node.op != "call_function":
                    continue
                if node.target == torch.ops.higher_order._export_tracepoint:
                    kind = node.kwargs["kind"]
                    if kind == "module_call_outputs":
                        nn_module_stack = node.meta["nn_module_stack"]
                    elif kind == "module_call_inputs":
                        nn_module_stack = None
                    else:
                        raise AssertionError(f"Unknown tracepoint kind: {kind}")
                elif node.meta["nn_module_stack"] == nn_module_stack:
                    node.meta["nn_module_stack"].popitem()
                else:
                    nn_module_stack = None
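
            # Second sweep, in reverse: symmetrically, an input tracepoint
            # marks the start of a module call, so preceding nodes that still
            # carry the same stack get the innermost entry popped.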
            nn_module_stack = None
            for node in reversed(module.graph.nodes):
                if node.op != "call_function":
                    continue
                if node.target == torch.ops.higher_order._export_tracepoint:
                    kind = node.kwargs["kind"]
                    if kind == "module_call_inputs":
                        nn_module_stack = node.meta["nn_module_stack"]
                    elif kind == "module_call_outputs":
                        nn_module_stack = None
                    else:
                        raise AssertionError(f"Unknown tracepoint kind: {kind}")
                elif node.meta["nn_module_stack"] == nn_module_stack:
                    node.meta["nn_module_stack"].popitem()
                else:
                    nn_module_stack = None
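
        # Helper for the aliased-call case below: create a fresh signature
        # with empty inputs/outputs that shares the original signature's
        # in/out tree specs.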
        def copy_sig(sig) -> ModuleCallSignature:
            from torch.export.exported_program import ModuleCallSignature

            return ModuleCallSignature(
                inputs=[],
                outputs=[],
                in_spec=sig.in_spec,
                out_spec=sig.out_spec,
                forward_arg_names=None,
            )
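
        # Main sweep: for each tracepoint, record the module call's
        # inputs/outputs in the matching signature, reroute the tracepoint's
        # getitem users to the underlying values, and erase the tracepoint
        # nodes from the graph.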
        for module in gm.modules():
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if node.op != "call_function":
                    continue
                if node.target == torch.ops.higher_order._export_tracepoint:
                    # There is some subtlety worth noting here: fqn
                    # corresponds to the call name, whereas path corresponds
                    # to the module name. They are not necessarily the same!
                    # When a submodule is shared through different aliases,
                    # there are as many _export_tracepoint markers as there
                    # are aliases, since the shared submodule is wrapped once
                    # for each alias.
                    path = node.kwargs["path"]
                    fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
                    module_key = next(reversed(node.meta["nn_module_stack"]))
                    if "@" in module_key:
                        suffix = module_key.split("@")[-1]
                        path = f"{path}@{suffix}"
                        call_fqn = f"{fqn}@{suffix}"
                        if call_fqn not in self.specs:
                            self.specs[call_fqn] = copy_sig(self.specs[fqn])
                        fqn = call_fqn
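                    # For example (hypothetical names): if a shared submodule
                    # is reached through a second alias, its stack key may
                    # look like "foo@1"; that call is then recorded under a
                    # separate "foo@1" signature rather than merged into
                    # "foo".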
                    kind = node.kwargs["kind"]
                    for i, arg in enumerate(node.args):
                        # We only update the signature of the alias used to
                        # call the submodule. Otherwise the signatures of all
                        # aliases would get conflated; the inputs/outputs of
                        # every call would be recorded in every other call as
                        # well.
                        if fqn == path:
                            if kind == "module_call_inputs":
                                self.specs[path].inputs.append(get_arg_spec(arg))
                            elif kind == "module_call_outputs":
                                self.specs[path].outputs.append(get_arg_spec(arg))
                            else:
                                raise AssertionError(
                                    f"Unknown tracepoint kind: {kind}"
                                )
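                        # The tracepoint acts as an identity on its args:
                        # each getitem user extracts the i-th result, so it
                        # can be replaced directly with the i-th argument.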
                        if isinstance(arg, torch.fx.Node):
                            for user in node.users:
                                assert user.op == "call_function"
                                assert user.target == operator.getitem
                                assert isinstance(user.args[1], int)
                                if user.args[1] == i:
                                    user.replace_all_uses_with(arg)
                                    self.sig.replace_all_uses(user.name, arg.name)
                                    break
                    users = list(node.users)
                    for user in users:
                        assert len(user.users) == 0
                        # These nodes live in the (sub)module's graph, so
                        # erase them from that graph, not necessarily gm's.
                        module.graph.erase_node(user)
                    module.graph.erase_node(node)
        return PassResult(gm, True)