• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
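# Usage: python create_dummy_model.py <name_of_the_file>
# Writes a TorchScript model to <name_of_the_file> and a pickled eager
# nn.Sequential to <name_of_the_file>.orig.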
import argparse
import sys

import torch
from torch import nn
6
7
class NeuralNetwork(nn.Module):
    """Small fully connected network sized for 28x28 inputs and 10 outputs.

    Each sample is flattened and passed through three linear layers
    (28*28 -> 512 -> 512 -> 10) with ReLU between them. ``forward`` returns
    the raw output of the last linear layer (no softmax is applied).
    """

    def __init__(self) -> None:
        super().__init__()
        in_features = 28 * 28
        hidden = 512
        num_outputs = 10

        self.flatten = nn.Flatten()
        # The layer order matters for two reasons. It fixes the state_dict
        # keys ("linear_relu_stack.0/2/4.weight|bias"). It also fixes the
        # order in which the global RNG initialises the weights.
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_outputs),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Flatten() keeps dim 0 (start_dim=1 by default). Each sample must
        # therefore contain 28 * 28 elements to match the first Linear layer.
        return self.linear_relu_stack(self.flatten(x))
24
25
def main(argv=None) -> int:
    """Build the dummy model and write it to disk in two formats.

    Two files are written:

    * ``<name_of_the_file>``: a TorchScript archive of ``NeuralNetwork``.
    * ``<name_of_the_file>.orig``: a ``torch.save`` pickle of that same
      model's ``linear_relu_stack`` (an eager ``nn.Sequential``).

    Args:
        argv: Command-line arguments excluding the program name. ``None``
            means argparse reads ``sys.argv[1:]``.

    Returns:
        Process exit status (0 on success). When the output path is missing,
        argparse prints a usage message and raises ``SystemExit(2)``.
    """
    parser = argparse.ArgumentParser(
        description="Create a dummy TorchScript model and its eager counterpart.",
    )
    parser.add_argument(
        "path",
        metavar="name_of_the_file",
        help="output path for the TorchScript model; the eager copy is "
        "written to this path with '.orig' appended",
    )
    args = parser.parse_args(argv)

    model = NeuralNetwork()
    torch.jit.save(torch.jit.script(model), args.path)

    # Reuse the model's own layer stack rather than building a second,
    # independently initialised nn.Sequential. This way the ".orig" file holds
    # the same weights as the TorchScript archive.
    # Saving the nn.Sequential (not NeuralNetwork itself) keeps the pickle
    # free of references to this script's NeuralNetwork class. That is
    # presumably why a plain nn.Sequential is saved here.
    # NOTE: unlike NeuralNetwork, this stack has no Flatten step. It expects
    # input whose last dimension is 28 * 28.
    torch.save(model.linear_relu_stack, args.path + ".orig")
    return 0


if __name__ == "__main__":
    sys.exit(main())