• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import pytest
16import mindspore.dataset as ds
17import mindspore.dataset.vision.c_transforms as vision
18from mindspore import log as logger
19
20DATA_DIR = "../data/dataset/testPK/data"
21
22
def test_imagefolder_basic():
    """Read every sample of the ImageFolder dataset once and check the total count."""
    logger.info("Test Case basic")
    repeat_count = 1

    # build the pipeline: source -> repeat
    dataset = ds.ImageFolderDataset(DATA_DIR)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44
41
42
def test_imagefolder_numsamples():
    """Verify num_samples and a capped RandomSampler limit how many rows are read."""
    logger.info("Test Case numSamples")
    repeat_count = 1

    # num_samples caps the dataset at 10 rows
    dataset = ds.ImageFolderDataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 10

    # a RandomSampler with replacement yields exactly num_samples rows
    sampler_with = ds.RandomSampler(num_samples=3, replacement=True)
    dataset = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=sampler_with)

    count = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1))
    assert count == 3

    # and likewise without replacement
    sampler_without = ds.RandomSampler(num_samples=3, replacement=False)
    dataset = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=sampler_without)

    count = sum(1 for _ in dataset.create_dict_iterator(num_epochs=1))
    assert count == 3
79
80
def test_imagefolder_numshards():
    """Sharding into 4 shards leaves 44/4 = 11 samples in shard 3."""
    logger.info("Test Case numShards")
    repeat_count = 1

    # read only the last of four shards
    dataset = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 11
99
100
def test_imagefolder_shardid():
    """A non-final shard id also receives 11 of the 44 samples."""
    logger.info("Test Case withShardID")
    repeat_count = 1

    # read shard 1 of 4
    dataset = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=1)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 11
119
120
def test_imagefolder_noshuffle():
    """Disabling shuffle still reads all 44 samples."""
    logger.info("Test Case noShuffle")
    repeat_count = 1

    # deterministic order: shuffle turned off at the source
    dataset = ds.ImageFolderDataset(DATA_DIR, shuffle=False)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44
139
140
def test_imagefolder_extrashuffle():
    """Stack an explicit shuffle op on an already-shuffled source, then repeat twice."""
    logger.info("Test Case extraShuffle")
    repeat_count = 2

    # source-level shuffle plus a pipeline shuffle with a small buffer
    dataset = ds.ImageFolderDataset(DATA_DIR, shuffle=True)
    dataset = dataset.shuffle(buffer_size=5)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 88
160
161
def test_imagefolder_classindex():
    """class_indexing maps the listed folders to the given label values, in order."""
    logger.info("Test Case classIndex")
    repeat_count = 1

    # only class1 and class3 are mapped, so only their samples are produced
    class_index = {"class3": 333, "class1": 111}
    dataset = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
    dataset = dataset.repeat(repeat_count)

    # 11 samples per class; class1 (111) sorts before class3 (333)
    golden = [111] * 11 + [333] * 11

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        assert row["label"] == golden[count]
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 22
185
186
def test_imagefolder_negative_classindex():
    """class_indexing accepts negative label values as well."""
    logger.info("Test Case negative classIndex")
    repeat_count = 1

    # map class3 to a negative label
    class_index = {"class3": -333, "class1": 111}
    dataset = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False)
    dataset = dataset.repeat(repeat_count)

    # 11 samples per class, class1 first
    golden = [111] * 11 + [-333] * 11

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        assert row["label"] == golden[count]
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 22
210
211
def test_imagefolder_extensions():
    """An extensions filter matching all files still yields all 44 samples."""
    logger.info("Test Case extensions")
    repeat_count = 1

    # restrict to .jpg/.JPEG files (which is everything in this dataset)
    ext = [".jpg", ".JPEG"]
    dataset = ds.ImageFolderDataset(DATA_DIR, extensions=ext)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44
231
232
def test_imagefolder_decode():
    """decode=True decodes images at the source without changing the row count."""
    logger.info("Test Case decode")
    repeat_count = 1

    # decode images while loading
    ext = [".jpg", ".JPEG"]
    dataset = ds.ImageFolderDataset(DATA_DIR, extensions=ext, decode=True)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44
