# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
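"""
Test sampler chaining via Sampler.add_child() across dataset types
(NumpySlicesDataset, ImageFolderDataset, MnistDataset, ManifestDataset,
CocoDataset, Cifar10Dataset, VOCDataset), with and without batch/repeat.
"""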
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c_transforms
from mindspore import log as logger
from util import save_and_check_md5

GENERATE_GOLDEN = False

IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train"
IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
                       "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
                       "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
                       "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
MNIST_DATA_DIR = "../data/dataset/testMnistData"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
VOC_DATA_DIR = "../data/dataset/testVOC2012"


def test_numpyslices_sampler_no_chain():
    """
    Test NumpySlicesDataset with sampler, no chain
    """
    logger.info("test_numpyslices_sampler_no_chain")

    # Create NumpySlicesDataset with sampler, no chain
    np_data = [1, 2, 3, 4]
    sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 2

    # Verify number of rows
    assert sum([1 for _ in data1]) == 2

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_numpyslices_sampler_chain():
    """
    Test NumpySlicesDataset sampler chain
    """
    logger.info("test_numpyslices_sampler_chain")

    # Create NumpySlicesDataset with sampler chain
    # Use 1 statement to add child sampler
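    # In a sampler chain the child sampler draws from the dataset first and the
    # parent sampler then samples from the child's output, so only one sample is
    # left for the parent here (hence the expected dataset size of 1 below).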
    np_data = [1, 2, 3, 4]
    sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
    data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 1

    # Verify number of rows
    assert sum([1 for _ in data1]) == 1

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_numpyslices_sampler_chain2():
    """
    Test NumpySlicesDataset sampler chain
    """
    logger.info("test_numpyslices_sampler_chain2")

    # Create NumpySlicesDataset with sampler chain
    # Use 2 statements to add child sampler
    np_data = [1, 2, 3, 4]
    sampler = ds.SequentialSampler(start_index=1, num_samples=1)
    child_sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler.add_child(child_sampler)
    data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 1

    # Verify number of rows
    assert sum([1 for _ in data1]) == 1

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_imagefolder_sampler_chain():
    """
    Test ImageFolderDataset sampler chain
    """
    logger.info("test_imagefolder_sampler_chain")

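    # Chain: child PKSampler(2) draws 2 samples per class from the image folder;
    # the parent SequentialSampler then takes 3 of those samples starting at index 1.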
    sampler = ds.SequentialSampler(start_index=1, num_samples=3)
    child_sampler = ds.PKSampler(2)
    sampler.add_child(child_sampler)
    data1 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, sampler=sampler)
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 3
    # Verify number of rows
    assert sum([1 for _ in data1]) == 3

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_mnist_sampler_chain():
    """
    Test Mnist sampler chain
    """
    logger.info("test_mnist_sampler_chain")

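    # Chain: child RandomSampler draws 4 samples (with replacement) from MNIST;
    # the parent DistributedSampler (single shard, offset=1) keeps 3 of them.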
    sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
    child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
    sampler.add_child(child_sampler)
    data1 = ds.MnistDataset(MNIST_DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 3
    # Verify number of rows
    assert sum([1 for _ in data1]) == 3

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_manifest_sampler_chain():
    """
    Test Manifest sampler chain
    """
    logger.info("test_manifest_sampler_chain")

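    # Chain: child DistributedSampler (single shard) yields 3 manifest samples;
    # the parent RandomSampler then draws 2 of them with replacement.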
    sampler = ds.RandomSampler(replacement=True, num_samples=2)
    child_sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
    sampler.add_child(child_sampler)
    data1 = ds.ManifestDataset(MANIFEST_DATA_FILE, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 2
    # Verify number of rows
    assert sum([1 for _ in data1]) == 2

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_coco_sampler_chain():
    """
    Test Coco sampler chain
    """
    logger.info("test_coco_sampler_chain")

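    # Chain: child RandomSampler draws 2 samples from the COCO data; the parent
    # DistributedSampler splits them across 2 shards, leaving 1 sample for shard 0.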
    sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
    child_sampler = ds.RandomSampler(replacement=True, num_samples=2)
    sampler.add_child(child_sampler)
    data1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True,
                           sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 1

    # Verify number of rows
    assert sum([1 for _ in data1]) == 1

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_cifar_sampler_chain():
    """
    Test Cifar sampler chain
    """
    logger.info("test_cifar_sampler_chain")

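    # Three-level chain: SequentialSampler feeds RandomSampler, which feeds DistributedSampler.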
    sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
    child_sampler = ds.RandomSampler(replacement=True, num_samples=4)
    child_sampler2 = ds.SequentialSampler(start_index=0, num_samples=2)
    child_sampler.add_child(child_sampler2)
    sampler.add_child(child_sampler)
    data1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, sampler=sampler)
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 1

    # Verify number of rows
    assert sum([1 for _ in data1]) == 1

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_voc_sampler_chain():
    """
    Test VOC sampler chain
    """
    logger.info("test_voc_sampler_chain")

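    # Chain: child SequentialSampler (no num_samples) passes every sample through;
    # the parent DistributedSampler keeps shard 0 of 2 shards, i.e. half of the data.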
    sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
    child_sampler = ds.SequentialSampler(start_index=0)
    sampler.add_child(child_sampler)
    data1 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 5

    # Verify number of rows
    assert sum([1 for _ in data1.create_dict_iterator(output_numpy=True)]) == 5

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_numpyslices_sampler_chain_batch():
    """
    Test NumpySlicesDataset sampler chaining, with batch
    """
    logger.info("test_numpyslices_sampler_chain_batch")

    # Create NumpySlicesDataset with sampler chain
    np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    sampler = ds.SequentialSampler(start_index=1, num_samples=3)
    sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
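    # Note: add_child() returns None (see test_sampler_chain_errors below), so
    # sampler is None here and NumpySlicesDataset falls back to its default
    # ordering over all 10 samples, giving 4 batches of size 3 (last batch of 1).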
    data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)
    data1 = data1.batch(batch_size=3, drop_remainder=False)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 4

    # Verify number of rows
    assert sum([1 for _ in data1]) == 4

    # Verify dataset contents
    res = []
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        logger.info("item: {}".format(item))
        res.append(item)
    logger.info("dataset: {}".format(res))


def test_sampler_chain_errors():
    """
    Test error cases for sampler chains
    """
    logger.info("test_sampler_chain_errors")

    error_msg_1 = "'NoneType' object has no attribute 'add_child'"
    # Test calling add_child on the None value returned by a previous add_child call
    sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
    with pytest.raises(AttributeError, match=error_msg_1):
        sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))

    # error_msg_2 = "'NoneType' object has no attribute 'add_child'"
    # Test add second and nested child sampler
    sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    child_sampler = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler.add_child(child_sampler)
    child_sampler2 = ds.SequentialSampler(start_index=1, num_samples=2)
    sampler.add_child(child_sampler2)
    # FIXME - no error is raised; uncomment after code issue is resolved
    # with pytest.raises(AttributeError, match=error_msg_2):
    #     sampler.add_child(child_sampler2)
    #     np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    #     data1 = ds.NumpySlicesDataset(np_data, sampler=sampler)

    error_msg_3 = "Conflicting arguments during sampler assignments."
    # Test conflicting arguments (sampler and shuffle=False) for sampler (no chain)
    np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    sampler = ds.SequentialSampler(start_index=1, num_samples=3)
    with pytest.raises(ValueError, match=error_msg_3):
        ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler)

    # error_msg_4 = "Conflicting arguments during sampler assignments."
    # Test conflicting arguments (sampler and shuffle=False) for sampler chaining
    np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    sampler = ds.SequentialSampler(start_index=1, num_samples=3)
    sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2))
    # FIXME - no error is raised; uncomment after code issue is resolved
    # with pytest.raises(ValueError, match=error_msg_4):
    #     ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler)


def test_manifest_sampler_chain_repeat():
    """
    Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with repeat
    """
    logger.info("test_manifest_sampler_chain_repeat")
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

    # Create sampler chain DistributedSampler->SequentialSampler
    sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5)
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)

    # Create ManifestDataset with sampler chain
    data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
    data1 = data1.repeat(count=2)
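    # The chain passes all 5 manifest samples through; repeat(count=2) doubles
    # the row count to 10.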
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 10

    # Verify number of rows
    assert sum([1 for _ in data1]) == 10

    # Verify dataset contents
    filename = "sampler_chain_manifest_repeat_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_manifest_sampler_chain_batch_repeat():
    """
    Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with batch then repeat
    """
    logger.info("test_manifest_sampler_chain_batch_repeat")
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

    # Create sampler chain DistributedSampler->SequentialSampler
    sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5)
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)

    # Create ManifestDataset with sampler chain
    data1 = ds.ManifestDataset(manifest_file, decode=True, sampler=sampler)
    one_hot_encode = c_transforms.OneHot(3)
    data1 = data1.map(operations=one_hot_encode, input_columns=["label"])
    data1 = data1.batch(batch_size=5, drop_remainder=False)
    data1 = data1.repeat(count=2)
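    # The 5 samples form a single batch of 5, which repeat(count=2) turns into 2 rows.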

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 2

    # Verify number of rows
    # FIXME: Uncomment the following assert when code issue is resolved
    # assert sum([1 for _ in data1]) == 2


if __name__ == '__main__':
    test_numpyslices_sampler_no_chain()
    test_numpyslices_sampler_chain()
    test_numpyslices_sampler_chain2()
    test_imagefolder_sampler_chain()
    test_mnist_sampler_chain()
    test_manifest_sampler_chain()
    test_coco_sampler_chain()
    test_cifar_sampler_chain()
    test_voc_sampler_chain()
    test_numpyslices_sampler_chain_batch()
    test_sampler_chain_errors()
    test_manifest_sampler_chain_repeat()
    test_manifest_sampler_chain_batch_repeat()