• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""
17Built-in validators.
18"""
19import inspect as ins
20import os
21import re
22from functools import wraps
23
24import numpy as np
25from mindspore._c_expression import typing
26from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
27    INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
28    validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
29    check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str
30
31from . import datasets
32from . import samplers
33from . import cache_client
34
35
def check_imagefolderdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # dataset_dir is mandatory and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional arguments are type-checked only when the caller supplied them.
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['extensions'], params, list)
        validate_dataset_param_value(['class_indexing'], params, dict)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
63
64
def check_mnist_cifar_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # usage is optional; when present it must name a known split.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
93
94
def check_manifestdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # dataset_file is mandatory and must be an existing file.
        check_file(params.get('dataset_file'))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['usage'], params, str)
        validate_dataset_param_value(['class_indexing'], params, dict)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
123
124
def check_sbu_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SBUDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        root = params.get('dataset_dir')
        check_dir(root)

        # The SBU layout is fixed: two metadata files plus an image folder.
        for meta_file in ("SBU_captioned_photo_dataset_urls.txt",
                          "SBU_captioned_photo_dataset_captions.txt"):
            check_file(os.path.join(root, meta_file))
        check_dir(os.path.join(root, "sbu_images"))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
153
154
def check_tfrecorddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # dataset_files may be a single path or a list of paths.
        if not isinstance(params.get('dataset_files'), (str, list)):
            raise TypeError("dataset_files should be type str or a list of strings.")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['columns_list'], params, list)
        validate_dataset_param_value(['shard_equal_rows'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
182
183
def check_usps_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(USPSDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # usage defaults to None ("all"); when given it must be a known split.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
209
210
def check_vocdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        dataset_dir = params.get('dataset_dir')
        check_dir(dataset_dir)

        task = params.get('task')
        type_check(task, (str,), "task")

        usage = params.get('usage')
        type_check(usage, (str,), "usage")
        dataset_dir = os.path.realpath(dataset_dir)

        # Each task reads its split list from a different ImageSets subfolder.
        if task == "Segmentation":
            if params.get('class_indexing') is not None:
                raise ValueError("class_indexing is not supported in Segmentation task.")
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
        elif task == "Detection":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
        else:
            raise ValueError("Invalid task : " + task + ".")

        check_file(imagesets_file)

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['class_indexing'], params, dict)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
254
255
def check_cocodataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))
        check_file(params.get('annotation_file'))

        task = params.get('task')
        type_check(task, (str,), "task")
        if task not in ('Detection', 'Stuff', 'Panoptic', 'Keypoint'):
            raise ValueError("Invalid task type: " + task + ".")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        # PKSampler needs per-class sampling, which CocoDataset does not support.
        if isinstance(params.get('sampler'), samplers.PKSampler):
            raise ValueError("CocoDataset doesn't support PKSampler.")
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
293
294
def check_celebadataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['extensions'], params, list)
        validate_dataset_param_value(['dataset_type'], params, str)

        # usage is optional; when present it must name a known partition.
        usage = params.get('usage')
        if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
            raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(params)

        # PKSampler needs per-class sampling, which CelebADataset does not support.
        if isinstance(params.get('sampler'), samplers.PKSampler):
            raise ValueError("CelebADataset doesn't support PKSampler.")

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
332
333
def check_save(method):
    """A wrapper that wraps a parameter checker around the saved operator.

    Validates num_files (must be in [1, 1000]), file_name and file_type
    (only 'mindrecord' is supported).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_files']
        nreq_param_str = ['file_name', 'file_type']
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        num_files = param_dict.get('num_files')
        # Valid range is [1, 1000]: at least one output file, at most 1000 shards.
        # The previous message said "between 0 and 1000", which contradicted the check.
        if num_files <= 0 or num_files > 1000:
            raise ValueError("num_files should be between 1 and 1000.")
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        # save() can only emit the mindrecord format.
        if param_dict.get('file_type') != 'mindrecord':
            raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
        return method(self, *args, **kwargs)

    return new_method
352
353
def check_tuple_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns, num_epochs, _, _], param_dict = parse_user_args(method, *args, **kwargs)

        validate_dataset_param_value(['output_numpy'], param_dict, bool)

        # num_epochs of -1 means "iterate indefinitely".
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        if columns is not None:
            check_columns(columns, "column_names")

        return method(self, *args, **kwargs)

    return new_method
