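"""Train and evaluate a hypercomplex model (HCModel) on MNIST with MindSpore.

Each grayscale MNIST image is duplicated along the channel axis so that the
two copies serve as the components of the dual (hypercomplex) input expected
by HCModel.
"""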
import argparse

import numpy as np
from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import MnistDataset, transforms, vision
from hcmodel import HCModel

# Run in graph mode on the CPU backend; switch device_target to "GPU" or
# "Ascend" when the corresponding hardware is available.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class ImageToDualImage:
    """Duplicate an image along the channel axis to form the two
    components of the dual (hypercomplex) input."""

    @staticmethod
    def __call__(img):
        return np.concatenate((img, img), axis=0)


def create_dataset(dataset_dir, batch_size, usage=None):
    dataset = MnistDataset(dataset_dir=dataset_dir, usage=usage)
    type_cast_op = transforms.TypeCast(mstype.int32)

    # Define map operations: rescale pixels to [0, 1], normalize, convert
    # HWC to CHW, then duplicate the channel for the hypercomplex input
    trans = [vision.Rescale(1.0 / 255.0, 0),
             vision.Normalize(mean=(0.1307,), std=(0.3081,)),
             vision.HWC2CHW(),
             ImageToDualImage()]

    dataset = dataset.map(operations=type_cast_op, input_columns="label")
    dataset = dataset.map(operations=trans, input_columns="image")
    dataset = dataset.batch(batch_size)
    return dataset


def train(model, dataset, loss_fn, optimizer):
    # Define the forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get the gradient function; has_aux=True treats the logits as an
    # auxiliary output that is not differentiated
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define the function of one training step
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        # ops.depend ties the returned loss to the optimizer call, so the
        # parameter update is executed before the loss value is read out
        loss = ops.depend(loss, optimizer(grads))
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")


def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def main():
    parser = argparse.ArgumentParser(description='MindSpore MNIST training and evaluation')
    parser.add_argument(
        '--dataset', default=None, type=str, metavar='DS', required=True,
        help='Path to the dataset folder'
    )
    parser.add_argument(
        '--bs', default=64, type=int, metavar='N', required=False,
        help='Mini-batch size'
    )
    args = parser.parse_args()

    # Build the training and test pipelines from the MNIST dataset
    train_dataset = create_dataset(args.dataset, args.bs, "train")
    test_dataset = create_dataset(args.dataset, args.bs, "test")

    # Print the shape and dtype of the first batch as a sanity check
    for img, lbl in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
        print(f"Shape of label: {lbl.shape} {lbl.dtype}")
        break

    # Initialize the hypercomplex model
    net = HCModel()

    # Initialize the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optim = nn.SGD(net.trainable_params(), learning_rate=1e-2)

    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(net, train_dataset, criterion, optim)
        test(net, test_dataset, criterion)
    print("Done!")


if __name__ == "__main__":
    main()
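
# Example invocation (the script name is illustrative; --dataset should point
# at a folder containing the raw MNIST files readable by MnistDataset):
#   python mnist_hc_train.py --dataset ./MNIST_Data --bs 64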