252
253
def test_sequential_sampler():
    """SequentialSampler must visit the 44 samples class by class, in order 0..3."""
    logger.info("Test Case SequentialSampler")

    # 11 samples per class, classes visited sequentially
    golden = [label for label in range(4) for _ in range(11)]

    repeat_count = 1

    # sequential order, no shuffling
    sampler = ds.SequentialSampler()
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    dataset = dataset.repeat(repeat_count)

    result = []
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        # collect the label of every row
        result.append(row["label"])

    assert len(result) == 44
    logger.info("Result: {}".format(result))
    assert result == golden
280
281
def test_random_sampler():
    """A default RandomSampler reads the full dataset exactly once."""
    logger.info("Test Case RandomSampler")
    repeat_count = 1

    # random order over all samples
    sampler = ds.RandomSampler()
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44
301
302
def test_distributed_sampler():
    """DistributedSampler with 10 shards leaves ceil-divided 5 samples in shard 1."""
    logger.info("Test Case DistributedSampler")
    repeat_count = 1

    # 10 shards, this process reads shard 1
    sampler = ds.DistributedSampler(10, 1)
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 5
322
323
def test_pk_sampler():
    """PKSampler(3) takes 3 samples from each of the 4 classes: 12 rows total."""
    logger.info("Test Case PKSampler")
    repeat_count = 1

    # 3 samples per class
    sampler = ds.PKSampler(3)
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 12
343
344
def test_subset_random_sampler():
    """SubsetRandomSampler yields exactly the rows named by its index list."""
    logger.info("Test Case SubsetRandomSampler")
    repeat_count = 1

    # 12 explicit indices into the dataset
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices)
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 12
365
366
def test_weighted_random_sampler():
    """WeightedRandomSampler draws exactly the requested number of samples."""
    logger.info("Test Case WeightedRandomSampler")
    repeat_count = 1

    # one weight per candidate row; draw 11 samples
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
    sampler = ds.WeightedRandomSampler(weights, 11)
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 11
387
388
def test_weighted_random_sampler_exception():
    """
    Test error cases for WeightedRandomSampler
    """
    logger.info("Test error cases for WeightedRandomSampler")
    # weights must be a list of numbers; a string is rejected at construction
    error_msg_1 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_1):
        weights = ""
        ds.WeightedRandomSampler(weights)

    # a tuple of numbers is likewise rejected; only a list is accepted
    error_msg_2 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_2):
        weights = (0.9, 0.8, 1.1)
        ds.WeightedRandomSampler(weights)

    # the remaining validations fire at parse() time, not at construction
    error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
    with pytest.raises(RuntimeError, match=error_msg_3):
        weights = []
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    # negative weights are invalid
    error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
    with pytest.raises(RuntimeError, match=error_msg_4):
        weights = [1.0, 0.1, 0.02, 0.3, -0.4]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    # an all-zero weight vector is invalid
    error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
    with pytest.raises(RuntimeError, match=error_msg_5):
        weights = [0, 0, 0, 0, 0]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()