372
373
def check_dict_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs)

        validate_dataset_param_value(['output_numpy'], param_dict, bool)

        # num_epochs of -1 means "iterate indefinitely".
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return new_method
389
390
def check_minddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        dataset_file = params.get('dataset_file')
        if isinstance(dataset_file, list):
            # Cap the number of mindrecord files to keep the handle count sane.
            if len(dataset_file) > 4096:
                raise ValueError("length of dataset_file should less than or equal to {}.".format(4096))
            for one_file in dataset_file:
                check_file(one_file)
        else:
            check_file(dataset_file)

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id',
                                      'num_padded'], params, int)
        validate_dataset_param_value(['columns_list'], params, list)
        validate_dataset_param_value(['padded_sample'], params, dict)

        check_sampler_shuffle_shard_options(params)

        check_padding_options(params)
        return method(self, *args, **kwargs)

    return new_method
421
422
def check_generatordataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset).

    Validates that `source` is callable, iterable or random-accessible, that at
    least one of `column_names`/`schema` is given, and that sharding/sampler
    options are mutually consistent with the capabilities of `source`.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        source = param_dict.get('source')

        # source must be callable, iterable, or support random access.
        if not callable(source):
            try:
                iter(source)
            except TypeError:
                raise TypeError("Input `source` function of GeneratorDataset should be callable, iterable or random"
                                " accessible, commonly it should implement one of the method like yield, __getitem__ or"
                                " __next__(__iter__).")

        column_names = param_dict.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")
        schema = param_dict.get('schema')
        if column_names is None and schema is None:
            # Fixed message: the parameter is named `column_names`, not `columns_names`.
            raise ValueError("Neither column_names nor schema are provided.")

        if schema is not None:
            if not isinstance(schema, (datasets.Schema, str)):
                raise ValueError("schema should be a path to schema file or a schema object.")

        # Optional arguments: validated only when supplied.
        nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        nreq_param_list = ["column_types"]
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        nreq_param_bool = ["shuffle"]
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        num_shards = param_dict.get("num_shards")
        shard_id = param_dict.get("shard_id")
        if (num_shards is None) != (shard_id is None):
            # These two parameters appear together.
            raise ValueError("num_shards and shard_id need to be passed in together.")
        if num_shards is not None:
            check_pos_int32(num_shards, "num_shards")
            if shard_id >= num_shards:
                raise ValueError("shard_id should be less than num_shards.")

        sampler = param_dict.get("sampler")
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("GeneratorDataset doesn't support PKSampler.")
            if not isinstance(sampler, samplers.BuiltinSampler):
                try:
                    iter(sampler)
                except TypeError:
                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.")

        # Both a sampler and sharding need random access into source.
        if sampler is not None and not hasattr(source, "__getitem__"):
            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.")
        if num_shards is not None and not hasattr(source, "__getitem__"):
            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.")

        return method(self, *args, **kwargs)

    return new_method
487
488
def check_random_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)
        validate_dataset_param_value(['columns_list'], params, list)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
512
513
def check_pad_info(key, val):
    """Check one (key, value) pair of the pad_info argument of batch.

    Args:
        key (str): name of the column to pad.
        val (tuple): (pad_shape, pad_value) pair; either element may be None.

    Raises:
        TypeError: if any component has the wrong type.
        ValueError: if val does not have exactly two elements.
    """
    type_check(key, (str,), "key in pad_info")

    if val is not None:
        # Verify the tuple type before len(): the original order made a
        # non-sized value (e.g. an int) crash inside len() with a raw
        # TypeError instead of the intended diagnostic.
        type_check(val, (tuple,), "value in pad_info")
        if len(val) != 2:
            raise ValueError("value of pad_info should be a tuple of size 2.")

        if val[0] is not None:
            type_check(val[0], (list,), "shape in pad_info")

            # Each dimension may be None (pad to max) or a positive int32.
            for dim in val[0]:
                if dim is not None:
                    check_pos_int32(dim, "dim of shape in pad_info")
        if val[1] is not None:
            type_check(val[1], (int, float, str, bytes), "pad_value")
