# Copyright 2019-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Built-in validators.
"""
import inspect as ins
import os
from functools import wraps
import numpy as np

from mindspore._c_expression import typing
from mindspore import log as logger
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
    INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
    validate_dataset_param_value, check_padding_options, \
    check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id, \
    check_valid_list_tuple, check_int32

from . import datasets
from . import samplers
from . 
import cache_client


def check_cmu_arctic_dataset(method):
    """A wrapper that wraps a parameter checker around the original CMUArcticDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # parameters that must be int / bool when provided
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check name: optional, must be one of the CMU Arctic speaker ids
        name = param_dict.get('name')
        if name is not None:
            check_valid_str(name, ['aew', 'ahw', 'aup', 'awb', 'axb', 'bdl', 'clb', 'eey',
                                   'fem', 'gka', 'jmk', 'ksp', 'ljm', 'lnh', 'rms', 'rxr', 'slp', 'slt'], "name")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_gtzan_dataset(method):
    """A wrapper that wraps a parameter checker around the original GTZANDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ['train', 'valid', 'test', 'all'], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_imagefolderdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_list = ['extensions']
        nreq_param_dict = ['class_indexing']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # decrypt is an optional callback; only its callability is checked here
        decrypt = param_dict.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_imdb_dataset(method):
    """A wrapper that wraps a parameter checker around the original IMDBDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        return method(self, *args, **kwargs)

    return new_method


