# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Utils """

from PIL import Image
import numpy as np

from mindspore.common import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.dataset.transforms.vision import Inter


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """ create dataset for train or test

    The "image" column is resized to 32x32 with linear interpolation,
    multiplied by 1/255 and converted from HWC to CHW layout. The "label"
    column is cast to int32. The result is shuffled, batched (the last
    incomplete batch is dropped) and repeated.

    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers

    Returns:
        The dataset object produced by the map/shuffle/batch/repeat pipeline.
    """
    # define dataset
    # (for a quick smoke run, MnistDataset can be limited with num_samples,
    # e.g. ds.MnistDataset(data_path, num_samples=32))
    mnist_ds = ds.MnistDataset(data_path)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    # resize images to (32, 32)
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)
    rescale_op = CV.Rescale(rescale, shift)  # rescale images
    # change shape from (height, width, channel) to (channel, height, width)
    # to fit network.
    hwc2chw_op = CV.HWC2CHW()
    # change data type of label to int32 to fit network
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations: labels first, then the image operations in the
    # order resize -> rescale -> HWC2CHW
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op,
                            num_parallel_workers=num_parallel_workers)
    for image_op in (resize_op, rescale_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(input_columns="image", operations=image_op,
                                num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    # 10000 as in LeNet train script
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


def save_img(data, name, size=32, num=32):
    """
    Visualize data and save to target files

    The first `num` entries of `data` are tiled into one grayscale ('L')
    image. The grid has 8 rows and is filled column by column: entry i is
    placed in column i // 8, row i % 8.

    Args:
        data: nparray of size (num, size, size); values are multiplied by
            255 before conversion to uint8, so they are expected to lie in
            [0, 1]
        name: output file name
        size: image size
        num: number of images
    """
    row = 8
    # Round up so that a trailing, partially filled column still fits on the
    # canvas (truncating division dropped those images, and produced a
    # zero-width canvas for num < 8).
    col = (num + row - 1) // row

    imgs = Image.new('L', (size * col, size * row))
    for i in range(num):
        col_idx, row_idx = divmod(i, row)
        img_data = np.resize(data[i], (size, size))
        # Clip before the cast: values slightly outside [0, 1] would
        # otherwise wrap around when converted to uint8.
        img_data = np.clip(img_data * 255, 0, 255).astype(np.uint8)
        im = Image.fromarray(img_data, 'L')
        imgs.paste(im, (col_idx * size, row_idx * size))
    imgs.save(name)