# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers


# test5trainimgs.json contains 5 images whose un-decoded sizes are [172876, 54214, 54214, 173673, 64631]
# and whose labels are [0, 0, 1, 0, 1]; each image can be uniquely identified
# via the (size, label) lookup table below
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
                  "End of file.", "Good luck to everyone."]


def split_with_invalid_inputs(d):
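    """Check that Dataset.split() rejects empty, mixed-type, negative, mis-summing, and too-small size lists."""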
    with pytest.raises(ValueError) as info:
        _, _ = d.split([])
    assert "sizes cannot be empty" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([5, 0.6])
    assert "sizes should be list of int or list of float" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([-1, 6])
    assert "there should be no negative or zero numbers" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([3, 1])
    assert "Sum of split sizes 4 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([5, 1])
    assert "Sum of split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([0.15, 0.15, 0.15, 0.15, 0.15, 0.25])
    assert "Sum of calculated split sizes 6 is not equal to dataset size 5" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([-0.5, 0.5])
    assert "there should be no numbers outside the range (0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([1.5, 0.5])
    assert "there should be no numbers outside the range (0, 1]" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([0.5, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(ValueError) as info:
        _, _ = d.split([0.3, 0.6])
    assert "percentages do not sum up to 1" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([0.05, 0.95])
    assert "percentage 0.05 is too small" in str(info.value)


def test_unmappable_invalid_input():
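    """Test that split() on a non-mappable TextFileDataset rejects invalid sizes and pre-sharded datasets."""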
    d = ds.TextFileDataset(text_file_dataset_path)
    split_with_invalid_inputs(d)

    d = ds.TextFileDataset(text_file_dataset_path, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([4, 1])
    assert "Dataset should not be sharded before split" in str(info.value)


def test_unmappable_split():
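    """Test non-randomized split() of a TextFileDataset using absolute rows, exact percentages, and fuzzy percentages."""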
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:4]
    assert s2_output == text_file_data[4:]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    assert s1_output == text_file_data[0:2]
    assert s2_output == text_file_data[2:]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_deterministic():
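    """Test that a randomized split of a TextFileDataset yields the same non-overlapping subsets on every iteration."""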
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order output by the ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
            s1_output.append(item["text"].item().decode("utf8"))

        s2_output = []
        for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
            s2_output.append(item["text"].item().decode("utf8"))

        # note: the two splits have no overlap
        assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
        assert s2_output == [text_file_data[3]]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_randomize_repeatable():
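    """Test that repeating a randomized split of a TextFileDataset reproduces the same subsets in every epoch."""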
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order output by the ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))

    # note: the two splits have no overlap
    assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
    assert s2_output == [text_file_data[3]] * num_epochs

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_unmappable_get_dataset_size():
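    """Test get_dataset_size() on a TextFileDataset and its split subsets."""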
    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    assert d.get_dataset_size() == 5
    assert s1.get_dataset_size() == 4
    assert s2.get_dataset_size() == 1


def test_unmappable_multi_split():
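    """Test splitting a split of a TextFileDataset, with and without randomization in the second split."""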
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)

    # the row order output by the ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
    ds.config.set_seed(53)

    d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
    s1, s2 = d.split([4, 1])

    s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(item["text"].item().decode("utf8"))
    assert s1_output == s1_correct_output

    # no randomize in second split
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(item["text"].item().decode("utf8"))

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(item["text"].item().decode("utf8"))

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(item["text"].item().decode("utf8"))

    assert s1s1_output == [s1_correct_output[0]]
    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
    assert s1s3_output == [s1_correct_output[3]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))
    assert s2_output == [text_file_data[3]]

    # randomize in second split
    # the row order output by the ShuffleOp for seed 53 is [2, 3, 1, 0]
    shuffled_ids = [2, 3, 1, 0]

    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(item["text"].item().decode("utf8"))

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(item["text"].item().decode("utf8"))

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(item["text"].item().decode("utf8"))

    assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
    assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
    assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(item["text"].item().decode("utf8"))
    assert s2_output == [text_file_data[3]]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_mappable_invalid_input():
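    """Test that split() on a mappable ManifestDataset rejects invalid sizes and pre-sharded datasets."""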
    d = ds.ManifestDataset(manifest_file)
    split_with_invalid_inputs(d)

    d = ds.ManifestDataset(manifest_file, num_shards=2, shard_id=0)
    with pytest.raises(RuntimeError) as info:
        _, _ = d.split([4, 1])
    assert "Dataset should not be sharded before split" in str(info.value)


def test_mappable_split_general():
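    """Test non-randomized split() of a ManifestDataset after take(), using absolute rows and percentages."""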
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    d = d.take(5)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_split_optimized():
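    """Same checks as test_mappable_split_general, but split() is applied directly to the source
    dataset (no preceding take()), exercising the optimized split path."""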
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # absolute rows
    s1, s2 = d.split([4, 1], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # exact percentages
    s1, s2 = d.split([0.8, 0.2], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2, 3]
    assert s2_output == [4]

    # fuzzy percentages
    s1, s2 = d.split([0.33, 0.67], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1]
    assert s2_output == [2, 3, 4]


def test_mappable_randomize_deterministic():
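    """Test that a randomized split of a ManifestDataset yields the same non-overlapping subsets on every iteration."""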
    # the sample order output by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    for _ in range(10):
        s1_output = []
        for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
            s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        s2_output = []
        for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
            s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

        # note: the two splits have no overlap
        assert s1_output == [0, 1, 3, 4]
        assert s2_output == [2]


def test_mappable_randomize_repeatable():
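    """Test that repeating a randomized split of a ManifestDataset reproduces the same subsets in every epoch."""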
    # the sample order output by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([0.8, 0.2])

    num_epochs = 5
    s1 = s1.repeat(num_epochs)
    s2 = s2.repeat(num_epochs)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # note: the two splits have no overlap
    assert s1_output == [0, 1, 3, 4] * num_epochs
    assert s2_output == [2] * num_epochs


def test_mappable_sharding():
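    """Test sharding after split: the two shards must not overlap and their union must equal the first split."""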
    # set an arbitrary seed for repeatability when sharding after split
    # the sample order output by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    num_epochs = 5
    first_split_num_rows = 4

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 0)
    s1.use_sampler(distributed_sampler)

    s1 = s1.repeat(num_epochs)

    # test sharding with a second dataset to simulate another training instance
    d2 = ds.ManifestDataset(manifest_file, shuffle=False)
    d2s1, d2s2 = d2.split([first_split_num_rows, 1])

    distributed_sampler = ds.DistributedSampler(2, 1)
    d2s1.use_sampler(distributed_sampler)

    d2s1 = d2s1.repeat(num_epochs)

    # shard 0
    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    # shard 1
    d2s1_output = []
    for item in d2s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        d2s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    rows_per_shard_per_epoch = 2
    assert len(s1_output) == rows_per_shard_per_epoch * num_epochs
    assert len(d2s1_output) == rows_per_shard_per_epoch * num_epochs

    # verify for each epoch that
    #   1. the shards contain no common elements
    #   2. the data was split the same way, and the union of the shards equals the split
    correct_sorted_split_result = [0, 1, 3, 4]
    for i in range(num_epochs):
        combined_data = []
        for j in range(rows_per_shard_per_epoch):
            combined_data.append(s1_output[i * rows_per_shard_per_epoch + j])
            combined_data.append(d2s1_output[i * rows_per_shard_per_epoch + j])

        assert sorted(combined_data) == correct_sorted_split_result

    # test the other split
    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    d2s2_output = []
    for item in d2s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        d2s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s2_output == [2]
    assert d2s2_output == [2]


def test_mappable_get_dataset_size():
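    """Test get_dataset_size() on a ManifestDataset and its split subsets."""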
    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([4, 1])

    assert d.get_dataset_size() == 5
    assert s1.get_dataset_size() == 4
    assert s2.get_dataset_size() == 1


def test_mappable_multi_split():
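    """Test splitting a split of a ManifestDataset, with and without randomization in the second split."""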
    # the sample order output by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
    ds.config.set_seed(53)

    d = ds.ManifestDataset(manifest_file, shuffle=False)
    s1, s2 = d.split([4, 1])

    s1_correct_output = [0, 1, 3, 4]

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s1_output == s1_correct_output

    # no randomize in second split
    s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1s1_output == [s1_correct_output[0]]
    assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
    assert s1s3_output == [s1_correct_output[3]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s2_output == [2]

    # randomize in second split
    # the ids output by the RandomSampler for seed 53 are [3, 1, 2, 0]
    random_sampler_ids = [3, 1, 2, 0]

    s1s1, s1s2, s1s3 = s1.split([1, 2, 1])

    s1s1_output = []
    for item in s1s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s2_output = []
    for item in s1s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s1s3_output = []
    for item in s1s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
    assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
    assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
    assert s2_output == [2]


def test_rounding():
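    """Test how fractional row counts from percentage splits are rounded (under- and over-rounding)."""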
    d = ds.ManifestDataset(manifest_file, shuffle=False)

    # under rounding
    s1, s2 = d.split([0.5, 0.5], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0, 1, 2]
    assert s2_output == [3, 4]

    # over rounding
    s1, s2, s3 = d.split([0.15, 0.55, 0.3], randomize=False)

    s1_output = []
    for item in s1.create_dict_iterator(num_epochs=1, output_numpy=True):
        s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s2_output = []
    for item in s2.create_dict_iterator(num_epochs=1, output_numpy=True):
        s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    s3_output = []
    for item in s3.create_dict_iterator(num_epochs=1, output_numpy=True):
        s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])

    assert s1_output == [0]
    assert s2_output == [1, 2]
    assert s3_output == [3, 4]


if __name__ == '__main__':
    test_unmappable_invalid_input()
    test_unmappable_split()
    test_unmappable_randomize_deterministic()
    test_unmappable_randomize_repeatable()
    test_unmappable_get_dataset_size()
    test_unmappable_multi_split()
    test_mappable_invalid_input()
    test_mappable_split_general()
    test_mappable_split_optimized()
    test_mappable_randomize_deterministic()
    test_mappable_randomize_repeatable()
    test_mappable_sharding()
    test_mappable_get_dataset_size()
    test_mappable_multi_split()
    test_rounding()