def check_iwslt2016_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(IWSLT2016dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        support_language_pair = [
            ['en', 'ar'], ['en', 'ar'], ['en', 'de'], ['en', 'fr'], ['en', 'cs'], ['ar', 'en'], ['fr', 'en'],
            ['de', 'en'], ['cs', 'en']
        ]
        support_language_pair_tuple = (
            ('en', 'ar'), ('en', 'ar'), ('en', 'de'), ('en', 'fr'), ('en', 'cs'), ('ar', 'en'), ('fr', 'en'),
            ('de', 'en'), ('cs', 'en')
        )
        support_set_type = ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"]
        # check language_pair: accepted as a list or a tuple of two language codes
        language_pair = param_dict.get('language_pair')
        if language_pair is not None:
            if isinstance(language_pair, (list,)):
                check_valid_list_tuple(language_pair, support_language_pair, (str,), "language_pair")
            elif isinstance(language_pair, (tuple,)):
                check_valid_list_tuple(language_pair, support_language_pair_tuple, (str,), "language_pair")
            else:
                raise TypeError("language_pair should be a type list or tuple of length 2.")

        # check valid_set
        valid_set = param_dict.get('valid_set')
        if valid_set is not None:
            check_valid_str(valid_set, support_set_type, "valid_set")

        # check test_set
        test_set = param_dict.get('test_set')
        if test_set is not None:
            check_valid_str(test_set, support_set_type, "test_set")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_iwslt2017_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(IWSLT2017dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        support_language_pair = [
            ['en', 'nl'], ['en', 'de'], ['en', 'it'], ['en', 'ro'], ['ro', 'de'], ['ro', 'en'], ['ro', 'nl'],
            ['ro', 'it'], ['de', 'ro'], ['de', 'en'], ['de', 'nl'], ['de', 'it'], ['it', 'en'], ['it', 'nl'],
            ['it', 'de'], ['it', 'ro'], ['nl', 'de'], ['nl', 'en'], ['nl', 'it'], ['nl', 'ro']
        ]
        support_language_pair_tuple = (
            ('en', 'nl'), ('en', 'de'), ('en', 'it'), ('en', 'ro'), ('ro', 'de'), ('ro', 'en'), ('ro', 'nl'),
            ('ro', 'it'), ('de', 'ro'), ('de', 'en'), ('de', 'nl'), ('de', 'it'), ('it', 'en'), ('it', 'nl'),
            ('it', 'de'), ('it', 'ro'), ('nl', 'de'), ('nl', 'en'), ('nl', 'it'), ('nl', 'ro')
        )
        # check language_pair: accepted as a list or a tuple of two language codes
        language_pair = param_dict.get('language_pair')
        if language_pair is not None:
            if isinstance(language_pair, (list,)):
                check_valid_list_tuple(language_pair, support_language_pair, (str,), "language_pair")
            elif isinstance(language_pair, (tuple,)):
                check_valid_list_tuple(language_pair, support_language_pair_tuple, (str,), "language_pair")
            else:
                raise TypeError("language_pair should be a type list or tuple of length 2.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_kittidataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(KITTIDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_lsun_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(LSUNDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_list = ['classes']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)

        categories = [
            'bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen',
            'living_room', 'restaurant', 'tower'
        ]
        # check classes: every requested class must be a known LSUN scene category
        classes = param_dict.get('classes')
        if classes is not None:
            for class_name in classes:
                check_valid_str(class_name, categories, "classes")

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_mnist_cifar_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_omniglotdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(OmniglotDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'background', 'decode']
        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_photo_tour_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(PhotoTourDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test"], "usage")
        # check name: validated unconditionally (required to be one of the six subsets)
        name = param_dict.get('name')
        check_valid_str(name, ["notredame", "yosemite", "liberty", "notredame_harris",
                               "yosemite_harris", "liberty_harris"], "name")
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)
        cache = param_dict.get('cache')
        check_cache_option(cache)
        return method(self, *args, **kwargs)

    return new_method


def check_places365_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Places365Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'small', 'decode']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train-standard", "train-challenge", "val"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_qmnist_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(QMnistDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'compat']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "test10k", "test50k", "nist", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_manifestdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_str = ['usage']
        nreq_param_dict = ['class_indexing']

        # check dataset_file; required argument (a manifest file, not a directory)
        dataset_file = param_dict.get('dataset_file')
        check_file(dataset_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_sbu_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SBUDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # the SBU layout requires the two url/caption files and the image folder
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_urls.txt"))
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_captions.txt"))
        check_dir(os.path.join(dataset_dir, "sbu_images"))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_sogou_news_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SogouNewsDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", 
"test", "all"], "usage") 554 555 validate_dataset_param_value(nreq_param_int, param_dict, int) 556 check_sampler_shuffle_shard_options(param_dict) 557 558 cache = param_dict.get('cache') 559 check_cache_option(cache) 560 561 return method(self, *args, **kwargs) 562 563 return new_method 564 565 566def check_tfrecorddataset(method): 567 """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset).""" 568 569 @wraps(method) 570 def new_method(self, *args, **kwargs): 571 _, param_dict = parse_user_args(method, *args, **kwargs) 572 573 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 574 nreq_param_list = ['columns_list'] 575 nreq_param_bool = ['shard_equal_rows'] 576 577 dataset_files = param_dict.get('dataset_files') 578 if not isinstance(dataset_files, (str, list)): 579 raise TypeError("dataset_files should be type str or a list of strings.") 580 if not dataset_files: 581 raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.") 582 583 validate_dataset_param_value(nreq_param_int, param_dict, int) 584 validate_dataset_param_value(nreq_param_list, param_dict, list) 585 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 586 587 compression_type = param_dict.get('compression_type') 588 if compression_type is not None and compression_type not in ['', 'ZLIB', 'GZIP']: 589 raise ValueError("Input compression_type can only be either '' (no compression), 'ZLIB', or 'GZIP', " + 590 "but got '" + str(compression_type) + "'.") 591 if compression_type is not None and compression_type in ['ZLIB', 'GZIP'] and \ 592 param_dict.get('num_samples') is not None: 593 if param_dict.get('num_shards') is not None and ((isinstance(dataset_files, str) and \ 594 param_dict.get('num_shards') > 1) or (isinstance(dataset_files, list) and \ 595 len(dataset_files) < param_dict.get('num_shards'))): 596 num_files = len(dataset_files) if isinstance(dataset_files, list) else 1 597 
act_num_shard = param_dict.get('num_shards') if param_dict.get('num_shards') is not None else 1 598 raise ValueError("When compression_type is provided, the number of dataset files cannot be less " + 599 "than num_shards, but the actual number of files is " + str(num_files) + 600 " and actual num_shards is " + str(act_num_shard) + ".") 601 if param_dict.get('shard_equal_rows') is None or not param_dict.get('shard_equal_rows'): 602 logger.warning("If compression_type is set, shard_equal_rows will be ignored.") 603 604 check_sampler_shuffle_shard_options(param_dict) 605 606 cache = param_dict.get('cache') 607 check_cache_option(cache) 608 609 return method(self, *args, **kwargs) 610 611 return new_method 612 613 614def check_udpos_dataset(method): 615 """A wrapper that wraps a parameter checker around the original Dataset(UDPOSDataset).""" 616 617 @wraps(method) 618 def new_method(self, *args, **kwargs): 619 _, param_dict = parse_user_args(method, *args, **kwargs) 620 621 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 622 623 # check dataset_dir; required argument 624 dataset_dir = param_dict.get('dataset_dir') 625 check_dir(dataset_dir) 626 627 # check usage 628 usage = param_dict.get('usage') 629 if usage is not None: 630 check_valid_str(usage, ["train", "valid", "test", "all"], "usage") 631 632 validate_dataset_param_value(nreq_param_int, param_dict, int) 633 check_sampler_shuffle_shard_options(param_dict) 634 635 cache = param_dict.get('cache') 636 check_cache_option(cache) 637 638 return method(self, *args, **kwargs) 639 640 return new_method 641 642 643def check_usps_dataset(method): 644 """A wrapper that wraps a parameter checker around the original Dataset(USPSDataset).""" 645 646 @wraps(method) 647 def new_method(self, *args, **kwargs): 648 _, param_dict = parse_user_args(method, *args, **kwargs) 649 650 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 651 652 dataset_dir = 
param_dict.get('dataset_dir') 653 check_dir(dataset_dir) 654 655 usage = param_dict.get('usage') 656 if usage is not None: 657 check_valid_str(usage, ["train", "test", "all"], "usage") 658 659 validate_dataset_param_value(nreq_param_int, param_dict, int) 660 check_sampler_shuffle_shard_options(param_dict) 661 662 cache = param_dict.get('cache') 663 check_cache_option(cache) 664 665 return method(self, *args, **kwargs) 666 667 return new_method 668 669 670def check_caltech101_dataset(method): 671 """A wrapper that wraps a parameter checker around the original Dataset(Caltech101Dataset).""" 672 673 @wraps(method) 674 def new_method(self, *args, **kwargs): 675 _, param_dict = parse_user_args(method, *args, **kwargs) 676 677 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 678 nreq_param_bool = ['shuffle', 'decode'] 679 nreq_param_str = ['target_type'] 680 681 dataset_dir = param_dict.get('dataset_dir') 682 check_dir(dataset_dir) 683 684 target_type = param_dict.get('target_type') 685 if target_type is not None: 686 check_valid_str(target_type, ["category", "annotation", "all"], "target_type") 687 688 validate_dataset_param_value(nreq_param_int, param_dict, int) 689 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 690 validate_dataset_param_value(nreq_param_str, param_dict, str) 691 check_sampler_shuffle_shard_options(param_dict) 692 693 cache = param_dict.get('cache') 694 check_cache_option(cache) 695 696 return method(self, *args, **kwargs) 697 698 return new_method 699 700 701def check_caltech256_dataset(method): 702 """A wrapper that wraps a parameter checker around the original Dataset(Caltech256Dataset).""" 703 704 @wraps(method) 705 def new_method(self, *args, **kwargs): 706 _, param_dict = parse_user_args(method, *args, **kwargs) 707 708 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 709 nreq_param_bool = ['shuffle', 'decode'] 710 711 dataset_dir = param_dict.get('dataset_dir') 
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_vocdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_dict = ['class_indexing']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # task and usage are both required strings
        task = param_dict.get('task')
        type_check(task, (str,), "task")

        usage = param_dict.get('usage')
        type_check(usage, (str,), "usage")
        dataset_dir = os.path.realpath(dataset_dir)

        # the VOC layout keeps the split list under ImageSets/<task area>/<usage>.txt
        if task == "Segmentation":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
            if param_dict.get('class_indexing') is not None:
                raise ValueError("class_indexing is not supported in Segmentation task.")
        elif task == "Detection":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
        else:
            raise ValueError("Invalid task : " + task + ".")

        # decrypt is an optional callback; only its callability is checked here
        decrypt = param_dict.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        check_file(imagesets_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_cocodataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check annotation_file; required argument
        annotation_file = param_dict.get('annotation_file')
        check_file(annotation_file)

        # check task; required string
        task = param_dict.get('task')
        type_check(task, (str,), "task")

        if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint', 'Captioning'}:
            raise ValueError("Invalid task type: " + task + ".")

        # decrypt is an optional callback; only its callability is checked here
        decrypt = param_dict.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        validate_dataset_param_value(nreq_param_int, param_dict, int)

        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        # PKSampler is explicitly unsupported for COCO
        sampler = param_dict.get('sampler')
        if sampler is not None and isinstance(sampler, samplers.PKSampler):
            raise ValueError("CocoDataset doesn't support PKSampler.")
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_celebadataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_list = ['extensions']
        nreq_param_str = ['dataset_type']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')

        check_dir(dataset_dir)

        # decrypt is an optional callback; only its callability is checked here
        decrypt = param_dict.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_str, param_dict, str)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
            raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(param_dict)

        # PKSampler is explicitly unsupported for CelebA
        sampler = param_dict.get('sampler')
        if sampler is not None and isinstance(sampler, samplers.PKSampler):
            raise ValueError("CelebADataset doesn't support PKSampler.")

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_libri_tts_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(LibriTTSDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        # check dataset_dir; required argument
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage: optional, one of the LibriTTS subset names
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100",
                                    "train-clean-360", "train-other-500", "all"], "usage")
        validate_dataset_param_value(nreq_param_int, param_dict, 
int) 879 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 880 881 check_sampler_shuffle_shard_options(param_dict) 882 cache = param_dict.get('cache') 883 check_cache_option(cache) 884 885 return method(self, *args, **kwargs) 886 887 return new_method 888 889 890def check_lj_speech_dataset(method): 891 """A wrapper that wraps a parameter checker around the original Dataset(LJSpeechDataset).""" 892 893 @wraps(method) 894 def new_method(self, *args, **kwargs): 895 _, param_dict = parse_user_args(method, *args, **kwargs) 896 897 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 898 nreq_param_bool = ['shuffle'] 899 900 dataset_dir = param_dict.get('dataset_dir') 901 check_dir(dataset_dir) 902 903 validate_dataset_param_value(nreq_param_int, param_dict, int) 904 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 905 906 check_sampler_shuffle_shard_options(param_dict) 907 908 cache = param_dict.get('cache') 909 check_cache_option(cache) 910 911 return method(self, *args, **kwargs) 912 913 return new_method 914 915 916def check_lfw_dataset(method): 917 """A wrapper that wraps a parameter checker around the original Dataset(LFWDataset).""" 918 919 @wraps(method) 920 def new_method(self, *args, **kwargs): 921 _, param_dict = parse_user_args(method, *args, **kwargs) 922 923 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 924 nreq_param_bool = ['shuffle', 'decode'] 925 926 dataset_dir = param_dict.get('dataset_dir') 927 check_dir(dataset_dir) 928 929 task = param_dict.get('task') 930 if task is not None: 931 check_valid_str(task, ["people", "pairs"], "task") 932 933 usage = param_dict.get('usage') 934 if usage is not None: 935 check_valid_str(usage, ["10fold", "train", "test", "all"], "usage") 936 937 image_set = param_dict.get('image_set') 938 if image_set is not None: 939 check_valid_str(image_set, ["original", "funneled", "deepfunneled"], "image_set") 940 941 
validate_dataset_param_value(nreq_param_int, param_dict, int) 942 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 943 check_sampler_shuffle_shard_options(param_dict) 944 945 cache = param_dict.get('cache') 946 check_cache_option(cache) 947 948 return method(self, *args, **kwargs) 949 950 return new_method 951 952 953def check_save(method): 954 """A wrapper that wraps a parameter checker around the saved operation.""" 955 956 @wraps(method) 957 def new_method(self, *args, **kwargs): 958 _, param_dict = parse_user_args(method, *args, **kwargs) 959 960 nreq_param_int = ['num_files'] 961 nreq_param_str = ['file_name', 'file_type'] 962 validate_dataset_param_value(nreq_param_int, param_dict, int) 963 if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): 964 raise ValueError("num_files should between 0 and 1000.") 965 validate_dataset_param_value(nreq_param_str, param_dict, str) 966 if param_dict.get('file_type') != 'mindrecord': 967 raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type'))) 968 return method(self, *args, **kwargs) 969 970 return new_method 971 972 973def check_tuple_iterator(method): 974 """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" 975 976 @wraps(method) 977 def new_method(self, *args, **kwargs): 978 [columns, num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs) 979 nreq_param_bool = ['output_numpy'] 980 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 981 if num_epochs is not None: 982 type_check(num_epochs, (int,), "num_epochs") 983 check_value(num_epochs, [-1, INT32_MAX], "num_epochs") 984 985 if columns is not None: 986 check_columns(columns, "column_names") 987 988 return method(self, *args, **kwargs) 989 990 return new_method 991 992 993def check_dict_iterator(method): 994 """A wrapper that wraps a parameter checker around the original create_tuple_iterator and 
create_dict_iterator.""" 995 996 @wraps(method) 997 def new_method(self, *args, **kwargs): 998 [num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs) 999 nreq_param_bool = ['output_numpy'] 1000 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1001 if num_epochs is not None: 1002 type_check(num_epochs, (int,), "num_epochs") 1003 check_value(num_epochs, [-1, INT32_MAX], "num_epochs") 1004 1005 return method(self, *args, **kwargs) 1006 1007 return new_method 1008 1009 1010def check_minddataset(method): 1011 """A wrapper that wraps a parameter checker around the original Dataset(MindDataset).""" 1012 1013 @wraps(method) 1014 def new_method(self, *args, **kwargs): 1015 _, param_dict = parse_user_args(method, *args, **kwargs) 1016 1017 nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded'] 1018 nreq_param_list = ['columns_list'] 1019 nreq_param_dict = ['padded_sample'] 1020 1021 dataset_file = param_dict.get('dataset_files') 1022 if isinstance(dataset_file, list): 1023 if len(dataset_file) > 4096: 1024 logger.warning("The number of MindRecord files greater than 4096" 1025 "may cause slow dataset initialization.") 1026 for f in dataset_file: 1027 check_file(f) 1028 else: 1029 check_file(dataset_file) 1030 1031 validate_dataset_param_value(nreq_param_int, param_dict, int) 1032 validate_dataset_param_value(nreq_param_list, param_dict, list) 1033 validate_dataset_param_value(nreq_param_dict, param_dict, dict) 1034 1035 check_sampler_shuffle_shard_options(param_dict) 1036 1037 check_padding_options(param_dict) 1038 return method(self, *args, **kwargs) 1039 1040 return new_method 1041 1042 1043def check_source_function(source): 1044 """Get used variable and source document in given function.""" 1045 # check whether source is an instanced object of user defined class 1046 from types import FunctionType 1047 var = tuple() 1048 source_doc = "" 1049 if isinstance(source, FunctionType): 1050 try: 
1051 var = ins.getclosurevars(source) 1052 source_doc = ins.getsource(source) 1053 except OSError: 1054 return "" 1055 else: 1056 try: 1057 source_attr = source.__class__.__dict__.keys() 1058 if '__init__' in source_attr: 1059 var = var + ins.getclosurevars(source.__class__.__init__) 1060 source_doc = source_doc + ins.getsource(source.__class__.__init__) 1061 if '__getitem__' in source_attr: 1062 var = var + ins.getclosurevars(source.__class__.__getitem__) 1063 source_doc = source_doc + ins.getsource(source.__class__.__getitem__) 1064 elif '__next__' in source_attr: 1065 var = var + ins.getclosurevars(source.__class__.__next__) 1066 source_doc = source_doc + ins.getsource(source.__class__.__next__) 1067 except (TypeError, OSError): 1068 # case: like input is LambdaType or GeneratorType, it will go to else branch, and unable to run normally 1069 pass 1070 return str(var) + source_doc 1071 1072 1073def check_generatordataset(method): 1074 """A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset).""" 1075 1076 @wraps(method) 1077 def new_method(self, *args, **kwargs): 1078 _, param_dict = parse_user_args(method, *args, **kwargs) 1079 1080 source = param_dict.get('source') 1081 1082 if not callable(source): 1083 try: 1084 iter(source) 1085 except TypeError: 1086 raise TypeError("Input `source` function of GeneratorDataset should be callable, iterable or random" 1087 " accessible, commonly it should implement one of the method like yield, __getitem__ or" 1088 " __next__(__iter__).") 1089 1090 # check used variable and function document whether contain computing operator 1091 check_doc = check_source_function(source) 1092 check_list = ['mindspore.nn', 'mindspore.ops', 'mindspore.numpy', 'mindspore.compression'] 1093 for item in check_list: 1094 if item in check_doc: 1095 setattr(self, 'operator_mixed', True) 1096 break 1097 1098 column_names = param_dict.get('column_names') 1099 if column_names is not None: 1100 check_columns(column_names, 
"column_names") 1101 schema = param_dict.get('schema') 1102 if column_names is None and schema is None: 1103 raise ValueError("Neither columns_names nor schema are provided.") 1104 1105 if schema is not None: 1106 if not isinstance(schema, (datasets.Schema, str)): 1107 raise ValueError("schema should be a path to schema file or a schema object.") 1108 1109 # check optional argument 1110 nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"] 1111 validate_dataset_param_value(nreq_param_int, param_dict, int) 1112 nreq_param_list = ["column_types"] 1113 validate_dataset_param_value(nreq_param_list, param_dict, list) 1114 nreq_param_bool = ["shuffle", "python_multiprocessing"] 1115 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1116 1117 check_value(param_dict.get("max_rowsize"), [-1, INT32_MAX], "max_rowsize") 1118 1119 num_shards = param_dict.get("num_shards") 1120 shard_id = param_dict.get("shard_id") 1121 check_dataset_num_shards_shard_id(num_shards, shard_id) 1122 1123 sampler = param_dict.get("sampler") 1124 if sampler is not None: 1125 if isinstance(sampler, samplers.PKSampler): 1126 raise ValueError("GeneratorDataset doesn't support PKSampler.") 1127 if not isinstance(sampler, samplers.BuiltinSampler): 1128 try: 1129 iter(sampler) 1130 except TypeError: 1131 raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.") 1132 1133 if sampler is not None and not hasattr(source, "__getitem__"): 1134 raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.") 1135 if num_shards is not None and not hasattr(source, "__getitem__"): 1136 raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.") 1137 1138 return method(self, *args, **kwargs) 1139 1140 return new_method 1141 1142 1143def check_random_dataset(method): 1144 """A wrapper that wraps a parameter checker around the original Dataset(RandomDataset).""" 
1145 1146 @wraps(method) 1147 def new_method(self, *args, **kwargs): 1148 _, param_dict = parse_user_args(method, *args, **kwargs) 1149 1150 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows'] 1151 nreq_param_bool = ['shuffle'] 1152 nreq_param_list = ['columns_list'] 1153 1154 validate_dataset_param_value(nreq_param_int, param_dict, int) 1155 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1156 validate_dataset_param_value(nreq_param_list, param_dict, list) 1157 1158 check_sampler_shuffle_shard_options(param_dict) 1159 1160 cache = param_dict.get('cache') 1161 check_cache_option(cache) 1162 1163 return method(self, *args, **kwargs) 1164 1165 return new_method 1166 1167 1168def check_rendered_sst2_dataset(method): 1169 """A wrapper that wraps a parameter checker around the original Dataset(RenderedSST2Dataset).""" 1170 1171 @wraps(method) 1172 def new_method(self, *args, **kwargs): 1173 _, param_dict = parse_user_args(method, *args, **kwargs) 1174 1175 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1176 nreq_param_bool = ['shuffle', 'decode'] 1177 1178 dataset_dir = param_dict.get('dataset_dir') 1179 usage = param_dict.get('usage') 1180 check_dir(dataset_dir) 1181 if usage is not None: 1182 check_valid_str(usage, ['val', 'all', 'train', 'test']) 1183 1184 validate_dataset_param_value(nreq_param_int, param_dict, int) 1185 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1186 check_sampler_shuffle_shard_options(param_dict) 1187 1188 cache = param_dict.get('cache') 1189 check_cache_option(cache) 1190 1191 return method(self, *args, **kwargs) 1192 1193 return new_method 1194 1195 1196def check_pad_info(key, val): 1197 """check the key and value pair of pad_info in batch""" 1198 type_check(key, (str,), "key in pad_info") 1199 1200 if val is not None: 1201 if len(val) != 2: 1202 raise ValueError("value of pad_info should be a tuple of size 2.") 1203 
type_check(val, (tuple,), "value in pad_info") 1204 1205 if val[0] is not None: 1206 type_check(val[0], (list,), "shape in pad_info") 1207 1208 for dim in val[0]: 1209 if dim is not None: 1210 check_pos_int32(dim, "dim of shape in pad_info") 1211 if val[1] is not None: 1212 type_check(val[1], (int, float, str, bytes), "pad_value") 1213 1214 1215def check_bucket_batch_by_length(method): 1216 """check the input arguments of bucket_batch_by_length.""" 1217 1218 @wraps(method) 1219 def new_method(self, *args, **kwargs): 1220 [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info, 1221 pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs) 1222 1223 nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] 1224 1225 type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list) 1226 1227 nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder'] 1228 type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list) 1229 1230 # check column_names: must be list of string. 
1231 check_columns(column_names, "column_names") 1232 1233 if element_length_function is None and len(column_names) != 1: 1234 raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") 1235 1236 if element_length_function is not None and not callable(element_length_function): 1237 raise TypeError("element_length_function object is not callable.") 1238 1239 # check bucket_boundaries: must be list of int, positive and strictly increasing 1240 if not bucket_boundaries: 1241 raise ValueError("bucket_boundaries cannot be empty.") 1242 1243 all_int = all(isinstance(item, int) for item in bucket_boundaries) 1244 if not all_int: 1245 raise TypeError("bucket_boundaries should be a list of int.") 1246 1247 all_non_negative = all(item > 0 for item in bucket_boundaries) 1248 if not all_non_negative: 1249 raise ValueError("bucket_boundaries must only contain positive numbers.") 1250 1251 for i in range(len(bucket_boundaries) - 1): 1252 if not bucket_boundaries[i + 1] > bucket_boundaries[i]: 1253 raise ValueError("bucket_boundaries should be strictly increasing.") 1254 1255 # check bucket_batch_sizes: must be list of int and positive 1256 if len(bucket_batch_sizes) != len(bucket_boundaries) + 1: 1257 raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.") 1258 1259 all_int = all(isinstance(item, int) for item in bucket_batch_sizes) 1260 if not all_int: 1261 raise TypeError("bucket_batch_sizes should be a list of int.") 1262 1263 all_non_negative = all(item > 0 for item in bucket_batch_sizes) 1264 if not all_non_negative: 1265 raise ValueError("bucket_batch_sizes should be a list of positive numbers.") 1266 1267 if pad_info is not None: 1268 type_check(pad_info, (dict,), "pad_info") 1269 1270 for k, v in pad_info.items(): 1271 check_pad_info(k, v) 1272 1273 return method(self, *args, **kwargs) 1274 1275 return new_method 1276 1277 1278def get_batch_kwargs_from_dict(param_dict): 1279 """get batch 
operation kwargs parameters.""" 1280 if param_dict is not None: 1281 per_batch_map = param_dict.get("per_batch_map", None) 1282 input_columns = param_dict.get("input_columns", None) 1283 output_columns = param_dict.get("output_columns", None) 1284 python_multiprocessing = param_dict.get("python_multiprocessing", False) 1285 max_rowsize = param_dict.get("max_rowsize", 16) 1286 return per_batch_map, input_columns, output_columns, python_multiprocessing, max_rowsize 1287 1288 1289def check_batch(method): 1290 """check the input arguments of batch.""" 1291 1292 @wraps(method) 1293 def new_method(self, *args, **kwargs): 1294 [batch_size, drop_remainder, num_parallel_workers, param_dict], _ = parse_user_args(method, *args, **kwargs) 1295 1296 (per_batch_map, input_columns, output_columns, python_multiprocessing, max_rowsize) = \ 1297 get_batch_kwargs_from_dict(param_dict) 1298 1299 if not (isinstance(batch_size, int) or (callable(batch_size))): 1300 raise TypeError("batch_size should either be an int or a callable.") 1301 1302 if callable(batch_size): 1303 sig = ins.signature(batch_size) 1304 if len(sig.parameters) != 1: 1305 raise ValueError("callable batch_size should take one parameter (BatchInfo).") 1306 else: 1307 check_pos_int32(int(batch_size), "batch_size") 1308 1309 if num_parallel_workers is not None: 1310 check_num_parallel_workers(num_parallel_workers) 1311 type_check(drop_remainder, (bool,), "drop_remainder") 1312 1313 check_max_rowsize(max_rowsize) 1314 1315 if (input_columns is not None) and (per_batch_map is None): 1316 # input_columns must be None when per_batch_map is not set 1317 raise ValueError("input_columns can be specified only when per_batch_map is set.") 1318 1319 if input_columns is not None: 1320 check_columns(input_columns, "input_columns") 1321 if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1): 1322 raise ValueError("The signature of per_batch_map should match with input columns.") 1323 1324 if output_columns is not 
None: 1325 check_columns(output_columns, "output_columns") 1326 1327 if python_multiprocessing is not None: 1328 type_check(python_multiprocessing, (bool,), "python_multiprocessing") 1329 1330 return method(self, *args, **kwargs) 1331 1332 return new_method 1333 1334 1335def check_padded_batch(method): 1336 """check the input arguments of padded_batch.""" 1337 1338 @wraps(method) 1339 def new_method(self, *args, **kwargs): 1340 [batch_size, drop_remainder, num_parallel_workers, pad_info], _ = parse_user_args(method, *args, **kwargs) 1341 1342 if not (isinstance(batch_size, int) or (callable(batch_size))): 1343 raise TypeError("batch_size should either be an int or a callable.") 1344 1345 if callable(batch_size): 1346 sig = ins.signature(batch_size) 1347 if len(sig.parameters) != 1: 1348 raise ValueError("callable batch_size should take one parameter (BatchInfo).") 1349 else: 1350 check_pos_int32(int(batch_size), "batch_size") 1351 1352 if num_parallel_workers is not None: 1353 check_num_parallel_workers(num_parallel_workers) 1354 type_check(drop_remainder, (bool,), "drop_remainder") 1355 1356 if pad_info is not None: 1357 type_check(pad_info, (dict,), "pad_info") 1358 for k, v in pad_info.items(): 1359 check_pad_info(k, v) 1360 1361 return method(self, *args, **kwargs) 1362 1363 return new_method 1364 1365 1366def check_sync_wait(method): 1367 """check the input arguments of sync_wait.""" 1368 1369 @wraps(method) 1370 def new_method(self, *args, **kwargs): 1371 [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs) 1372 1373 type_check(condition_name, (str,), "condition_name") 1374 type_check(num_batch, (int,), "num_batch") 1375 1376 return method(self, *args, **kwargs) 1377 1378 return new_method 1379 1380 1381def check_shuffle(method): 1382 """check the input arguments of shuffle.""" 1383 1384 @wraps(method) 1385 def new_method(self, *args, **kwargs): 1386 [buffer_size], _ = parse_user_args(method, *args, **kwargs) 1387 1388 
type_check(buffer_size, (int,), "buffer_size") 1389 1390 check_value(buffer_size, [2, INT32_MAX], "buffer_size") 1391 1392 return method(self, *args, **kwargs) 1393 1394 return new_method 1395 1396 1397def get_map_kwargs_from_dict(param_dict): 1398 """get map operation kwargs parameters.""" 1399 if param_dict is not None: 1400 python_multiprocessing = param_dict.get("python_multiprocessing", False) 1401 max_rowsize = param_dict.get("max_rowsize", 16) 1402 cache = param_dict.get("cache", None) 1403 callbacks = param_dict.get("callbacks", None) 1404 offload = param_dict.get("offload", None) 1405 return python_multiprocessing, max_rowsize, cache, callbacks, offload 1406 1407 1408def check_max_rowsize(max_rowsize): 1409 """check the max_rowsize""" 1410 type_check(max_rowsize, (int, list), "max_rowsize") 1411 if isinstance(max_rowsize, int): 1412 type_check(max_rowsize, (int,), "max_rowsize") 1413 check_value(max_rowsize, [-1, INT32_MAX], "max_rowsize") 1414 elif isinstance(max_rowsize, list) and len(max_rowsize) == 2: 1415 for index, value in enumerate(max_rowsize): 1416 type_check(value, (int,), "max_rowsize[{}]".format(index)) 1417 check_value(value, [-1, INT32_MAX], "max_rowsizei[{}]".format(index)) 1418 else: 1419 raise TypeError("max_rowsize should be a single integer or a list[in_rowsize, out_rowsize] of length 2.") 1420 1421 1422def check_map(method): 1423 """check the input arguments of map.""" 1424 1425 @wraps(method) 1426 def new_method(self, *args, **kwargs): 1427 from mindspore.dataset.callback import DSCallback 1428 [operations, input_columns, output_columns, column_order, num_parallel_workers, param_dict], _ = \ 1429 parse_user_args(method, *args, **kwargs) 1430 1431 if column_order is not None: 1432 raise ValueError("The parameter 'column_order' had been deleted in map operation. 
" 1433 "Please use '.project' operation instead.\n" 1434 ">> # Usage of old api:\n" 1435 ">> dataset = dataset.map(operations=PyFunc,\n" 1436 ">> input_columns=[\"column_a\"],\n" 1437 ">> output_columns=[\"column_b\", \"column_c\"],\n" 1438 ">> column_order=[\"column_b\", \"column_c\"])\n" 1439 ">> # Usage of new api:\n" 1440 ">> dataset = dataset.map(operations=PyFunc,\n" 1441 ">> input_columns=[\"column_a\"],\n" 1442 ">> output_columns=[\"column_b\", \"column_c\"])\n" 1443 ">> dataset = dataset.project([\"column_b\", \"column_c\"])") 1444 1445 (python_multiprocessing, max_rowsize, cache, callbacks, offload) = get_map_kwargs_from_dict(param_dict) 1446 1447 # check whether network computing operator exist in input operations(python function) 1448 # check used variable and function document whether contain computing operator 1449 from types import FunctionType 1450 if isinstance(operations, FunctionType): 1451 try: 1452 var = ins.getclosurevars(operations) 1453 operations_doc = ins.getsource(operations) 1454 check_list = ['mindspore.nn', 'mindspore.ops', 'mindspore.numpy', 'mindspore.compression'] 1455 check_doc = str(var) + operations_doc 1456 for item in check_list: 1457 if item in check_doc: 1458 setattr(self, 'operator_mixed', True) 1459 break 1460 except OSError: 1461 pass 1462 1463 operations = operations if isinstance(operations, list) else [operations] 1464 # import nn and ops locally for type check 1465 from mindspore import nn, ops 1466 for item in operations: 1467 if isinstance(item, (nn.Cell, ops.Primitive)): 1468 raise ValueError("Input operations should not contain network computing operator like in " 1469 "mindspore.nn or mindspore.ops, got operation: ", str(item)) 1470 1471 nreq_param_columns = ['input_columns', 'output_columns'] 1472 1473 if num_parallel_workers is not None: 1474 check_num_parallel_workers(num_parallel_workers) 1475 type_check(python_multiprocessing, (bool,), "python_multiprocessing") 1476 check_cache_option(cache) 1477 
check_max_rowsize(max_rowsize) 1478 if offload is not None: 1479 type_check(offload, (bool,), "offload") 1480 1481 if callbacks is not None: 1482 if isinstance(callbacks, (list, tuple)): 1483 type_check_list(callbacks, (DSCallback,), "callbacks") 1484 else: 1485 type_check(callbacks, (DSCallback,), "callbacks") 1486 1487 for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]): 1488 if param is not None: 1489 check_columns(param, param_name) 1490 if callbacks is not None: 1491 type_check(callbacks, (list, DSCallback), "callbacks") 1492 1493 return method(self, *args, **kwargs) 1494 1495 return new_method 1496 1497 1498def check_filter(method): 1499 """"check the input arguments of filter.""" 1500 1501 @wraps(method) 1502 def new_method(self, *args, **kwargs): 1503 [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) 1504 if not callable(predicate): 1505 raise TypeError("Predicate should be a Python function or a callable Python object.") 1506 1507 if num_parallel_workers is not None: 1508 check_num_parallel_workers(num_parallel_workers) 1509 1510 if input_columns is not None: 1511 check_columns(input_columns, "input_columns") 1512 1513 return method(self, *args, **kwargs) 1514 1515 return new_method 1516 1517 1518def check_repeat(method): 1519 """check the input arguments of repeat.""" 1520 1521 @wraps(method) 1522 def new_method(self, *args, **kwargs): 1523 [count], _ = parse_user_args(method, *args, **kwargs) 1524 1525 type_check(count, (int, type(None)), "repeat") 1526 if isinstance(count, int): 1527 if (count <= 0 and count != -1) or count > INT32_MAX: 1528 raise ValueError("count should be either -1 or positive integer, range[1, INT32_MAX].") 1529 return method(self, *args, **kwargs) 1530 1531 return new_method 1532 1533 1534def check_skip(method): 1535 """check the input arguments of skip.""" 1536 1537 @wraps(method) 1538 def new_method(self, *args, **kwargs): 1539 [count], _ = 
parse_user_args(method, *args, **kwargs) 1540 1541 type_check(count, (int,), "count") 1542 check_value(count, (0, INT32_MAX), "count") 1543 1544 return method(self, *args, **kwargs) 1545 1546 return new_method 1547 1548 1549def check_take(method): 1550 """check the input arguments of take.""" 1551 1552 @wraps(method) 1553 def new_method(self, *args, **kwargs): 1554 [count], _ = parse_user_args(method, *args, **kwargs) 1555 type_check(count, (int,), "count") 1556 if (count <= 0 and count != -1) or count > INT32_MAX: 1557 raise ValueError("count should be either -1 or within the required interval of ({}, {}], got {}." 1558 .format(0, INT32_MAX, count)) 1559 1560 return method(self, *args, **kwargs) 1561 1562 return new_method 1563 1564 1565def check_positive_int32(method): 1566 """check whether the input argument is positive and int, only works for functions with one input.""" 1567 1568 @wraps(method) 1569 def new_method(self, *args, **kwargs): 1570 [count], param_dict = parse_user_args(method, *args, **kwargs) 1571 para_name = None 1572 for key in list(param_dict.keys()): 1573 if key not in ['self', 'cls']: 1574 para_name = key 1575 # Need to get default value of param 1576 if count is not None: 1577 check_pos_int32(count, para_name) 1578 1579 return method(self, *args, **kwargs) 1580 1581 return new_method 1582 1583 1584def check_device_send(method): 1585 """check the input argument of device_que.""" 1586 1587 @wraps(method) 1588 def new_method(self, *args, **kwargs): 1589 [send_epoch_end, create_data_info_queue, queue_name], _ = parse_user_args(method, *args, **kwargs) 1590 type_check(send_epoch_end, (bool,), "send_epoch_end") 1591 type_check(create_data_info_queue, (bool,), "create_data_info_queue") 1592 type_check(queue_name, (str,), "queue_name") 1593 1594 return method(self, *args, **kwargs) 1595 1596 return new_method 1597 1598 1599def check_total_batch(total_batch): 1600 check_int32(total_batch, "total_batch") 1601 1602 1603def check_zip(method): 1604 
    """check the input arguments of zip."""

    @wraps(method)
    def new_method(*args, **kwargs):
        # module-level zip() helper -- no `self`, unlike the Dataset-method checkers
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (tuple,), "datasets")

        return method(*args, **kwargs)

    return new_method


