import copy import typing import torch from torch.export.exported_program import _decompose_exported_program def _copy_graph_module_and_signature( ep: torch.fx.GraphModule, ) -> typing.Tuple[ torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature ]: # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), # and this can break placeholder names in some particular cases. # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. # So we manually overwrite placeholder names by reading the old graph. gm = copy.deepcopy(ep.graph_module) new_graph_signature = copy.deepcopy(ep.graph_signature) # iterate over old/new graph modules for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] # iterate over placeholders assert len(old_phs) == len(new_phs) for old_node, new_node in zip(old_phs, new_phs): new_node.name = old_node.name return gm, new_graph_signature def _remove_detach_pass( gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature ) -> None: with gm._set_replace_hook(sig.get_replace_hook()): for node in list(reversed(gm.graph.nodes)): if node.op != "call_function": continue if ( node.target == torch.ops.aten.detach.default and len(node.users) == 1 and next(iter(node.users)).target == torch.ops.aten.detach.default ): next(iter(node.users)).replace_all_uses_with(node) gm.graph.eliminate_dead_code() gm.recompile() def _export_forward_backward( ep: torch.export.ExportedProgram, joint_loss_index: int = 0 ) -> torch.export.ExportedProgram: """ WARNING: This API is highly unstable and will be subject to change in the future. """ from torch._decomp import core_aten_decompositions ep = _decompose_exported_program( ep, decomp_table=core_aten_decompositions(), _preserve_ops=(), # type: ignore[arg-type] joint_loss_index=joint_loss_index, ) gm, new_graph_signature = _copy_graph_module_and_signature(ep) _remove_detach_pass(gm, new_graph_signature) return ep._update(gm, new_graph_signature)