# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger


def test_imagefolder_shardings(print_res=False):
    image_folder_dir = "../data/dataset/testPK/data"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, class_index, repeat_cnt=1):
        data1 = ds.ImageFolderDataset(image_folder_dir, num_samples=num_samples, num_shards=num_shards,
                                      shard_id=shard_id,
                                      shuffle=shuffle, class_indexing=class_index, decode=True)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # total 44 rows in dataset (per-shard row counts are sketched after this test)
    assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1])  # 5 rows
    assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3])  # 11 rows
    assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])  # 11 rows
    assert (sharding_config(1, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])  # 44 rows
    assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])  # 22 rows
    assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3])  # 22 rows
    # total 22 rows in dataset because class indexing takes only 2 folders
    assert len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6
    assert len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3
    # test with repeat
    assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3)
    assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5)
    assert len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20

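
# A minimal sketch (not part of the original tests) of the shard-size arithmetic the assertions
# above rely on: a sharded dataset hands each shard roughly ceil(total_rows / num_shards) rows,
# optionally capped by num_samples, which is why 44 rows split across 4 shards yield 11 rows per
# shard and 22 rows split across 4 shards yield 6. The helper name expected_shard_rows is
# hypothetical and only illustrates the expected counts, not MindSpore's internal logic.
def expected_shard_rows(total_rows, num_shards, num_samples=None):
    per_shard = (total_rows + num_shards - 1) // num_shards  # ceiling division
    return per_shard if num_samples is None else min(per_shard, num_samples)

# For example, expected_shard_rows(44, 4) == 11 and expected_shard_rows(22, 4) == 6,
# matching the row counts noted in test_imagefolder_shardings above.
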

def test_tfrecord_shardings1(print_res=False):
    """ Test TFRecordDataset sharding with num_parallel_workers=1 """

    # total 40 rows in dataset across 4 files (file-to-shard assignment is sketched after this test)
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=1)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["scalars"][0])
        if print_res:
            logger.info("scalars of dataset: {}".format(res))
        return res

    assert sharding_config(2, 0, None, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, None, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(2, 0, 3, 1) == [11, 12, 13]  # 3 rows
    assert sharding_config(2, 1, 3, 1) == [1, 2, 3]  # 3 rows
    assert sharding_config(2, 0, 40, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, 40, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(2, 0, 55, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]  # 20 rows
    assert sharding_config(2, 1, 55, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]  # 20 rows
    assert sharding_config(3, 0, 8, 1) == [11, 12, 13, 14, 15, 16, 17, 18]  # 8 rows
    assert sharding_config(3, 1, 8, 1) == [1, 2, 3, 4, 5, 6, 7, 8]  # 8 rows
    assert sharding_config(3, 2, 8, 1) == [21, 22, 23, 24, 25, 26, 27, 28]  # 8 rows
    assert sharding_config(4, 0, 2, 1) == [11, 12]  # 2 rows
    assert sharding_config(4, 1, 2, 1) == [1, 2]  # 2 rows
    assert sharding_config(4, 2, 2, 1) == [21, 22]  # 2 rows
    assert sharding_config(4, 3, 2, 1) == [31, 32]  # 2 rows
    assert sharding_config(3, 0, 4, 2) == [11, 12, 13, 14, 21, 22, 23, 24]  # 8 rows
    assert sharding_config(3, 1, 4, 2) == [1, 2, 3, 4, 11, 12, 13, 14]  # 8 rows
    assert sharding_config(3, 2, 4, 2) == [21, 22, 23, 24, 31, 32, 33, 34]  # 8 rows

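
# A rough model (an assumption for illustration, not MindSpore's actual implementation) of what
# the assertions above demonstrate: when a TFRecordDataset is sharded with
# shuffle=ds.Shuffle.FILES, whole files rather than individual rows are handed to each shard, so
# every shard's output is made of complete blocks from test1..test4 (which hold rows 1-10, 11-20,
# 21-30 and 31-40, judging from the expected values above). The helper below simply deals an
# already-shuffled file list to shards round-robin; its name deal_files_to_shards is hypothetical.
def deal_files_to_shards(shuffled_files, num_shards):
    shards = {shard_id: [] for shard_id in range(num_shards)}
    for position, filename in enumerate(shuffled_files):
        # the file at position p goes to shard p % num_shards
        shards[position % num_shards].append(filename)
    return shards
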

def test_tfrecord_shardings4(print_res=False):
    """ Test TFRecordDataset sharding with num_parallel_workers=4 """

    # total 40 rows in dataset
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data",
                "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"]

    def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1):
        data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=ds.Shuffle.FILES, num_parallel_workers=4)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["scalars"][0])
        if print_res:
            logger.info("scalars of dataset: {}".format(res))
        return res

    def check_result(result_list, expect_length, expect_set):
        assert len(result_list) == expect_length
        assert set(result_list) == expect_set

    check_result(sharding_config(2, 0, None, 1), 20,
                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, None, 1), 20,
                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21})
    check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31})
    check_result(sharding_config(2, 0, 40, 1), 20,
                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, 40, 1), 20,
                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(2, 0, 55, 1), 20,
                 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30})
    check_result(sharding_config(2, 1, 55, 1), 20,
                 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40})
    check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31})
    check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8})
    check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28})
    check_result(sharding_config(4, 0, 2, 1), 2, {11, 12})
    check_result(sharding_config(4, 1, 2, 1), 2, {1, 2})
    check_result(sharding_config(4, 2, 2, 1), 2, {21, 22})
    check_result(sharding_config(4, 3, 2, 1), 2, {31, 32})
    check_result(sharding_config(3, 0, 4, 2), 8, {32, 1, 2, 11, 12, 21, 22, 31})
    check_result(sharding_config(3, 1, 4, 2), 8, {1, 2, 3, 4, 11, 12, 13, 14})
    check_result(sharding_config(3, 2, 4, 2), 8, {32, 33, 34, 21, 22, 23, 24, 31})


def test_manifest_shardings(print_res=False):
    manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, num_shards=num_shards, shard_id=shard_id,
                                   shuffle=shuffle, decode=True)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 5 train images in total
    sharding_config(2, 0, None, False)
    assert (sharding_config(2, 0, None, False) == [0, 1, 1])
    assert (sharding_config(2, 1, None, False) == [0, 0, 0])
    assert (sharding_config(2, 0, 2, False) == [0, 1])
    assert (sharding_config(2, 1, 2, False) == [0, 0])
    # with repeat
    assert (sharding_config(2, 1, None, False, 3) == [0, 0, 0] * 3)
    assert (sharding_config(2, 0, 2, False, 5) == [0, 1] * 5)


def test_voc_shardings(print_res=False):
    voc_dir = "../data/dataset/testVOC2012"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
        data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["image"].shape[0])
        if print_res:
            logger.info("first dims of images: {}".format(res))
        return res

    # 10 images in total; always decode to get the shape
    # first dims of all 10 images: [2268, 2268, 2268, 2268, 642, 607, 561, 596, 612, 2268]
    # 3 shard workers; the 0th worker gets the 0th, 3rd, 6th and 9th images
    # (the index assignment is sketched after this test)
    assert (sharding_config(3, 0, None, False, 2) == [2268, 2268, 561, 2268] * 2)
    # 3 shard workers; the 1st worker gets the 1st, 4th, 7th and 0th images; the last one wraps around because of rounding up
    assert (sharding_config(3, 1, 5, False, 3) == [2268, 642, 596, 2268] * 3)
    # 3 shard workers; the 2nd worker gets the 2nd, 5th, 8th and 11th (i.e. the 1st) images,
    # then takes only the first 2 because num_samples = 2
    assert (sharding_config(3, 2, 2, False, 4) == [2268, 607] * 4)
    # test that with shuffle=True each shard worker returns different samples each epoch
    assert len(sharding_config(2, 0, None, True, 1)) == 5
    assert len(set(sharding_config(11, 0, None, True, 10))) > 1

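
# A small sketch (not part of the original tests) of the index assignment described in the
# comments above: with shuffle=False a DistributedSampler effectively walks the dataset
# round-robin, so shard k takes indices k, k + num_shards, k + 2 * num_shards, ..., and because
# each shard is padded up to ceil(total / num_shards) samples the last index can wrap around to
# the start of the dataset. The helper name shard_indices is hypothetical.
def shard_indices(total_rows, num_shards, shard_id):
    samples_per_shard = (total_rows + num_shards - 1) // num_shards  # rounded up per shard
    return [(shard_id + step * num_shards) % total_rows for step in range(samples_per_shard)]

# For example, shard_indices(10, 3, 1) == [1, 4, 7, 0]: the wrapped 0th image is why the
# second VOC assertion above ends with 2268 again.
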

def test_cifar10_shardings(print_res=False):
    cifar10_dir = "../data/dataset/testCifar10Data"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.Cifar10Dataset(cifar10_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                  shuffle=shuffle)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 10000 rows in total. CIFAR-10 reads everything into memory, which would make each test case
    # very slow; therefore, only 2 test cases for now.
    assert sharding_config(10000, 9999, 7, False, 1) == [9]
    assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0]


def test_cifar100_shardings(print_res=False):
    cifar100_dir = "../data/dataset/testCifar100Data"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.Cifar100Dataset(cifar100_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                   shuffle=shuffle)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["coarse_label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 10000 rows in total in the CIFAR-100 test.bin file
    assert (sharding_config(1000, 999, 7, False, 2) == [1, 18, 10, 17, 5, 0, 15] * 2)
    assert (sharding_config(1000, 0, None, False) == [10, 16, 2, 11, 10, 17, 11, 14, 13, 3])


def test_mnist_shardings(print_res=False):
    mnist_dir = "../data/dataset/testMnistData"

    def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1):
        data1 = ds.MnistDataset(mnist_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples,
                                shuffle=shuffle)
        data1 = data1.repeat(repeat_cnt)
        res = []
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            res.append(item["label"].item())
        if print_res:
            logger.info("labels of dataset: {}".format(res))
        return res

    # 70K rows in total, divided across 10K hosts; each host gets 7 images
    assert sharding_config(10000, 0, num_samples=5, shuffle=False, repeat_cnt=3) == [0, 0, 0]
    assert sharding_config(10000, 9999, num_samples=None, shuffle=False, repeat_cnt=1) == [9]


if __name__ == '__main__':
    test_imagefolder_shardings(True)
    test_tfrecord_shardings1(True)
    test_tfrecord_shardings4(True)
    test_manifest_shardings(True)
    test_voc_shardings(True)
    test_cifar10_shardings(True)
    test_cifar100_shardings(True)
    test_mnist_shardings(True)