531
532
def check_bucket_batch_by_length(method):
    """check the input arguments of bucket_batch_by_length.

    Enforces: column_names/bucket_boundaries/bucket_batch_sizes are lists;
    bucket_boundaries is non-empty, all-int, strictly positive and strictly
    increasing; bucket_batch_sizes has exactly one more element than
    bucket_boundaries and contains only positive ints; pad_info, when given,
    is a dict of valid (shape, value) pairs.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
         pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs)

        nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']

        type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list)

        nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
        type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)

        # check column_names: must be list of string.
        check_columns(column_names, "column_names")

        # Without an explicit length function, element length is the length of
        # the single named column, so exactly one column must be given.
        if element_length_function is None and len(column_names) != 1:
            raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")

        if element_length_function is not None and not callable(element_length_function):
            raise TypeError("element_length_function object is not callable.")

        # check bucket_boundaries: must be list of int, positive and strictly increasing
        if not bucket_boundaries:
            raise ValueError("bucket_boundaries cannot be empty.")

        all_int = all(isinstance(item, int) for item in bucket_boundaries)
        if not all_int:
            raise TypeError("bucket_boundaries should be a list of int.")

        # Renamed from `all_non_negative`: the predicate requires strictly
        # positive values, not merely non-negative ones.
        all_positive = all(item > 0 for item in bucket_boundaries)
        if not all_positive:
            raise ValueError("bucket_boundaries must only contain positive numbers.")

        for i in range(len(bucket_boundaries) - 1):
            if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
                raise ValueError("bucket_boundaries should be strictly increasing.")

        # check bucket_batch_sizes: must be list of int and positive.
        # One extra batch size is needed for elements beyond the last boundary.
        if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
            raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")

        all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
        if not all_int:
            raise TypeError("bucket_batch_sizes should be a list of int.")

        all_positive = all(item > 0 for item in bucket_batch_sizes)
        if not all_positive:
            raise ValueError("bucket_batch_sizes should be a list of positive numbers.")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")

            for k, v in pad_info.items():
                check_pad_info(k, v)

        return method(self, *args, **kwargs)

    return new_method
594
595
def check_batch(method):
    """check the input arguments of batch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, drop_remainder, num_parallel_workers, per_batch_map,
         input_columns, output_columns, column_order, pad_info,
         python_multiprocessing, max_rowsize], param_dict = parse_user_args(method, *args, **kwargs)

        # batch_size: a fixed positive int, or a callable taking one BatchInfo.
        if not isinstance(batch_size, int) and not callable(batch_size):
            raise TypeError("batch_size should either be an int or a callable.")

        if callable(batch_size):
            if len(ins.signature(batch_size).parameters) != 1:
                raise ValueError("callable batch_size should take one parameter (BatchInfo).")
        else:
            check_pos_int32(int(batch_size), "batch_size")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")
        type_check(max_rowsize, (int,), "max_rowsize")

        # pad_info and per_batch_map are mutually exclusive features.
        if pad_info is not None and per_batch_map is not None:
            raise ValueError("pad_info and per_batch_map can't both be set.")

        if pad_info is not None:
            type_check(param_dict["pad_info"], (dict,), "pad_info")
            for pad_key, pad_val in param_dict.get('pad_info').items():
                check_pad_info(pad_key, pad_val)

        if (per_batch_map is None) != (input_columns is None):
            # These two parameters appear together.
            raise ValueError("per_batch_map and input_columns need to be passed in together.")

        if input_columns is not None:
            check_columns(input_columns, "input_columns")
            # per_batch_map receives each input column plus a trailing BatchInfo.
            if len(input_columns) != len(ins.signature(per_batch_map).parameters) - 1:
                raise ValueError("The signature of per_batch_map should match with input columns.")

        if output_columns is not None:
            check_columns(output_columns, "output_columns")

        if column_order is not None:
            check_columns(column_order, "column_order")

        if python_multiprocessing is not None:
            type_check(python_multiprocessing, (bool,), "python_multiprocessing")

        return method(self, *args, **kwargs)

    return new_method
649
650
def check_sync_wait(method):
    """check the input arguments of sync_wait."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [cond_name, batch_count, _], _ = parse_user_args(method, *args, **kwargs)

        # Both arguments are mandatory and strictly typed.
        type_check(cond_name, (str,), "condition_name")
        type_check(batch_count, (int,), "num_batch")

        return method(self, *args, **kwargs)

    return new_method
664
665
def check_shuffle(method):
    """check the input arguments of shuffle."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size], _ = parse_user_args(method, *args, **kwargs)

        # A shuffle buffer needs at least two rows to do anything useful.
        type_check(size, (int,), "buffer_size")
        check_value(size, [2, INT32_MAX], "buffer_size")

        return method(self, *args, **kwargs)

    return new_method