def check_zip_dataset(method):
    """check the input arguments of zip method in `Dataset` ."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        # unlike module-level zip, a single Dataset is also accepted here
        type_check(ds, (tuple, datasets.Dataset), "datasets")

        return method(self, *args, **kwargs)

    return new_method


def check_concat(method):
    """check the input arguments of concat method in `Dataset` ."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (list, datasets.Dataset), "datasets")
        if isinstance(ds, list):
            # every element of the list must itself be a Dataset
            type_check_list(ds, (datasets.Dataset,), "dataset")
        return method(self, *args, **kwargs)

    return new_method


def check_rename(method):
    """check the input arguments of rename."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)

        req_param_columns = ['input_columns', 'output_columns']
        for param_name, param in zip(req_param_columns, values):
            check_columns(param, param_name)

        # a bare string counts as one column, so both sizes default to 1
        input_size, output_size = 1, 1
        input_columns, output_columns = values
        if isinstance(input_columns, list):
            input_size = len(input_columns)
        if isinstance(output_columns, list):
            output_size = len(output_columns)
        if input_size != output_size:
            raise ValueError("Number of column in input_columns and output_columns is not equal.")

        return method(self, *args, **kwargs)

    return new_method


def check_output_shape(method):
    """check the input arguments of output_shape."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        estimate = param_dict.get('estimate')
        type_check(estimate, (bool,), "estimate")

        return method(self, *args, **kwargs)

    return new_method


