# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the MNIST dataset pipeline.
"""

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST input pipeline used for training or testing.

    Args:
        data_path: Location of the MNIST data, passed as-is to ``ds.MnistDataset``.
        batch_size (int): Samples per batch; an incomplete final batch is dropped. Default: 32.
        repeat_size (int): Number of times the batched dataset is repeated. Default: 1.
        num_parallel_workers (int): Worker count forwarded to every ``map`` call. Default: 1.

    Returns:
        The dataset after label casting, image transforms, shuffling, batching and repeating.
    """
    # Normalization constants (commonly cited MNIST mean/std). Rescale computes
    # x * scale + shift, so (x - mean) / std becomes scale = 1 / std, shift = -mean / std.
    pixel_mean, pixel_std = 0.1307, 0.3081

    # Image transforms, each applied as its own map stage in this order:
    # resize to 32x32 (bilinear) -> scale to [0, 1] -> normalize -> HWC to CHW.
    image_transforms = [
        CV.Resize((32, 32), interpolation=Inter.LINEAR),
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.Rescale(1 / pixel_std, -1 * pixel_mean / pixel_std),
        CV.HWC2CHW(),
    ]

    pipeline = ds.MnistDataset(data_path)

    # Labels are cast to int32 before any image processing.
    pipeline = pipeline.map(operations=C.TypeCast(mstype.int32), input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    for transform in image_transforms:
        pipeline = pipeline.map(operations=transform, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # A shuffle buffer of 10000 matches the LeNet train script.
    pipeline = pipeline.shuffle(buffer_size=10000)
    pipeline = pipeline.batch(batch_size, drop_remainder=True)
    return pipeline.repeat(repeat_size)