680
681
def check_map(method):
    """check the input arguments of map.

    Validates column arguments, worker count, multiprocessing flag, cache and
    callbacks. callbacks may be a single DSCallback or a list/tuple of them.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        from mindspore.dataset.callback import DSCallback
        [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
         callbacks, max_rowsize], _ = \
            parse_user_args(method, *args, **kwargs)

        nreq_param_columns = ['input_columns', 'output_columns', 'column_order']

        if column_order is not None:
            type_check(column_order, (list,), "column_order")
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(python_multiprocessing, (bool,), "python_multiprocessing")
        check_cache_option(cache)
        type_check(max_rowsize, (int,), "max_rowsize")

        # callbacks: a single DSCallback, or a list/tuple of DSCallback.
        # The former trailing type_check(callbacks, (list, DSCallback)) was
        # removed: it contradicted this check by rejecting tuples accepted here.
        if callbacks is not None:
            if isinstance(callbacks, (list, tuple)):
                type_check_list(callbacks, (DSCallback,), "callbacks")
            else:
                type_check(callbacks, (DSCallback,), "callbacks")

        for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]):
            if param is not None:
                check_columns(param, param_name)

        return method(self, *args, **kwargs)

    return new_method
717
718
def check_filter(method):
    """Check the input arguments of filter."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)

        # The predicate is evaluated row-by-row and must therefore be callable.
        if not callable(predicate):
            raise TypeError("Predicate should be a Python function or a callable Python object.")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(self, *args, **kwargs)

    return new_method
