1# Copyright 2019 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15import mindspore.dataset as ds 16from mindspore import log as logger 17 18 19def test_imagefolder_shardings(print_res=False): 20 image_folder_dir = "../data/dataset/testPK/data" 21 22 def sharding_config(num_shards, shard_id, num_samples, shuffle, class_index, repeat_cnt=1): 23 data1 = ds.ImageFolderDataset(image_folder_dir, num_samples=num_samples, num_shards=num_shards, 24 shard_id=shard_id, 25 shuffle=shuffle, class_indexing=class_index, decode=True) 26 data1 = data1.repeat(repeat_cnt) 27 res = [] 28 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 29 res.append(item["label"].item()) 30 if print_res: 31 logger.info("labels of dataset: {}".format(res)) 32 return res 33 34 # total 44 rows in dataset 35 assert (sharding_config(4, 0, 5, False, dict()) == [0, 0, 0, 1, 1]) # 5 rows 36 assert (sharding_config(4, 0, 12, False, dict()) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3]) # 11 rows 37 assert (sharding_config(4, 3, None, False, dict()) == [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) # 11 rows 38 assert (sharding_config(1, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]) # 44 rows 39 assert (sharding_config(2, 0, 55, False, dict()) == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) # 22 rows 40 assert (sharding_config(2, 1, 55, False, dict()) == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]) # 22 rows 41 # total 22 in dataset rows because of class indexing which takes only 2 folders 42 assert len(sharding_config(4, 0, None, True, {"class1": 111, "class2": 999})) == 6 43 assert len(sharding_config(4, 2, 3, True, {"class1": 111, "class2": 999})) == 3 44 # test with repeat 45 assert (sharding_config(4, 0, 12, False, dict(), 3) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3] * 3) 46 assert (sharding_config(4, 0, 5, False, dict(), 5) == [0, 0, 0, 1, 1] * 5) 47 assert len(sharding_config(5, 1, None, True, {"class1": 111, "class2": 999}, 4)) == 20 48 49 50def test_tfrecord_shardings1(print_res=False): 51 """ Test TFRecordDataset sharding with num_parallel_workers=1 """ 52 53 # total 40 rows in dataset 54 tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", 55 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] 56 57 def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1): 58 data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples, 59 shuffle=ds.Shuffle.FILES, num_parallel_workers=1) 60 data1 = data1.repeat(repeat_cnt) 61 res = [] 62 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 63 res.append(item["scalars"][0]) 64 if print_res: 65 logger.info("scalars of dataset: {}".format(res)) 66 return res 67 68 assert sharding_config(2, 0, None, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30] # 20 rows 69 assert sharding_config(2, 1, None, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] # 20 rows 70 assert sharding_config(2, 0, 3, 1) == [11, 12, 13] # 3 rows 71 assert sharding_config(2, 1, 3, 1) == [1, 2, 3] # 3 rows 72 assert sharding_config(2, 0, 40, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30] # 20 rows 73 assert sharding_config(2, 1, 40, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] # 20 rows 74 assert sharding_config(2, 0, 55, 1) == [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30] # 20 rows 75 assert sharding_config(2, 1, 55, 1) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] # 20 rows 76 assert sharding_config(3, 0, 8, 1) == [11, 12, 13, 14, 15, 16, 17, 18] # 8 rows 77 assert sharding_config(3, 1, 8, 1) == [1, 2, 3, 4, 5, 6, 7, 8] # 8 rows 78 assert sharding_config(3, 2, 8, 1) == [21, 22, 23, 24, 25, 26, 27, 28] # 8 rows 79 assert sharding_config(4, 0, 2, 1) == [11, 12] # 2 rows 80 assert sharding_config(4, 1, 2, 1) == [1, 2] # 2 rows 81 assert sharding_config(4, 2, 2, 1) == [21, 22] # 2 rows 82 assert sharding_config(4, 3, 2, 1) == [31, 32] # 2 rows 83 assert sharding_config(3, 0, 4, 2) == [11, 12, 13, 14, 21, 22, 23, 24] # 8 rows 84 assert sharding_config(3, 1, 4, 2) == [1, 2, 3, 4, 11, 12, 13, 14] # 8 rows 85 assert sharding_config(3, 2, 4, 2) == [21, 22, 23, 24, 31, 32, 33, 34] # 8 rows 86 87 88def test_tfrecord_shardings4(print_res=False): 89 """ Test TFRecordDataset sharding with num_parallel_workers=4 """ 90 91 # total 40 rows in dataset 92 tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", 93 "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] 94 95 def sharding_config(num_shards, shard_id, num_samples, repeat_cnt=1): 96 data1 = ds.TFRecordDataset(tf_files, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples, 97 shuffle=ds.Shuffle.FILES, num_parallel_workers=4) 98 data1 = data1.repeat(repeat_cnt) 99 res = [] 100 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 101 res.append(item["scalars"][0]) 102 if print_res: 103 logger.info("scalars of dataset: {}".format(res)) 104 return res 105 106 def check_result(result_list, expect_length, expect_set): 107 assert len(result_list) == expect_length 108 assert set(result_list) == expect_set 109 110 check_result(sharding_config(2, 0, None, 1), 20, 111 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) 112 check_result(sharding_config(2, 1, None, 1), 20, 113 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) 114 check_result(sharding_config(2, 0, 3, 1), 3, {11, 12, 21}) 115 check_result(sharding_config(2, 1, 3, 1), 3, {1, 2, 31}) 116 check_result(sharding_config(2, 0, 40, 1), 20, 117 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) 118 check_result(sharding_config(2, 1, 40, 1), 20, 119 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) 120 check_result(sharding_config(2, 0, 55, 1), 20, 121 {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30}) 122 check_result(sharding_config(2, 1, 55, 1), 20, 123 {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40}) 124 check_result(sharding_config(3, 0, 8, 1), 8, {32, 33, 34, 11, 12, 13, 14, 31}) 125 check_result(sharding_config(3, 1, 8, 1), 8, {1, 2, 3, 4, 5, 6, 7, 8}) 126 check_result(sharding_config(3, 2, 8, 1), 8, {21, 22, 23, 24, 25, 26, 27, 28}) 127 check_result(sharding_config(4, 0, 2, 1), 2, {11, 12}) 128 check_result(sharding_config(4, 1, 2, 1), 2, {1, 2}) 129 check_result(sharding_config(4, 2, 2, 1), 2, {21, 22}) 130 check_result(sharding_config(4, 3, 2, 1), 2, {31, 32}) 131 check_result(sharding_config(3, 0, 4, 2), 8, {32, 1, 2, 11, 12, 21, 22, 31}) 132 check_result(sharding_config(3, 1, 4, 2), 8, {1, 2, 3, 4, 11, 12, 13, 14}) 133 check_result(sharding_config(3, 2, 4, 2), 8, {32, 33, 34, 21, 22, 23, 24, 31}) 134 135 136def test_manifest_shardings(print_res=False): 137 manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" 138 139 def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1): 140 data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, num_shards=num_shards, shard_id=shard_id, 141 shuffle=shuffle, decode=True) 142 data1 = data1.repeat(repeat_cnt) 143 res = [] 144 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 145 res.append(item["label"].item()) 146 if print_res: 147 logger.info("labels of dataset: {}".format(res)) 148 return res 149 150 # 5 train images in total 151 sharding_config(2, 0, None, False) 152 assert (sharding_config(2, 0, None, False) == [0, 1, 1]) 153 assert (sharding_config(2, 1, None, False) == [0, 0, 0]) 154 assert (sharding_config(2, 0, 2, False) == [0, 1]) 155 assert (sharding_config(2, 1, 2, False) == [0, 0]) 156 # with repeat 157 assert (sharding_config(2, 1, None, False, 3) == [0, 0, 0] * 3) 158 assert (sharding_config(2, 0, 2, False, 5) == [0, 1] * 5) 159 160 161def test_voc_shardings(print_res=False): 162 voc_dir = "../data/dataset/testVOC2012" 163 164 def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1): 165 sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) 166 data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler) 167 data1 = data1.repeat(repeat_cnt) 168 res = [] 169 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 170 res.append(item["image"].shape[0]) 171 if print_res: 172 logger.info("labels of dataset: {}".format(res)) 173 return res 174 175 # 10 images in total, always decode to get the shape 176 # first dim of all 10 images [2268,2268,2268,2268,642,607,561,596,612,2268] 177 # 3 shard_workers, 0th worker will get 0-th, 3nd, 6th and 9th image 178 assert (sharding_config(3, 0, None, False, 2) == [2268, 2268, 561, 2268] * 2) 179 # 3 shard_workers, 1st worker will get 1-st, 4nd, 7th and 0th image, the last one goes back bc of rounding up 180 assert (sharding_config(3, 1, 5, False, 3) == [2268, 642, 596, 2268] * 3) 181 # 3 shard_workers, 2nd worker will get 2nd, 5th, 8th and 11th (which is 1st) 182 # then takes the first 2 bc num_samples = 2 183 assert (sharding_config(3, 2, 2, False, 4) == [2268, 607] * 4) 184 # test that each epoch, each shard_worker returns a different sample 185 assert len(sharding_config(2, 0, None, True, 1)) == 5 186 assert len(set(sharding_config(11, 0, None, True, 10))) > 1 187 188 189def test_cifar10_shardings(print_res=False): 190 cifar10_dir = "../data/dataset/testCifar10Data" 191 192 def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1): 193 data1 = ds.Cifar10Dataset(cifar10_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples, 194 shuffle=shuffle) 195 data1 = data1.repeat(repeat_cnt) 196 res = [] 197 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 198 res.append(item["label"].item()) 199 if print_res: 200 logger.info("labels of dataset: {}".format(res)) 201 return res 202 203 # 10000 rows in total. CIFAR reads everything in memory which would make each test case very slow 204 # therefore, only 2 test cases for now. 205 assert sharding_config(10000, 9999, 7, False, 1) == [9] 206 assert sharding_config(10000, 0, 4, False, 3) == [0, 0, 0] 207 208 209def test_cifar100_shardings(print_res=False): 210 cifar100_dir = "../data/dataset/testCifar100Data" 211 212 def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1): 213 data1 = ds.Cifar100Dataset(cifar100_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples, 214 shuffle=shuffle) 215 data1 = data1.repeat(repeat_cnt) 216 res = [] 217 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 218 res.append(item["coarse_label"].item()) 219 if print_res: 220 logger.info("labels of dataset: {}".format(res)) 221 return res 222 223 # 10000 rows in total in test.bin CIFAR100 file 224 assert (sharding_config(1000, 999, 7, False, 2) == [1, 18, 10, 17, 5, 0, 15] * 2) 225 assert (sharding_config(1000, 0, None, False) == [10, 16, 2, 11, 10, 17, 11, 14, 13, 3]) 226 227 228def test_mnist_shardings(print_res=False): 229 mnist_dir = "../data/dataset/testMnistData" 230 231 def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1): 232 data1 = ds.MnistDataset(mnist_dir, num_shards=num_shards, shard_id=shard_id, num_samples=num_samples, 233 shuffle=shuffle) 234 data1 = data1.repeat(repeat_cnt) 235 res = [] 236 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary 237 res.append(item["label"].item()) 238 if print_res: 239 logger.info("labels of dataset: {}".format(res)) 240 return res 241 242 # 70K rows in total , divide across 10K hosts, each host has 7 images 243 assert sharding_config(10000, 0, num_samples=5, shuffle=False, repeat_cnt=3) == [0, 0, 0] 244 assert sharding_config(10000, 9999, num_samples=None, shuffle=False, repeat_cnt=1) == [9] 245 246 247if __name__ == '__main__': 248 test_imagefolder_shardings(True) 249 test_tfrecord_shardings1(True) 250 test_tfrecord_shardings4(True) 251 test_manifest_shardings(True) 252 test_voc_shardings(True) 253 test_cifar10_shardings(True) 254 test_cifar100_shardings(True) 255 test_mnist_shardings(True) 256