1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
|
import copy
import typing
import torch
from torch.export.exported_program import _decompose_exported_program
def _copy_graph_module_and_signature(
ep: torch.fx.GraphModule,
) -> typing.Tuple[
torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature
]:
# copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
# and this can break placeholder names in some particular cases.
# For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'.
# So we manually overwrite placeholder names by reading the old graph.
gm = copy.deepcopy(ep.graph_module)
new_graph_signature = copy.deepcopy(ep.graph_signature)
# iterate over old/new graph modules
for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr]
old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"]
new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"]
# iterate over placeholders
assert len(old_phs) == len(new_phs)
for old_node, new_node in zip(old_phs, new_phs):
new_node.name = old_node.name
return gm, new_graph_signature # type: ignore[return-value]
def _remove_detach_pass(
gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature
) -> None:
with gm._set_replace_hook(sig.get_replace_hook()):
for node in list(reversed(gm.graph.nodes)):
if node.op != "call_function":
continue
if (
node.target == torch.ops.aten.detach.default
and len(node.users) == 1
and next(iter(node.users)).target == torch.ops.aten.detach.default
):
next(iter(node.users)).replace_all_uses_with(node)
gm.graph.eliminate_dead_code()
gm.recompile()
def _export_forward_backward(
    ep: torch.export.ExportedProgram, joint_loss_index: int = 0
) -> torch.export.ExportedProgram:
    """
    WARNING: This API is highly unstable and will be subject to change in the future.

    Decompose ``ep`` using the core ATen decomposition table, then remove
    redundant consecutive detach calls from a copy of the decomposed graph.

    Args:
        ep: Exported program to decompose.
        joint_loss_index: Forwarded unchanged to
            ``_decompose_exported_program``; presumably the index of the
            output treated as the loss -- confirm against that helper.

    Returns:
        The decomposed program updated with the cleaned-up graph module copy
        and its copied graph signature.
    """
    from torch._decomp import core_aten_decompositions

    decomp_table = core_aten_decompositions()
    decomposed = _decompose_exported_program(
        ep,
        cia_to_decomp={},
        python_decomp_table=decomp_table,
        joint_loss_index=joint_loss_index,
    )

    # Clean up a copy so the decomposed program's own graph stays untouched.
    copied_gm, copied_sig = _copy_graph_module_and_signature(decomposed)
    _remove_detach_pass(copied_gm, copied_sig)
    return decomposed._update(copied_gm, copied_sig)
|