737
738
def check_repeat(method):
    """check the input arguments of repeat."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)

        # None means "repeat indefinitely"; an int must be -1 or in [1, INT32_MAX].
        type_check(count, (int, type(None)), "repeat")
        if isinstance(count, int):
            out_of_range = (count <= 0 and count != -1) or count > INT32_MAX
            if out_of_range:
                raise ValueError("count should be either -1 or positive integer, range[1, INT32_MAX].")
        return method(self, *args, **kwargs)

    return new_method
753
754
def check_skip(method):
    """check the input arguments of skip."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [row_count], _ = parse_user_args(method, *args, **kwargs)

        # skip(0) is a no-op, so zero is allowed here (unlike take/repeat).
        type_check(row_count, (int,), "count")
        check_value(row_count, (0, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return new_method
768
769
def check_take(method):
    """check the input arguments of take."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [row_count], _ = parse_user_args(method, *args, **kwargs)

        # -1 means "take everything"; otherwise count must be in (0, INT32_MAX].
        type_check(row_count, (int,), "count")
        if (row_count <= 0 and row_count != -1) or row_count > INT32_MAX:
            raise ValueError("count should be either -1 or within the required interval of ({}, {}], got {}."
                             .format(0, INT32_MAX, row_count))

        return method(self, *args, **kwargs)

    return new_method
784
785
def check_positive_int32(method):
    """check whether the input argument is positive and int, only works for functions with one input."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], param_dict = parse_user_args(method, *args, **kwargs)

        # Recover the real parameter name (the last non-self/cls key) so the
        # error message can reference it.
        para_name = None
        for key in list(param_dict.keys()):
            if key not in ('self', 'cls'):
                para_name = key

        # None means "use the parameter's default", which needs no validation.
        if count is not None:
            check_pos_int32(count, para_name)

        return method(self, *args, **kwargs)

    return new_method
803
804
def check_device_send(method):
    """Validate the arguments of `to_device` and `device_que`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [send_epoch_end, create_data_info_queue], _ = parse_user_args(method, *args, **kwargs)
        # Both options are plain boolean flags.
        for value, name in ((send_epoch_end, "send_epoch_end"),
                            (create_data_info_queue, "create_data_info_queue")):
            type_check(value, (bool,), name)

        return method(self, *args, **kwargs)

    return new_method
817
818
def check_zip(method):
    """Validate the argument of the module-level `zip` function."""

    @wraps(method)
    def new_method(*args, **kwargs):
        [dataset_group], _ = parse_user_args(method, *args, **kwargs)
        # The free function only accepts a tuple of datasets.
        type_check(dataset_group, (tuple,), "datasets")

        return method(*args, **kwargs)

    return new_method
830
831
def check_zip_dataset(method):
    """check the input arguments of zip method in `Dataset`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [other], _ = parse_user_args(method, *args, **kwargs)
        # The method form accepts either a tuple of datasets or one dataset.
        type_check(other, (tuple, datasets.Dataset), "datasets")

        return method(self, *args, **kwargs)

    return new_method
843
844
def check_concat(method):
    """check the input arguments of concat method in `Dataset`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [other], _ = parse_user_args(method, *args, **kwargs)
        type_check(other, (list, datasets.Dataset), "datasets")
        # A list argument must contain nothing but Dataset instances.
        if isinstance(other, list):
            type_check_list(other, (datasets.Dataset,), "dataset")
        return method(self, *args, **kwargs)

    return new_method
857
858
def check_rename(method):
    """Validate the arguments of `rename`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)

        for param_name, param in zip(['input_columns', 'output_columns'], values):
            check_columns(param, param_name)

        input_columns, output_columns = values
        # A bare (non-list) argument counts as a single column.
        input_size = len(input_columns) if isinstance(input_columns, list) else 1
        output_size = len(output_columns) if isinstance(output_columns, list) else 1
        if input_size != output_size:
            raise ValueError("Number of column in input_columns and output_columns is not equal.")

        return method(self, *args, **kwargs)

    return new_method
882
883
def check_project(method):
    """Validate the `columns` argument of `project`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns], _ = parse_user_args(method, *args, **kwargs)
        # Delegate the column-name validation to the shared helper.
        check_columns(columns, 'columns')
        return method(self, *args, **kwargs)

    return new_method
895
896
def check_schema(method):
    """Validate the argument of `Schema.__init__`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [schema_file], _ = parse_user_args(method, *args, **kwargs)

        # schema_file is optional; only verify the path when one is supplied.
        if schema_file is not None:
            check_file(schema_file)
        return method(self, *args, **kwargs)

    return new_method
910
911
def check_add_column(method):
    """Validate the arguments of `add_column`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)

        # The column name must be a non-empty string.
        type_check(name, (str,), "name")
        if not name:
            raise TypeError("Expected non-empty string for column name.")

        # de_type is mandatory: either a mindspore typing.Type or a valid detype.
        if de_type is None:
            raise TypeError("Expected non-empty string for de_type.")
        if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
            raise TypeError("Unknown column type: {}.".format(de_type))

        # shape, when supplied, must be a list of ints.
        if shape is not None:
            type_check(shape, (list,), "shape")
            type_check_list(shape, (int,), "shape")

        return method(self, *args, **kwargs)

    return new_method
937
938
def check_cluedataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # dataset_files is required: a single path or a list of paths.
        type_check(param_dict.get('dataset_files'), (str, list), "dataset files")

        # task must be one of the CLUE benchmark task names.
        if param_dict.get('task') not in ('AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL'):
            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")

        if param_dict.get('usage') not in ('train', 'test', 'eval'):
            raise ValueError("usage should be 'train', 'test' or 'eval'.")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
970
971
def check_csvdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset).

    Validates dataset_files, field_delim, column_defaults and column_names,
    then the shared numeric/sampler/cache options, before delegating to the
    wrapped constructor.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_files; required argument
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")

        # check field_delim: optional; must be a single character and must not
        # collide with the quote or line-break characters of the CSV format.
        field_delim = param_dict.get('field_delim')
        if field_delim is not None:
            type_check(field_delim, (str,), 'field delim')
            if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
                raise ValueError("field_delim is invalid.")

        # check column_defaults: optional list of per-column default values;
        # only str/int/float members are supported.
        column_defaults = param_dict.get('column_defaults')
        if column_defaults is not None:
            if not isinstance(column_defaults, list):
                raise TypeError("column_defaults should be type of list.")
            for item in column_defaults:
                if not isinstance(item, (str, int, float)):
                    raise TypeError("column type in column_defaults is invalid.")

        # check column_names: must be list of string.
        column_names = param_dict.get("column_names")
        if column_names is not None:
            all_string = all(isinstance(item, str) for item in column_names)
            if not all_string:
                raise TypeError("column_names should be a list of str.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
1017
1018
def check_flowers102dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Flowers102Dataset).

    Verifies the on-disk layout (a jpg/ image directory plus the
    imagelabels.mat and setid.mat metadata files) and the usage/task options
    before calling the constructor.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # Images are expected under <dataset_dir>/jpg.
        check_dir(os.path.join(dataset_dir, "jpg"))

        # Both .mat metadata files are mandatory.
        check_file(os.path.join(dataset_dir, "imagelabels.mat"))
        check_file(os.path.join(dataset_dir, "setid.mat"))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Classification", "Segmentation"], "task")
        # The Segmentation task additionally requires the segmim/ directory.
        if task == "Segmentation":
            check_dir(os.path.join(dataset_dir, "segmim"))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
1055
1056
def check_textfiledataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # dataset_files is required: a single path or a list of paths.
        type_check(param_dict.get('dataset_files'), (str, list), "dataset files")
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
1077
1078
def check_split(method):
    """Validate the arguments of `split`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [sizes, randomize], _ = parse_user_args(method, *args, **kwargs)

        type_check(sizes, (list,), "sizes")
        type_check(randomize, (bool,), "randomize")

        if not sizes:
            raise ValueError("sizes cannot be empty.")

        # sizes must be homogeneous: all ints (absolute sizes) or all floats
        # (fractions of the dataset).
        all_int = all(isinstance(item, int) for item in sizes)
        all_float = all(isinstance(item, float) for item in sizes)
        if not (all_int or all_float):
            raise ValueError("sizes should be list of int or list of float.")

        if all_int and not all(item > 0 for item in sizes):
            raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")

        if all_float:
            if not all(0 < item <= 1 for item in sizes):
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")

            # Tolerate tiny floating-point error when checking the sum.
            if abs(sum(sizes) - 1) >= 0.00001:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(self, *args, **kwargs)

    return new_method
1116
1117
def check_hostname(hostname):
    """Return True if `hostname` is a syntactically valid host name.

    Rules enforced: non-empty, at most 255 characters, and every dot-separated
    label is 1-63 characters of letters/digits/hyphens not starting or ending
    with a hyphen. One trailing dot (absolute name) is tolerated.
    """
    if not hostname or len(hostname) > 255:
        return False
    # Strip a single trailing dot marking a fully-qualified name.
    name = hostname[:-1] if hostname.endswith(".") else hostname
    label_pattern = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
    for label in name.split("."):
        if not label_pattern.match(label):
            return False
    return True
1125
1126
def check_gnn_graphdata(method):
    """check the input arguments of graphdata.

    Validates the dataset file path, optional worker count, hostname, working
    mode ('local'/'client'/'server'), port (1024-65535), client count (1-255)
    and the auto_shutdown flag, in that order.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [dataset_file, num_parallel_workers, working_mode, hostname,
         port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
        check_file(dataset_file)
        # num_parallel_workers is optional; None keeps the default.
        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(hostname, (str,), "hostname")
        if check_hostname(hostname) is False:
            raise ValueError("The hostname is illegal")
        type_check(working_mode, (str,), "working_mode")
        if working_mode not in {'local', 'client', 'server'}:
            raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'.")
        # Port restricted to the non-privileged range.
        type_check(port, (int,), "port")
        check_value(port, (1024, 65535), "port")
        type_check(num_client, (int,), "num_client")
        check_value(num_client, (1, 255), "num_client")
        type_check(auto_shutdown, (bool,), "auto_shutdown")
        return method(self, *args, **kwargs)

    return new_method
1151
1152
def check_gnn_get_all_nodes(method):
    """A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_type], _ = parse_user_args(method, *args, **kwargs)
        # Node types are represented as plain integers.
        type_check(node_type, (int,), "node_type")
        return method(self, *args, **kwargs)

    return new_method
