# Usage: python create_dummy_model.py <output_path>
#
# Writes two files:
#   <output_path>       - TorchScript archive of a scripted NeuralNetwork
#   <output_path>.orig  - a pickled (non-scripted) nn.Sequential with the same
#                         layer layout, saved via torch.save
import sys

import torch
from torch import nn

# Layer sizes shared by both saved models.
_IN_FEATURES = 28 * 28
_HIDDEN_FEATURES = 512
_OUT_FEATURES = 10


def _build_mlp() -> nn.Sequential:
    """Return a freshly initialised Linear/ReLU stack (784 -> 512 -> 512 -> 10)."""
    return nn.Sequential(
        nn.Linear(_IN_FEATURES, _HIDDEN_FEATURES),
        nn.ReLU(),
        nn.Linear(_HIDDEN_FEATURES, _HIDDEN_FEATURES),
        nn.ReLU(),
        nn.Linear(_HIDDEN_FEATURES, _OUT_FEATURES),
    )


class NeuralNetwork(nn.Module):
    """Small fully connected network: flatten the input, then apply an MLP."""

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = _build_mlp()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` (all dims after the first) and return the 10 logits per row."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def main(argv=None) -> int:
    """Create the dummy models and write them to disk.

    Args:
        argv: Command-line arguments excluding the program name. Defaults to
            ``sys.argv[1:]``. The first element is the output path; any
            further elements are ignored, as in the original script.

    Returns:
        Process exit status: 0 on success, 2 when the output path is missing.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        # Previously this surfaced as an IndexError traceback on sys.argv[1].
        print("Usage: python create_dummy_model.py <output_path>", file=sys.stderr)
        return 2
    output_path = args[0]

    # Scripted model first, then the plain one: same construction order as
    # before, so parameter initialisation consumes the RNG identically.
    jit_module = torch.jit.script(NeuralNetwork())
    torch.jit.save(jit_module, output_path)

    # torch.save on a module pickles the whole object, not just a state_dict.
    # NOTE(review): presumably a plain nn.Sequential is used here (instead of
    # NeuralNetwork) so the pickle does not reference a class defined in this
    # script -- confirm against whatever loads the ".orig" file.
    orig_module = _build_mlp()
    torch.save(orig_module, output_path + ".orig")
    return 0


if __name__ == "__main__":
    sys.exit(main())