# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Built-in validators.
"""
import inspect as ins
import os
import re
from functools import wraps

import numpy as np
from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
    INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
    validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
    check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str

from . import datasets
from . import samplers
from . import cache_client


def check_imagefolderdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_list = ['extensions']
        nreq_param_dict = ['class_indexing']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_mnist_cifar_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_manifestdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_str = ['usage']
        nreq_param_dict = ['class_indexing']

        dataset_file = param_dict.get('dataset_file')
        check_file(dataset_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_sbu_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SBUDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # SBU requires these fixed files/folders to exist under dataset_dir.
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_urls.txt"))
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_captions.txt"))
        check_dir(os.path.join(dataset_dir, "sbu_images"))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_tfrecorddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_list = ['columns_list']
        nreq_param_bool = ['shard_equal_rows']

        dataset_files = param_dict.get('dataset_files')
        if not isinstance(dataset_files, (str, list)):
            raise TypeError("dataset_files should be type str or a list of strings.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_usps_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(USPSDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_vocdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_dict = ['class_indexing']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        task = param_dict.get('task')
        type_check(task, (str,), "task")

        usage = param_dict.get('usage')
        type_check(usage, (str,), "usage")
        dataset_dir = os.path.realpath(dataset_dir)

        # The image-set file location depends on the VOC task flavor.
        if task == "Segmentation":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
            if param_dict.get('class_indexing') is not None:
                raise ValueError("class_indexing is not supported in Segmentation task.")
        elif task == "Detection":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
        else:
            raise ValueError("Invalid task : " + task + ".")

        check_file(imagesets_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_cocodataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        annotation_file = param_dict.get('annotation_file')
        check_file(annotation_file)

        task = param_dict.get('task')
        type_check(task, (str,), "task")

        if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
            raise ValueError("Invalid task type: " + task + ".")

        validate_dataset_param_value(nreq_param_int, param_dict, int)

        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        sampler = param_dict.get('sampler')
        if sampler is not None and isinstance(sampler, samplers.PKSampler):
            raise ValueError("CocoDataset doesn't support PKSampler.")
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_celebadataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']
        nreq_param_list = ['extensions']
        nreq_param_str = ['dataset_type']

        dataset_dir = param_dict.get('dataset_dir')

        check_dir(dataset_dir)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_str, param_dict, str)

        usage = param_dict.get('usage')
        if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
            raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(param_dict)

        sampler = param_dict.get('sampler')
        if sampler is not None and isinstance(sampler, samplers.PKSampler):
            raise ValueError("CelebADataset doesn't support PKSampler.")

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_save(method):
    """A wrapper that wraps a parameter checker around the saved operator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_files']
        nreq_param_str = ['file_name', 'file_type']
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        # Valid range is [1, 1000]: zero or negative file counts are rejected.
        if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
            raise ValueError("num_files should be between 1 and 1000.")
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        if param_dict.get('file_type') != 'mindrecord':
            raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
        return method(self, *args, **kwargs)

    return new_method


def check_tuple_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns, num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_bool = ['output_numpy']
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        # -1 means "iterate forever"; otherwise num_epochs must fit in int32.
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        if columns is not None:
            check_columns(columns, "column_names")

        return method(self, *args, **kwargs)

    return new_method


def check_dict_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)
        nreq_param_bool = ['output_numpy']
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return new_method


def check_minddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded']
        nreq_param_list = ['columns_list']
        nreq_param_dict = ['padded_sample']

        dataset_file = param_dict.get('dataset_file')
        if isinstance(dataset_file, list):
            if len(dataset_file) > 4096:
                raise ValueError("length of dataset_file should be less than or equal to {}.".format(4096))
            for f in dataset_file:
                check_file(f)
        else:
            check_file(dataset_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)

        check_sampler_shuffle_shard_options(param_dict)

        check_padding_options(param_dict)
        return method(self, *args, **kwargs)

    return new_method