421
422
def test_chained_sampler_01():
    """Chain RandomSampler over SequentialSampler, repeat 3x: 44 * 3 = 132 rows."""
    logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")

    # the parent sampler draws from its child's output
    parent = ds.RandomSampler()
    parent.add_child(ds.SequentialSampler())
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=parent)

    dataset = dataset.repeat(count=3)

    # reported size must reflect the repeat
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == 132

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 132
450
451
def test_chained_sampler_02():
    """Chained sampler with batch(5, drop_remainder) then repeat 2x: 8 * 2 = 16 batches."""
    logger.info("Test Case Chained Sampler - Random and Sequential, with batch then repeat")

    # the parent sampler draws from its child's output
    parent = ds.RandomSampler()
    parent.add_child(ds.SequentialSampler())
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=parent)

    # 44 // 5 = 8 full batches, repeated twice
    dataset = dataset.batch(batch_size=5, drop_remainder=True)
    dataset = dataset.repeat(count=2)

    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == 16

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 16
480
481
def test_chained_sampler_03():
    """Chained sampler with repeat 2x then batch(5, keep remainder): ceil(88/5) = 18 batches."""
    logger.info("Test Case Chained Sampler - Random and Sequential, with repeat then batch")

    # the parent sampler draws from its child's output
    parent = ds.RandomSampler()
    parent.add_child(ds.SequentialSampler())
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=parent)

    # 44 * 2 = 88 rows, batched into 17 full + 1 partial batch
    dataset = dataset.repeat(count=2)
    dataset = dataset.batch(batch_size=5, drop_remainder=False)

    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == 18

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 18
510
511
def test_chained_sampler_04():
    """Distributed-over-Random chain, batched then repeated: (11 // 5) * 3 = 6 batches."""
    logger.info("Test Case Chained Sampler - Distributed and Random, with batch then repeat")

    # distributed parent draws its shard from the random child's output
    parent = ds.DistributedSampler(num_shards=4, shard_id=3)
    parent.add_child(ds.RandomSampler())
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=parent)

    dataset = dataset.batch(batch_size=5, drop_remainder=True)
    dataset = dataset.repeat(count=3)

    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == 6

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    # each of the 4 shards holds 44/4 = 11 samples,
    # so iterations = (11 // 5 = 2 batches) * 3 repeats = 6
    assert count == 6
542
543
def skip_test_chained_sampler_05():
    # NOTE(review): the 'skip_' prefix intentionally hides this case from pytest
    # collection and it is also commented out in __main__ — presumably the
    # PKSampler -> WeightedRandomSampler chain is broken; confirm before re-enabling.
    logger.info("Test Case Chained Sampler - PKSampler and WeightedRandom")

    # Create chained sampler, PKSampler and WeightedRandom
    sampler = ds.PKSampler(num_val=3)  # Number of elements per class is 3 (and there are 4 classes)
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
    child_sampler = ds.WeightedRandomSampler(weights, num_samples=12)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 12

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: PKSampler produces 4x3=12 samples
    # Note: Child WeightedRandomSampler produces 12 samples
    assert num_iter == 12
572
573
def test_chained_sampler_06():
    """Chain WeightedRandomSampler over a PKSampler child: 12 rows either way."""
    logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")

    # weighted parent draws 12 samples from the PK child's output
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
    parent = ds.WeightedRandomSampler(weights=weights, num_samples=12)
    # PK child: 3 elements for each of the 4 classes
    parent.add_child(ds.PKSampler(num_val=3))
    dataset = ds.ImageFolderDataset(DATA_DIR, sampler=parent)

    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == 12

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    # WeightedRandomSampler yields 12 samples; the child PKSampler also yields 12
    assert count == 12
602
603
def test_chained_sampler_07():
    # Chain a SubsetRandomSampler parent over a 2-shard DistributedSampler child.
    # Because of a known issue (see FIXME below) this test only verifies the
    # pipeline runs; the shard-count assertion is disabled.
    logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 2 shards")

    # Create chained sampler, subset random and distributed
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices, num_samples=12)
    child_sampler = ds.DistributedSampler(num_shards=2, shard_id=1)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    # NOTE(review): reported size is the parent's 12 samples, not one shard's 6
    assert data1_size == 12

    # Verify number of iterations

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: SubsetRandomSampler produces 12 samples
    # Note: Each of 2 shards has 6 samples
    # FIXME: Uncomment the following assert when code issue is resolved; at runtime, number of samples is 12 not 6
    # assert num_iter == 6
634
635
def skip_test_chained_sampler_08():
    # NOTE(review): the 'skip_' prefix intentionally hides this case from pytest
    # collection and it is commented out in __main__ — presumably the same
    # sharding issue as test_chained_sampler_07; confirm before re-enabling.
    logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 4 shards")

    # Create chained sampler, subset random and distributed
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices, num_samples=12)
    child_sampler = ds.DistributedSampler(num_shards=4, shard_id=1)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 3

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):  # each data is a dictionary
        # in this example, each dictionary has keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: SubsetRandomSampler returns 12 samples
    # Note: Each of 4 shards has 3 samples
    assert num_iter == 3
