1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15import numpy as np 16 17import mindspore.common.dtype as mstype 18import mindspore.dataset as ds 19import mindspore.dataset.transforms.c_transforms as C 20import mindspore.dataset.transforms.py_transforms 21import mindspore.dataset.vision.py_transforms as F 22from mindspore import log as logger 23 24 25# In generator dataset: Number of rows is 3; its values are 0, 1, 2 26def generator(): 27 for i in range(3): 28 yield (np.array([i]),) 29 30 31# In generator_10 dataset: Number of rows is 7; its values are 3, 4, 5 ... 9 32def generator_10(): 33 for i in range(3, 10): 34 yield (np.array([i]),) 35 36 37# In generator_20 dataset: Number of rows is 10; its values are 10, 11, 12 ... 19 38def generator_20(): 39 for i in range(10, 20): 40 yield (np.array([i]),) 41 42 43# In generator_29 dataset: Number of rows is 9; its values are 20, 21, 22 ... 28 44def generator_29(): 45 for i in range(20, 29): 46 yield (np.array([i]),) 47 48 49def test_concat_01(): 50 """ 51 Test concat: test concat 2 datasets that have the same column name and data type 52 """ 53 logger.info("test_concat_01") 54 data1 = ds.GeneratorDataset(generator, ["col1"]) 55 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 56 57 data3 = data1 + data2 58 59 # Here i refers to index, d refers to data element 60 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 61 t = d 62 logger.info("data: %i", t[0][0]) 63 assert i == t[0][0] 64 65 assert sum([1 for _ in data3]) == 10 66 67 68def test_concat_02(): 69 """ 70 Test concat: test concat 2 datasets using concat operation not "+" operation 71 """ 72 logger.info("test_concat_02") 73 data1 = ds.GeneratorDataset(generator, ["col1"]) 74 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 75 76 data3 = data1.concat(data2) 77 78 # Here i refers to index, d refers to data element 79 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 80 t = d 81 logger.info("data: %i", t[0][0]) 82 assert i == t[0][0] 83 84 assert sum([1 for _ in data3]) == 10 85 86 87def test_concat_03(): 88 """ 89 Test concat: test concat dataset that has different column 90 """ 91 logger.info("test_concat_03") 92 data1 = ds.GeneratorDataset(generator, ["col1"]) 93 data2 = ds.GeneratorDataset(generator_10, ["col2"]) 94 95 data3 = data1 + data2 96 97 try: 98 for _, _ in enumerate(data3): 99 pass 100 assert False 101 except RuntimeError: 102 pass 103 104 105def test_concat_04(): 106 """ 107 Test concat: test concat dataset that has different rank 108 """ 109 logger.info("test_concat_04") 110 data1 = ds.GeneratorDataset(generator, ["col1"]) 111 data2 = ds.GeneratorDataset(generator_10, ["col2"]) 112 data2 = data2.batch(3) 113 114 data3 = data1 + data2 115 116 try: 117 for _, _ in enumerate(data3): 118 pass 119 assert False 120 except RuntimeError: 121 pass 122 123 124def test_concat_05(): 125 """ 126 Test concat: test concat dataset that has different data type 127 """ 128 logger.info("test_concat_05") 129 data1 = ds.GeneratorDataset(generator, ["col1"]) 130 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 131 132 type_cast_op = C.TypeCast(mstype.float32) 133 data1 = data1.map(operations=type_cast_op, input_columns=["col1"]) 134 135 data3 = data1 + data2 136 137 try: 138 for _, _ in enumerate(data3): 139 pass 140 assert False 141 except RuntimeError: 142 pass 143 144 145def test_concat_06(): 146 """ 147 Test concat: test concat multi datasets in one time 148 """ 149 logger.info("test_concat_06") 150 data1 = ds.GeneratorDataset(generator, ["col1"]) 151 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 152 data3 = ds.GeneratorDataset(generator_20, ["col1"]) 153 154 dataset = data1 + data2 + data3 155 156 # Here i refers to index, d refers to data element 157 for i, d in enumerate(dataset.create_tuple_iterator(output_numpy=True)): 158 t = d 159 logger.info("data: %i", t[0][0]) 160 assert i == t[0][0] 161 162 assert sum([1 for _ in dataset]) == 20 163 164 165def test_concat_07(): 166 """ 167 Test concat: test concat one dataset with multi datasets (datasets list) 168 """ 169 logger.info("test_concat_07") 170 data1 = ds.GeneratorDataset(generator, ["col1"]) 171 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 172 data3 = ds.GeneratorDataset(generator_20, ["col1"]) 173 174 dataset = [data2] + [data3] 175 data4 = data1 + dataset 176 177 # Here i refers to index, d refers to data element 178 for i, d in enumerate(data4.create_tuple_iterator(output_numpy=True)): 179 t = d 180 logger.info("data: %i", t[0][0]) 181 assert i == t[0][0] 182 183 assert sum([1 for _ in data4]) == 20 184 185 186def test_concat_08(): 187 """ 188 Test concat: test concat 2 datasets, and then repeat 189 """ 190 logger.info("test_concat_08") 191 data1 = ds.GeneratorDataset(generator, ["col1"]) 192 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 193 194 data3 = data1 + data2 195 data3 = data3.repeat(2) 196 197 # Here i refers to index, d refers to data element 198 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 199 t = d 200 logger.info("data: %i", t[0][0]) 201 assert i % 10 == t[0][0] 202 203 assert sum([1 for _ in data3]) == 20 204 205 206def test_concat_09(): 207 """ 208 Test concat: test concat 2 datasets, both of them have been repeat before 209 """ 210 logger.info("test_concat_09") 211 data1 = ds.GeneratorDataset(generator, ["col1"]) 212 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 213 214 data1 = data1.repeat(2) 215 data2 = data2.repeat(2) 216 data3 = data1 + data2 217 218 res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9] 219 # Here i refers to index, d refers to data element 220 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 221 t = d 222 logger.info("data: %i", t[0][0]) 223 assert res[i] == t[0][0] 224 225 assert sum([1 for _ in data3]) == 20 226 227 228def test_concat_10(): 229 """ 230 Test concat: test concat 2 datasets, one of them have repeat before 231 """ 232 logger.info("test_concat_10") 233 data1 = ds.GeneratorDataset(generator, ["col1"]) 234 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 235 236 data1 = data1.repeat(2) 237 data3 = data1 + data2 238 239 res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 240 # Here i refers to index, d refers to data element 241 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 242 t = d 243 logger.info("data: %i", t[0][0]) 244 assert res[i] == t[0][0] 245 246 assert sum([1 for _ in data3]) == 13 247 248 249def test_concat_11(): 250 """ 251 Test concat: test dataset batch then concat 252 """ 253 logger.info("test_concat_11") 254 data1 = ds.GeneratorDataset(generator, ["col1"]) 255 data2 = ds.GeneratorDataset(generator_20, ["col1"]) 256 257 data1 = data1.batch(3) 258 data2 = data2.batch(5) 259 260 data3 = data1 + data2 261 res = [0, 10, 15, 20] 262 263 # Here i refers to index, d refers to data element 264 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 265 t = d 266 logger.info("data: %i", t[0][0]) 267 assert res[i] == t[0][0] 268 269 assert sum([1 for _ in data3]) == 3 270 271 272def test_concat_12(): 273 """ 274 Test concat: test dataset concat then shuffle 275 """ 276 logger.info("test_concat_12") 277 data1 = ds.GeneratorDataset(generator, ["col1"]) 278 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 279 280 data3 = data1 + data2 281 res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1] 282 283 ds.config.set_seed(1) 284 assert data3.get_dataset_size() == 10 285 data3 = data3.shuffle(buffer_size=10) 286 287 # Here i refers to index, d refers to data element 288 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 289 t = d 290 logger.info("data: %i", t[0][0]) 291 assert res[i] == t[0][0] 292 293 assert sum([1 for _ in data3]) == 10 294 295 296def test_concat_13(): 297 """ 298 Test concat: test dataset batch then shuffle and concat 299 """ 300 logger.info("test_concat_13") 301 data1 = ds.GeneratorDataset(generator, ["col1"]) 302 data2 = ds.GeneratorDataset(generator_20, ["col1"]) 303 304 data1 = data1.batch(3) 305 data2 = data2.batch(5) 306 307 data3 = data1 + data2 308 res = [15, 0, 10] 309 310 ds.config.set_seed(1) 311 assert data3.get_dataset_size() == 3 312 313 data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size())) 314 315 # Here i refers to index, d refers to data element 316 for i, d in enumerate(data3.create_tuple_iterator(output_numpy=True)): 317 t = d 318 logger.info("data: %i", t[0][0]) 319 assert res[i] == t[0][0] 320 321 assert sum([1 for _ in data3]) == 3 322 323 324def test_concat_14(): 325 """ 326 Test concat: Testing concat on two different source datasets with different dataset operations. 327 """ 328 logger.info("test_concat_14") 329 DATA_DIR = "../data/dataset/testPK/data" 330 DATA_DIR2 = "../data/dataset/testImageNetData/train/" 331 332 data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=3) 333 data2 = ds.ImageFolderDataset(DATA_DIR2, num_samples=2) 334 335 transforms1 = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(), 336 F.Resize((224, 224)), 337 F.ToTensor()]) 338 339 data1 = data1.map(operations=transforms1, input_columns=["image"]) 340 data2 = data2.map(operations=transforms1, input_columns=["image"]) 341 data3 = data1 + data2 342 343 expected, output = [], [] 344 for d in data1.create_tuple_iterator(output_numpy=True): 345 expected.append(d[0]) 346 for d in data2.create_tuple_iterator(output_numpy=True): 347 expected.append(d[0]) 348 for d in data3.create_tuple_iterator(output_numpy=True): 349 output.append(d[0]) 350 351 assert len(expected) == len(output) 352 np.array_equal(np.array(output), np.array(expected)) 353 354 assert sum([1 for _ in data3]) == 5 355 assert data3.get_dataset_size() == 5 356 357 358def test_concat_15(): 359 """ 360 Test concat: create dataset with different format of dataset file, and then concat 361 """ 362 logger.info("test_concat_15") 363 DATA_DIR = "../data/dataset/testPK/data" 364 DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] 365 366 data1 = ds.ImageFolderDataset(DATA_DIR) 367 data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"]) 368 369 data1 = data1.project(["image"]) 370 data3 = data1 + data2 371 372 assert sum([1 for _ in data3]) == 47 373 374 375def test_concat_16(): 376 """ 377 Test concat: test get_dataset_size on nested concats 378 """ 379 logger.info("test_concat_16") 380 DATA_DIR = "../data/dataset/testPK/data" 381 DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] 382 383 data1 = ds.ImageFolderDataset(DATA_DIR) 384 data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"]) 385 386 data3 = ds.GeneratorDataset(generator, ["col1"]) 387 data4 = ds.GeneratorDataset(generator_10, ["col1"]) 388 389 data5 = data1 + data2 390 data6 = data3 + data4 391 data7 = data5 + data6 392 393 ds.config.set_seed(1) 394 395 # 57 is the total size of all 4 leaf datasets 396 assert data7.get_dataset_size() == 57 397 398 399def test_concat_17(): 400 """ 401 Test concat: test get_dataset_size on nested concats (with sampler) 402 """ 403 logger.info("test_concat_17") 404 405 data1 = ds.GeneratorDataset(generator, ["col1"]) 406 data2 = ds.GeneratorDataset(generator_10, ["col1"]) 407 408 data3 = ds.GeneratorDataset(generator_20, ["col1"]) 409 data4 = ds.GeneratorDataset(generator_29, ["col1"]) 410 411 data5 = data1 + data2 412 data6 = data3 + data4 413 data7 = data5 + data6 414 415 ds.config.set_seed(1) 416 shard_num = 10 417 counter = 0 418 419 for i in range(shard_num): 420 distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None) 421 data7.use_sampler(distributed_sampler) 422 iter_counter = 0 423 for _ in data7.create_dict_iterator(num_epochs=1, output_numpy=True): 424 counter += 1 425 iter_counter += 1 426 assert data7.get_dataset_size() == iter_counter 427 428 # 29 is the total size of all 4 leaf datasets 429 assert counter == 29 430 431 432if __name__ == "__main__": 433 test_concat_01() 434 test_concat_02() 435 test_concat_03() 436 test_concat_04() 437 test_concat_05() 438 test_concat_06() 439 test_concat_07() 440 test_concat_08() 441 test_concat_09() 442 test_concat_10() 443 test_concat_11() 444 test_concat_12() 445 test_concat_13() 446 test_concat_14() 447 test_concat_15() 448 test_concat_16() 449 test_concat_17() 450