def check_generatordataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        source = param_dict.get('source')

        # source must be callable, iterable, or random-accessible.
        if not callable(source):
            try:
                iter(source)
            except TypeError:
                raise TypeError("Input `source` function of GeneratorDataset should be callable, iterable or random"
                                " accessible, commonly it should implement one of the method like yield, __getitem__ or"
                                " __next__(__iter__).")

        column_names = param_dict.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")
        schema = param_dict.get('schema')
        if column_names is None and schema is None:
            raise ValueError("Neither column_names nor schema are provided.")

        if schema is not None:
            if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
                raise ValueError("schema should be a path to schema file or a schema object.")

        # check optional argument
        nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        nreq_param_list = ["column_types"]
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        nreq_param_bool = ["shuffle"]
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        num_shards = param_dict.get("num_shards")
        shard_id = param_dict.get("shard_id")
        if (num_shards is None) != (shard_id is None):
            # These two parameters appear together.
            raise ValueError("num_shards and shard_id need to be passed in together.")
        if num_shards is not None:
            check_pos_int32(num_shards, "num_shards")
            if shard_id >= num_shards:
                raise ValueError("shard_id should be less than num_shards.")

        sampler = param_dict.get("sampler")
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("GeneratorDataset doesn't support PKSampler.")
            if not isinstance(sampler, samplers.BuiltinSampler):
                try:
                    iter(sampler)
                except TypeError:
                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.")

        # Sampling and sharding require random access on the source.
        if sampler is not None and not hasattr(source, "__getitem__"):
            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.")
        if num_shards is not None and not hasattr(source, "__getitem__"):
            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.")

        return method(self, *args, **kwargs)

    return new_method


def check_random_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows']
        nreq_param_bool = ['shuffle']
        nreq_param_list = ['columns_list']

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_list, param_dict, list)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_pad_info(key, val):
    """check the key and value pair of pad_info in batch"""
    type_check(key, (str,), "key in pad_info")

    if val is not None:
        if len(val) != 2:
            raise ValueError("value of pad_info should be a tuple of size 2.")
        type_check(val, (tuple,), "value in pad_info")

        if val[0] is not None:
            type_check(val[0], (list,), "shape in pad_info")

            for dim in val[0]:
                if dim is not None:
                    check_pos_int32(dim, "dim of shape in pad_info")
        if val[1] is not None:
            type_check(val[1], (int, float, str, bytes), "pad_value")


def check_bucket_batch_by_length(method):
    """check the input arguments of bucket_batch_by_length."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
         pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs)

        nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']

        type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list)

        nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
        type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)

        # check column_names: must be list of string.
        check_columns(column_names, "column_names")

        if element_length_function is None and len(column_names) != 1:
            raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")

        if element_length_function is not None and not callable(element_length_function):
            raise TypeError("element_length_function object is not callable.")

        # check bucket_boundaries: must be list of int, positive and strictly increasing
        if not bucket_boundaries:
            raise ValueError("bucket_boundaries cannot be empty.")

        all_int = all(isinstance(item, int) for item in bucket_boundaries)
        if not all_int:
            raise TypeError("bucket_boundaries should be a list of int.")

        all_positive = all(item > 0 for item in bucket_boundaries)
        if not all_positive:
            raise ValueError("bucket_boundaries must only contain positive numbers.")

        for i in range(len(bucket_boundaries) - 1):
            if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
                raise ValueError("bucket_boundaries should be strictly increasing.")

        # check bucket_batch_sizes: must be list of int and positive
        if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
            raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")

        all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
        if not all_int:
            raise TypeError("bucket_batch_sizes should be a list of int.")

        all_positive = all(item > 0 for item in bucket_batch_sizes)
        if not all_positive:
            raise ValueError("bucket_batch_sizes should be a list of positive numbers.")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")

            for k, v in pad_info.items():
                check_pad_info(k, v)

        return method(self, *args, **kwargs)

    return new_method


def check_batch(method):
    """check the input arguments of batch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, drop_remainder, num_parallel_workers, per_batch_map,
         input_columns, output_columns, column_order, pad_info,
         python_multiprocessing, max_rowsize], param_dict = parse_user_args(method, *args, **kwargs)

        if not (isinstance(batch_size, int) or (callable(batch_size))):
            raise TypeError("batch_size should either be an int or a callable.")

        if callable(batch_size):
            sig = ins.signature(batch_size)
            if len(sig.parameters) != 1:
                raise ValueError("callable batch_size should take one parameter (BatchInfo).")
        else:
            check_pos_int32(int(batch_size), "batch_size")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")
        type_check(max_rowsize, (int,), "max_rowsize")

        if (pad_info is not None) and (per_batch_map is not None):
            raise ValueError("pad_info and per_batch_map can't both be set.")

        if pad_info is not None:
            type_check(param_dict["pad_info"], (dict,), "pad_info")
            for k, v in param_dict.get('pad_info').items():
                check_pad_info(k, v)

        if (per_batch_map is None) != (input_columns is None):
            # These two parameters appear together.
            raise ValueError("per_batch_map and input_columns need to be passed in together.")

        if input_columns is not None:
            check_columns(input_columns, "input_columns")
            # per_batch_map takes one arg per input column plus a trailing BatchInfo.
            if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
                raise ValueError("The signature of per_batch_map should match with input columns.")

        if output_columns is not None:
            check_columns(output_columns, "output_columns")

        if column_order is not None:
            check_columns(column_order, "column_order")

        if python_multiprocessing is not None:
            type_check(python_multiprocessing, (bool,), "python_multiprocessing")

        return method(self, *args, **kwargs)

    return new_method