664
665
def test_imagefolder_rename():
    """Rename the "image" column to "image2" and iterate under the new name."""
    logger.info("Test Case rename")
    repeat_count = 1

    dataset = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
    dataset = dataset.repeat(repeat_count)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # every row is a dict with "image" and "label" columns
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 10

    # after rename, the image column is exposed as "image2"
    dataset = dataset.rename(input_columns=["image"], output_columns="image2")

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # same rows, accessed through the renamed column
        logger.info("image is {}".format(row["image2"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 10
696
697
def test_imagefolder_zip():
    """Zip two 10-sample datasets after renaming one side to avoid column clashes."""
    logger.info("Test Case zip")
    repeat_count = 2

    dataset_a = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
    dataset_b = ds.ImageFolderDataset(DATA_DIR, num_samples=10)

    dataset_a = dataset_a.repeat(repeat_count)
    # rename the second dataset's columns so the zip has no duplicates
    dataset_b = dataset_b.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
    zipped = ds.zip((dataset_a, dataset_b))

    count = 0
    for row in zipped.create_dict_iterator(num_epochs=1):
        # zipped rows carry "image"/"label" plus the renamed "image1"/"label1"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    # zip stops at the shorter input: 10 rows
    assert count == 10
721
722
def test_imagefolder_exception():
    """
    Test error propagation: a raising map PyFunc and an invalid dataset
    directory must each surface as RuntimeError with a descriptive message.
    """
    logger.info("Test imagefolder exception")

    # one-input PyFunc that always fails
    def exception_func(item):
        raise Exception("Error occur!")

    # two-input PyFunc that always fails
    def exception_func2(image, label):
        raise Exception("Error occur!")

    # case 1: raising PyFunc on a single column
    try:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    # case 2: raising PyFunc with multiple input/output columns
    try:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=exception_func2, input_columns=["image", "label"],
                        output_columns=["image", "label", "label1"],
                        column_order=["image", "label", "label1"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    # case 3: raising PyFunc placed after a successful Decode op
    try:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    # case 4: a directory whose children are not class folders of image files
    data_dir_invalid = "../data/dataset/testPK"
    try:
        data = ds.ImageFolderDataset(data_dir_invalid)
        for _ in data.__iter__():
            pass
        assert False
    except RuntimeError as e:
        assert "should be file, but got directory" in str(e)
770
771
if __name__ == '__main__':
    # Run every enabled test case in declaration order and log its completion.
    # skip_test_chained_sampler_05 and skip_test_chained_sampler_08 are
    # intentionally disabled (see the skip_ prefix on their definitions).
    # The per-test log message is derived from __name__, so it matches the
    # former hand-written "<test name> Ended." lines exactly and cannot go
    # stale when a test is renamed.
    test_cases = [
        test_imagefolder_basic,
        test_imagefolder_numsamples,
        test_sequential_sampler,
        test_random_sampler,
        test_distributed_sampler,
        test_pk_sampler,
        test_subset_random_sampler,
        test_weighted_random_sampler,
        test_weighted_random_sampler_exception,
        test_chained_sampler_01,
        test_chained_sampler_02,
        test_chained_sampler_03,
        test_chained_sampler_04,
        test_chained_sampler_06,
        test_chained_sampler_07,
        test_imagefolder_numshards,
        test_imagefolder_shardid,
        test_imagefolder_noshuffle,
        test_imagefolder_extrashuffle,
        test_imagefolder_classindex,
        test_imagefolder_negative_classindex,
        test_imagefolder_extensions,
        test_imagefolder_decode,
        test_imagefolder_rename,
        test_imagefolder_zip,
        test_imagefolder_exception,
    ]
    for test_case in test_cases:
        test_case()
        logger.info('{} Ended.\n'.format(test_case.__name__))
856