1164
1165
def check_gnn_get_all_edges(method):
    """A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_type], _ = parse_user_args(method, *args, **kwargs)
        # Edge types are represented as plain integers.
        type_check(edge_type, (int,), "edge_type")
        return method(self, *args, **kwargs)

    return new_method
1177
1178
def check_gnn_get_nodes_from_edges(method):
    """A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_list], _ = parse_user_args(method, *args, **kwargs)
        # Accepts a python list or an ndarray of edge ids.
        check_gnn_list_or_ndarray(edge_list, "edge_list")
        return method(self, *args, **kwargs)

    return new_method
1190
1191
def check_gnn_get_edges_from_nodes(method):
    """A wrapper that wraps a parameter checker around the GNN `get_edges_from_nodes` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list], _ = parse_user_args(method, *args, **kwargs)
        # Each member must be a (src, dst) node-id pair.
        check_gnn_list_of_pair_or_ndarray(node_list, "node_list")
        return method(self, *args, **kwargs)

    return new_method
1203
1204
def check_gnn_get_all_neighbors(method):
    """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Third positional argument is intentionally ignored here.
        [node_list, neighbour_type, _], _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')
        type_check(neighbour_type, (int,), "neighbour_type")

        return method(self, *args, **kwargs)

    return new_method
1218
1219
def check_gnn_get_sampled_neighbors(method):
    """A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Fourth positional argument is intentionally ignored here.
        [node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')

        # Both hop descriptions share the same shape rule: 1 to 6 members.
        for arg, arg_name in ((neighbor_nums, 'neighbor_nums'), (neighbor_types, 'neighbor_types')):
            check_gnn_list_or_ndarray(arg, arg_name)
            if not arg or len(arg) > 6:
                raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
                    arg_name, len(arg)))

        # Each sampled hop needs both a count and a type.
        if len(neighbor_nums) != len(neighbor_types):
            raise ValueError(
                "The number of members of neighbor_nums and neighbor_types is inconsistent.")

        return method(self, *args, **kwargs)

    return new_method
1246
1247
def check_gnn_get_neg_sampled_neighbors(method):
    """A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, neg_neighbor_num, neg_neighbor_type], _ = parse_user_args(method, *args, **kwargs)

        check_gnn_list_or_ndarray(node_list, 'node_list')
        # Negative-sampling count and type are single integers.
        for value, name in ((neg_neighbor_num, "neg_neighbor_num"),
                            (neg_neighbor_type, "neg_neighbor_type")):
            type_check(value, (int,), name)

        return method(self, *args, **kwargs)

    return new_method