def check_sync_wait(method):
    """check the input arguments of sync_wait."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs)

        type_check(condition_name, (str,), "condition_name")
        type_check(num_batch, (int,), "num_batch")

        return method(self, *args, **kwargs)

    return new_method


def check_shuffle(method):
    """check the input arguments of shuffle."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [buffer_size], _ = parse_user_args(method, *args, **kwargs)

        type_check(buffer_size, (int,), "buffer_size")

        check_value(buffer_size, [2, INT32_MAX], "buffer_size")

        return method(self, *args, **kwargs)

    return new_method


def check_map(method):
    """check the input arguments of map."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        from mindspore.dataset.callback import DSCallback
        [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
         callbacks, max_rowsize], _ = \
            parse_user_args(method, *args, **kwargs)

        nreq_param_columns = ['input_columns', 'output_columns', 'column_order']

        if column_order is not None:
            type_check(column_order, (list,), "column_order")
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(python_multiprocessing, (bool,), "python_multiprocessing")
        check_cache_option(cache)
        type_check(max_rowsize, (int,), "max_rowsize")

        if callbacks is not None:
            if isinstance(callbacks, (list, tuple)):
                type_check_list(callbacks, (DSCallback,), "callbacks")
            else:
                type_check(callbacks, (DSCallback,), "callbacks")

        for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]):
            if param is not None:
                check_columns(param, param_name)
        # NOTE(review): this second callbacks check is redundant with the one above
        # and, unlike it, rejects tuples of DSCallback — confirm which behavior is intended.
        if callbacks is not None:
            type_check(callbacks, (list, DSCallback), "callbacks")

        return method(self, *args, **kwargs)

    return new_method


def check_filter(method):
    """"check the input arguments of filter."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)
        if not callable(predicate):
            raise TypeError("Predicate should be a Python function or a callable Python object.")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(self, *args, **kwargs)

    return new_method


