• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16import pytest
17
18import mindspore.dataset as ds
19from mindspore import log as logger
20from util import dataset_equal
21
22
# test5trainimgs.json contains 5 images. In sequential-sampler order their un-decoded sizes are
# [172876, 54214, 54214, 173673, 64631] and their labels are [0, 0, 1, 0, 1], so each image can be
# uniquely identified via the following lookup table (dict):
# {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
26
def test_sequential_sampler(print_res=False):
    """Check that SequentialSampler yields rows in sequential order, honoring num_samples and repeat().

    Args:
        print_res: when True, log the lookup-table indices produced by each configuration.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index of the image in sequential order
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_samples, num_repeats=None):
        pipeline = ds.ManifestDataset(manifest_file, sampler=ds.SequentialSampler(num_samples=num_samples))
        if num_repeats is not None:
            pipeline = pipeline.repeat(num_repeats)
        collected = []
        for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            key = (row["image"].shape[0], row["label"].item())
            logger.info("item[image].shape[0]: {}, item[label].item(): {}".format(*key))
            collected.append(map_[key])
        if print_res:
            logger.info("image.shapes and labels: {}".format(collected))
        return collected

    # num_samples truncates the epoch; repeat() replays the same truncated sequence.
    assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2]
    assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2
    assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2
48
49
def test_random_sampler(print_res=False):
    """Check RandomSampler: per-epoch reshuffling, and sampling with vs. without replacement.

    Args:
        print_res: when True, log the lookup-table indices produced by each configuration.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index of the image in sequential order
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(replacement, num_samples, num_repeats):
        sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
        pipeline = ds.ManifestDataset(manifest_file, sampler=sampler).repeat(num_repeats)
        rows = pipeline.create_dict_iterator(num_epochs=1, output_numpy=True)
        collected = [map_[(row["image"].shape[0], row["label"].item())] for row in rows]
        if print_res:
            logger.info("image.shapes and labels: {}".format(collected))
        return collected

    # Each epoch draws 2 distinct images; across 6 epochs more than 2 distinct images should appear,
    # i.e. each epoch COULD return different samples than the previous one.
    # NOTE(review): probabilistic - fails only if every epoch draws the same pair (~1e-5 if draws are uniform).
    assert len(set(test_config(replacement=False, num_samples=2, num_repeats=6))) > 2

    # Without replacement, 4 epochs contain every image exactly 4 times...
    ordered_res = [index for index in range(5) for _ in range(4)]
    assert sorted(test_config(replacement=False, num_samples=None, num_repeats=4)) == ordered_res
    # ...while with replacement the per-image counts should be unbalanced.
    # NOTE(review): probabilistic - a perfectly balanced draw (~0.3% of runs if uniform) makes this fail.
    assert sorted(test_config(replacement=True, num_samples=None, num_repeats=4)) != ordered_res
71
72
def test_random_sampler_multi_iter(print_res=False):
    """Check that re-iterating one RandomSampler pipeline (with replacement) produces varying output.

    The same dataset object is iterated up to ``num_repeats`` times. The test passes as soon as one
    pass, once sorted, differs from ``validate``. It fails only if every pass matches ``validate``.

    Args:
        print_res: when True, log the lookup-table indices produced by each pass.
    """
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index of the image in sequential order
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(replacement, num_samples, num_repeats, validate):
        sampler = ds.RandomSampler(replacement=replacement, num_samples=num_samples)
        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        while num_repeats > 0:
            res = []
            for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
                res.append(map_[(item["image"].shape[0], item["label"].item())])
            if print_res:
                logger.info("image.shapes and labels: {}".format(res))
            # Stop at the first pass whose sorted samples differ from `validate`.
            if validate != sorted(res):
                break
            num_repeats -= 1
        # num_repeats only reaches 0 when every pass matched `validate`, i.e. sampling never varied.
        assert num_repeats > 0

    # `validate` must contain exactly num_samples entries. Otherwise sorted(res) can never equal it and
    # the loop breaks on the first pass, which makes the check vacuous. With replacement, a pass equals
    # [0, 1, 2, 3, 4] only when it happens to be a permutation (about 4% per pass if draws are uniform).
    test_config(replacement=True, num_samples=5, num_repeats=5, validate=[0, 1, 2, 3, 4])
92
93
def test_sampler_py_api():
    """Smoke-test parsed samplers: a parsed SequentialSampler can be attached as the child of a parsed RandomSampler."""
    child = ds.SequentialSampler().parse()
    parent = ds.RandomSampler().parse()
    parent.add_child(child)