1262
1263
def check_gnn_random_walk(method):
    """A wrapper that wraps a parameter checker around the GNN `random_walk` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [target_nodes, meta_path, step_home_param, step_away_param, default_node], _ = \
            parse_user_args(method, *args, **kwargs)
        check_gnn_list_or_ndarray(target_nodes, 'target_nodes')
        check_gnn_list_or_ndarray(meta_path, 'meta_path')
        # The walk hyper-parameters must be floats.
        type_check(step_home_param, (float,), "step_home_param")
        type_check(step_away_param, (float,), "step_away_param")
        # default_node accepts -1 as its lower bound.
        type_check(default_node, (int,), "default_node")
        check_value(default_node, (-1, INT32_MAX), "default_node")

        return method(self, *args, **kwargs)

    return new_method
1281
1282
def check_aligned_list(param, param_name, member_type):
    """Check whether the structure of each member of the list is the same.

    Recursively verifies that `param` is a non-empty list whose members are
    either all lists of identical length (each checked recursively) or all
    scalars of `member_type`. Mixing lists and scalars, or lists of different
    lengths, raises TypeError.
    """

    type_check(param, (list,), "param")
    if not param:
        raise TypeError(
            "Parameter {0} or its members are empty".format(param_name))
    # member_have_list tracks what kind of members were seen so far:
    # None (nothing yet), True (lists), False (scalars). list_len records the
    # required length once the first list member has been observed.
    member_have_list = None
    list_len = None
    for member in param:
        if isinstance(member, list):
            check_aligned_list(member, param_name, member_type)

            # A list member after scalar members means the list is mixed.
            if member_have_list not in (None, True):
                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
                    param_name))
            # All list members at this level must have the same length.
            if list_len is not None and len(member) != list_len:
                raise TypeError("The size of each member of parameter {0} is inconsistent.".format(
                    param_name))
            member_have_list = True
            list_len = len(member)
        else:
            type_check(member, (member_type,), param_name)
            # A scalar member after list members means the list is mixed.
            if member_have_list not in (None, False):
                raise TypeError("The type of each member of the parameter {0} is inconsistent.".format(
                    param_name))
            member_have_list = False
1310
1311
def check_gnn_get_node_feature(method):
    """A wrapper that wraps a parameter checker around the GNN `get_node_feature` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs)

        type_check(node_list, (list, np.ndarray), "node_list")
        if isinstance(node_list, list):
            # Nested lists must be rectangular and hold ints only.
            check_aligned_list(node_list, 'node_list', int)
        elif node_list.dtype != np.int32:
            # After type_check, a non-list argument is guaranteed to be an ndarray.
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                node_list, node_list.dtype))

        check_gnn_list_or_ndarray(feature_types, 'feature_types')

        return method(self, *args, **kwargs)

    return new_method
