# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create train or eval dataset."""

import os

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    create a train or eval dataset.

    The images are read with ``ds.ImageFolderDataset``. The number of devices
    and the index of the current device are taken from the ``RANK_SIZE`` and
    ``RANK_ID`` environment variables; when a variable is unset or empty the
    function falls back to a single device (``RANK_SIZE=1``, ``RANK_ID=0``).

    Training pipeline: random crop + decode + resize to 224, random horizontal
    flip, normalize, HWC -> CHW.
    Eval pipeline: decode, resize to 256x256, center crop to 224, normalize,
    HWC -> CHW.
    In both cases the label column is cast to int32, samples are batched with
    incomplete trailing batches dropped, and the result is repeated
    ``repeat_num`` times.

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset

    Raises:
        ValueError: if RANK_SIZE or RANK_ID is set to a non-integer string.
    """

    # Fall back to a single-device configuration when the launcher did not
    # export the variables: int(None) would raise TypeError and int("") would
    # raise ValueError, which breaks plain single-device runs.
    device_num = int(os.getenv("RANK_SIZE") or "1")
    rank_id = int(os.getenv("RANK_ID") or "0")

    if do_train:
        if device_num == 1:
            # Single device: no sharding arguments are passed.
            data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=16, shuffle=True)
        else:
            # Multiple devices: each rank reads its own shard.
            data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True,
                                             num_shards=device_num, shard_id=rank_id)
    else:
        # Eval always passes the shard arguments and keeps the sample order.
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=False,
                                         num_shards=device_num, shard_id=rank_id)

    image_size = 224
    # Per-channel statistics multiplied by 255 (presumably ImageNet mean/std
    # expressed on a 0-255 pixel scale -- confirm against the training recipe).
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize((256, 256)),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=12)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=4)

    # apply batch operations
    data_set = data_set.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set