def check_repeat(method):
    """check the input arguments of repeat."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)

        type_check(count, (int, type(None)), "repeat")
        # -1 means "repeat indefinitely"; any other value must be a positive int32.
        if isinstance(count, int):
            if (count <= 0 and count != -1) or count > INT32_MAX:
                raise ValueError("count should be either -1 or positive integer, range[1, INT32_MAX].")
        return method(self, *args, **kwargs)

    return new_method


def check_skip(method):
    """check the input arguments of skip."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)

        type_check(count, (int,), "count")
        check_value(count, (0, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return new_method


def check_take(method):
    """check the input arguments of take."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)
        type_check(count, (int,), "count")
        if (count <= 0 and count != -1) or count > INT32_MAX:
            raise ValueError("count should be either -1 or within the required interval of ({}, {}], got {}."
                             .format(0, INT32_MAX, count))

        return method(self, *args, **kwargs)

    return new_method


def check_positive_int32(method):
    """check whether the input argument is positive and int, only works for functions with one input."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], param_dict = parse_user_args(method, *args, **kwargs)
        para_name = None
        for key in list(param_dict.keys()):
            if key not in ['self', 'cls']:
                para_name = key
        # Need to get default value of param
        if count is not None:
            check_pos_int32(count, para_name)

        return method(self, *args, **kwargs)

    return new_method


def check_device_send(method):
    """check the input argument for to_device and device_que."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [send_epoch_end, create_data_info_queue], _ = parse_user_args(method, *args, **kwargs)
        type_check(send_epoch_end, (bool,), "send_epoch_end")
        type_check(create_data_info_queue, (bool,), "create_data_info_queue")

        return method(self, *args, **kwargs)

    return new_method


def check_zip(method):
    """check the input arguments of zip."""

    @wraps(method)
    def new_method(*args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (tuple,), "datasets")

        return method(*args, **kwargs)

    return new_method


def check_zip_dataset(method):
    """check the input arguments of zip method in `Dataset`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (tuple, datasets.Dataset), "datasets")

        return method(self, *args, **kwargs)

    return new_method


def check_concat(method):
    """check the input arguments of concat method in `Dataset`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ds], _ = parse_user_args(method, *args, **kwargs)
        type_check(ds, (list, datasets.Dataset), "datasets")
        if isinstance(ds, list):
            type_check_list(ds, (datasets.Dataset,), "dataset")
        return method(self, *args, **kwargs)

    return new_method


def check_rename(method):
    """check the input arguments of rename."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)

        req_param_columns = ['input_columns', 'output_columns']
        for param_name, param in zip(req_param_columns, values):
            check_columns(param, param_name)

        # A single string counts as one column; lists are compared by length.
        input_size, output_size = 1, 1
        input_columns, output_columns = values
        if isinstance(input_columns, list):
            input_size = len(input_columns)
        if isinstance(output_columns, list):
            output_size = len(output_columns)
        if input_size != output_size:
            raise ValueError("Number of column in input_columns and output_columns is not equal.")

        return method(self, *args, **kwargs)

    return new_method


def check_project(method):
    """check the input arguments of project."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns], _ = parse_user_args(method, *args, **kwargs)
        check_columns(columns, 'columns')

        return method(self, *args, **kwargs)

    return new_method


def check_schema(method):
    """check the input arguments of Schema.__init__."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [schema_file], _ = parse_user_args(method, *args, **kwargs)

        if schema_file is not None:
            check_file(schema_file)

        return method(self, *args, **kwargs)

    return new_method


def check_add_column(method):
    """check the input arguments of add_column."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)

        type_check(name, (str,), "name")

        if not name:
            raise TypeError("Expected non-empty string for column name.")

        if de_type is not None:
            if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
                raise TypeError("Unknown column type: {}.".format(de_type))
        else:
            raise TypeError("Expected non-empty string for de_type.")

        if shape is not None:
            type_check(shape, (list,), "shape")
            type_check_list(shape, (int,), "shape")

        return method(self, *args, **kwargs)

    return new_method


def check_cluedataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")

        # check task
        task_param = param_dict.get('task')
        if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")

        # check usage
        usage_param = param_dict.get('usage')
        if usage_param not in ['train', 'test', 'eval']:
            raise ValueError("usage should be 'train', 'test' or 'eval'.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_csvdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_files; required argument
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")

        # check field_delim
        field_delim = param_dict.get('field_delim')
        if field_delim is not None:
            type_check(field_delim, (str,), 'field delim')
            # A delimiter must be a single char and cannot collide with quoting/EOL.
            if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
                raise ValueError("field_delim is invalid.")

        # check column_defaults
        column_defaults = param_dict.get('column_defaults')
        if column_defaults is not None:
            if not isinstance(column_defaults, list):
                raise TypeError("column_defaults should be type of list.")
            for item in column_defaults:
                if not isinstance(item, (str, int, float)):
                    raise TypeError("column type in column_defaults is invalid.")

        # check column_names: must be list of string.
        column_names = param_dict.get("column_names")
        if column_names is not None:
            all_string = all(isinstance(item, str) for item in column_names)
            if not all_string:
                raise TypeError("column_names should be a list of str.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method


def check_flowers102dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Flowers102Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # Flowers102 requires these fixed files/folders to exist under dataset_dir.
        check_dir(os.path.join(dataset_dir, "jpg"))

        check_file(os.path.join(dataset_dir, "imagelabels.mat"))
        check_file(os.path.join(dataset_dir, "setid.mat"))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Classification", "Segmentation"], "task")
            if task == "Segmentation":
                check_dir(os.path.join(dataset_dir, "segmim"))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method


def check_textfiledataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""
@wraps(method) 1061 def new_method(self, *args, **kwargs): 1062 _, param_dict = parse_user_args(method, *args, **kwargs) 1063 1064 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1065 1066 dataset_files = param_dict.get('dataset_files') 1067 type_check(dataset_files, (str, list), "dataset files") 1068 validate_dataset_param_value(nreq_param_int, param_dict, int) 1069 check_sampler_shuffle_shard_options(param_dict) 1070 1071 cache = param_dict.get('cache') 1072 check_cache_option(cache) 1073 1074 return method(self, *args, **kwargs) 1075 1076 return new_method 1077 1078 1079def check_split(method): 1080 """check the input arguments of split.""" 1081 1082 @wraps(method) 1083 def new_method(self, *args, **kwargs): 1084 [sizes, randomize], _ = parse_user_args(method, *args, **kwargs) 1085 1086 type_check(sizes, (list,), "sizes") 1087 type_check(randomize, (bool,), "randomize") 1088 1089 # check sizes: must be list of float or list of int 1090 if not sizes: 1091 raise ValueError("sizes cannot be empty.") 1092 1093 all_int = all(isinstance(item, int) for item in sizes) 1094 all_float = all(isinstance(item, float) for item in sizes) 1095 1096 if not (all_int or all_float): 1097 raise ValueError("sizes should be list of int or list of float.") 1098 1099 if all_int: 1100 all_positive = all(item > 0 for item in sizes) 1101 if not all_positive: 1102 raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.") 1103 1104 if all_float: 1105 all_valid_percentages = all(0 < item <= 1 for item in sizes) 1106 if not all_valid_percentages: 1107 raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].") 1108 1109 epsilon = 0.00001 1110 if not abs(sum(sizes) - 1) < epsilon: 1111 raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.") 1112 1113 return method(self, *args, **kwargs) 1114 1115 return new_method 1116 1117 1118def 
check_hostname(hostname): 1119 if not hostname or len(hostname) > 255: 1120 return False 1121 if hostname[-1] == ".": 1122 hostname = hostname[:-1] # strip exactly one dot from the right, if present 1123 allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE) 1124 return all(allowed.match(x) for x in hostname.split(".")) 1125 1126 1127def check_gnn_graphdata(method): 1128 """check the input arguments of graphdata.""" 1129 1130 @wraps(method) 1131 def new_method(self, *args, **kwargs): 1132 [dataset_file, num_parallel_workers, working_mode, hostname, 1133 port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs) 1134 check_file(dataset_file) 1135 if num_parallel_workers is not None: 1136 check_num_parallel_workers(num_parallel_workers) 1137 type_check(hostname, (str,), "hostname") 1138 if check_hostname(hostname) is False: 1139 raise ValueError("The hostname is illegal") 1140 type_check(working_mode, (str,), "working_mode") 1141 if working_mode not in {'local', 'client', 'server'}: 1142 raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'.") 1143 type_check(port, (int,), "port") 1144 check_value(port, (1024, 65535), "port") 1145 type_check(num_client, (int,), "num_client") 1146 check_value(num_client, (1, 255), "num_client") 1147 type_check(auto_shutdown, (bool,), "auto_shutdown") 1148 return method(self, *args, **kwargs) 1149 1150 return new_method 1151 1152 1153def check_gnn_get_all_nodes(method): 1154 """A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function.""" 1155 1156 @wraps(method) 1157 def new_method(self, *args, **kwargs): 1158 [node_type], _ = parse_user_args(method, *args, **kwargs) 1159 type_check(node_type, (int,), "node_type") 1160 1161 return method(self, *args, **kwargs) 1162 1163 return new_method 1164 1165 1166def check_gnn_get_all_edges(method): 1167 """A wrapper that wraps a parameter checker around the GNN `get_all_edges` function.""" 1168 1169 
@wraps(method) 1170 def new_method(self, *args, **kwargs): 1171 [edge_type], _ = parse_user_args(method, *args, **kwargs) 1172 type_check(edge_type, (int,), "edge_type") 1173 1174 return method(self, *args, **kwargs) 1175 1176 return new_method 1177 1178 1179def check_gnn_get_nodes_from_edges(method): 1180 """A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function.""" 1181 1182 @wraps(method) 1183 def new_method(self, *args, **kwargs): 1184 [edge_list], _ = parse_user_args(method, *args, **kwargs) 1185 check_gnn_list_or_ndarray(edge_list, "edge_list") 1186 1187 return method(self, *args, **kwargs) 1188 1189 return new_method 1190 1191 1192def check_gnn_get_edges_from_nodes(method): 1193 """A wrapper that wraps a parameter checker around the GNN `get_edges_from_nodes` function.""" 1194 1195 @wraps(method) 1196 def new_method(self, *args, **kwargs): 1197 [node_list], _ = parse_user_args(method, *args, **kwargs) 1198 check_gnn_list_of_pair_or_ndarray(node_list, "node_list") 1199 1200 return method(self, *args, **kwargs) 1201 1202 return new_method 1203 1204 1205def check_gnn_get_all_neighbors(method): 1206 """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function.""" 1207 1208 @wraps(method) 1209 def new_method(self, *args, **kwargs): 1210 [node_list, neighbour_type, _], _ = parse_user_args(method, *args, **kwargs) 1211 1212 check_gnn_list_or_ndarray(node_list, 'node_list') 1213 type_check(neighbour_type, (int,), "neighbour_type") 1214 1215 return method(self, *args, **kwargs) 1216 1217 return new_method 1218 1219 1220def check_gnn_get_sampled_neighbors(method): 1221 """A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function.""" 1222 1223 @wraps(method) 1224 def new_method(self, *args, **kwargs): 1225 [node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs) 1226 1227 check_gnn_list_or_ndarray(node_list, 'node_list') 1228 1229 
check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') 1230 if not neighbor_nums or len(neighbor_nums) > 6: 1231 raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format( 1232 'neighbor_nums', len(neighbor_nums))) 1233 1234 check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') 1235 if not neighbor_types or len(neighbor_types) > 6: 1236 raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format( 1237 'neighbor_types', len(neighbor_types))) 1238 1239 if len(neighbor_nums) != len(neighbor_types): 1240 raise ValueError( 1241 "The number of members of neighbor_nums and neighbor_types is inconsistent.") 1242 1243 return method(self, *args, **kwargs) 1244 1245 return new_method 1246 1247 1248def check_gnn_get_neg_sampled_neighbors(method): 1249 """A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function.""" 1250 1251 @wraps(method) 1252 def new_method(self, *args, **kwargs): 1253 [node_list, neg_neighbor_num, neg_neighbor_type], _ = parse_user_args(method, *args, **kwargs) 1254 1255 check_gnn_list_or_ndarray(node_list, 'node_list') 1256 type_check(neg_neighbor_num, (int,), "neg_neighbor_num") 1257 type_check(neg_neighbor_type, (int,), "neg_neighbor_type") 1258 1259 return method(self, *args, **kwargs) 1260 1261 return new_method 1262 1263 1264def check_gnn_random_walk(method): 1265 """A wrapper that wraps a parameter checker around the GNN `random_walk` function.""" 1266 1267 @wraps(method) 1268 def new_method(self, *args, **kwargs): 1269 [target_nodes, meta_path, step_home_param, step_away_param, default_node], _ = parse_user_args(method, *args, 1270 **kwargs) 1271 check_gnn_list_or_ndarray(target_nodes, 'target_nodes') 1272 check_gnn_list_or_ndarray(meta_path, 'meta_path') 1273 type_check(step_home_param, (float,), "step_home_param") 1274 type_check(step_away_param, (float,), "step_away_param") 1275 type_check(default_node, (int,), 
"default_node") 1276 check_value(default_node, (-1, INT32_MAX), "default_node") 1277 1278 return method(self, *args, **kwargs) 1279 1280 return new_method 1281 1282 1283def check_aligned_list(param, param_name, member_type): 1284 """Check whether the structure of each member of the list is the same.""" 1285 1286 type_check(param, (list,), "param") 1287 if not param: 1288 raise TypeError( 1289 "Parameter {0} or its members are empty".format(param_name)) 1290 member_have_list = None 1291 list_len = None 1292 for member in param: 1293 if isinstance(member, list): 1294 check_aligned_list(member, param_name, member_type) 1295 1296 if member_have_list not in (None, True): 1297 raise TypeError("The type of each member of the parameter {0} is inconsistent.".format( 1298 param_name)) 1299 if list_len is not None and len(member) != list_len: 1300 raise TypeError("The size of each member of parameter {0} is inconsistent.".format( 1301 param_name)) 1302 member_have_list = True 1303 list_len = len(member) 1304 else: 1305 type_check(member, (member_type,), param_name) 1306 if member_have_list not in (None, False): 1307 raise TypeError("The type of each member of the parameter {0} is inconsistent.".format( 1308 param_name)) 1309 member_have_list = False 1310 1311 1312def check_gnn_get_node_feature(method): 1313 """A wrapper that wraps a parameter checker around the GNN `get_node_feature` function.""" 1314 1315 @wraps(method) 1316 def new_method(self, *args, **kwargs): 1317 [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs) 1318 1319 type_check(node_list, (list, np.ndarray), "node_list") 1320 if isinstance(node_list, list): 1321 check_aligned_list(node_list, 'node_list', int) 1322 elif isinstance(node_list, np.ndarray): 1323 if not node_list.dtype == np.int32: 1324 raise TypeError("Each member in {0} should be of type int32. 
Got {1}.".format( 1325 node_list, node_list.dtype)) 1326 1327 check_gnn_list_or_ndarray(feature_types, 'feature_types') 1328 1329 return method(self, *args, **kwargs) 1330 1331 return new_method 1332 1333 1334def check_gnn_get_edge_feature(method): 1335 """A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function.""" 1336 1337 @wraps(method) 1338 def new_method(self, *args, **kwargs): 1339 [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs) 1340 1341 type_check(edge_list, (list, np.ndarray), "edge_list") 1342 if isinstance(edge_list, list): 1343 check_aligned_list(edge_list, 'edge_list', int) 1344 elif isinstance(edge_list, np.ndarray): 1345 if not edge_list.dtype == np.int32: 1346 raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( 1347 edge_list, edge_list.dtype)) 1348 1349 check_gnn_list_or_ndarray(feature_types, 'feature_types') 1350 1351 return method(self, *args, **kwargs) 1352 1353 return new_method 1354 1355 1356def check_numpyslicesdataset(method): 1357 """A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset).""" 1358 1359 @wraps(method) 1360 def new_method(self, *args, **kwargs): 1361 _, param_dict = parse_user_args(method, *args, **kwargs) 1362 1363 data = param_dict.get("data") 1364 column_names = param_dict.get("column_names") 1365 type_check(data, (list, tuple, dict, np.ndarray), "data") 1366 if data is None or len(data) == 0: # pylint: disable=len-as-condition 1367 raise ValueError("Argument data cannot be empty") 1368 if isinstance(data, tuple): 1369 type_check(data[0], (list, np.ndarray), "data[0]") 1370 1371 # check column_names 1372 if column_names is not None: 1373 check_columns(column_names, "column_names") 1374 1375 # check num of input column in column_names 1376 column_num = 1 if isinstance(column_names, str) else len(column_names) 1377 if isinstance(data, dict): 1378 data_column = len(list(data.keys())) 1379 if column_num != 
data_column: 1380 raise ValueError("Num of input column names is {0}, but required is {1}." 1381 .format(column_num, data_column)) 1382 1383 elif isinstance(data, tuple): 1384 if column_num != len(data): 1385 raise ValueError("Num of input column names is {0}, but required is {1}." 1386 .format(column_num, len(data))) 1387 else: 1388 if column_num != 1: 1389 raise ValueError("Num of input column names is {0}, but required is {1} as data is list." 1390 .format(column_num, 1)) 1391 1392 return method(self, *args, **kwargs) 1393 1394 return new_method 1395 1396 1397def check_paddeddataset(method): 1398 """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset).""" 1399 1400 @wraps(method) 1401 def new_method(self, *args, **kwargs): 1402 _, param_dict = parse_user_args(method, *args, **kwargs) 1403 1404 padded_samples = param_dict.get("padded_samples") 1405 if not padded_samples: 1406 raise ValueError("padded_samples cannot be empty.") 1407 type_check(padded_samples, (list,), "padded_samples") 1408 type_check(padded_samples[0], (dict,), "padded_element") 1409 return method(self, *args, **kwargs) 1410 1411 return new_method 1412 1413 1414def check_cache_option(cache): 1415 """Sanity check for cache parameter""" 1416 if cache is not None: 1417 type_check(cache, (cache_client.DatasetCache,), "cache") 1418 1419 1420def check_to_device_send(method): 1421 """Check the input arguments of send function for TransferDataset.""" 1422 1423 @wraps(method) 1424 def new_method(self, *args, **kwargs): 1425 [num_epochs], _ = parse_user_args(method, *args, **kwargs) 1426 1427 if num_epochs is not None: 1428 type_check(num_epochs, (int,), "num_epochs") 1429 check_value(num_epochs, [-1, INT32_MAX], "num_epochs") 1430 1431 return method(self, *args, **kwargs) 1432 1433 return new_method 1434 1435 1436def check_flickr_dataset(method): 1437 """A wrapper that wraps a parameter checker around the original Dataset(Flickr8k, Flickr30k).""" 1438 1439 @wraps(method) 
1440 def new_method(self, *args, **kwargs): 1441 _, param_dict = parse_user_args(method, *args, **kwargs) 1442 1443 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1444 nreq_param_bool = ['shuffle', 'decode'] 1445 1446 dataset_dir = param_dict.get('dataset_dir') 1447 annotation_file = param_dict.get('annotation_file') 1448 check_dir(dataset_dir) 1449 check_file(annotation_file) 1450 1451 validate_dataset_param_value(nreq_param_int, param_dict, int) 1452 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1453 1454 check_sampler_shuffle_shard_options(param_dict) 1455 1456 cache = param_dict.get('cache') 1457 check_cache_option(cache) 1458 1459 return method(self, *args, **kwargs) 1460 1461 return new_method 1462 1463 1464def check_sb_dataset(method): 1465 """A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset.""" 1466 1467 @wraps(method) 1468 def new_method(self, *args, **kwargs): 1469 _, param_dict = parse_user_args(method, *args, **kwargs) 1470 1471 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1472 nreq_param_bool = ['shuffle', 'decode'] 1473 1474 dataset_dir = param_dict.get('dataset_dir') 1475 check_dir(dataset_dir) 1476 1477 usage = param_dict.get('usage') 1478 if usage is not None: 1479 check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage") 1480 1481 task = param_dict.get('task') 1482 if task is not None: 1483 check_valid_str(task, ["Boundaries", "Segmentation"], "task") 1484 1485 validate_dataset_param_value(nreq_param_int, param_dict, int) 1486 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1487 1488 check_sampler_shuffle_shard_options(param_dict) 1489 1490 return method(self, *args, **kwargs) 1491 1492 return new_method 1493 1494 1495def check_cityscapes_dataset(method): 1496 """A wrapper that wraps a parameter checker around the original CityScapesDataset.""" 1497 1498 @wraps(method) 1499 def 
new_method(self, *args, **kwargs): 1500 _, param_dict = parse_user_args(method, *args, **kwargs) 1501 1502 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1503 nreq_param_bool = ['shuffle', 'decode'] 1504 1505 dataset_dir = param_dict.get('dataset_dir') 1506 check_dir(dataset_dir) 1507 1508 task = param_dict.get('task') 1509 check_valid_str(task, ["instance", "semantic", "polygon", "color"], "task") 1510 1511 quality_mode = param_dict.get('quality_mode') 1512 check_valid_str(quality_mode, ["fine", "coarse"], "quality_mode") 1513 1514 usage = param_dict.get('usage') 1515 if quality_mode == "fine": 1516 valid_strings = ["train", "test", "val", "all"] 1517 else: 1518 valid_strings = ["train", "train_extra", "val", "all"] 1519 check_valid_str(usage, valid_strings, "usage") 1520 1521 validate_dataset_param_value(nreq_param_int, param_dict, int) 1522 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1523 1524 check_sampler_shuffle_shard_options(param_dict) 1525 1526 return method(self, *args, **kwargs) 1527 1528 return new_method 1529 1530 1531def check_div2k_dataset(method): 1532 """A wrapper that wraps a parameter checker around the original DIV2KDataset.""" 1533 1534 @wraps(method) 1535 def new_method(self, *args, **kwargs): 1536 _, param_dict = parse_user_args(method, *args, **kwargs) 1537 1538 nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] 1539 nreq_param_bool = ['shuffle', 'decode'] 1540 1541 dataset_dir = param_dict.get('dataset_dir') 1542 check_dir(dataset_dir) 1543 1544 usage = param_dict.get('usage') 1545 check_valid_str(usage, ['train', 'valid', 'all'], "usage") 1546 1547 downgrade = param_dict.get('downgrade') 1548 check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade') 1549 1550 validate_dataset_param_value(['scale'], param_dict, int) 1551 scale = param_dict.get('scale') 1552 scale_values = [2, 3, 4, 8] 1553 if scale not in scale_values: 
1554 raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values))) 1555 1556 if scale == 8 and downgrade != "bicubic": 1557 raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.") 1558 1559 downgrade_2018 = ["mild", "difficult", "wild"] 1560 if downgrade in downgrade_2018 and scale != 4: 1561 raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade)) 1562 1563 validate_dataset_param_value(nreq_param_int, param_dict, int) 1564 validate_dataset_param_value(nreq_param_bool, param_dict, bool) 1565 1566 check_sampler_shuffle_shard_options(param_dict) 1567 1568 return method(self, *args, **kwargs) 1569 1570 return new_method 1571