def check_project(method):
    """check the input arguments of project."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns], _ = parse_user_args(method, *args, **kwargs)
        check_columns(columns, 'columns')

        return method(self, *args, **kwargs)

    return new_method


def check_schema(method):
    """check the input arguments of Schema.__init__."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [schema_file], _ = parse_user_args(method, *args, **kwargs)

        # schema_file is optional; when given it must be an existing file
        if schema_file is not None:
            check_file(schema_file)

        return method(self, *args, **kwargs)

    return new_method


def check_add_column(method):
    """check the input arguments of add_column."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)

        type_check(name, (str,), "name")

        if not name:
            raise TypeError("Expected non-empty string for column name.")

        if de_type is not None:
            # accept either a mindspore typing.Type or a value check_valid_detype recognizes
            if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
                raise TypeError("Unknown column type: {}.".format(de_type))
        else:
            raise TypeError("Expected non-empty string for de_type.")

        if shape is not None:
            type_check(shape, (list,), "shape")
            type_check_list(shape, (int,), "shape")

        return method(self, *args, **kwargs)

    return new_method


def check_cluedataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        # check task
        task_param = param_dict.get('task')
        if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")

        # check usage
        usage_param = param_dict.get('usage')
        if usage_param not in ['train', 'test', 'eval']:
            raise ValueError("usage should be 'train', 'test' or 'eval'.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_csvdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_files; required argument
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        # check field_delim
        field_delim = param_dict.get('field_delim')
        if field_delim is not None:
            type_check(field_delim, (str,), 'field delim')
            if field_delim
in ['"', '\r', '\n'] or len(field_delim) > 1: 1792 raise ValueError("field_delim is invalid.") 1793 1794 # check column_defaults 1795 column_defaults = param_dict.get('column_defaults') 1796 if column_defaults is not None: 1797 if not isinstance(column_defaults, list): 1798 raise TypeError("column_defaults should be type of list.") 1799 for item in column_defaults: 1800 if not isinstance(item, (str, int, float)): 1801 raise TypeError("column type in column_defaults is invalid.") 1802 1803 # check column_names: must be list of string. 1804 column_names = param_dict.get("column_names") 1805 if column_names is not None: 1806 all_string = all(isinstance(item, str) for item in column_names) 1807 if not all_string: 1808 raise TypeError("column_names should be a list of str.") 1809 1810 validate_dataset_param_value(nreq_param_int, param_dict, int) 1811 check_sampler_shuffle_shard_options(param_dict) 1812 1813 cache = param_dict.get('cache') 1814 check_cache_option(cache) 1815 1816 return method(self, *args, **kwargs) 1817 1818 return new_method 1819 1820 1821def check_flowers102dataset(method): 1822 """A wrapper that wraps a parameter checker around the original Dataset(Flowers102Dataset).""" 1823 1824 @wraps(method) 1825 def new_method(self, *args, **kwargs): 1826 _, param_dict = parse_user_args(method, *args, **kwargs) 1827 1828 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1829 nreq_param_bool = ['shuffle', 'decode'] 1830 1831 dataset_dir = param_dict.get('dataset_dir') 1832 check_dir(dataset_dir) 1833 1834 check_dir(os.path.join(dataset_dir, "jpg")) 1835 1836 check_file(os.path.join(dataset_dir, "imagelabels.mat")) 1837 check_file(os.path.join(dataset_dir, "setid.mat")) 1838 1839 usage = param_dict.get('usage') 1840 if usage is not None: 1841 check_valid_str(usage, ["train", "valid", "test", "all"], "usage") 1842 1843 task = param_dict.get('task') 1844 if task is not None: 1845 check_valid_str(task, ["Classification", 
"Segmentation"], "task") 1846 if task == "Segmentation": 1847 check_dir(os.path.join(dataset_dir, "segmim")) 1848 1849 validate_dataset_param_value(nreq_param_int, param_dict, int) 1850 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1851 1852 check_sampler_shuffle_shard_options(param_dict) 1853 1854 return method(self, *args, **kwargs) 1855 1856 return new_method 1857 1858 1859def check_textfiledataset(method): 1860 """A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset).""" 1861 1862 @wraps(method) 1863 def new_method(self, *args, **kwargs): 1864 _, param_dict = parse_user_args(method, *args, **kwargs) 1865 1866 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1867 1868 dataset_files = param_dict.get('dataset_files') 1869 type_check(dataset_files, (str, list), "dataset files") 1870 if not dataset_files: 1871 raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.") 1872 1873 validate_dataset_param_value(nreq_param_int, param_dict, int) 1874 check_sampler_shuffle_shard_options(param_dict) 1875 1876 cache = param_dict.get('cache') 1877 check_cache_option(cache) 1878 1879 return method(self, *args, **kwargs) 1880 1881 return new_method 1882 1883 1884def check_penn_treebank_dataset(method): 1885 """A wrapper that wraps a parameter checker around the original Dataset(PennTreebankDataset).""" 1886 1887 @wraps(method) 1888 def new_method(self, *args, **kwargs): 1889 _, param_dict = parse_user_args(method, *args, **kwargs) 1890 1891 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1892 1893 # check dataset_dir; required argument 1894 dataset_dir = param_dict.get('dataset_dir') 1895 check_dir(dataset_dir) 1896 1897 # check usage 1898 usage = param_dict.get('usage') 1899 if usage is not None: 1900 check_valid_str(usage, ["train", "valid", "test", "all"], "usage") 1901 1902 validate_dataset_param_value(nreq_param_int, 
param_dict, int) 1903 check_sampler_shuffle_shard_options(param_dict) 1904 1905 cache = param_dict.get('cache') 1906 check_cache_option(cache) 1907 1908 return method(self, *args, **kwargs) 1909 1910 return new_method 1911 1912 1913def check_split(method): 1914 """check the input arguments of split.""" 1915 1916 @wraps(method) 1917 def new_method(self, *args, **kwargs): 1918 [sizes, randomize], _ = parse_user_args(method, *args, **kwargs) 1919 1920 type_check(sizes, (list,), "sizes") 1921 type_check(randomize, (bool,), "randomize") 1922 1923 # check sizes: must be list of float or list of int 1924 if not sizes: 1925 raise ValueError("sizes cannot be empty.") 1926 1927 all_int = all(isinstance(item, int) for item in sizes) 1928 all_float = all(isinstance(item, float) for item in sizes) 1929 1930 if not (all_int or all_float): 1931 raise ValueError("sizes should be list of int or list of float.") 1932 1933 if all_int: 1934 all_positive = all(item > 0 for item in sizes) 1935 if not all_positive: 1936 raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.") 1937 1938 if all_float: 1939 all_valid_percentages = all(0 < item <= 1 for item in sizes) 1940 if not all_valid_percentages: 1941 raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].") 1942 1943 epsilon = 0.00001 1944 if not abs(sum(sizes) - 1) < epsilon: 1945 raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.") 1946 1947 return method(self, *args, **kwargs) 1948 1949 return new_method 1950 1951 1952def check_hostname(hostname): 1953 if not hostname or len(hostname) > 255: 1954 return False 1955 if hostname[-1] == ".": 1956 hostname = hostname[:-1] # strip exactly one dot from the right, if present 1957 allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE) 1958 return all(allowed.match(x) for x in hostname.split(".")) 1959 1960 1961def check_numpyslicesdataset(method): 1962 """A wrapper 
that wraps a parameter checker around the original Dataset(NumpySlicesDataset).""" 1963 1964 @wraps(method) 1965 def new_method(self, *args, **kwargs): 1966 _, param_dict = parse_user_args(method, *args, **kwargs) 1967 1968 data = param_dict.get("data") 1969 column_names = param_dict.get("column_names") 1970 type_check(data, (list, tuple, dict, np.ndarray), "data") 1971 if data is None or len(data) == 0: # pylint: disable=len-as-condition 1972 raise ValueError("Argument data cannot be empty") 1973 if isinstance(data, tuple): 1974 type_check(data[0], (list, np.ndarray), "data[0]") 1975 1976 # check column_names 1977 if column_names is not None: 1978 check_columns(column_names, "column_names") 1979 1980 # check num of input column in column_names 1981 column_num = 1 if isinstance(column_names, str) else len(column_names) 1982 if isinstance(data, dict): 1983 data_column = len(list(data.keys())) 1984 if column_num != data_column: 1985 raise ValueError("Num of input column names is {0}, but required is {1}." 1986 .format(column_num, data_column)) 1987 1988 elif isinstance(data, tuple): 1989 if column_num != len(data): 1990 raise ValueError("Num of input column names is {0}, but required is {1}." 1991 .format(column_num, len(data))) 1992 else: 1993 if column_num != 1: 1994 raise ValueError("Num of input column names is {0}, but required is {1} as data is list." 
1995 .format(column_num, 1)) 1996 1997 return method(self, *args, **kwargs) 1998 1999 return new_method 2000 2001 2002def check_paddeddataset(method): 2003 """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset).""" 2004 2005 @wraps(method) 2006 def new_method(self, *args, **kwargs): 2007 _, param_dict = parse_user_args(method, *args, **kwargs) 2008 2009 padded_samples = param_dict.get("padded_samples") 2010 if not padded_samples: 2011 raise ValueError("padded_samples cannot be empty.") 2012 type_check(padded_samples, (list,), "padded_samples") 2013 type_check(padded_samples[0], (dict,), "padded_element") 2014 return method(self, *args, **kwargs) 2015 2016 return new_method 2017 2018 2019def check_cache_option(cache): 2020 """Sanity check for cache parameter""" 2021 if cache is not None: 2022 type_check(cache, (cache_client.DatasetCache,), "cache") 2023 2024 2025def check_to_device_send(method): 2026 """Check the input arguments of send function for TransferDataset.""" 2027 2028 @wraps(method) 2029 def new_method(self, *args, **kwargs): 2030 [num_epochs], _ = parse_user_args(method, *args, **kwargs) 2031 2032 if num_epochs is not None: 2033 type_check(num_epochs, (int,), "num_epochs") 2034 check_value(num_epochs, [-1, INT32_MAX], "num_epochs") 2035 2036 return method(self, *args, **kwargs) 2037 2038 return new_method 2039 2040 2041def check_emnist_dataset(method): 2042 """A wrapper that wraps a parameter checker emnist dataset""" 2043 2044 @wraps(method) 2045 def new_method(self, *args, **kwargs): 2046 _, param_dict = parse_user_args(method, *args, **kwargs) 2047 2048 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2049 nreq_param_bool = ['shuffle'] 2050 2051 validate_dataset_param_value(nreq_param_int, param_dict, int) 2052 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2053 2054 dataset_dir = param_dict.get('dataset_dir') 2055 check_dir(dataset_dir) 2056 2057 name = 
param_dict.get('name') 2058 check_valid_str(name, ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"], "name") 2059 2060 usage = param_dict.get('usage') 2061 if usage is not None: 2062 check_valid_str(usage, ["train", "test", "all"], "usage") 2063 2064 check_sampler_shuffle_shard_options(param_dict) 2065 2066 cache = param_dict.get('cache') 2067 check_cache_option(cache) 2068 2069 return method(self, *args, **kwargs) 2070 2071 return new_method 2072 2073 2074def check_flickr_dataset(method): 2075 """A wrapper that wraps a parameter checker around the original Dataset(Flickr8k, Flickr30k).""" 2076 2077 @wraps(method) 2078 def new_method(self, *args, **kwargs): 2079 _, param_dict = parse_user_args(method, *args, **kwargs) 2080 2081 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2082 nreq_param_bool = ['shuffle', 'decode'] 2083 2084 dataset_dir = param_dict.get('dataset_dir') 2085 annotation_file = param_dict.get('annotation_file') 2086 check_dir(dataset_dir) 2087 check_file(annotation_file) 2088 2089 validate_dataset_param_value(nreq_param_int, param_dict, int) 2090 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2091 2092 check_sampler_shuffle_shard_options(param_dict) 2093 2094 cache = param_dict.get('cache') 2095 check_cache_option(cache) 2096 2097 return method(self, *args, **kwargs) 2098 2099 return new_method 2100 2101 2102def check_food101_dataset(method): 2103 """A wrapper that wraps a parameter checker around the Food101Dataset.""" 2104 2105 @wraps(method) 2106 def new_method(self, *args, **kwargs): 2107 _, param_dict = parse_user_args(method, *args, **kwargs) 2108 2109 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2110 nreq_param_bool = ['decode', 'shuffle'] 2111 2112 dataset_dir = param_dict.get('dataset_dir') 2113 check_dir(dataset_dir) 2114 2115 usage = param_dict.get('usage') 2116 if usage is not None: 2117 check_valid_str(usage, ["train", "test", 
"all"], "usage") 2118 2119 validate_dataset_param_value(nreq_param_int, param_dict, int) 2120 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2121 2122 check_sampler_shuffle_shard_options(param_dict) 2123 2124 cache = param_dict.get('cache') 2125 check_cache_option(cache) 2126 2127 return method(self, *args, **kwargs) 2128 2129 return new_method 2130 2131 2132def check_sb_dataset(method): 2133 """A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset.""" 2134 2135 @wraps(method) 2136 def new_method(self, *args, **kwargs): 2137 _, param_dict = parse_user_args(method, *args, **kwargs) 2138 2139 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2140 nreq_param_bool = ['shuffle', 'decode'] 2141 2142 dataset_dir = param_dict.get('dataset_dir') 2143 check_dir(dataset_dir) 2144 2145 usage = param_dict.get('usage') 2146 if usage is not None: 2147 check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage") 2148 2149 task = param_dict.get('task') 2150 if task is not None: 2151 check_valid_str(task, ["Boundaries", "Segmentation"], "task") 2152 2153 validate_dataset_param_value(nreq_param_int, param_dict, int) 2154 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2155 2156 check_sampler_shuffle_shard_options(param_dict) 2157 2158 return method(self, *args, **kwargs) 2159 2160 return new_method 2161 2162 2163def check_speech_commands_dataset(method): 2164 """A wrapper that wraps a parameter checker around the original Dataset(SpeechCommandsDataset).""" 2165 2166 @wraps(method) 2167 def new_method(self, *args, **kwargs): 2168 _, param_dict = parse_user_args(method, *args, **kwargs) 2169 2170 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2171 nreq_param_bool = ['shuffle'] 2172 2173 dataset_dir = param_dict.get('dataset_dir') 2174 check_dir(dataset_dir) 2175 2176 usage = param_dict.get('usage') 2177 if usage is not None: 2178 
check_valid_str(usage, ["train", "test", "valid", "all"], "usage") 2179 2180 validate_dataset_param_value(nreq_param_int, param_dict, int) 2181 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2182 2183 check_sampler_shuffle_shard_options(param_dict) 2184 2185 cache = param_dict.get('cache') 2186 check_cache_option(cache) 2187 2188 return method(self, *args, **kwargs) 2189 2190 return new_method 2191 2192 2193def check_squad_dataset(method): 2194 """A wrapper that wraps a parameter checker around the original Dataset(SQuADDataset).""" 2195 2196 @wraps(method) 2197 def new_method(self, *args, **kwargs): 2198 _, param_dict = parse_user_args(method, *args, **kwargs) 2199 2200 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2201 2202 dataset_dir = param_dict.get('dataset_dir') 2203 check_dir(dataset_dir) 2204 2205 # check usage 2206 usage = param_dict.get('usage') 2207 if usage is not None: 2208 check_valid_str(usage, ['train', 'dev', 'all'], "usage") 2209 2210 validate_dataset_param_value(nreq_param_int, param_dict, int) 2211 check_sampler_shuffle_shard_options(param_dict) 2212 2213 cache = param_dict.get('cache') 2214 check_cache_option(cache) 2215 2216 return method(self, *args, **kwargs) 2217 2218 return new_method 2219 2220 2221def check_cityscapes_dataset(method): 2222 """A wrapper that wraps a parameter checker around the original CityScapesDataset.""" 2223 2224 @wraps(method) 2225 def new_method(self, *args, **kwargs): 2226 _, param_dict = parse_user_args(method, *args, **kwargs) 2227 2228 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2229 nreq_param_bool = ['shuffle', 'decode'] 2230 2231 dataset_dir = param_dict.get('dataset_dir') 2232 check_dir(dataset_dir) 2233 2234 task = param_dict.get('task') 2235 check_valid_str(task, ["instance", "semantic", "polygon", "color"], "task") 2236 2237 quality_mode = param_dict.get('quality_mode') 2238 check_valid_str(quality_mode, ["fine", 
"coarse"], "quality_mode") 2239 2240 usage = param_dict.get('usage') 2241 if quality_mode == "fine": 2242 valid_strings = ["train", "test", "val", "all"] 2243 else: 2244 valid_strings = ["train", "train_extra", "val", "all"] 2245 check_valid_str(usage, valid_strings, "usage") 2246 2247 validate_dataset_param_value(nreq_param_int, param_dict, int) 2248 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2249 2250 check_sampler_shuffle_shard_options(param_dict) 2251 2252 return method(self, *args, **kwargs) 2253 2254 return new_method 2255 2256 2257def check_div2k_dataset(method): 2258 """A wrapper that wraps a parameter checker around the original DIV2KDataset.""" 2259 2260 @wraps(method) 2261 def new_method(self, *args, **kwargs): 2262 _, param_dict = parse_user_args(method, *args, **kwargs) 2263 2264 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2265 nreq_param_bool = ['shuffle', 'decode'] 2266 2267 dataset_dir = param_dict.get('dataset_dir') 2268 check_dir(dataset_dir) 2269 2270 usage = param_dict.get('usage') 2271 check_valid_str(usage, ['train', 'valid', 'all'], "usage") 2272 2273 downgrade = param_dict.get('downgrade') 2274 check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade') 2275 2276 validate_dataset_param_value(['scale'], param_dict, int) 2277 scale = param_dict.get('scale') 2278 scale_values = [2, 3, 4, 8] 2279 if scale not in scale_values: 2280 raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values))) 2281 2282 if scale == 8 and downgrade != "bicubic": 2283 raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.") 2284 2285 downgrade_2018 = ["mild", "difficult", "wild"] 2286 if downgrade in downgrade_2018 and scale != 4: 2287 raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade)) 2288 2289 validate_dataset_param_value(nreq_param_int, param_dict, int) 2290 
validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2291 2292 check_sampler_shuffle_shard_options(param_dict) 2293 2294 return method(self, *args, **kwargs) 2295 2296 return new_method 2297 2298 2299def check_fake_image_dataset(method): 2300 """A wrapper that wraps a parameter checker around the original Dataset(FakeImageDataset).""" 2301 2302 @wraps(method) 2303 def new_method(self, *args, **kwargs): 2304 _, param_dict = parse_user_args(method, *args, **kwargs) 2305 2306 nreq_param_int = ['num_images', 'num_classes', 'base_seed', 'num_samples', 2307 'num_parallel_workers', 'num_shards', 'shard_id'] 2308 nreq_param_bool = ['shuffle'] 2309 2310 validate_dataset_param_value(nreq_param_int, param_dict, int) 2311 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2312 2313 num_images = param_dict.get("num_images") 2314 check_pos_int32(num_images, "num_images") 2315 2316 image_size = param_dict.get("image_size") 2317 type_check(image_size, (list, tuple), "image_size") 2318 if len(image_size) != 3: 2319 raise ValueError("image_size should be a list or tuple of length 3, but got {0}".format(len(image_size))) 2320 for i, value in enumerate(image_size): 2321 check_pos_int32(value, "image_size[{0}]".format(i)) 2322 2323 num_classes = param_dict.get("num_classes") 2324 check_pos_int32(num_classes, "num_classes") 2325 2326 check_sampler_shuffle_shard_options(param_dict) 2327 2328 cache = param_dict.get('cache') 2329 check_cache_option(cache) 2330 2331 return method(self, *args, **kwargs) 2332 2333 return new_method 2334 2335 2336def check_ag_news_dataset(method): 2337 """A wrapper that wraps a parameter checker around the original Dataset(AGNewsDataset).""" 2338 2339 @wraps(method) 2340 def new_method(self, *args, **kwargs): 2341 _, param_dict = parse_user_args(method, *args, **kwargs) 2342 2343 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2344 2345 # check dataset_files; required argument 2346 dataset_dir = 
param_dict.get('dataset_dir') 2347 check_dir(dataset_dir) 2348 2349 # check usage 2350 usage = param_dict.get('usage') 2351 if usage is not None: 2352 check_valid_str(usage, ["train", "test", "all"], "usage") 2353 2354 validate_dataset_param_value(nreq_param_int, param_dict, int) 2355 check_sampler_shuffle_shard_options(param_dict) 2356 2357 cache = param_dict.get('cache') 2358 check_cache_option(cache) 2359 2360 return method(self, *args, **kwargs) 2361 2362 return new_method 2363 2364 2365def check_dbpedia_dataset(method): 2366 """A wrapper that wraps a parameter checker around the original DBpediaDataset.""" 2367 2368 @wraps(method) 2369 def new_method(self, *args, **kwargs): 2370 _, param_dict = parse_user_args(method, *args, **kwargs) 2371 2372 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2373 2374 dataset_dir = param_dict.get('dataset_dir') 2375 check_dir(dataset_dir) 2376 2377 usage = param_dict.get('usage') 2378 if usage is not None: 2379 check_valid_str(usage, ["train", "test", "all"], "usage") 2380 2381 validate_dataset_param_value(nreq_param_int, param_dict, int) 2382 2383 check_sampler_shuffle_shard_options(param_dict) 2384 2385 cache = param_dict.get('cache') 2386 check_cache_option(cache) 2387 2388 return method(self, *args, **kwargs) 2389 2390 return new_method 2391 2392 2393def check_wider_face_dataset(method): 2394 """A wrapper that wraps a parameter checker around the WIDERFaceDataset.""" 2395 2396 @wraps(method) 2397 def new_method(self, *args, **kwargs): 2398 _, param_dict = parse_user_args(method, *args, **kwargs) 2399 2400 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2401 nreq_param_bool = ['decode', 'shuffle'] 2402 2403 dataset_dir = param_dict.get('dataset_dir') 2404 check_dir(dataset_dir) 2405 2406 usage = param_dict.get('usage') 2407 if usage is not None: 2408 check_valid_str(usage, ["train", "test", "valid", "all"], "usage") 2409 2410 
validate_dataset_param_value(nreq_param_int, param_dict, int) 2411 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2412 2413 check_sampler_shuffle_shard_options(param_dict) 2414 2415 cache = param_dict.get('cache') 2416 check_cache_option(cache) 2417 2418 return method(self, *args, **kwargs) 2419 2420 return new_method 2421 2422 2423def check_yelp_review_dataset(method): 2424 """A wrapper that wraps a parameter checker around the original Dataset(YelpReviewDataset).""" 2425 2426 @wraps(method) 2427 def new_method(self, *args, **kwargs): 2428 _, param_dict = parse_user_args(method, *args, **kwargs) 2429 2430 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2431 2432 dataset_dir = param_dict.get('dataset_dir') 2433 check_dir(dataset_dir) 2434 2435 # check usage 2436 usage = param_dict.get('usage') 2437 if usage is not None: 2438 check_valid_str(usage, ["train", "test", "all"], "usage") 2439 2440 validate_dataset_param_value(nreq_param_int, param_dict, int) 2441 check_sampler_shuffle_shard_options(param_dict) 2442 2443 cache = param_dict.get('cache') 2444 check_cache_option(cache) 2445 2446 return method(self, *args, **kwargs) 2447 2448 return new_method 2449 2450 2451def check_yes_no_dataset(method): 2452 """A wrapper that wraps a parameter checker around the original Dataset(YesNoDataset).""" 2453 2454 @wraps(method) 2455 def new_method(self, *args, **kwargs): 2456 _, param_dict = parse_user_args(method, *args, **kwargs) 2457 2458 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2459 nreq_param_bool = ['shuffle'] 2460 2461 dataset_dir = param_dict.get('dataset_dir') 2462 check_dir(dataset_dir) 2463 2464 validate_dataset_param_value(nreq_param_int, param_dict, int) 2465 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2466 2467 check_sampler_shuffle_shard_options(param_dict) 2468 2469 cache = param_dict.get('cache') 2470 check_cache_option(cache) 2471 2472 return 
method(self, *args, **kwargs) 2473 2474 return new_method 2475 2476 2477def check_tedlium_dataset(method): 2478 """Wrapper method to check the parameters of TedliumDataset.""" 2479 2480 @wraps(method) 2481 def new_method(self, *args, **kwargs): 2482 _, param_dict = parse_user_args(method, *args, **kwargs) 2483 2484 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2485 nreq_param_bool = ['shuffle'] 2486 2487 release = param_dict.get('release') 2488 check_valid_str(release, ["release1", "release2", "release3"], "release") 2489 2490 dataset_dir = param_dict.get('dataset_dir') 2491 check_dir(dataset_dir) 2492 2493 usage = param_dict.get('usage') 2494 if usage is not None: 2495 if release in ["release1", "release2"]: 2496 check_valid_str(usage, ["train", "test", "dev", "all"], "usage") 2497 else: 2498 check_valid_str(usage, ["all"], "usage") 2499 2500 validate_dataset_param_value(nreq_param_int, param_dict, int) 2501 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2502 2503 check_sampler_shuffle_shard_options(param_dict) 2504 2505 cache = param_dict.get('cache') 2506 check_cache_option(cache) 2507 2508 return method(self, *args, **kwargs) 2509 2510 return new_method 2511 2512 2513def check_svhn_dataset(method): 2514 """A wrapper that wraps a parameter checker around the original Dataset(SVHNDataset).""" 2515 2516 @wraps(method) 2517 def new_method(self, *args, **kwargs): 2518 _, param_dict = parse_user_args(method, *args, **kwargs) 2519 dataset_dir = param_dict.get('dataset_dir') 2520 check_dir(dataset_dir) 2521 2522 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2523 nreq_param_bool = ['shuffle'] 2524 2525 usage = param_dict.get('usage') 2526 if usage is not None: 2527 check_valid_str(usage, ["train", "test", "extra", "all"], "usage") 2528 if usage == "all": 2529 for _usage in ["train", "test", "extra"]: 2530 check_file(os.path.join(dataset_dir, _usage + "_32x32.mat")) 2531 else: 2532 
check_file(os.path.join(dataset_dir, usage + "_32x32.mat")) 2533 2534 validate_dataset_param_value(nreq_param_int, param_dict, int) 2535 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2536 2537 check_sampler_shuffle_shard_options(param_dict) 2538 2539 return method(self, *args, **kwargs) 2540 2541 return new_method 2542 2543 2544def check_sst2_dataset(method): 2545 """A wrapper that wraps a parameter checker around the original SST2 Dataset.""" 2546 2547 @wraps(method) 2548 def new_method(self, *args, **kwargs): 2549 _, param_dict = parse_user_args(method, *args, **kwargs) 2550 2551 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2552 2553 dataset_dir = param_dict.get('dataset_dir') 2554 check_dir(dataset_dir) 2555 2556 usage = param_dict.get('usage') 2557 if usage is not None: 2558 check_valid_str(usage, ["train", "test", "dev"], "usage") 2559 2560 validate_dataset_param_value(nreq_param_int, param_dict, int) 2561 2562 check_sampler_shuffle_shard_options(param_dict) 2563 2564 cache = param_dict.get('cache') 2565 check_cache_option(cache) 2566 2567 return method(self, *args, **kwargs) 2568 2569 return new_method 2570 2571 2572def check_stl10_dataset(method): 2573 """A wrapper that wraps a parameter checker around the original Dataset(STL10Dataset).""" 2574 2575 @wraps(method) 2576 def new_method(self, *args, **kwargs): 2577 _, param_dict = parse_user_args(method, *args, **kwargs) 2578 2579 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2580 nreq_param_bool = ['shuffle'] 2581 2582 dataset_dir = param_dict.get('dataset_dir') 2583 check_dir(dataset_dir) 2584 2585 usage = param_dict.get('usage') 2586 if usage is not None: 2587 check_valid_str(usage, ["train", "test", "unlabeled", "train+unlabeled", "all"], "usage") 2588 if usage == "all": 2589 for _usage in ["train", "test", "unlabeled"]: 2590 check_file(os.path.join(dataset_dir, _usage + "_X.bin")) 2591 if _usage == "unlabeled": 
2592 continue 2593 else: 2594 check_file(os.path.join(dataset_dir, _usage + "_y.bin")) 2595 elif usage == "train+unlabeled": 2596 check_file(os.path.join(dataset_dir, "train_X.bin")) 2597 check_file(os.path.join(dataset_dir, "train_y.bin")) 2598 check_file(os.path.join(dataset_dir, "unlabeled_X.bin")) 2599 elif usage == "unlabeled": 2600 check_file(os.path.join(dataset_dir, "unlabeled_X.bin")) 2601 else: 2602 check_file(os.path.join(dataset_dir, usage + "_X.bin")) 2603 check_file(os.path.join(dataset_dir, usage + "_y.bin")) 2604 2605 validate_dataset_param_value(nreq_param_int, param_dict, int) 2606 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2607 2608 check_sampler_shuffle_shard_options(param_dict) 2609 2610 cache = param_dict.get('cache') 2611 check_cache_option(cache) 2612 2613 return method(self, *args, **kwargs) 2614 2615 return new_method 2616 2617 2618def check_sun397_dataset(method): 2619 """A wrapper that wraps a parameter checker around the original Dataset(SUN397Dataset).""" 2620 2621 @wraps(method) 2622 def new_method(self, *args, **kwargs): 2623 _, param_dict = parse_user_args(method, *args, **kwargs) 2624 2625 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2626 nreq_param_bool = ['shuffle', 'decode'] 2627 2628 dataset_dir = param_dict.get('dataset_dir') 2629 check_dir(dataset_dir) 2630 2631 validate_dataset_param_value(nreq_param_int, param_dict, int) 2632 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2633 check_sampler_shuffle_shard_options(param_dict) 2634 2635 cache = param_dict.get('cache') 2636 check_cache_option(cache) 2637 2638 return method(self, *args, **kwargs) 2639 2640 return new_method 2641 2642 2643def check_yahoo_answers_dataset(method): 2644 """A wrapper that wraps a parameter checker around the original YahooAnswers Dataset.""" 2645 2646 @wraps(method) 2647 def new_method(self, *args, **kwargs): 2648 _, param_dict = parse_user_args(method, *args, **kwargs) 2649 
2650 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2651 2652 dataset_dir = param_dict.get('dataset_dir') 2653 check_dir(dataset_dir) 2654 2655 usage = param_dict.get('usage') 2656 if usage is not None: 2657 check_valid_str(usage, ["train", "test", "all"], "usage") 2658 2659 validate_dataset_param_value(nreq_param_int, param_dict, int) 2660 2661 check_sampler_shuffle_shard_options(param_dict) 2662 2663 cache = param_dict.get('cache') 2664 check_cache_option(cache) 2665 2666 return method(self, *args, **kwargs) 2667 2668 return new_method 2669 2670 2671def check_conll2000_dataset(method): 2672 """ A wrapper that wraps a parameter checker around the original Dataset(CoNLL2000Dataset).""" 2673 2674 @wraps(method) 2675 def new_method(self, *args, **kwargs): 2676 _, param_dict = parse_user_args(method, *args, **kwargs) 2677 2678 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2679 2680 # check dataset_dir 2681 dataset_dir = param_dict.get('dataset_dir') 2682 check_dir(dataset_dir) 2683 2684 # check usage 2685 usage = param_dict.get('usage') 2686 if usage is not None: 2687 check_valid_str(usage, ["train", "test", "all"], "usage") 2688 2689 validate_dataset_param_value(nreq_param_int, param_dict, int) 2690 check_sampler_shuffle_shard_options(param_dict) 2691 2692 cache = param_dict.get('cache') 2693 check_cache_option(cache) 2694 2695 return method(self, *args, **kwargs) 2696 2697 return new_method 2698 2699 2700def check_amazon_review_dataset(method): 2701 """A wrapper that wraps a parameter checker around the original Dataset(AmazonReviewDataset).""" 2702 2703 @wraps(method) 2704 def new_method(self, *args, **kwargs): 2705 _, param_dict = parse_user_args(method, *args, **kwargs) 2706 2707 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2708 2709 # check dataset_files 2710 dataset_dir = param_dict.get('dataset_dir') 2711 check_dir(dataset_dir) 2712 2713 # check usage 
2714 usage = param_dict.get('usage') 2715 if usage is not None: 2716 check_valid_str(usage, ["train", "test", "all"], "usage") 2717 2718 validate_dataset_param_value(nreq_param_int, param_dict, int) 2719 check_sampler_shuffle_shard_options(param_dict) 2720 2721 cache = param_dict.get('cache') 2722 check_cache_option(cache) 2723 2724 return method(self, *args, **kwargs) 2725 2726 return new_method 2727 2728 2729def check_semeion_dataset(method): 2730 """Wrapper method to check the parameters of SemeionDataset.""" 2731 2732 @wraps(method) 2733 def new_method(self, *args, **kwargs): 2734 _, param_dict = parse_user_args(method, *args, **kwargs) 2735 2736 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2737 nreq_param_bool = ['shuffle'] 2738 2739 dataset_dir = param_dict.get('dataset_dir') 2740 check_dir(dataset_dir) 2741 2742 validate_dataset_param_value(nreq_param_int, param_dict, int) 2743 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2744 2745 check_sampler_shuffle_shard_options(param_dict) 2746 2747 cache = param_dict.get('cache') 2748 check_cache_option(cache) 2749 2750 return method(self, *args, **kwargs) 2751 2752 return new_method 2753 2754 2755def check_wiki_text_dataset(method): 2756 """A wrapper that wraps a parameter checker around the original Dataset(WikiTextDataset).""" 2757 2758 @wraps(method) 2759 def new_method(self, *args, **kwargs): 2760 _, param_dict = parse_user_args(method, *args, **kwargs) 2761 2762 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2763 2764 # check dataset_dir 2765 dataset_dir = param_dict.get('dataset_dir') 2766 check_dir(dataset_dir) 2767 2768 # check usage 2769 usage = param_dict.get('usage') 2770 if usage is not None: 2771 check_valid_str(usage, ["train", "valid", "test", "all"], "usage") 2772 2773 validate_dataset_param_value(nreq_param_int, param_dict, int) 2774 check_sampler_shuffle_shard_options(param_dict) 2775 2776 cache = 
param_dict.get('cache') 2777 check_cache_option(cache) 2778 2779 return method(self, *args, **kwargs) 2780 2781 return new_method 2782 2783 2784def check_en_wik9_dataset(method): 2785 """Wrapper method to check the parameters of EnWik9 dataset.""" 2786 2787 @wraps(method) 2788 def new_method(self, *args, **kwargs): 2789 _, param_dict = parse_user_args(method, *args, **kwargs) 2790 2791 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2792 dataset_dir = param_dict.get('dataset_dir') 2793 check_dir(dataset_dir) 2794 2795 validate_dataset_param_value(nreq_param_int, param_dict, int) 2796 check_sampler_shuffle_shard_options(param_dict) 2797 2798 cache = param_dict.get('cache') 2799 check_cache_option(cache) 2800 2801 return method(self, *args, **kwargs) 2802 2803 return new_method 2804 2805 2806def check_multi30k_dataset(method): 2807 """A wrapper that wraps a parameter checker around the original Dataset (Multi30kDataset).""" 2808 2809 @wraps(method) 2810 def new_method(self, *args, **kwargs): 2811 _, param_dict = parse_user_args(method, *args, **kwargs) 2812 2813 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 2814 nreq_param_bool = ['shuffle', 'decode'] 2815 2816 dataset_dir = param_dict.get('dataset_dir') 2817 check_dir(dataset_dir) 2818 2819 usage = param_dict.get('usage') 2820 if usage is not None: 2821 check_valid_str(usage, ["train", "test", "valid", "all"], "usage") 2822 2823 language_pair = param_dict.get('language_pair') 2824 support_language_pair = [['en', 'de'], ['de', 'en'], ('en', 'de'), ('de', 'en')] 2825 if language_pair is not None: 2826 type_check(language_pair, (list, tuple), "language_pair") 2827 if len(language_pair) != 2: 2828 raise ValueError( 2829 "language_pair should be a list or tuple of length 2, but got {0}".format(len(language_pair))) 2830 if language_pair not in support_language_pair: 2831 raise ValueError( 2832 "language_pair can only be ['en', 'de'] or ['en', 'de'], 
but got {0}".format(language_pair)) 2833 2834 validate_dataset_param_value(nreq_param_int, param_dict, int) 2835 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2836 2837 check_sampler_shuffle_shard_options(param_dict) 2838 2839 return method(self, *args, **kwargs) 2840 2841 return new_method 2842 2843 2844def check_obsminddataset(method): 2845 """A wrapper that wraps a parameter checker around the original Dataset(OBSMindDataset).""" 2846 2847 @wraps(method) 2848 def new_method(self, *args, **kwargs): 2849 _, param_dict = parse_user_args(method, *args, **kwargs) 2850 2851 nreq_param_int = ['num_shards', 'shard_id'] 2852 nreq_param_list = ['columns_list'] 2853 nreq_param_bool = ['shard_equal_rows'] 2854 nreq_param_str = ['server', 'ak', 'sk', 'sync_obs_path'] 2855 2856 dataset_files = param_dict.get('dataset_files') 2857 type_check(dataset_files, (list,), "dataset_files") 2858 for dataset_file in dataset_files: 2859 if not isinstance(dataset_file, str): 2860 raise TypeError("Item of dataset files is not of type [{}], but got {}.".format(type(''), 2861 type(dataset_file))) 2862 validate_dataset_param_value(nreq_param_int, param_dict, int) 2863 validate_dataset_param_value(nreq_param_list, param_dict, list) 2864 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 2865 validate_dataset_param_value(nreq_param_str, param_dict, str) 2866 2867 server = param_dict.get('server') 2868 if not server.startswith(('http://', 'https://')): 2869 raise ValueError("server should be a str that starts with http:// or https://, but got {}.".format(server)) 2870 2871 check_sampler_shuffle_shard_options(param_dict) 2872 2873 return method(self, *args, **kwargs) 2874 2875 return new_method 2876