• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" Utils """
16
17from PIL import Image
18import numpy as np
19
20from mindspore.common import dtype as mstype
21import mindspore.dataset as ds
22import mindspore.dataset.transforms.c_transforms as C
23import mindspore.dataset.transforms.vision.c_transforms as CV
24from mindspore.dataset.transforms.vision import Inter
25
26
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline used for training or testing.

    Label and image columns are transformed in this order: labels are cast
    to int32; images are resized to 32x32 (``Inter.LINEAR``), rescaled by
    1/255 with a 0.0 shift, and converted from (height, width, channel) to
    (channel, height, width) to fit the network. The result is then
    shuffled, batched (``drop_remainder=True``) and repeated.

    Args:
        data_path: Data path handed to ``ds.MnistDataset``.
        batch_size: The number of data records in each batch.
        repeat_size: The number of times the batched dataset is repeated.
        num_parallel_workers: The number of parallel workers for each map.

    Returns:
        The transformed, shuffled, batched and repeated dataset.
    """
    dataset = ds.MnistDataset(data_path)

    height = width = 32
    scale, offset = 1.0 / 255.0, 0.0
    shuffle_buffer = 10000  # 10000 as in LeNet train script

    # Ordered (column, operation) pairs; each one becomes its own map()
    # call, applied in exactly this order.
    transforms = (
        ("label", C.TypeCast(mstype.int32)),
        ("image", CV.Resize((height, width), interpolation=Inter.LINEAR)),
        ("image", CV.Rescale(scale, offset)),
        ("image", CV.HWC2CHW()),
    )
    for column, operation in transforms:
        dataset = dataset.map(input_columns=column, operations=operation,
                              num_parallel_workers=num_parallel_workers)

    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)
73
74
def save_img(data, name, size=32, num=32, row=8):
    """
    Visualize data by tiling grayscale images into one grid and save it.

    Images are placed column by column: image ``i`` goes to grid column
    ``i // row`` and grid row ``i % row``. The number of columns is
    ``ceil(num / row)``, so every image fits on the canvas. When ``num`` is
    not a multiple of ``row``, the unused cells of the last column stay black.

    Args:
        data: indexable collection of at least ``num`` images, e.g. an
            nparray of size (num, size, size). Each ``data[i]`` is passed
            through ``np.resize`` to shape (size, size). Values are expected
            in [0, 1]; values outside that range are clipped.
        name: output file name
        size: edge length in pixels of each tile
        num: number of images to draw
        row: number of tiles stacked in each grid column (default 8)

    Raises:
        ValueError: if ``num`` or ``row`` is less than 1.
    """
    if num < 1:
        raise ValueError("num must be at least 1, got %r" % (num,))
    if row < 1:
        raise ValueError("row must be at least 1, got %r" % (row,))

    # Ceiling division. The old int(num / 8) floored, so trailing images
    # were pasted outside the canvas and lost, or the canvas had width 0.
    col = (num + row - 1) // row

    imgs = Image.new('L', (size * col, size * row))
    for i in range(num):
        img_data = np.resize(data[i], (size, size))
        # Clip before casting, so values outside [0, 1] saturate
        # instead of wrapping around in uint8.
        img_data = np.clip(img_data * 255, 0, 255).astype(np.uint8)
        im = Image.fromarray(img_data, 'L')
        imgs.paste(im, ((i // row) * size, (i % row) * size))
    imgs.save(name)
97