1332
1333
def check_gnn_get_edge_feature(method):
    """A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs)

        type_check(edge_list, (list, np.ndarray), "edge_list")
        if isinstance(edge_list, list):
            # Nested lists must be rectangular and hold ints only.
            check_aligned_list(edge_list, 'edge_list', int)
        elif edge_list.dtype != np.int32:
            # After type_check, a non-list argument is guaranteed to be an ndarray.
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                edge_list, edge_list.dtype))

        check_gnn_list_or_ndarray(feature_types, 'feature_types')

        return method(self, *args, **kwargs)

    return new_method
1354
1355
def check_numpyslicesdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset).

    Validates `data` (a non-empty list/tuple/dict/ndarray) and, when
    `column_names` is supplied, checks that its length matches the number of
    columns implied by the structure of `data`.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        data = param_dict.get("data")
        column_names = param_dict.get("column_names")
        type_check(data, (list, tuple, dict, np.ndarray), "data")
        if data is None or len(data) == 0:  # pylint: disable=len-as-condition
            raise ValueError("Argument data cannot be empty")
        # Tuple input: each member is one column; spot-check the first member.
        if isinstance(data, tuple):
            type_check(data[0], (list, np.ndarray), "data[0]")

        # check column_names
        if column_names is not None:
            check_columns(column_names, "column_names")

            # check num of input column in column_names
            # A bare string names exactly one column.
            column_num = 1 if isinstance(column_names, str) else len(column_names)
            if isinstance(data, dict):
                # dict input: one column per key.
                data_column = len(list(data.keys()))
                if column_num != data_column:
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, data_column))

            elif isinstance(data, tuple):
                # tuple input: one column per member.
                if column_num != len(data):
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, len(data)))
            else:
                # list/ndarray input maps to exactly one column.
                if column_num != 1:
                    raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
                                     .format(column_num, 1))

        return method(self, *args, **kwargs)

    return new_method
1395
1396
def check_paddeddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        samples = param_dict.get("padded_samples")
        # Required: a non-empty list whose first element is a sample dict.
        if not samples:
            raise ValueError("padded_samples cannot be empty.")
        type_check(samples, (list,), "padded_samples")
        type_check(samples[0], (dict,), "padded_element")
        return method(self, *args, **kwargs)

    return new_method
1412
1413
def check_cache_option(cache):
    """Sanity check for cache parameter"""
    # cache is optional; only validate its type when one was supplied.
    if cache is None:
        return
    type_check(cache, (cache_client.DatasetCache,), "cache")
1418
1419
def check_to_device_send(method):
    """Check the input arguments of send function for TransferDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_epochs], _ = parse_user_args(method, *args, **kwargs)

        # None keeps the default; otherwise accept ints from -1 up to INT32_MAX.
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return new_method
1434
1435
def check_flickr_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Flickr8k, Flickr30k)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # Both the image directory and the annotation file must exist.
        check_dir(param_dict.get('dataset_dir'))
        check_file(param_dict.get('annotation_file'))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        validate_dataset_param_value(['shuffle', 'decode'], param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
1462
1463
def check_sb_dataset(method):
    """A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        check_dir(param_dict.get('dataset_dir'))

        # usage and task are optional; validate only when provided.
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Boundaries", "Segmentation"], "task")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        validate_dataset_param_value(['shuffle', 'decode'], param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
1493
1494
def check_cityscapes_dataset(method):
    """A wrapper that wraps a parameter checker around the original CityScapesDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        check_dir(param_dict.get('dataset_dir'))

        check_valid_str(param_dict.get('task'), ["instance", "semantic", "polygon", "color"], "task")

        quality_mode = param_dict.get('quality_mode')
        check_valid_str(quality_mode, ["fine", "coarse"], "quality_mode")

        # Which usage splits are valid depends on the selected quality mode.
        if quality_mode == "fine":
            valid_strings = ["train", "test", "val", "all"]
        else:
            valid_strings = ["train", "train_extra", "val", "all"]
        check_valid_str(param_dict.get('usage'), valid_strings, "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        validate_dataset_param_value(['shuffle', 'decode'], param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
1529
1530
def check_div2k_dataset(method):
    """A wrapper that wraps a parameter checker around the original DIV2KDataset.

    Validates dataset_dir, usage, downgrade and scale, including the cross
    constraints between the downgrade mode and the scale factor.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        check_valid_str(usage, ['train', 'valid', 'all'], "usage")

        downgrade = param_dict.get('downgrade')
        check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade')

        # scale must be an int drawn from the supported factors.
        validate_dataset_param_value(['scale'], param_dict, int)
        scale = param_dict.get('scale')
        scale_values = [2, 3, 4, 8]
        if scale not in scale_values:
            raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values)))

        # Cross constraint: scale 8 is only available with the bicubic downgrade.
        if scale == 8 and downgrade != "bicubic":
            raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.")

        # Cross constraint: mild/difficult/wild downgrades require scale 4.
        downgrade_2018 = ["mild", "difficult", "wild"]
        if downgrade in downgrade_2018 and scale != 4:
            raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
1571