• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import torch
3
4from torch.utils import _pytree as pytree
5
class PytreeFlatten(torch.nn.Module):
    """
    Pytree from PyTorch can be captured by TorchDynamo.
    """

    def forward(self, x):
        """Return the first flattened leaf of the pytree ``x`` plus 1.

        Args:
            x: A pytree (nested dict / list / tuple container) whose first
                leaf supports ``+ 1``; the accompanying example passes a dict
                of tensors.

        Returns:
            ``leaves[0] + 1``, where ``leaves`` is the leaf list produced by
            ``pytree.tree_flatten(x)``.

        Raises:
            IndexError: if ``x`` contains no leaves (e.g. an empty dict).
        """
        # tree_flatten (not tree_leaves) is called deliberately: tracing this
        # call is what the example demonstrates. Its spec output is unused.
        leaves, _spec = pytree.tree_flatten(x)
        return leaves[0] + 1
14
# Positional inputs for ``model``: a 1-tuple whose only element is a dict
# pytree of two tensors, so ``forward`` receives that dict as ``x``.
# NOTE: no trailing comma after the closing paren -- one would nest this
# tuple inside another and hand ``forward`` a ``(dict,)`` tuple instead.
example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},)
model = PytreeFlatten()
17