1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15import pytest 16import mindspore.dataset as ds 17import mindspore.dataset.transforms.c_transforms as c_transforms 18from mindspore import log as logger 19from util import save_and_check_md5 20 21GENERATE_GOLDEN = False 22 23IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train" 24IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", 25 "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", 26 "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", 27 "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] 28MNIST_DATA_DIR = "../data/dataset/testMnistData" 29MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" 30CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data" 31COCO_DATA_DIR = "../data/dataset/testCOCO/train/" 32ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" 33VOC_DATA_DIR = "../data/dataset/testVOC2012" 34 35 36def test_numpyslices_sampler_no_chain(): 37 """ 38 Test NumpySlicesDataset with sampler, no chain 39 """ 40 logger.info("test_numpyslices_sampler_no_chain") 41 42 # Create NumpySlicesDataset with sampler, no chain 43 np_data = [1, 2, 3, 4] 44 sampler = ds.SequentialSampler(start_index=1, num_samples=2) 45 data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) 46 47 # Verify dataset size 48 data1_size = data1.get_dataset_size() 49 logger.info("dataset size is: {}".format(data1_size)) 50 assert data1_size == 2 51 52 # Verify number of rows 53 assert sum([1 for _ in data1]) == 2 54 55 # Verify dataset contents 56 res = [] 57 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 58 logger.info("item: {}".format(item)) 59 res.append(item) 60 logger.info("dataset: {}".format(res)) 61 62 63def test_numpyslices_sampler_chain(): 64 """ 65 Test NumpySlicesDataset sampler chain 66 """ 67 logger.info("test_numpyslices_sampler_chain") 68 69 # Create NumpySlicesDataset with sampler chain 70 # Use 1 statement to add child sampler 71 np_data = [1, 2, 3, 4] 72 sampler = ds.SequentialSampler(start_index=1, num_samples=2) 73 sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) 74 data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) 75 76 # Verify dataset size 77 data1_size = data1.get_dataset_size() 78 logger.info("dataset size is: {}".format(data1_size)) 79 assert data1_size == 1 80 81 # Verify number of rows 82 assert sum([1 for _ in data1]) == 1 83 84 # Verify dataset contents 85 res = [] 86 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 87 logger.info("item: {}".format(item)) 88 res.append(item) 89 logger.info("dataset: {}".format(res)) 90 91 92def test_numpyslices_sampler_chain2(): 93 """ 94 Test NumpySlicesDataset sampler chain 95 """ 96 logger.info("test_numpyslices_sampler_chain2") 97 98 # Create NumpySlicesDataset with sampler chain 99 # Use 2 statements to add child sampler 100 np_data = [1, 2, 3, 4] 101 sampler = ds.SequentialSampler(start_index=1, num_samples=1) 102 child_sampler = ds.SequentialSampler(start_index=1, num_samples=2) 103 sampler.add_child(child_sampler) 104 data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) 105 106 # Verify dataset size 107 data1_size = data1.get_dataset_size() 108 logger.info("dataset size is: {}".format(data1_size)) 109 assert data1_size == 1 110 111 # Verify number of rows 112 assert sum([1 for _ in data1]) == 1 113 114 # Verify dataset contents 115 res = [] 116 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 117 logger.info("item: {}".format(item)) 118 res.append(item) 119 logger.info("dataset: {}".format(res)) 120 121 122def test_imagefolder_sampler_chain(): 123 """ 124 Test ImageFolderDataset sampler chain 125 """ 126 logger.info("test_imagefolder_sampler_chain") 127 128 sampler = ds.SequentialSampler(start_index=1, num_samples=3) 129 child_sampler = ds.PKSampler(2) 130 sampler.add_child(child_sampler) 131 data1 = ds.ImageFolderDataset(IMAGENET_RAWDATA_DIR, sampler=sampler) 132 # Verify dataset size 133 data1_size = data1.get_dataset_size() 134 logger.info("dataset size is: {}".format(data1_size)) 135 assert data1_size == 3 136 # Verify number of rows 137 assert sum([1 for _ in data1]) == 3 138 139 # Verify dataset contents 140 res = [] 141 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 142 logger.info("item: {}".format(item)) 143 res.append(item) 144 logger.info("dataset: {}".format(res)) 145 146 147def test_mnist_sampler_chain(): 148 """ 149 Test Mnist sampler chain 150 """ 151 logger.info("test_mnist_sampler_chain") 152 153 sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1) 154 child_sampler = ds.RandomSampler(replacement=True, num_samples=4) 155 sampler.add_child(child_sampler) 156 data1 = ds.MnistDataset(MNIST_DATA_DIR, sampler=sampler) 157 158 # Verify dataset size 159 data1_size = data1.get_dataset_size() 160 logger.info("dataset size is: {}".format(data1_size)) 161 assert data1_size == 3 162 # Verify number of rows 163 assert sum([1 for _ in data1]) == 3 164 165 # Verify dataset contents 166 res = [] 167 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 168 logger.info("item: {}".format(item)) 169 res.append(item) 170 logger.info("dataset: {}".format(res)) 171 172 173def test_manifest_sampler_chain(): 174 """ 175 Test Manifest sampler chain 176 """ 177 logger.info("test_manifest_sampler_chain") 178 179 sampler = ds.RandomSampler(replacement=True, num_samples=2) 180 child_sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1) 181 sampler.add_child(child_sampler) 182 data1 = ds.ManifestDataset(MANIFEST_DATA_FILE, sampler=sampler) 183 184 # Verify dataset size 185 data1_size = data1.get_dataset_size() 186 logger.info("dataset size is: {}".format(data1_size)) 187 assert data1_size == 2 188 # Verify number of rows 189 assert sum([1 for _ in data1]) == 2 190 191 # Verify dataset contents 192 res = [] 193 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 194 logger.info("item: {}".format(item)) 195 res.append(item) 196 logger.info("dataset: {}".format(res)) 197 198 199def test_coco_sampler_chain(): 200 """ 201 Test Coco sampler chain 202 """ 203 logger.info("test_coco_sampler_chain") 204 205 sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5) 206 child_sampler = ds.RandomSampler(replacement=True, num_samples=2) 207 sampler.add_child(child_sampler) 208 data1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True, 209 sampler=sampler) 210 211 # Verify dataset size 212 data1_size = data1.get_dataset_size() 213 logger.info("dataset size is: {}".format(data1_size)) 214 assert data1_size == 1 215 216 # Verify number of rows 217 assert sum([1 for _ in data1]) == 1 218 219 # Verify dataset contents 220 res = [] 221 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 222 logger.info("item: {}".format(item)) 223 res.append(item) 224 logger.info("dataset: {}".format(res)) 225 226 227def test_cifar_sampler_chain(): 228 """ 229 Test Cifar sampler chain 230 """ 231 logger.info("test_cifar_sampler_chain") 232 233 sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5) 234 child_sampler = ds.RandomSampler(replacement=True, num_samples=4) 235 child_sampler2 = ds.SequentialSampler(start_index=0, num_samples=2) 236 child_sampler.add_child(child_sampler2) 237 sampler.add_child(child_sampler) 238 data1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, sampler=sampler) 239 # Verify dataset size 240 data1_size = data1.get_dataset_size() 241 logger.info("dataset size is: {}".format(data1_size)) 242 assert data1_size == 1 243 244 # Verify number of rows 245 assert sum([1 for _ in data1]) == 1 246 247 # Verify dataset contents 248 res = [] 249 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 250 logger.info("item: {}".format(item)) 251 res.append(item) 252 logger.info("dataset: {}".format(res)) 253 254 255def test_voc_sampler_chain(): 256 """ 257 Test VOC sampler chain 258 """ 259 logger.info("test_voc_sampler_chain") 260 261 sampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5) 262 child_sampler = ds.SequentialSampler(start_index=0) 263 sampler.add_child(child_sampler) 264 data1 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", sampler=sampler) 265 266 # Verify dataset size 267 data1_size = data1.get_dataset_size() 268 logger.info("dataset size is: {}".format(data1_size)) 269 assert data1_size == 5 270 271 # Verify number of rows 272 assert sum([1 for _ in data1.create_dict_iterator(output_numpy=True)]) == 5 273 274 # Verify dataset contents 275 res = [] 276 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 277 logger.info("item: {}".format(item)) 278 res.append(item) 279 logger.info("dataset: {}".format(res)) 280 281 282def test_numpyslices_sampler_chain_batch(): 283 """ 284 Test NumpySlicesDataset sampler chaining, with batch 285 """ 286 logger.info("test_numpyslices_sampler_chain_batch") 287 288 # Create NumpySlicesDataset with sampler chain 289 np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 290 sampler = ds.SequentialSampler(start_index=1, num_samples=3) 291 sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) 292 data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) 293 data1 = data1.batch(batch_size=3, drop_remainder=False) 294 295 # Verify dataset size 296 data1_size = data1.get_dataset_size() 297 logger.info("dataset size is: {}".format(data1_size)) 298 assert data1_size == 4 299 300 # Verify number of rows 301 assert sum([1 for _ in data1]) == 4 302 303 # Verify dataset contents 304 res = [] 305 for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True): 306 logger.info("item: {}".format(item)) 307 res.append(item) 308 logger.info("dataset: {}".format(res)) 309 310 311def test_sampler_chain_errors(): 312 """ 313 Test error cases for sampler chains 314 """ 315 logger.info("test_sampler_chain_errors") 316 317 error_msg_1 = "'NoneType' object has no attribute 'add_child'" 318 # Test add child sampler within child sampler 319 sampler = ds.SequentialSampler(start_index=1, num_samples=2) 320 sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) 321 with pytest.raises(AttributeError, match=error_msg_1): 322 sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) 323 324 # error_msg_2 = "'NoneType' object has no attribute 'add_child'" 325 # Test add second and nested child sampler 326 sampler = ds.SequentialSampler(start_index=1, num_samples=2) 327 child_sampler = ds.SequentialSampler(start_index=1, num_samples=2) 328 sampler.add_child(child_sampler) 329 child_sampler2 = ds.SequentialSampler(start_index=1, num_samples=2) 330 sampler.add_child(child_sampler2) 331 # FIXME - no error is raised; uncomment after code issue is resolved 332 # with pytest.raises(AttributeError, match=error_msg_2): 333 # sampler.add_child(child_sampler2) 334 # np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 335 # data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) 336 337 error_msg_3 = "Conflicting arguments during sampler assignments." 338 # Test conflicting arguments (sampler and shuffle=False) for sampler (no chain) 339 np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 340 sampler = ds.SequentialSampler(start_index=1, num_samples=3) 341 with pytest.raises(ValueError, match=error_msg_3): 342 ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler) 343 344 # error_msg_4 = "Conflicting arguments during sampler assignments." 345 # Test conflicting arguments (sampler and shuffle=False) for sampler chaining 346 np_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 347 sampler = ds.SequentialSampler(start_index=1, num_samples=3) 348 sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) 349 # FIXME - no error is raised; uncomment after code issue is resolved 350 # with pytest.raises(ValueError, match=error_msg_4): 351 # ds.NumpySlicesDataset(np_data, shuffle=False, sampler=sampler) 352 353 354def test_manifest_sampler_chain_repeat(): 355 """ 356 Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with repeat 357 """ 358 logger.info("test_manifest_sampler_chain_batch") 359 manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" 360 361 # Create sampler chain DistributedSampler->SequentialSampler 362 sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5) 363 child_sampler = ds.SequentialSampler() 364 sampler.add_child(child_sampler) 365 366 # Create ManifestDataset with sampler chain 367 data1 = ds.ManifestDataset(manifest_file, sampler=sampler) 368 data1 = data1.repeat(count=2) 369 370 # Verify dataset size 371 data1_size = data1.get_dataset_size() 372 logger.info("dataset size is: {}".format(data1_size)) 373 assert data1_size == 10 374 375 # Verify number of rows 376 assert sum([1 for _ in data1]) == 10 377 378 # Verify dataset contents 379 filename = "sampler_chain_manifest_repeat_result.npz" 380 save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) 381 382 383def test_manifest_sampler_chain_batch_repeat(): 384 """ 385 Test ManifestDataset sampler chain DistributedSampler->SequentialSampler, with batch then repeat 386 """ 387 logger.info("test_manifest_sampler_chain_batch_repeat") 388 manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" 389 390 # Create sampler chain DistributedSampler->SequentialSampler 391 sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=5) 392 child_sampler = ds.SequentialSampler() 393 sampler.add_child(child_sampler) 394 395 # Create ManifestDataset with sampler chain 396 data1 = ds.ManifestDataset(manifest_file, decode=True, sampler=sampler) 397 one_hot_encode = c_transforms.OneHot(3) 398 data1 = data1.map(operations=one_hot_encode, input_columns=["label"]) 399 data1 = data1.batch(batch_size=5, drop_remainder=False) 400 data1 = data1.repeat(count=2) 401 402 # Verify dataset size 403 data1_size = data1.get_dataset_size() 404 logger.info("dataset size is: {}".format(data1_size)) 405 assert data1_size == 2 406 407 # Verify number of rows 408 # FIXME: Uncomment the following assert when code issue is resolved 409 # assert sum([1 for _ in data1]) == 2 410 411 412if __name__ == '__main__': 413 test_numpyslices_sampler_no_chain() 414 test_numpyslices_sampler_chain() 415 test_numpyslices_sampler_chain2() 416 test_imagefolder_sampler_chain() 417 test_mnist_sampler_chain() 418 test_manifest_sampler_chain() 419 test_coco_sampler_chain() 420 test_cifar_sampler_chain() 421 test_voc_sampler_chain() 422 test_numpyslices_sampler_chain_batch() 423 test_sampler_chain_errors() 424 test_manifest_sampler_chain_repeat() 425 test_manifest_sampler_chain_batch_repeat() 426