# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers


# test5trainimgs.json contains 5 images with labels [0, 0, 0, 1, 1]; each image can be
# uniquely identified via its (un-decoded size, label) pair, using the lookup table
# manifest_map below.
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
                  "End of file.", "Good luck to everyone."]


def split_with_invalid_inputs(d):
    with pytest.raises(ValueError) as info:
        _, _ = d.split([])
    assert "sizes cannot be empty" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([5, 0.6])
    assert "sizes should be list of int or list of float" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([-1, 6])
    assert "there should be no negative or zero numbers" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([3, 1])
    assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([5, 1])
    assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
    assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([-0.5, 0.5])
    assert "there should be no numbers outside the range (0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([1.5, 0.5])
    assert "there should be no numbers outside the range (0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([0.5, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([0.3, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([0.05, 0.95])
    assert "percentage 0.05 is too small" in str(info.value)


def test_unmappable_invalid_input():
    d = ds.TextFileDataset(text_file_dataset_path)
    split_with_invalid_inputs(d)

    d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([4, 1])
    assert "Dataset should not be sharded before split" in str(info.value)
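

# Hypothetical helper, not part of the original test suite: every unmappable test below
# collects a split's "text" column with the same loop. This sketch shows that pattern in
# one place; the tests keep their explicit loops.
def _collect_text_rows(dataset):
    """Collect the decoded "text" column of a dataset split into a list."""
    return [item["text"].item().decode("utf8")
            for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]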


def test_unmappable_split():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # absolute rows
    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:2]
    assert s2_output == text_file_data[2:]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_deterministic():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order output by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
            s1_output.append(item["text"].item().decode("utf8"))

        s2_output = []
        for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
            s2_output.append(item["text"].item().decode("utf8"))

        # note: the two splits share no rows
        assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
        assert s2_output == [text_file_data[3]]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_repeatable():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order output by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    # note: the two splits share no rows
    assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
    assert s2_output == [text_file_data[3]] * num_epochs

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_get_dataset_size():
    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    assert d.get_dataset_size() == 5
    assert s1.get_dataset_size() == 4
    assert s2.get_dataset_size() == 1


def test_unmappable_multi_split():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order output by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([4, 1])

    s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))
    assert s1_output == s1_correct_output

    # no randomize in second split
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(item["text"].item().decode("utf8"))

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(item["text"].item().decode("utf8"))

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(item["text"].item().decode("utf8"))

    assert s1s1_output == [s1_correct_output[0]]
    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
    assert s1s3_output == [s1_correct_output[3]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))
    assert s2_output == [text_file_data[3]]

    # randomize in second split
    # the row order output by the ShuffleOp for seed 53 is [2, 3, 1, 0]
    shuffled_ids = [2, 3, 1, 0]

    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(item["text"].item().decode("utf8"))

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(item["text"].item().decode("utf8"))

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(item["text"].item().decode("utf8"))

    assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
    assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
    assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))
    assert s2_output == [text_file_data[3]]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_mappable_invalid_input():
    d = ds.ManifestDataset(manifest_file)
    split_with_invalid_inputs(d)

    d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([4, 1])
    assert "Dataset should not be sharded before split" in str(info.value)
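

# Hypothetical helper, not part of the original test suite: the mappable tests identify
# each row by looking up its (un-decoded image size, label) pair in manifest_map. This
# sketch captures that repeated pattern; the tests keep their explicit loops.
def _collect_manifest_ids(dataset):
    """Map each row of a ManifestDataset split to its id via (image size, label)."""
    return [manifest_map[(item["image"].shape[0], item["label"].item())]
            for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]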


def test_mappable_split_general():
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    # take(5) keeps all 5 rows but inserts an op between the leaf and the split,
    # so this test exercises the general split path
    d = d.take(5)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_split_optimized():
    # same checks as test_mappable_split_general, but with the split applied
    # directly to the leaf dataset (no intervening take)
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_randomize_deterministic():
    # the row ids output by ManifestDataset for seed 53 are [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
            s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        s2_output = []
        for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
            s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        # note: the two splits share no rows
        assert s1_output == [0, 1, 3, 4]
        assert s2_output == [2]


def test_mappable_randomize_repeatable():
    # the row ids output by ManifestDataset for seed 53 are [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # note: the two splits share no rows
    assert s1_output == [0, 1, 3, 4] * num_epochs
    assert s2_output == [2] * num_epochs


def test_mappable_sharding():
    # set an arbitrary seed for repeatability when sharding after a split
    # the row ids output by ManifestDataset for seed 53 are [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    num_epochs = 5
    first_split_num_rows = 4

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 0)
    s1.use_sampler(distributed_sampler)

    s1 = s1.repeat(num_epochs)

    # testing sharding: a second dataset simulates another training instance
    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
    d2s1, d2s2 = d2.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 1)
    d2s1.use_sampler(distributed_sampler)

    d2s1 = d2s1.repeat(num_epochs)

    # shard 0
    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # shard 1
    d2s1_output = []
    for item in d2s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    rows_per_shard_per_epoch = 2
    assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
    assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs

    # verify for each epoch that
    # 1. the shards contain no common elements
    # 2. the data was split the same way, and the union of the shards equals the split
    correct_sorted_split_result = [0, 1, 3, 4]
    for i in range(num_epochs):
        combined_data = []
        for j in range(rows_per_shard_per_epoch):
            combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
            combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])

        assert sorted(combined_data) == correct_sorted_split_result

    # test the other split
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    d2s2_output = []
    for item in d2s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s2_output == [2]
    assert d2s2_output == [2]


def test_mappable_get_dataset_size():
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([4, 1])

    assert d.get_dataset_size() == 5
    assert s1.get_dataset_size() == 4
    assert s2.get_dataset_size() == 1


def test_mappable_multi_split():
    # the row ids output by ManifestDataset for seed 53 are [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([4, 1])

    s1_correct_output = [0, 1, 3, 4]

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s1_output == s1_correct_output

    # no randomize in second split
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1s1_output == [s1_correct_output[0]]
    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
    assert s1s3_output == [s1_correct_output[3]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s2_output == [2]

    # randomize in second split
    # the row ids output by the RandomSampler for seed 53 are [3, 1, 2, 0]
    random_sampler_ids = [3, 1, 2, 0]

    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
    assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
    assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s2_output == [2]
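

# Worked arithmetic for test_rounding() below, derived from its expected outputs (the
# exact rounding rule is internal to MindSpore's split implementation):
#   under rounding: [0.5, 0.5] of 5 rows -> raw sizes 2.5 / 2.5, actual sizes [3, 2]
#   over rounding:  [0.15, 0.55, 0.3] of 5 rows -> raw sizes 0.75 / 2.75 / 1.5,
#                   actual sizes [1, 2, 2]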


def test_rounding():
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # under rounding
    s1, s2 = d.split([0.5, 0.5], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2]
    assert s2_output == [3, 4]

    # over rounding
    s1, s2, s3 = d.split([0.15, 0.55, 0.3], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s3_output = []
    for item in s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0]
    assert s2_output == [1, 2]
    assert s3_output == [3, 4]


if __name__ == '__main__':
    test_unmappable_invalid_input()
    test_unmappable_split()
    test_unmappable_randomize_deterministic()
    test_unmappable_randomize_repeatable()
    test_unmappable_get_dataset_size()
    test_unmappable_multi_split()
    test_mappable_invalid_input()
    test_mappable_split_general()
    test_mappable_split_optimized()
    test_mappable_randomize_deterministic()
    test_mappable_randomize_repeatable()
    test_mappable_sharding()
    test_mappable_get_dataset_size()
    test_mappable_multi_split()
    test_rounding()