# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" create train dataset. """

from functools import partial
import mindspore.dataset as ds
import mindspore.common.dtype as mstype
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, config, repeat_num=1, batch_size=32):
    """
    Build a batched CIFAR-10 pipeline with preprocessed images and int32 labels.

    The data is read without shuffling. Labels are cast to int32; images are
    resized to the configured size, rescaled by 1/255, normalized per channel
    and converted from HWC to CHW layout. Batches that cannot be filled
    completely are dropped, and batching happens before repeating.

    Args:
        dataset_path(string): the path of dataset.
        config(EasyDict): the basic config; ``image_height`` and
            ``image_width`` are read from it as the resize target.
        repeat_num(int): the repeat times of dataset. Default: 1.
        batch_size(int): the batch size of dataset. Default: 32.

    Returns:
        dataset
    """
    # Fix the reader options once; only the dataset location varies per call.
    open_cifar10 = partial(
        ds.Cifar10Dataset, num_parallel_workers=8, shuffle=False)
    source = open_cifar10(dataset_path)

    target_size = (config.image_height, config.image_width)
    scale_factor = 1.0 / 255.0
    offset = 0.0

    # Image transforms, applied in list order.
    image_ops = [
        # Resize is left at its default interpolation mode.
        C.Resize(target_size),
        C.Rescale(scale_factor, offset),
        C.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        C.HWC2CHW(),
    ]
    label_cast = C2.TypeCast(mstype.int32)

    # Map labels first, then images, then batch, then repeat.
    return (
        source.map(input_columns="label", operations=label_cast)
        .map(input_columns="image", operations=image_ops)
        .batch(batch_size, drop_remainder=True)
        .repeat(repeat_num)
    )