• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import torch
2from torch import Tensor
3
4
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    """TorchScript module interface declaring a single method, ``one``.

    Modules such as ``OrigModule`` and ``NewModule`` provide a method with
    this exact signature. An attribute annotated with this type is compiled
    by TorchScript against the interface signature rather than against a
    concrete module class.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Combine two tensors into one; the result is implementation-defined."""
        # Placeholder body: torch.jit.interface only uses the signature
        # (argument and return annotations), never this body.
        pass
9
10
class OrigModule(torch.nn.Module):
    """A module that implements ModuleInterface.

    Provides ``one`` with the signature declared by ``ModuleInterface``,
    plus an extra method ``two`` that the interface does not declare.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Return ``inp1 + inp2 + 1``."""
        return inp1 + inp2 + 1

    def two(self, input: Tensor) -> Tensor:
        """Return ``input + 2``; not part of ``ModuleInterface``."""
        return input + 2

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input + one(input, input) + 1``, i.e. ``3 * input + 2``."""
        return input + self.one(input, input) + 1
22
23
class NewModule(torch.nn.Module):
    """A *different* module that implements ModuleInterface.

    Its ``one`` has the same signature as the interface method but computes
    a different result. Presumably this is so it can be told apart from
    ``OrigModule`` when one is substituted for the other.
    """

    def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
        """Return ``inp1 * inp2 + 1`` (elementwise product plus one)."""
        return inp1 * inp2 + 1

    def forward(self, input: Tensor) -> Tensor:
        """Return ``one(input, input + 1)``, i.e. ``input * (input + 1) + 1``."""
        return self.one(input, input + 1)
32
33
class UsesInterface(torch.nn.Module):
    """Delegates ``forward`` to a submodule typed by ``ModuleInterface``.

    ``proxy_mod`` starts out as an ``OrigModule``. Because it is annotated
    with the interface type rather than the concrete class, TorchScript
    compiles the call in ``forward`` against the interface's ``one``
    signature. NOTE(review): presumably this is so ``proxy_mod`` can later
    be replaced by another conforming module such as ``NewModule``; confirm
    against the callers.
    """

    # Class-level annotation: TorchScript types this attribute as the
    # ModuleInterface interface instead of inferring OrigModule from __init__.
    proxy_mod: ModuleInterface

    def __init__(self) -> None:
        super().__init__()
        self.proxy_mod = OrigModule()

    def forward(self, input: Tensor) -> Tensor:
        """Return ``proxy_mod.one(input, input)``."""
        return self.proxy_mod.one(input, input)
43