1# mypy: allow-untyped-defs 2# Copyright (c) Meta Platforms, Inc. and affiliates 3from typing import Dict 4 5import torch 6from torch.export.unflatten import _ModuleFrame 7 8 9def _outline_submodules(orig_graph: torch.fx.Graph): 10 # Create an empty GraphModule to hold the outlined modules 11 new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) 12 seen_nodes: Dict[str, torch.fx.Node] = {} 13 seen_modules: Dict[int, torch.nn.Module] = {} 14 _ModuleFrame( 15 orig_graph, 16 tuple(orig_graph.nodes), 17 seen_nodes, 18 seen_modules, 19 None, 20 [""], 21 "", 22 {}, 23 module=new_module, 24 ).run_outer() 25 new_module.graph.lint() 26 new_module.recompile() 27 return new_module 28