# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict

# Note: Number of rows in test.data dataset: 12
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
GENERATE_GOLDEN = False
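
# The batch-count assertions below all follow from the 12-row count noted
# above: batch() yields floor(rows / batch_size) batches when
# drop_remainder=True and ceil(rows / batch_size) batches otherwise.
# A minimal sketch of that arithmetic (this helper is illustrative only
# and is not used by the golden-file checks):
def expected_num_batches(num_rows, batch_size, drop_remainder=False):
    if drop_remainder:
        return num_rows // batch_size
    # ceiling division without importing math
    return (num_rows + batch_size - 1) // batch_size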


def test_batch_01():
    """
    Test batch: batch_size>1, drop_remainder=True, no remainder exists
    """
    logger.info("test_batch_01")
    # define parameters
    batch_size = 2
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)

    assert sum([1 for _ in data1]) == 6
    filename = "batch_01_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_02():
    """
    Test batch: batch_size>1, drop_remainder=True, remainder exists
    """
    logger.info("test_batch_02")
    # define parameters
    batch_size = 5
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 2
    filename = "batch_02_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_03():
    """
    Test batch: batch_size>1, drop_remainder=False, no remainder exists
    """
    logger.info("test_batch_03")
    # define parameters
    batch_size = 3
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 4
    filename = "batch_03_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_04():
    """
    Test batch: batch_size>1, drop_remainder=False, remainder exists
    """
    logger.info("test_batch_04")
    # define parameters
    batch_size = 7
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)

    assert sum([1 for _ in data1]) == 2
    filename = "batch_04_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_05():
    """
    Test batch: batch_size=1 (minimum valid size), drop_remainder default
    """
    logger.info("test_batch_05")
    # define parameters
    batch_size = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size)

    assert sum([1 for _ in data1]) == 12
    filename = "batch_05_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_06():
    """
    Test batch: batch_size = number-of-rows-in-dataset, drop_remainder=False, reorder params
    """
    logger.info("test_batch_06")
    # define parameters
    batch_size = 12
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)

    assert sum([1 for _ in data1]) == 1
    filename = "batch_06_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_07():
    """
    Test batch: num_parallel_workers>1, drop_remainder=False, reorder params
    """
    logger.info("test_batch_07")
    # define parameters
    batch_size = 4
    drop_remainder = False
    num_parallel_workers = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
                        batch_size=batch_size)

    assert sum([1 for _ in data1]) == 3
    filename = "batch_07_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_08():
    """
    Test batch: num_parallel_workers=1, drop_remainder default
    """
    logger.info("test_batch_08")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)

    assert sum([1 for _ in data1]) == 2
    filename = "batch_08_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_09():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=False
    """
    logger.info("test_batch_09")
    # define parameters
    batch_size = 13
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 1
    filename = "batch_09_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_10():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=True
    """
    logger.info("test_batch_10")
    # define parameters
    batch_size = 99
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 0
    filename = "batch_10_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_11():
    """
    Test batch: batch_size=1 and dataset-size=1
    """
    logger.info("test_batch_11")
    # define parameters
    batch_size = 1

    # apply dataset operations
    # Use schema file with 1 row
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema1Row.json"
    data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
    data1 = data1.batch(batch_size)

    assert sum([1 for _ in data1]) == 1
    filename = "batch_11_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
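
# test_batch_12 and test_batch_13 below rely on Python treating bool as a
# subclass of int: True compares equal to 1, so batch_size=True behaves
# like batch_size=1 (and False like the invalid 0, which
# test_batch_exception_05 covers). A one-line illustration of that
# language-level fact:
assert isinstance(True, int) and True == 1 and False == 0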


def test_batch_12():
    """
    Test batch: batch_size boolean value True, treated as valid value 1
    """
    logger.info("test_batch_12")
    # define parameters
    batch_size = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size)

    assert sum([1 for _ in data1]) == 12
    filename = "batch_12_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_13():
    """
    Test batch: python_multiprocessing=True is a no-op when per_batch_map is None
    """
    logger.info("test_batch_13")
    # define parameters
    batch_size = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size, python_multiprocessing=True)

    assert sum([1 for _ in data1]) == 12
    # batch_size=True behaves like 1, so the output matches the batch_12 golden file
    filename = "batch_12_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_exception_01():
    """
    Test batch exception: num_parallel_workers=0
    """
    logger.info("test_batch_exception_01")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_02():
    """
    Test batch exception: num_parallel_workers<0
    """
    logger.info("test_batch_exception_02")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_03():
    """
    Test batch exception: batch_size=0
    """
    logger.info("test_batch_exception_03")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_04():
    """
    Test batch exception: batch_size<0
    """
    logger.info("test_batch_exception_04")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=-1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_05():
    """
    Test batch exception: batch_size boolean value False, treated as invalid value 0
    """
    logger.info("test_batch_exception_05")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=False)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_07():
    """
    Test batch exception: drop_remainder wrong type
    """
    logger.info("test_batch_exception_07")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_08():
    """
    Test batch exception: num_parallel_workers wrong type
    """
    logger.info("test_batch_exception_08")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_09():
    """
    Test batch exception: missing mandatory batch_size
    """
    logger.info("test_batch_exception_09")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(drop_remainder=True, num_parallel_workers=4)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_10():
    """
    Test batch exception: num_parallel_workers>>1
    """
    logger.info("test_batch_exception_10")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=4, num_parallel_workers=8192)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_11():
    """
    Test batch exception: wrong input order, num_parallel_workers wrongly used as drop_remainder
    """
    logger.info("test_batch_exception_11")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(batch_size, num_parallel_workers)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_12():
    """
    Test batch exception: wrong input order, drop_remainder wrongly used as batch_size
    """
    logger.info("test_batch_exception_12")
    # define parameters
    batch_size = 1
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(drop_remainder, batch_size)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_13():
    """
    Test batch exception: invalid input parameter shard_id
    """
    logger.info("test_batch_exception_13")
    # define parameters
    batch_size = 4

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(batch_size, shard_id=1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "shard_id" in str(e)


def test_batch_exception_14():
    """
    Test batch exception: input_columns specified without per_batch_map
    """
    logger.info("test_batch_exception_14")
    batch_size = 2
    input_columns = ["num"]
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        _ = data1.batch(batch_size=batch_size, input_columns=input_columns)
    except ValueError as e:
        assert "per_batch_map and input_columns need to be passed in together." in str(e)
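
# For contrast with test_batch_exception_14: a minimal sketch of the valid
# pairing, where input_columns is passed together with per_batch_map. The
# NumpySlicesDataset source, the "col1" column name, and the add_one
# callable are assumptions for illustration; they are not part of the
# golden-file tests in this file, and this function is not run by pytest
# or the __main__ block.
def example_batch_per_batch_map():
    import numpy as np

    def add_one(cols, batch_info):
        # per_batch_map receives one list of arrays per input column plus a
        # BatchInfo object, and must return a tuple with one list per
        # output column.
        return ([np.array(c) + 1 for c in cols],)

    data = ds.NumpySlicesDataset([1, 2, 3, 4], column_names=["col1"], shuffle=False)
    data = data.batch(batch_size=2, per_batch_map=add_one, input_columns=["col1"])

    # consume the pipeline the same way the tests above do: 4 rows with
    # batch_size=2 yields 2 batches
    assert sum([1 for _ in data]) == 2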


def test_batch_exception_15():
    """
    Test batch exception: batch_size = int32 max value + 1
    """
    logger.info("test_batch_exception_15")
    batch_size = 2147483647 + 1
    input_columns = ["num"]
    data1 = ds.TFRecordDataset(DATA_DIR)
    err_msg = ""
    try:
        _ = data1.batch(batch_size=batch_size, input_columns=input_columns)
    except ValueError as e:
        err_msg = str(e)
    assert "batch_size is not within the required interval of [1, 2147483647]" in err_msg


if __name__ == '__main__':
    test_batch_01()
    test_batch_02()
    test_batch_03()
    test_batch_04()
    test_batch_05()
    test_batch_06()
    test_batch_07()
    test_batch_08()
    test_batch_09()
    test_batch_10()
    test_batch_11()
    test_batch_12()
    test_batch_13()
    test_batch_exception_01()
    test_batch_exception_02()
    test_batch_exception_03()
    test_batch_exception_04()
    test_batch_exception_05()
    test_batch_exception_07()
    test_batch_exception_08()
    test_batch_exception_09()
    test_batch_exception_10()
    test_batch_exception_11()
    test_batch_exception_12()
    test_batch_exception_13()
    test_batch_exception_14()
    test_batch_exception_15()
    logger.info('\n')