98
99
def test_python_sampler():
    """Check user-defined Python samplers: ds.Sampler subclasses and plain iterables."""
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index of the image in sequential order
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    class Sp1(ds.Sampler):
        """Yield every dataset index in order."""

        def __iter__(self):
            return iter(range(self.dataset_size))

    class Sp2(ds.Sampler):
        """Yield the same index num_samples times; the index advances by one on every reset()."""

        def __init__(self, num_samples=None):
            super().__init__(num_samples)
            # at this stage, self.dataset_size and self.num_samples are not yet known
            self.cnt = 0

        def __iter__(self):  # first epoch all 0, second epoch all 1, third all 2, etc.
            return iter([self.cnt] * self.num_samples)

        def reset(self):
            self.cnt = (self.cnt + 1) % self.dataset_size

    def test_config(num_repeats, sampler):
        data1 = ds.ManifestDataset(manifest_file, sampler=sampler)
        if num_repeats is not None:
            data1 = data1.repeat(num_repeats)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("item[image].shape[0]: {}, item[label].item(): {}"
                        .format(item["image"].shape[0], item["label"].item()))
            res.append(map_[(item["image"].shape[0], item["label"].item())])
        return res

    def check_reversed_generator(sampler):
        # 100 scalar rows whose value equals their index; `sampler` yields 99..0, so rows come out reversed.
        data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler=sampler)
        expected = 99
        for data in data1:
            assert data[0].asnumpy() == expected
            expected -= 1
        # Without this, an empty pipeline would pass silently: all 100 rows must have been produced.
        assert expected == -1

    def test_generator():
        class MySampler(ds.Sampler):
            def __iter__(self):
                for i in range(99, -1, -1):
                    yield i

        check_reversed_generator(MySampler())

    # A plain iterable sampler (no ds.Sampler inheritance) must behave exactly like the subclass above.
    def test_generator_iter_sampler():
        class MySampler:
            def __iter__(self):
                for i in range(99, -1, -1):
                    yield i

        check_reversed_generator(MySampler())

    assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0]
    test_generator()
    test_generator_iter_sampler()
161
162
def test_sequential_sampler2():
    """Check SequentialSampler(start_index, num_samples) returns the expected contiguous slice."""
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index of the image in sequential order
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(start_index, num_samples):
        sampler = ds.SequentialSampler(start_index, num_samples)
        dataset = ds.ManifestDataset(manifest_file, sampler=sampler)
        rows = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
        return [map_[(row["image"].shape[0], row["label"].item())] for row in rows]

    # (start_index, num_samples, expected lookup-table indices)
    cases = [
        (0, 1, [0]),
        (0, 2, [0, 1]),
        (0, 3, [0, 1, 2]),
        (0, 4, [0, 1, 2, 3]),
        (0, 5, [0, 1, 2, 3, 4]),
        (1, 1, [1]),
        (2, 3, [2, 3, 4]),
        (3, 2, [3, 4]),
        (4, 1, [4]),
        (4, None, [4]),  # num_samples=None reads through the end of the dataset
    ]
    for start_index, num_samples, expected in cases:
        assert test_config(start_index, num_samples) == expected
187
188
def test_subset_sampler():
    """Check SubsetSampler, both as an explicit sampler and when an index list is passed as `sampler=`.

    Each valid configuration is run through two pipelines: an explicit ds.SubsetSampler, and the raw
    indices passed as `sampler`. The produced values and get_dataset_size() of both are compared
    against indices[:num_samples]. Invalid configurations must raise with the given message.
    """
    def test_config(indices, num_samples=None, exception_msg=None):
        def pipeline():
            sampler = ds.SubsetSampler(indices, num_samples)
            data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler)
            data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples)
            dataset_size = data.get_dataset_size()
            # Must query data2 itself; querying `data` again would leave data2's size unchecked.
            dataset_size2 = data2.get_dataset_size()
            res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size
            res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2
            return res1, res2

        if exception_msg is None:
            (res, size), (res2, size2) = pipeline()
            # list() also accepts numpy index arrays; slicing with None keeps every index.
            expected = list(indices)[:num_samples]
            assert expected == res
            assert len(expected) == size
            assert expected == res2
            assert len(expected) == size2
        else:
            with pytest.raises(Exception) as error_info:
                pipeline()
            logger.info(str(error_info.value))
            assert exception_msg in str(error_info.value)

    test_config([1, 2, 3])
    test_config(list(range(10)))
    test_config([0])
    test_config([9])
    test_config(list(range(0, 10, 2)))
    test_config(list(range(1, 10, 2)))
    test_config(list(range(9, 0, -1)))
    test_config(list(range(9, 0, -2)))
    test_config(list(range(8, 0, -2)))
    test_config([0, 9, 3, 2])
    test_config([0, 0, 0, 0])
    test_config([0, 9, 3, 2], num_samples=2)
    test_config([0, 9, 3, 2], num_samples=5)

    test_config(np.array([1, 2, 3]))

    test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]")
    test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]")
    test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]")
    test_config([0, 9, -6, 2], exception_msg="Sample ID (-6) is out of bound, expected range [0, 9]")
    # TODO: temporary until we check with MindDataset
    # test_config([], exception_msg="Indices list is empty")
    test_config([0, 9, 3, 2], num_samples=-1,
                exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)")
    test_config(np.array([[1], [5]]), num_samples=10,
                exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1],"
                              " type: <class 'numpy.ndarray'>.")
