1# Copyright 2019 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15""" 16This is the test module for mindrecord 17""" 18import os 19import pytest 20import numpy as np 21 22import mindspore.dataset as ds 23from mindspore import log as logger 24from mindspore.dataset.text import to_str 25from mindspore.mindrecord import FileWriter 26 27FILES_NUM = 4 28CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" 29CV_DIR_NAME = "../data/mindrecord/testImageNetData" 30 31 32@pytest.fixture 33def add_and_remove_cv_file(): 34 """add/remove cv file""" 35 paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) 36 for x in range(FILES_NUM)] 37 try: 38 for x in paths: 39 if os.path.exists("{}".format(x)): 40 os.remove("{}".format(x)) 41 if os.path.exists("{}.db".format(x)): 42 os.remove("{}.db".format(x)) 43 writer = FileWriter(CV_FILE_NAME, FILES_NUM) 44 data = get_data(CV_DIR_NAME, True) 45 cv_schema_json = {"id": {"type": "int32"}, 46 "file_name": {"type": "string"}, 47 "label": {"type": "int32"}, 48 "data": {"type": "bytes"}} 49 writer.add_schema(cv_schema_json, "img_schema") 50 writer.add_index(["file_name", "label"]) 51 writer.write_raw_data(data) 52 writer.commit() 53 yield "yield_cv_data" 54 except Exception as error: 55 for x in paths: 56 os.remove("{}".format(x)) 57 os.remove("{}.db".format(x)) 58 raise error 59 else: 60 for x in paths: 61 os.remove("{}".format(x)) 62 os.remove("{}.db".format(x)) 63 64 65def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file): 66 """tutorial for cv minderdataset.""" 67 num_readers = 4 68 sampler = ds.PKSampler(2) 69 data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, 70 sampler=sampler) 71 72 assert data_set.get_dataset_size() == 6 73 num_iter = 0 74 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 75 logger.info( 76 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 77 logger.info("-------------- item[file_name]: \ 78 {}------------------------".format(to_str(item["file_name"]))) 79 logger.info( 80 "-------------- item[label]: {} ----------------------------".format(item["label"])) 81 num_iter += 1 82 83 84def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): 85 """tutorial for cv minderdataset.""" 86 columns_list = ["data", "file_name", "label"] 87 num_readers = 4 88 sampler = ds.PKSampler(2) 89 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 90 sampler=sampler) 91 92 assert data_set.get_dataset_size() == 6 93 num_iter = 0 94 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 95 logger.info( 96 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 97 logger.info("-------------- item[data]: \ 98 {}------------------------".format(item["data"][:10])) 99 logger.info("-------------- item[file_name]: \ 100 {}------------------------".format(to_str(item["file_name"]))) 101 logger.info( 102 "-------------- item[label]: {} ----------------------------".format(item["label"])) 103 num_iter += 1 104 105 106def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): 107 """tutorial for cv minderdataset.""" 108 columns_list = ["data", "file_name", "label"] 109 num_readers = 4 110 sampler = ds.PKSampler(3, None, True) 111 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 112 sampler=sampler) 113 114 assert data_set.get_dataset_size() == 9 115 num_iter = 0 116 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 117 logger.info( 118 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 119 logger.info("-------------- item[file_name]: \ 120 {}------------------------".format(to_str(item["file_name"]))) 121 logger.info( 122 "-------------- item[label]: {} ----------------------------".format(item["label"])) 123 num_iter += 1 124 assert num_iter == 9 125 126 127def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file): 128 """tutorial for cv minderdataset.""" 129 columns_list = ["data", "file_name", "label"] 130 num_readers = 4 131 sampler = ds.PKSampler(3, None, True, 'label', 5) 132 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 133 sampler=sampler) 134 135 assert data_set.get_dataset_size() == 5 136 num_iter = 0 137 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 138 logger.info( 139 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 140 logger.info("-------------- item[file_name]: \ 141 {}------------------------".format(to_str(item["file_name"]))) 142 logger.info( 143 "-------------- item[label]: {} ----------------------------".format(item["label"])) 144 num_iter += 1 145 assert num_iter == 5 146 147 148def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file): 149 """tutorial for cv minderdataset.""" 150 columns_list = ["data", "file_name", "label"] 151 num_readers = 4 152 sampler = ds.PKSampler(3, None, True, 'label', 10) 153 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 154 sampler=sampler) 155 156 assert data_set.get_dataset_size() == 9 157 num_iter = 0 158 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 159 logger.info( 160 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 161 logger.info("-------------- item[file_name]: \ 162 {}------------------------".format(to_str(item["file_name"]))) 163 logger.info( 164 "-------------- item[label]: {} ----------------------------".format(item["label"])) 165 num_iter += 1 166 assert num_iter == 9 167 168 169def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file): 170 """tutorial for cv minderdataset.""" 171 columns_list = ["data", "file_name", "label"] 172 num_readers = 4 173 sampler = ds.PKSampler(5, None, True) 174 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 175 sampler=sampler) 176 assert data_set.get_dataset_size() == 15 177 num_iter = 0 178 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 179 logger.info( 180 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 181 logger.info("-------------- item[file_name]: \ 182 {}------------------------".format(to_str(item["file_name"]))) 183 logger.info( 184 "-------------- item[label]: {} ----------------------------".format(item["label"])) 185 num_iter += 1 186 assert num_iter == 15 187 188 189def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file): 190 """tutorial for cv minderdataset.""" 191 columns_list = ["data", "file_name", "label"] 192 num_readers = 4 193 sampler = ds.PKSampler(5, None, True, 'label', 20) 194 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 195 sampler=sampler) 196 assert data_set.get_dataset_size() == 15 197 num_iter = 0 198 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 199 logger.info( 200 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 201 logger.info("-------------- item[file_name]: \ 202 {}------------------------".format(to_str(item["file_name"]))) 203 logger.info( 204 "-------------- item[label]: {} ----------------------------".format(item["label"])) 205 num_iter += 1 206 assert num_iter == 15 207 208 209def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file): 210 """tutorial for cv minderdataset.""" 211 columns_list = ["data", "file_name", "label"] 212 num_readers = 4 213 sampler = ds.PKSampler(5, None, True, 'label', 10) 214 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 215 sampler=sampler) 216 assert data_set.get_dataset_size() == 10 217 num_iter = 0 218 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 219 logger.info( 220 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 221 logger.info("-------------- item[file_name]: \ 222 {}------------------------".format(to_str(item["file_name"]))) 223 logger.info( 224 "-------------- item[label]: {} ----------------------------".format(item["label"])) 225 num_iter += 1 226 assert num_iter == 10 227 228 229def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): 230 """tutorial for cv minderdataset.""" 231 columns_list = ["data", "file_name", "label"] 232 num_readers = 4 233 indices = [1, 2, 3, 5, 7] 234 samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices)) 235 for sampler in samplers: 236 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 237 sampler=sampler) 238 assert data_set.get_dataset_size() == 5 239 num_iter = 0 240 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 241 logger.info( 242 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 243 logger.info( 244 "-------------- item[data]: {} -----------------------------".format(item["data"])) 245 logger.info( 246 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 247 logger.info( 248 "-------------- item[label]: {} ----------------------------".format(item["label"])) 249 num_iter += 1 250 assert num_iter == 5 251 252 253def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file): 254 """tutorial for cv minderdataset.""" 255 columns_list = ["data", "file_name", "label"] 256 num_readers = 4 257 indices = [1, 2, 2, 5, 7, 9] 258 samplers = ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices) 259 for sampler in samplers: 260 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 261 sampler=sampler) 262 assert data_set.get_dataset_size() == 6 263 num_iter = 0 264 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 265 logger.info( 266 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 267 logger.info( 268 "-------------- item[data]: {} -----------------------------".format(item["data"])) 269 logger.info( 270 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 271 logger.info( 272 "-------------- item[label]: {} ----------------------------".format(item["label"])) 273 num_iter += 1 274 assert num_iter == 6 275 276 277def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): 278 """tutorial for cv minderdataset.""" 279 columns_list = ["data", "file_name", "label"] 280 num_readers = 4 281 indices = [] 282 samplers = ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices) 283 for sampler in samplers: 284 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 285 sampler=sampler) 286 assert data_set.get_dataset_size() == 0 287 num_iter = 0 288 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 289 logger.info( 290 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 291 logger.info( 292 "-------------- item[data]: {} -----------------------------".format(item["data"])) 293 logger.info( 294 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 295 logger.info( 296 "-------------- item[label]: {} ----------------------------".format(item["label"])) 297 num_iter += 1 298 assert num_iter == 0 299 300 301def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file): 302 """tutorial for cv minderdataset.""" 303 columns_list = ["data", "file_name", "label"] 304 num_readers = 4 305 indices = [1, 2, 4, 11, 13] 306 samplers = ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices) 307 for sampler in samplers: 308 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 309 sampler=sampler) 310 assert data_set.get_dataset_size() == 5 311 num_iter = 0 312 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 313 logger.info( 314 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 315 logger.info( 316 "-------------- item[data]: {} -----------------------------".format(item["data"])) 317 logger.info( 318 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 319 logger.info( 320 "-------------- item[label]: {} ----------------------------".format(item["label"])) 321 num_iter += 1 322 assert num_iter == 5 323 324 325def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): 326 columns_list = ["data", "file_name", "label"] 327 num_readers = 4 328 indices = [1, 2, 4, -1, -2] 329 samplers = ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices) 330 for sampler in samplers: 331 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 332 sampler=sampler) 333 assert data_set.get_dataset_size() == 5 334 num_iter = 0 335 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 336 logger.info( 337 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 338 logger.info( 339 "-------------- item[data]: {} -----------------------------".format(item["data"])) 340 logger.info( 341 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 342 logger.info( 343 "-------------- item[label]: {} ----------------------------".format(item["label"])) 344 num_iter += 1 345 assert num_iter == 5 346 347 348def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file): 349 data = get_data(CV_DIR_NAME, True) 350 columns_list = ["data", "file_name", "label"] 351 num_readers = 4 352 sampler = ds.RandomSampler() 353 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 354 sampler=sampler) 355 assert data_set.get_dataset_size() == 10 356 num_iter = 0 357 new_dataset = [] 358 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 359 logger.info( 360 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 361 logger.info( 362 "-------------- item[data]: {} -----------------------------".format(item["data"])) 363 logger.info( 364 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 365 logger.info( 366 "-------------- item[label]: {} ----------------------------".format(item["label"])) 367 num_iter += 1 368 new_dataset.append(item['file_name']) 369 assert num_iter == 10 370 assert new_dataset != [x['file_name'] for x in data] 371 372 373def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file): 374 columns_list = ["data", "file_name", "label"] 375 num_readers = 4 376 sampler = ds.RandomSampler() 377 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 378 sampler=sampler) 379 assert data_set.get_dataset_size() == 10 380 ds1 = data_set.repeat(3) 381 num_iter = 0 382 epoch1_dataset = [] 383 epoch2_dataset = [] 384 epoch3_dataset = [] 385 for item in ds1.create_dict_iterator(num_epochs=1, output_numpy=True): 386 logger.info( 387 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 388 logger.info( 389 "-------------- item[data]: {} -----------------------------".format(item["data"])) 390 logger.info( 391 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 392 logger.info( 393 "-------------- item[label]: {} ----------------------------".format(item["label"])) 394 num_iter += 1 395 if num_iter <= 10: 396 epoch1_dataset.append(item['file_name']) 397 elif num_iter <= 20: 398 epoch2_dataset.append(item['file_name']) 399 else: 400 epoch3_dataset.append(item['file_name']) 401 assert num_iter == 30 402 assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset) 403 assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset) 404 assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset) 405 406 407def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file): 408 columns_list = ["data", "file_name", "label"] 409 num_readers = 4 410 sampler = ds.RandomSampler(replacement=True, num_samples=5) 411 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 412 sampler=sampler) 413 assert data_set.get_dataset_size() == 5 414 num_iter = 0 415 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 416 logger.info( 417 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 418 logger.info( 419 "-------------- item[data]: {} -----------------------------".format(item["data"])) 420 logger.info( 421 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 422 logger.info( 423 "-------------- item[label]: {} ----------------------------".format(item["label"])) 424 num_iter += 1 425 assert num_iter == 5 426 427 428def test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file): 429 columns_list = ["data", "file_name", "label"] 430 num_readers = 4 431 sampler = ds.RandomSampler(replacement=False, num_samples=2) 432 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 433 sampler=sampler) 434 assert data_set.get_dataset_size() == 2 435 num_iter = 0 436 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 437 logger.info( 438 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 439 logger.info( 440 "-------------- item[data]: {} -----------------------------".format(item["data"])) 441 logger.info( 442 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 443 logger.info( 444 "-------------- item[label]: {} ----------------------------".format(item["label"])) 445 num_iter += 1 446 assert num_iter == 2 447 448 449def test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file): 450 columns_list = ["data", "file_name", "label"] 451 num_readers = 4 452 sampler = ds.RandomSampler(replacement=False, num_samples=20) 453 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 454 sampler=sampler) 455 assert data_set.get_dataset_size() == 10 456 num_iter = 0 457 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 458 logger.info( 459 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 460 logger.info( 461 "-------------- item[data]: {} -----------------------------".format(item["data"])) 462 logger.info( 463 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 464 logger.info( 465 "-------------- item[label]: {} ----------------------------".format(item["label"])) 466 num_iter += 1 467 assert num_iter == 10 468 469 470def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file): 471 data = get_data(CV_DIR_NAME, True) 472 columns_list = ["data", "file_name", "label"] 473 num_readers = 4 474 sampler = ds.SequentialSampler(1, 4) 475 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 476 sampler=sampler) 477 assert data_set.get_dataset_size() == 4 478 num_iter = 0 479 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 480 logger.info( 481 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 482 logger.info( 483 "-------------- item[data]: {} -----------------------------".format(item["data"])) 484 logger.info( 485 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 486 logger.info( 487 "-------------- item[label]: {} ----------------------------".format(item["label"])) 488 assert item['file_name'] == np.array( 489 data[num_iter + 1]['file_name'], dtype='S') 490 num_iter += 1 491 assert num_iter == 4 492 493 494def test_cv_minddataset_sequential_sampler_offeset(add_and_remove_cv_file): 495 data = get_data(CV_DIR_NAME, True) 496 columns_list = ["data", "file_name", "label"] 497 num_readers = 4 498 sampler = ds.SequentialSampler(2, 10) 499 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 500 sampler=sampler) 501 dataset_size = data_set.get_dataset_size() 502 assert dataset_size == 10 503 num_iter = 0 504 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 505 logger.info( 506 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 507 logger.info( 508 "-------------- item[data]: {} -----------------------------".format(item["data"])) 509 logger.info( 510 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 511 logger.info( 512 "-------------- item[label]: {} ----------------------------".format(item["label"])) 513 assert item['file_name'] == np.array( 514 data[(num_iter + 2) % dataset_size]['file_name'], dtype='S') 515 num_iter += 1 516 assert num_iter == 10 517 518 519def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file): 520 data = get_data(CV_DIR_NAME, True) 521 columns_list = ["data", "file_name", "label"] 522 num_readers = 4 523 sampler = ds.SequentialSampler(2, 20) 524 data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, 525 sampler=sampler) 526 dataset_size = data_set.get_dataset_size() 527 assert dataset_size == 10 528 num_iter = 0 529 for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): 530 logger.info( 531 "-------------- cv reader basic: {} ------------------------".format(num_iter)) 532 logger.info( 533 "-------------- item[data]: {} -----------------------------".format(item["data"])) 534 logger.info( 535 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 536 logger.info( 537 "-------------- item[label]: {} ----------------------------".format(item["label"])) 538 assert item['file_name'] == np.array( 539 data[(num_iter + 2) % dataset_size]['file_name'], dtype='S') 540 num_iter += 1 541 assert num_iter == 10 542 543 544def test_cv_minddataset_split_basic(add_and_remove_cv_file): 545 data = get_data(CV_DIR_NAME, True) 546 columns_list = ["data", "file_name", "label"] 547 num_readers = 4 548 d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, 549 num_readers, shuffle=False) 550 d1, d2 = d.split([8, 2], randomize=False) 551 assert d.get_dataset_size() == 10 552 assert d1.get_dataset_size() == 8 553 assert d2.get_dataset_size() == 2 554 num_iter = 0 555 for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True): 556 logger.info( 557 "-------------- item[data]: {} -----------------------------".format(item["data"])) 558 logger.info( 559 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 560 logger.info( 561 "-------------- item[label]: {} ----------------------------".format(item["label"])) 562 assert item['file_name'] == np.array(data[num_iter]['file_name'], 563 dtype='S') 564 num_iter += 1 565 assert num_iter == 8 566 num_iter = 0 567 for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True): 568 logger.info( 569 "-------------- item[data]: {} -----------------------------".format(item["data"])) 570 logger.info( 571 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 572 logger.info( 573 "-------------- item[label]: {} ----------------------------".format(item["label"])) 574 assert item['file_name'] == np.array(data[num_iter + 8]['file_name'], 575 dtype='S') 576 num_iter += 1 577 assert num_iter == 2 578 579 580def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file): 581 data = get_data(CV_DIR_NAME, True) 582 columns_list = ["data", "file_name", "label"] 583 num_readers = 4 584 d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, 585 num_readers, shuffle=False) 586 d1, d2 = d.split([0.8, 0.2], randomize=False) 587 assert d.get_dataset_size() == 10 588 assert d1.get_dataset_size() == 8 589 assert d2.get_dataset_size() == 2 590 num_iter = 0 591 for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True): 592 logger.info( 593 "-------------- item[data]: {} -----------------------------".format(item["data"])) 594 logger.info( 595 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 596 logger.info( 597 "-------------- item[label]: {} ----------------------------".format(item["label"])) 598 assert item['file_name'] == np.array( 599 data[num_iter]['file_name'], dtype='S') 600 num_iter += 1 601 assert num_iter == 8 602 num_iter = 0 603 for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True): 604 logger.info( 605 "-------------- item[data]: {} -----------------------------".format(item["data"])) 606 logger.info( 607 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 608 logger.info( 609 "-------------- item[label]: {} ----------------------------".format(item["label"])) 610 assert item['file_name'] == np.array(data[num_iter + 8]['file_name'], 611 dtype='S') 612 num_iter += 1 613 assert num_iter == 2 614 615 616def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file): 617 data = get_data(CV_DIR_NAME, True) 618 columns_list = ["data", "file_name", "label"] 619 num_readers = 4 620 d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, 621 num_readers, shuffle=False) 622 d1, d2 = d.split([0.41, 0.59], randomize=False) 623 assert d.get_dataset_size() == 10 624 assert d1.get_dataset_size() == 4 625 assert d2.get_dataset_size() == 6 626 num_iter = 0 627 for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True): 628 logger.info( 629 "-------------- item[data]: {} -----------------------------".format(item["data"])) 630 logger.info( 631 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 632 logger.info( 633 "-------------- item[label]: {} ----------------------------".format(item["label"])) 634 assert item['file_name'] == np.array( 635 data[num_iter]['file_name'], dtype='S') 636 num_iter += 1 637 assert num_iter == 4 638 num_iter = 0 639 for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True): 640 logger.info( 641 "-------------- item[data]: {} -----------------------------".format(item["data"])) 642 logger.info( 643 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 644 logger.info( 645 "-------------- item[label]: {} ----------------------------".format(item["label"])) 646 assert item['file_name'] == np.array(data[num_iter + 4]['file_name'], 647 dtype='S') 648 num_iter += 1 649 assert num_iter == 6 650 651 652def test_cv_minddataset_split_deterministic(add_and_remove_cv_file): 653 columns_list = ["data", "file_name", "label"] 654 num_readers = 4 655 d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, 656 num_readers, shuffle=False) 657 # should set seed to avoid data overlap 658 ds.config.set_seed(111) 659 d1, d2 = d.split([0.8, 0.2]) 660 assert d.get_dataset_size() == 10 661 assert d1.get_dataset_size() == 8 662 assert d2.get_dataset_size() == 2 663 664 d1_dataset = [] 665 d2_dataset = [] 666 num_iter = 0 667 for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True): 668 logger.info( 669 "-------------- item[data]: {} -----------------------------".format(item["data"])) 670 logger.info( 671 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 672 logger.info( 673 "-------------- item[label]: {} ----------------------------".format(item["label"])) 674 d1_dataset.append(item['file_name']) 675 num_iter += 1 676 assert num_iter == 8 677 num_iter = 0 678 for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True): 679 logger.info( 680 "-------------- item[data]: {} -----------------------------".format(item["data"])) 681 logger.info( 682 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 683 logger.info( 684 "-------------- item[label]: {} ----------------------------".format(item["label"])) 685 d2_dataset.append(item['file_name']) 686 num_iter += 1 687 assert num_iter == 2 688 inter_dataset = [x for x in d1_dataset if x in d2_dataset] 689 assert inter_dataset == [] # intersection of d1 and d2 690 691 692def test_cv_minddataset_split_sharding(add_and_remove_cv_file): 693 data = get_data(CV_DIR_NAME, True) 694 columns_list = ["data", "file_name", "label"] 695 num_readers = 4 696 d = ds.MindDataset(CV_FILE_NAME + "0", columns_list, 697 num_readers, shuffle=False) 698 # should set seed to avoid data overlap 699 ds.config.set_seed(111) 700 d1, d2 = d.split([0.8, 0.2]) 701 assert d.get_dataset_size() == 10 702 assert d1.get_dataset_size() == 8 703 assert d2.get_dataset_size() == 2 704 distributed_sampler = ds.DistributedSampler(2, 0) 705 d1.use_sampler(distributed_sampler) 706 assert d1.get_dataset_size() == 4 707 708 num_iter = 0 709 d1_shard1 = [] 710 for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True): 711 logger.info( 712 "-------------- item[data]: {} -----------------------------".format(item["data"])) 713 logger.info( 714 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 715 logger.info( 716 "-------------- item[label]: {} ----------------------------".format(item["label"])) 717 num_iter += 1 718 d1_shard1.append(item['file_name']) 719 assert num_iter == 4 720 assert d1_shard1 != [x['file_name'] for x in data[0:4]] 721 722 distributed_sampler = ds.DistributedSampler(2, 1) 723 d1.use_sampler(distributed_sampler) 724 assert d1.get_dataset_size() == 4 725 726 d1s = d1.repeat(3) 727 epoch1_dataset = [] 728 epoch2_dataset = [] 729 epoch3_dataset = [] 730 num_iter = 0 731 for item in d1s.create_dict_iterator(num_epochs=1, output_numpy=True): 732 logger.info( 733 "-------------- item[data]: {} -----------------------------".format(item["data"])) 734 logger.info( 735 "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) 736 logger.info( 737 "-------------- item[label]: {} ----------------------------".format(item["label"])) 738 num_iter += 1 739 if num_iter <= 4: 740 epoch1_dataset.append(item['file_name']) 741 elif num_iter <= 8: 742 epoch2_dataset.append(item['file_name']) 743 else: 744 epoch3_dataset.append(item['file_name']) 745 assert len(epoch1_dataset) == 4 746 assert len(epoch2_dataset) == 4 747 assert len(epoch3_dataset) == 4 748 inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset] 749 assert inter_dataset == [] # intersection of d1's shard1 and d1's shard2 750 assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset) 751 assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset) 752 assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset) 753 754 epoch1_dataset.sort() 755 epoch2_dataset.sort() 756 epoch3_dataset.sort() 757 assert epoch1_dataset != epoch2_dataset 758 assert epoch2_dataset != epoch3_dataset 759 assert epoch3_dataset != epoch1_dataset 760 761 762def get_data(dir_name, sampler=False): 763 """ 764 usage: get data from imagenet dataset 765 params: 766 dir_name: directory containing folder images and annotation information 767 768 """ 769 if not os.path.isdir(dir_name): 770 raise IOError("Directory {} not exists".format(dir_name)) 771 img_dir = os.path.join(dir_name, "images") 772 if sampler: 773 ann_file = os.path.join(dir_name, "annotation_sampler.txt") 774 else: 775 ann_file = os.path.join(dir_name, "annotation.txt") 776 with open(ann_file, "r") as file_reader: 777 lines = file_reader.readlines() 778 779 data_list = [] 780 for i, line in enumerate(lines): 781 try: 782 filename, label = line.split(",") 783 label = label.strip("\n") 784 with open(os.path.join(img_dir, filename), "rb") as file_reader: 785 img = file_reader.read() 786 data_json = {"id": i, 787 "file_name": filename, 788 "data": img, 789 "label": int(label)} 790 data_list.append(data_json) 791 except FileNotFoundError: 792 continue 793 return data_list 794 795 796if __name__ == '__main__': 797 test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file) 798 test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file) 799 test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file) 800 test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file) 801 test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file) 802 test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file) 803 test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file) 804 test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file) 805 test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file) 806 test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file) 807 test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file) 808 test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file) 809 test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file) 810 test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file) 811 test_cv_minddataset_split_basic(add_and_remove_cv_file) 812 test_cv_minddataset_split_exact_percent(add_and_remove_cv_file) 813 test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file) 814 test_cv_minddataset_split_deterministic(add_and_remove_cv_file) 815 test_cv_minddataset_split_sharding(add_and_remove_cv_file) 816