import argparse

import numpy as np
import mindspore
from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import MnistDataset
from hcmodel import HCModel

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class ImageToDualImage:
    """Duplicate an image along its first axis.

    In :func:`create_dataset` this runs after ``HWC2CHW``, so the first axis is
    the channel axis and a C-channel image becomes a 2*C-channel image whose
    two halves are identical.
    """

    def __call__(self, img):
        return np.concatenate((img, img), axis=0)


def create_dataset(dataset_dir, batch_size, usage=None):
    """Build a batched MNIST pipeline.

    Args:
        dataset_dir: Path to the MNIST dataset folder.
        batch_size: Number of samples per batch.
        usage: Split passed to ``MnistDataset`` (``"train"``/``"test"`` here).

    Returns:
        A batched dataset yielding ``(image, label)`` with labels cast to int32
        and images rescaled to [0, 1], normalized, converted to CHW and
        duplicated along the channel axis.
    """
    dataset = MnistDataset(dataset_dir=dataset_dir, usage=usage)
    type_cast_op = mindspore.dataset.transforms.TypeCast(mstype.int32)

    # Image transforms, applied in order.
    trans = [mindspore.dataset.vision.Rescale(1.0 / 255.0, 0),
             mindspore.dataset.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
             mindspore.dataset.vision.HWC2CHW(),
             ImageToDualImage()]

    dataset = dataset.map(operations=type_cast_op, input_columns="label")
    dataset = dataset.map(operations=trans, input_columns="image")
    dataset = dataset.batch(batch_size)
    return dataset


def train(model, dataset, loss_fn, optimizer):
    """Run one training epoch of ``model`` over ``dataset``.

    Args:
        model: Network to train; called as ``model(data)``.
        dataset: Batched dataset yielding ``(data, label)`` tuples.
        loss_fn: Callable ``loss_fn(logits, label)`` returning the loss.
        optimizer: MindSpore optimizer; gradients are taken with respect to
            ``optimizer.parameters`` and passed to ``optimizer(grads)``.
    """
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # has_aux=True: only the first output (the loss) is differentiated; the
    # logits are passed through unchanged.
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        # Make the returned loss depend on the optimizer update so the update
        # is guaranteed to execute.
        loss = ops.depend(loss, optimizer(grads))
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            print(f"loss: {loss.asnumpy():>7f} [{batch:>3d}/{size:>3d}]")


def test(model, dataset, loss_fn):
    """Evaluate ``model`` on ``dataset`` and print accuracy and average loss.

    The reported loss is the mean of the per-batch values returned by
    ``loss_fn``. If the dataset yields no samples, a notice is printed instead
    of dividing by zero.
    """
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        # argmax over axis 1 picks the predicted class per sample.
        correct += (pred.argmax(1) == label).asnumpy().sum()
    if num_batches == 0 or total == 0:
        print("Test: \n No samples to evaluate \n")
        return
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def main():
    """Parse CLI arguments, then train and evaluate ``HCModel`` on MNIST."""
    parser = argparse.ArgumentParser(description='MindSpore MNIST Testing')
    parser.add_argument(
        '--dataset', default=None, type=str, metavar='DS', required=True,
        help='Path to the dataset folder'
    )
    parser.add_argument(
        '--bs', default=64, type=int, metavar='N', required=False,
        help='Mini-batch size'
    )
    parser.add_argument(
        '--epochs', default=10, type=int, metavar='E', required=False,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--lr', default=1e-2, type=float, metavar='LR', required=False,
        help='SGD learning rate'
    )
    args = parser.parse_args()

    # Process the MNIST dataset.
    train_dataset = create_dataset(args.dataset, args.bs, "train")
    test_dataset = create_dataset(args.dataset, args.bs, "test")

    # Print the shapes of the first test batch as a sanity check.
    for img, lbl in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
        print(f"Shape of label: {lbl.shape} {lbl.dtype}")
        break

    # Initialize hypercomplex model
    net = HCModel()

    # Initialize loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optim = nn.SGD(net.trainable_params(), args.lr)

    for t in range(args.epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(net, train_dataset, criterion, optim)
        test(net, test_dataset, criterion)
    print("Done!")


if __name__ == "__main__":
    main()