"""Train and evaluate a ResNet-18 on CIFAR-10 with MindSpore.

Usage: ``python <script> --dataset /path/to/cifar10 [--bs N]``.
"""
import argparse

import numpy as np
from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms

from resnet import resnet18

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class ImageToDualImage:
    """Dataset transform that stacks an image with a copy of itself."""

    def __call__(self, img):
        """Return ``img`` concatenated with itself along axis 0.

        In ``create_dataset`` this runs after ``vision.ToTensor``, so axis 0
        is the channel axis and a 3-channel image becomes a 6-channel one.
        """
        return np.concatenate((img, img), axis=0)


def create_dataset(dataset_dir, batch_size, usage=None, augment=None):
    """Build a batched CIFAR-10 pipeline.

    Args:
        dataset_dir: Path to the folder holding the CIFAR-10 binary files.
        batch_size: Number of samples per batch.
        usage: Split passed through to ``Cifar10Dataset`` (e.g. ``"train"``
            or ``"test"``); ``None`` keeps the dataset's own default.
        augment: Whether to apply random crop / horizontal flip. ``None``
            (the default) enables augmentation for every split except
            ``"test"`` so that evaluation is deterministic.

    Returns:
        The mapped and batched dataset with columns ``image`` and ``label``.
    """
    if augment is None:
        augment = usage != "test"

    dataset = Cifar10Dataset(dataset_dir=dataset_dir, usage=usage)
    label_cast = transforms.TypeCast(mstype.int32)

    image_ops = [vision.ToPIL()]
    if augment:
        image_ops += [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
        ]
    image_ops += [
        vision.Resize((224, 224)),
        # ToTensor converts HWC -> CHW and already rescales pixel values to
        # [0.0, 1.0]; no additional Rescale(1/255) must follow it, otherwise
        # the Normalize statistics below (given in [0, 1] units) are wrong.
        vision.ToTensor(),
        vision.Normalize(
            [0.4914, 0.4822, 0.4465],
            [0.2023, 0.1994, 0.2010],
            is_hwc=False,
        ),
        ImageToDualImage(),
    ]

    dataset = dataset.map(operations=label_cast, input_columns="label")
    dataset = dataset.map(operations=image_ops, input_columns="image")
    return dataset.batch(batch_size)


def train(model, dataset, loss_fn, optimizer):
    """Run one training epoch of ``model`` over ``dataset``.

    Args:
        model: Network to train; called as ``model(data)``.
        dataset: Dataset yielding ``(data, label)`` batches.
        loss_fn: Callable computing the loss from ``(logits, label)``.
        optimizer: Optimizer; called with the gradients of its parameters.

    The current loss is printed every 100 batches.
    """

    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Differentiate w.r.t. the optimizer's parameters only; the logits are an
    # auxiliary output and do not contribute to the gradient (has_aux=True).
    grad_fn = ops.value_and_grad(
        forward_fn, None, optimizer.parameters, has_aux=True
    )

    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        # depend() ties the returned loss to the optimizer update so the
        # update is executed as part of this step.
        loss = ops.depend(loss, optimizer(grads))
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")


def test(model, dataset, loss_fn):
    """Evaluate ``model`` on ``dataset`` and print accuracy and mean loss.

    Args:
        model: Network to evaluate; called as ``model(data)``.
        dataset: Dataset yielding ``(data, label)`` batches.
        loss_fn: Callable computing the loss from ``(pred, label)``.

    The reported loss is averaged over batches, the accuracy over samples.
    """
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()

    # An empty dataset would otherwise fail with ZeroDivisionError below.
    if not num_batches or not total:
        print("Test: \n No samples to evaluate. \n")
        return

    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def main():
    """Parse the command line, then train and evaluate for 10 epochs."""
    parser = argparse.ArgumentParser(description='MindSpore ResNet Testing')
    parser.add_argument(
        '--dataset', type=str, metavar='DS', required=True,
        help='Path to the dataset folder'
    )
    parser.add_argument(
        '--bs', default=64, type=int, metavar='N', required=False,
        help='Mini-batch size'
    )
    args = parser.parse_args()

    # Process the cifar dataset.
    train_dataset = create_dataset(args.dataset, args.bs, "train")
    test_dataset = create_dataset(args.dataset, args.bs, "test")

    # Show the shape/dtype of the first batch only, as a sanity check.
    for img, lbl in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
        print(f"Shape of label: {lbl.shape} {lbl.dtype}")
        break

    # Initialize hypercomplex model
    net = resnet18()

    # Initialize loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optim = nn.SGD(net.trainable_params(), 1e-2)

    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(net, train_dataset, criterion, optim)
        test(net, test_dataset, criterion)
    print("Done!")


if __name__ == "__main__":
    main()