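"""Train and evaluate a hypercomplex ResNet-18 on CIFAR-10 with MindSpore.

The model comes from a local `resnet` module; each input image is duplicated
channel-wise so the network receives both of its hypercomplex components.
"""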
import argparse

import numpy as np

from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms

from resnet import resnet18

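# Run in static graph mode on GPU; device_target can also be "CPU" or "Ascend".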
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


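# Transform that stacks two copies of a CHW image along the channel axis,
# turning (3, H, W) into (6, H, W) so the hypercomplex network gets both
# of its components from the same RGB image.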
class ImageToDualImage:
    @staticmethod
    def __call__(img):
        return np.concatenate((img, img), axis=0)


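# Build the CIFAR-10 input pipeline: cast labels to int32 and map the
# augmentation/normalization chain onto the image column before batching.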
def create_dataset(dataset_dir, batch_size, usage=None):
    dataset = Cifar10Dataset(dataset_dir=dataset_dir, usage=usage)
    type_cast_op = transforms.TypeCast(mstype.int32)

    # Define map operations. ToTensor already converts the HWC uint8 image to a
    # CHW float32 array in [0.0, 1.0], so no extra Rescale is needed before Normalize.
    trans = [vision.ToPIL(),
             vision.RandomCrop((32, 32), (4, 4, 4, 4)),
             vision.RandomHorizontalFlip(prob=0.5),
             vision.Resize((224, 224)),
             vision.ToTensor(),
             vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010], is_hwc=False),
             ImageToDualImage()]

    dataset = dataset.map(operations=type_cast_op, input_columns="label")
    dataset = dataset.map(operations=trans, input_columns="image")
    dataset = dataset.batch(batch_size)
    return dataset


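# One epoch of training with the functional API: ops.value_and_grad wraps the
# forward pass, and ops.depend makes the returned loss depend on the optimizer
# update so the update is not pruned from the graph.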
def train(model, dataset, loss_fn, optimizer):
    # Define the forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get the gradient function (gradients w.r.t. the optimizer's parameters)
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define the one-step training function
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        loss = ops.depend(loss, optimizer(grads))
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")


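# Evaluate the model: average the loss over batches and report the accuracy
# over the whole test split.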
def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


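# Entry point. Assuming this file is saved as train.py and the CIFAR-10 binary
# dataset is extracted under ./cifar-10-batches-bin (both names are examples),
# a typical invocation is:
#   python train.py --dataset ./cifar-10-batches-bin --bs 64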
def main():
    parser = argparse.ArgumentParser(description='MindSpore ResNet training and testing')
    parser.add_argument(
        '--dataset', default=None, type=str, metavar='DS', required=True,
        help='Path to the dataset folder'
    )
    parser.add_argument(
        '--bs', default=64, type=int, metavar='N', required=False,
        help='Mini-batch size'
    )
    args = parser.parse_args()

    # Load and preprocess the CIFAR-10 dataset.
    train_dataset = create_dataset(args.dataset, args.bs, "train")
    test_dataset = create_dataset(args.dataset, args.bs, "test")

    # Inspect the shape and dtype of one preprocessed batch.
    for img, lbl in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
        print(f"Shape of label: {lbl.shape} {lbl.dtype}")
        break

    # Initialize the hypercomplex model
    net = resnet18()

    # Initialize the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optim = nn.SGD(net.trainable_params(), 1e-2)

    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(net, train_dataset, criterion, optim)
        test(net, test_dataset, criterion)
    print("Done!")


if __name__ == "__main__":
    main()