1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
# mypy: allow-untyped-defs
import contextlib
import torch
from torch.fx.graph_module import GraphModule
# Sentinel used by _node_metadata_hook as the "nn_module_stack" key (and as both
# entries of its value tuple) when the input node it copies from carries no
# "nn_module_stack" metadata.
_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"
def _node_metadata_hook(node: torch.fx.Node, stack_trace: str) -> None:
    """
    Hook for adding the appropriate metadata to nodes that are created during a
    pass using graph.create_node. An example of how to use it:

    ```
    with _set_node_metadata_hook(gm,
        functools.partial(_node_metadata_hook, stack_trace="file")
    ):
        my_pass(gm)
    ```

    This hook does not work for all generic cases -- specifically it assumes
    that nodes being added are only call_function nodes, and copies over the
    nn_module_stack of the first input node (the first top-level positional
    Node argument if there is one, otherwise the first Node found anywhere in
    args/kwargs).

    Metadata written on ``node``:

    - ``"val"``: ``None`` for an ``OpOverload`` whose schema has no returns;
      otherwise the result of calling ``node.target`` with every input Node
      (at any nesting depth, in both args and kwargs) replaced by that Node's
      ``meta["val"]``.
    - ``"stack_trace"``: the ``stack_trace`` argument, verbatim.
    - ``"nn_module_stack"``: copied from the chosen input node, or a placeholder
      keyed by ``_EMPTY_NN_MODULE_STACK_KEY`` when that node has none.
    - ``"torch_fn"``: a ``(name, qualified_name)`` tuple built from
      ``node.target``.

    Args:
        node: The newly created node. Must be a ``call_function`` node with a
            callable target and at least one input Node.
        stack_trace: Value stored in ``node.meta["stack_trace"]``.

    Raises:
        AssertionError: If ``node`` is not a ``call_function`` node with a
            callable target, or has no input Nodes.
        KeyError: If the target is evaluated and an input Node has no
            ``"val"`` metadata.
    """
    assert node.op == "call_function" and callable(node.target)

    # Prefer the first top-level positional Node (historical behavior); fall
    # back to any input Node so calls like ``aten.cat([a, b])`` or calls whose
    # Node inputs are only keyword arguments are still supported.
    input_nodes = [
        arg for arg in node.args if isinstance(arg, torch.fx.Node)
    ] or node.all_input_nodes
    assert len(input_nodes) >= 1
    arg_meta = input_nodes[0].meta

    if (
        isinstance(node.target, torch._ops.OpOverload)
        and len(node.target._schema.returns) == 0
    ):
        node.meta["val"] = None
    else:
        # Substitute every Node (including ones nested in lists/tuples/dicts)
        # with its stored value, and forward kwargs as well; dropping kwargs
        # would compute a value that does not match the node's real call.
        def _to_val(n: torch.fx.Node):
            return n.meta["val"]

        fake_args = torch.fx.node.map_arg(node.args, _to_val)
        fake_kwargs = torch.fx.node.map_arg(node.kwargs, _to_val)
        fake_res = node.target(*fake_args, **fake_kwargs)
        node.meta["val"] = fake_res

    node.meta["stack_trace"] = stack_trace
    node.meta["nn_module_stack"] = arg_meta.get(
        "nn_module_stack",
        {
            _EMPTY_NN_MODULE_STACK_KEY: (
                _EMPTY_NN_MODULE_STACK_KEY,
                _EMPTY_NN_MODULE_STACK_KEY,
            )
        },
    )
    node.meta["torch_fn"] = (
        f"{node.target.__name__}_0",
        f"{node.target.__class__.__name__}.{node.target.__name__}",
    )
@contextlib.contextmanager
def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):
"""
Takes a callable which will be called after we create a new node. The
callable takes the newly created node as input and returns None.
"""
assert callable(f), "node_metadata_hook must be a callable."
# Add the hook to all submodules
for m in gm.modules():
if isinstance(m, GraphModule):
m._register_create_node_hook(f)
try:
yield
finally:
# Restore hook for all submodules
for m in gm.modules():
if isinstance(m, GraphModule):
m._unregister_create_node_hook(f)
|