# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``Dataset.apply``.

Each test checks that a pipeline built through ``apply(fn)`` behaves like the
same operations chained directly, or that invalid ``apply`` usage raises.
"""
import itertools

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testPK/data"

# Fill value for zip_longest in _assert_same_rows: marks "this pipeline ran out of rows".
_MISSING = object()

# Expected row counts for _run_apply_flow_case over the 64-row generator_1d source,
# keyed by flow id. Any id not listed takes the shuffle branch, which keeps all rows.
_FLOW_EXPECTED_ROWS = {0: 16, 1: 128, 2: 32}
_FLOW_DEFAULT_ROWS = 64


# Generate 1d int numpy array from 0 - 63
def generator_1d():
    """Yield 64 single-column rows; row i is ``np.array([i])`` for i in 0..63."""
    for i in range(64):
        yield (np.array([i]),)


def _assert_same_rows(data1, data2, column):
    """Assert two pipelines yield identical values for ``column``, row by row.

    Unlike a plain ``zip``, which silently stops at the shorter iterator, this
    fails when one pipeline produces more rows than the other.
    """
    iter1 = data1.create_dict_iterator(num_epochs=1, output_numpy=True)
    iter2 = data2.create_dict_iterator(num_epochs=1, output_numpy=True)
    for item1, item2 in itertools.zip_longest(iter1, iter2, fillvalue=_MISSING):
        assert item1 is not _MISSING and item2 is not _MISSING, \
            "pipelines yield a different number of rows"
        np.testing.assert_array_equal(item1[column], item2[column])


def test_apply_generator_case():
    """apply() with repeat+batch must match calling repeat+batch directly."""
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data2 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    data1 = data1.apply(dataset_fn)
    data2 = data2.repeat(2)
    data2 = data2.batch(4)

    _assert_same_rows(data1, data2, "data")


def test_apply_imagefolder_case():
    """apply() with decode/normalize maps must match chaining the maps directly.

    Both pipelines read shard 3 of 4 from DATA_DIR so they see the same files.
    """
    # apply dataset map operations
    data1 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
    data2 = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)

    decode_op = vision.Decode()
    normalize_op = vision.Normalize([121.0, 115.0, 100.0], [70.0, 68.0, 71.0])

    def dataset_fn(ds_):
        ds_ = ds_.map(operations=decode_op)
        ds_ = ds_.map(operations=normalize_op)
        ds_ = ds_.repeat(2)
        return ds_

    data1 = data1.apply(dataset_fn)
    data2 = data2.map(operations=decode_op)
    data2 = data2.map(operations=normalize_op)
    data2 = data2.repeat(2)

    _assert_same_rows(data1, data2, "image")


def _flow_fn(id_):
    """Return a transform for ``Dataset.apply`` chosen by control flow on ``id_``.

    0 -> batch(4); 1 -> repeat(2); 2 -> batch(4) then repeat(2);
    any other value -> shuffle(buffer_size=4).
    """
    def dataset_fn(ds_):
        if id_ == 0:
            ds_ = ds_.batch(4)
        elif id_ == 1:
            ds_ = ds_.repeat(2)
        elif id_ == 2:
            ds_ = ds_.batch(4)
            ds_ = ds_.repeat(2)
        else:
            ds_ = ds_.shuffle(buffer_size=4)
        return ds_

    return dataset_fn


def _run_apply_flow_case(id_):
    """Apply the ``id_``-selected transform to generator_1d and check the row count."""
    # apply control flow operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.apply(_flow_fn(id_))
    num_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))
    assert num_iter == _FLOW_EXPECTED_ROWS.get(id_, _FLOW_DEFAULT_ROWS)


def test_apply_flow_case_0(id_=0):
    """batch(4) over 64 rows -> 16 rows."""
    _run_apply_flow_case(id_)


def test_apply_flow_case_1(id_=1):
    """repeat(2) over 64 rows -> 128 rows."""
    _run_apply_flow_case(id_)


def test_apply_flow_case_2(id_=2):
    """batch(4) then repeat(2) over 64 rows -> 32 rows."""
    _run_apply_flow_case(id_)


def test_apply_flow_case_3(id_=3):
    """shuffle(buffer_size=4) over 64 rows -> 64 rows."""
    _run_apply_flow_case(id_)


def test_apply_exception_case():
    """Invalid uses of apply() must raise the documented exception types."""
    # apply exception operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    def dataset_fn(ds_):
        ds_ = ds_.repeat(2)
        return ds_.batch(4)

    def exception_fn():
        # Takes no dataset argument and returns an ndarray, not a dataset.
        return np.array([[0], [1], [3], [4], [5]])

    # Each check uses try/except/else with an explicit raise instead of
    # `assert False`, so a missing exception is still reported under `python -O`.

    # A non-callable argument must be rejected.
    try:
        data1 = data1.apply("123")
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
    except TypeError:
        pass
    else:
        raise AssertionError("apply('123') did not raise TypeError")

    # A callable with the wrong signature / return type must be rejected.
    try:
        data1 = data1.apply(exception_fn)
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
    except TypeError:
        pass
    else:
        raise AssertionError("apply(exception_fn) did not raise TypeError")

    # data1 feeds two separate apply() pipelines. The test expects iterating them
    # together to raise ValueError (presumably because a node may not have
    # multiple consumers in this version of the pipeline; TODO confirm).
    try:
        data2 = data1.apply(dataset_fn)
        _ = data1.apply(dataset_fn)
        for _, _ in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
            pass
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
    else:
        raise AssertionError("reusing data1 in two apply() pipelines did not raise ValueError")


if __name__ == '__main__':
    logger.info("Running test_apply.py test_apply_generator_case() function")
    test_apply_generator_case()

    logger.info("Running test_apply.py test_apply_imagefolder_case() function")
    test_apply_imagefolder_case()

    logger.info("Running test_apply.py test_apply_flow_case(id) function")
    test_apply_flow_case_0()
    test_apply_flow_case_1()
    test_apply_flow_case_2()
    test_apply_flow_case_3()

    logger.info("Running test_apply.py test_apply_exception_case() function")
    test_apply_exception_case()