# mypy: allow-untyped-defs
from __future__ import annotations

from typing import Sequence

import torch
from torch.onnx._internal.fx import _pass, diagnostics


class RestoreParameterAndBufferNames(_pass.Transform):
    """Restore parameter and buffer names from the original nn.Module.

    This pass improves the readability of the exported ONNX graph by restoring the
    parameter and buffer names from the original nn.Module. For example, if the
    original nn.Module has a parameter named `root.linear.0.weight` that was renamed
    to `_param_constant9` by FX, this pass renames it back.

    This pass must be run after the `Decompose` pass, because it expects the
    `fx.GraphModule` produced by `proxy_tensor.make_fx`, in which all parameters and
    buffers are registered at the root level.
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        fx_module: torch.fx.GraphModule,
        original_nn_module: torch.nn.Module,
    ):
        super().__init__(diagnostic_context, fx_module)
        self.original_nn_module = original_nn_module

    def _rename_param_and_buffer(
        self,
        diagnostic: diagnostics.Diagnostic,
        nodes: Sequence[torch.fx.Node],
        new_name: str,
    ) -> None:
        """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target."""
        assert len(nodes) > 0, "`nodes` cannot be empty"
        assert (
            len({node.target for node in nodes}) == 1
        ), "`nodes` must all have the same `target`"
        old_name = nodes[0].target
        assert isinstance(old_name, str), f"Expected str, got {type(old_name)}"
        # Parameter/buffer name cannot contain "."
        normalized_name = new_name.replace(".", "/")
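        # e.g. "linear.0.weight" becomes "linear/0/weight".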
        attr_value = getattr(self.module, old_name)
        setattr(self.module, normalized_name, attr_value)
        delattr(self.module, old_name)
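        # Rewire every `get_attr` node that targeted the old name to a fresh
        # `get_attr` node targeting the normalized name, then drop the old node.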
        for node in nodes:
            with self.module.graph.inserting_before(node):
                new_node = self.module.graph.get_attr(normalized_name)
                new_node.meta = node.meta
                node.replace_all_uses_with(new_node)
                self.module.graph.erase_node(node)
        diagnostic.info(
            "Renamed 'self.%s' to 'self.%s', "
            "normalized from original parameter name '%s'.",
            old_name,
            normalized_name,
            new_name,
        )

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Restore parameter and buffer names from the original module.

        For each `get_attr` node whose target is a str naming a parameter or buffer
        under `self.module`, rename that parameter or buffer to its original name.
        The parameters and buffers of `self.module` and `self.original_nn_module`
        refer to the same objects, which lets us use each object as a key to look up
        its original name.
        """
        assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args"
        assert (
            len(kwargs) == 0
        ), "RestoreParameterAndBufferNames does not take any kwargs"
        # state_to_readable_name[parameter/buffer] returns the original readable
        # name of the parameter/buffer, e.g. "linear.weight".
        state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {}
        state_to_readable_name.update(
            {v: k for k, v in self.original_nn_module.named_parameters()}
        )
        state_to_readable_name.update(
            {v: k for k, v in self.original_nn_module.named_buffers()}
        )
        diagnostic = self.diagnostic_context.inflight_diagnostic()
        # old_name_to_nodes[old_name] returns a tuple (nodes, new_name), where
        # `nodes` is the list of `get_attr` nodes with `old_name` as `target` and
        # `new_name` is the new readable name.
        old_name_to_nodes: dict[str, tuple[list[torch.fx.Node], str]] = {}
        for node in self.module.graph.nodes:
            if node.op == "get_attr":
                assert isinstance(
                    node.target, str
                ), f"Expected str, got {type(node.target)}"
                if node.target.find(".") != -1:
                    raise RuntimeError(
                        f"Unexpected target {node.target} in get_attr, found '.' in target. "
                        f"All parameters and buffers are expected to be registered at "
                        f"root level, i.e., self.module."
                    )
                if node.target in old_name_to_nodes:
                    # We have already processed this parameter/buffer.
                    old_name_to_nodes[node.target][0].append(node)
                    continue
                attr_value = getattr(self.module, node.target)
                if (
                    isinstance(attr_value, (torch.nn.Parameter, torch.Tensor))
                    and attr_value in state_to_readable_name
                ):
                    readable_name = state_to_readable_name[attr_value]
                    old_name_to_nodes[node.target] = ([node], readable_name)
                    continue
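                # Reaching here means the attribute has no entry in the original
                # module's state, so the FX-generated name is left unchanged.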
                diagnostic.info(
                    "Cannot find readable name for self.%s: %s. The name is unchanged.",
                    node.target,
                    type(attr_value),
                )
                if isinstance(attr_value, torch.nn.Parameter):
                    # A parameter that cannot be renamed is more serious, so
                    # raise the diagnostic level.
                    diagnostic.level = diagnostics.levels.WARNING
                else:
                    diagnostic.level = diagnostics.levels.NONE
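        # Apply the renames only after the full scan above, so the graph is not
        # mutated while `self.module.graph.nodes` is being iterated.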
        for nodes, new_name in old_name_to_nodes.values():
            self._rename_param_and_buffer(diagnostic, nodes, new_name)
        return self.module