# Usage: python create_dummy_model.py <name_of_the_file>
"""Create a small dummy MLP and save it in two serialized forms.

Given an output path ``P`` this script writes:

* ``P``      -- a TorchScript archive (``torch.jit.save``) of ``NeuralNetwork``,
                which flattens its input before running the MLP.
* ``P.orig`` -- a pickled ``nn.Sequential`` (``torch.save``) with the same
                layer layout but *without* the ``Flatten`` step.

The two modules are constructed separately, so they share an architecture but
not parameters: each has its own random initialisation.
"""
import sys
from typing import Optional, Sequence

import torch
from torch import nn

_IN_FEATURES = 28 * 28
_HIDDEN = 512
_OUT_FEATURES = 10


def _build_mlp() -> nn.Sequential:
    """Return a freshly initialised 784 -> 512 -> 512 -> 10 ReLU MLP."""
    # Single source of truth for the layer layout, shared by the scripted
    # model and the ".orig" module so the two cannot drift apart.
    return nn.Sequential(
        nn.Linear(_IN_FEATURES, _HIDDEN),
        nn.ReLU(),
        nn.Linear(_HIDDEN, _HIDDEN),
        nn.ReLU(),
        nn.Linear(_HIDDEN, _OUT_FEATURES),
    )


class NeuralNetwork(nn.Module):
    """Flatten the input, then apply a 784 -> 512 -> 512 -> 10 ReLU MLP."""

    def __init__(self) -> None:
        super().__init__()
        # Attribute names are kept stable: they become submodule names in the
        # saved TorchScript archive.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = _build_mlp()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Flatten (default start_dim=1) keeps dim 0 and flattens the rest,
        # so the trailing dims must multiply to 28 * 28,
        # e.g. (batch, 28, 28) or (batch, 1, 28, 28).
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Build both dummy models and save them next to each other.

    Args:
        argv: Command-line arguments including the program name; defaults to
            ``sys.argv``. ``argv[1]`` is the output path; extra items are ignored.

    Returns:
        Exit status: 0 on success, 2 when the output path is missing.
    """
    args = sys.argv if argv is None else argv
    if len(args) < 2:
        print(
            "Usage: python create_dummy_model.py <name_of_the_file>",
            file=sys.stderr,
        )
        return 2
    out_path = args[1]

    # Build the scripted model first (as before), so the random-number draws
    # for initialisation happen in the same order as the original script.
    jit_module = torch.jit.script(NeuralNetwork())
    torch.jit.save(jit_module, out_path)

    # Plain (non-scripted) module, pickled whole; note it has no Flatten layer,
    # so it expects inputs whose last dimension is already 28 * 28.
    orig_module = _build_mlp()
    torch.save(orig_module, out_path + ".orig")
    return 0


if __name__ == "__main__":
    sys.exit(main())