• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3from typing import Dict
4
5import torch
6from torch.export.unflatten import _ModuleFrame
7
8
9def _outline_submodules(orig_graph: torch.fx.Graph):
10    # Create an empty GraphModule to hold the outlined modules
11    new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
12    seen_nodes: Dict[str, torch.fx.Node] = {}
13    seen_modules: Dict[int, torch.nn.Module] = {}
14    _ModuleFrame(
15        orig_graph,
16        tuple(orig_graph.nodes),
17        seen_nodes,
18        seen_modules,
19        None,
20        [""],
21        "",
22        {},
23        module=new_module,
24    ).run_outer()
25    new_module.graph.lint()
26    new_module.recompile()
27    return new_module
28