244
245
def test_sampler_chain():
    """Check a DistributedSampler (no shuffle) chained with a SequentialSampler child across shard layouts."""
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    # (un-decoded image size, label) -> index of the image in sequential order
    map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}

    def test_config(num_shards, shard_id):
        parent = ds.DistributedSampler(num_shards, shard_id, shuffle=False, num_samples=5)
        parent.add_child(ds.SequentialSampler())

        pipeline = ds.ManifestDataset(manifest_file, sampler=parent)

        collected = []
        for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            key = (row["image"].shape[0], row["label"].item())
            logger.info("item[image].shape[0]: {}, item[label].item(): {}".format(*key))
            collected.append(map_[key])
        return collected

    # Two shards over 5 rows: shard 1 gets [1, 3, 0], i.e. it appears to wrap to row 0 so both shards have 3 rows.
    assert test_config(2, 0) == [0, 2, 4]
    assert test_config(2, 1) == [1, 3, 0]
    # Five shards: each shard gets exactly the row whose index equals its shard_id.
    for shard_id in range(5):
        assert test_config(5, shard_id) == [shard_id]
271
272
def test_add_sampler_invalid_input():
    """Check rejection of non-sampler objects by use_sampler(), and of sampler + num_samples together."""
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
    data1 = ds.ManifestDataset(manifest_file)

    # use_sampler() must reject anything that is not a sampler instance.
    for bad_sampler in (1, "sampler"):
        with pytest.raises(TypeError) as info:
            data1.use_sampler(bad_sampler)
        assert "not an instance of a sampler" in str(info.value)

    # An explicit sampler and num_samples are mutually exclusive dataset arguments.
    sampler = ds.SequentialSampler()
    with pytest.raises(RuntimeError) as info:
        ds.ManifestDataset(manifest_file, sampler=sampler, num_samples=20)
    assert "sampler and num_samples cannot be specified at the same time" in str(info.value)
290
291
def test_distributed_sampler_invalid_offset():
    """Check that offset=5 with num_shards=4 is rejected with a RuntimeError when the sampler is built and parsed."""
    with pytest.raises(RuntimeError) as info:
        # Only the raised error matters; the parsed sampler itself is not used.
        ds.DistributedSampler(num_shards=4, shard_id=0, shuffle=False, num_samples=None, offset=5).parse()
    assert "DistributedSampler: offset must be no more than num_shards(4)" in str(info.value)
296
297
def test_sampler_list():
    """Check that an index list or int passed as `sampler=` selects those rows, and that invalid values are rejected."""
    data_dir = "../data/dataset/testPK/data"

    def single_row(index):
        # Row `index` of the unshuffled dataset, isolated with take/skip.
        return ds.ImageFolderDataset(data_dir, shuffle=False).take(index + 1).skip(index)

    data1 = ds.ImageFolderDataset(data_dir, sampler=[1, 3, 5])
    row1, row3, row5 = single_row(1), single_row(3), single_row(5)

    dataset_equal(data1, row1 + row3 + row5, 0)

    # A single int as `sampler` selects just that row.
    data3 = ds.ImageFolderDataset(data_dir, sampler=1)
    dataset_equal(data3, row1, 0)

    def bad_pipeline(sampler, msg):
        with pytest.raises(Exception) as info:
            bad_data = ds.ImageFolderDataset(data_dir, sampler=sampler)
            for _ in bad_data:
                pass
        assert msg in str(info.value)

    invalid_cases = [
        ([1.5, 7], "Type of indices element must be int, but got list[0]: 1.5, type: <class 'float'>"),
        (["a", "b"], "Type of indices element must be int, but got list[0]: a, type: <class 'str'>."),
        ("a", "Unsupported sampler object of type (<class 'str'>)"),
        ("", "Unsupported sampler object of type (<class 'str'>)"),
        (np.array([[1, 2]]),
         "Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>."),
    ]
    for bad_sampler, expected_msg in invalid_cases:
        bad_pipeline(sampler=bad_sampler, msg=expected_msg)
325
326
if __name__ == '__main__':
    # These tests take a print_res flag; enable result logging when the file is run as a script.
    for verbose_test in (test_sequential_sampler, test_random_sampler, test_random_sampler_multi_iter):
        verbose_test(True)
    for plain_test in (test_sampler_py_api, test_python_sampler, test_sequential_sampler2,
                       test_subset_sampler, test_sampler_chain, test_add_sampler_invalid_input,
                       test_distributed_sampler_invalid_offset, test_sampler_list):
        plain_test()
339