# mypy: ignore-errors

import os

from torchvision import datasets, transforms

import torch
import torch._lazy
import torch._lazy.metrics
import torch._lazy.ts_backend
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR


# Register the TorchScript-based lazy backend before any "lazy" tensors are created.
torch._lazy.ts_backend.init()


class Net(nn.Module):
    """Small MNIST convnet: two conv layers, two dropout layers, two linear layers."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(log_interval, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Flush the lazy trace: compile and execute the graph accumulated
        # for this training step.
        torch._lazy.mark_step()

        if batch_idx % log_interval == 0:
            print(
                f"Train Epoch: {epoch} "
                f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
                f"({100.0 * batch_idx / len(train_loader):.0f}%)]"
                f"\tLoss: {loss.item():.6f}"
            )


if __name__ == "__main__":
    bsz = 64
    device = "lazy"  # tensors moved here are traced instead of executed eagerly
    epochs = 14
    log_interval = 10
    lr = 1  # Adadelta tolerates a large base learning rate
    gamma = 0.7
    train_kwargs = {"batch_size": bsz}
    # Route the TorchScript backend to CUDA when LTC_TS_CUDA is set.
    if "LTC_TS_CUDA" in os.environ:
        cuda_kwargs = {
            "num_workers": 1,
            "pin_memory": True,
            "shuffle": True,
            "batch_size": bsz,
        }
        train_kwargs.update(cuda_kwargs)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(log_interval, model, device, train_loader, optimizer, epoch)
        scheduler.step()
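
    # A minimal sketch, not part of the original training recipe: the
    # torch._lazy.metrics module imported above is otherwise unused. Assuming
    # its counter_names()/counter_value() helpers, this prints the backend's
    # internal counters accumulated during training, which is a quick way to
    # see what the lazy tracer actually recorded and executed.
    for name in torch._lazy.metrics.counter_names():
        print(f"{name}: {torch._lazy.metrics.counter_value(name)}")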