• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""
17Built-in validators.
18"""
19import inspect as ins
20import os
21from functools import wraps
22import numpy as np
23
24from mindspore._c_expression import typing
25from mindspore import log as logger
26from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
27    INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
28    validate_dataset_param_value, check_padding_options, \
29    check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id, \
30    check_valid_list_tuple, check_int32
31
32from . import datasets
33from . import samplers
34from . import cache_client
35
36
def check_cmu_arctic_dataset(method):
    """Validate the user-supplied arguments of CMUArcticDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional speaker name must be one of the known CMU Arctic speakers.
        speaker = params.get('name')
        if speaker is not None:
            check_valid_str(speaker, ['aew', 'ahw', 'aup', 'awb', 'axb', 'bdl', 'clb', 'eey',
                                      'fem', 'gka', 'jmk', 'ksp', 'ljm', 'lnh', 'rms', 'rxr', 'slp', 'slt'], "name")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
66
67
def check_gtzan_dataset(method):
    """Validate the user-supplied arguments of GTZANDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects which split of the dataset to read.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ['train', 'valid', 'test', 'all'], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
96
97
def check_imagefolderdataset(method):
    """Validate the user-supplied arguments of ImageFolderDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # 'decrypt', when supplied, must be a callable image-decryption hook.
        decrypt = params.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['extensions'], params, list)
        validate_dataset_param_value(['class_indexing'], params, dict)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
129
130
def check_imdb_dataset(method):
    """Validate the user-supplied arguments of IMDBDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        # Optional usage selects the train/test split (checked last, matching
        # the original validation order).
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        return method(self, *args, **kwargs)

    return new_method
158
159
def check_iwslt2016_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(IWSLT2016dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        # Supported (source, target) language pairs.  The tuple form is derived
        # from the list form so the two collections can never drift apart.
        # (The original code listed ['en', 'ar'] twice; the redundant duplicate
        # has been removed — membership checking is unaffected.)
        support_language_pair = [
            ['en', 'ar'], ['en', 'de'], ['en', 'fr'], ['en', 'cs'], ['ar', 'en'], ['fr', 'en'],
            ['de', 'en'], ['cs', 'en']
        ]
        support_language_pair_tuple = tuple(tuple(pair) for pair in support_language_pair)
        support_set_type = ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"]
        # check language_pair: must be a list or tuple naming a supported pair
        language_pair = param_dict.get('language_pair')
        if language_pair is not None:
            if isinstance(language_pair, (list,)):
                check_valid_list_tuple(language_pair, support_language_pair, (str,), "language_pair")
            elif isinstance(language_pair, (tuple,)):
                check_valid_list_tuple(language_pair, support_language_pair_tuple, (str,), "language_pair")
            else:
                raise TypeError("language_pair should be a type list or tuple of length 2.")

        # check valid_set
        valid_set = param_dict.get('valid_set')
        if valid_set is not None:
            check_valid_str(valid_set, support_set_type, "valid_set")

        # check test_set
        test_set = param_dict.get('test_set')
        if test_set is not None:
            check_valid_str(test_set, support_set_type, "test_set")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
215
216
def check_iwslt2017_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(IWSLT2017dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        # Supported (source, target) language pairs.  The tuple form is derived
        # from the list form so the two collections can never drift apart.
        support_language_pair = [
            ['en', 'nl'], ['en', 'de'], ['en', 'it'], ['en', 'ro'], ['ro', 'de'], ['ro', 'en'], ['ro', 'nl'],
            ['ro', 'it'], ['de', 'ro'], ['de', 'en'], ['de', 'nl'], ['de', 'it'], ['it', 'en'], ['it', 'nl'],
            ['it', 'de'], ['it', 'ro'], ['nl', 'de'], ['nl', 'en'], ['nl', 'it'], ['nl', 'ro']
        ]
        support_language_pair_tuple = tuple(tuple(pair) for pair in support_language_pair)
        # check language_pair: must be a list or tuple naming a supported pair
        language_pair = param_dict.get('language_pair')
        if language_pair is not None:
            if isinstance(language_pair, (list,)):
                check_valid_list_tuple(language_pair, support_language_pair, (str,), "language_pair")
            elif isinstance(language_pair, (tuple,)):
                check_valid_list_tuple(language_pair, support_language_pair_tuple, (str,), "language_pair")
            else:
                raise TypeError("language_pair should be a type list or tuple of length 2.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
263
264
def check_kittidataset(method):
    """Validate the user-supplied arguments of KITTIDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects the train or test split.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
292
293
def check_lsun_dataset(method):
    """Validate the user-supplied arguments of LSUNDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects which split of the dataset to read.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['classes'], params, list)

        # Every requested class must be one of the ten LSUN scene categories.
        categories = [
            'bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen',
            'living_room', 'restaurant', 'tower'
        ]
        requested = params.get('classes')
        if requested is not None:
            for class_name in requested:
                check_valid_str(class_name, categories, "classes")

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
333
334
def check_mnist_cifar_dataset(method):
    """Validate the user-supplied arguments of MnistDataset / Cifar10-100Dataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects the train/test split.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
363
364
def check_omniglotdataset(method):
    """Validate the user-supplied arguments of OmniglotDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'background', 'decode'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
387
388
def check_photo_tour_dataset(method):
    """Validate the user-supplied arguments of PhotoTourDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects the train or test split.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test"], "usage")
        # NOTE(review): unlike 'usage', 'name' is validated unconditionally
        # (no None guard) — it looks like a required argument; confirm upstream.
        name = params.get('name')
        check_valid_str(name, ["notredame", "yosemite", "liberty", "notredame_harris",
                               "yosemite_harris", "liberty_harris"], "name")
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))
        return method(self, *args, **kwargs)

    return new_method
417
418
def check_places365_dataset(method):
    """Validate the user-supplied arguments of Places365Dataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects which Places365 variant/split to read.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train-standard", "train-challenge", "val"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'small', 'decode'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
447
448
def check_qmnist_dataset(method):
    """Validate the user-supplied arguments of QMnistDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects which QMNIST subset to read.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "test10k", "test50k", "nist", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'compat'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
477
478
def check_manifestdataset(method):
    """Validate the user-supplied arguments of ManifestDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # ManifestDataset is driven by a manifest file rather than a directory.
        check_file(params.get('dataset_file'))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['usage'], params, str)
        validate_dataset_param_value(['class_indexing'], params, dict)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
507
508
def check_sbu_dataset(method):
    """Validate the user-supplied arguments of SBUDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        dataset_dir = params.get('dataset_dir')
        check_dir(dataset_dir)

        # The SBU layout is fixed: a URL file, a caption file and an image folder.
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_urls.txt"))
        check_file(os.path.join(dataset_dir, "SBU_captioned_photo_dataset_captions.txt"))
        check_dir(os.path.join(dataset_dir, "sbu_images"))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
537
538
def check_sogou_news_dataset(method):
    """Validate the user-supplied arguments of SogouNewsDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects the train/test split.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
564
565
def check_tfrecorddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # Optional parameters grouped by the type each must have.
        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_list = ['columns_list']
        nreq_param_bool = ['shard_equal_rows']

        # 'dataset_files' is required: a single path string or a non-empty list.
        dataset_files = param_dict.get('dataset_files')
        if not isinstance(dataset_files, (str, list)):
            raise TypeError("dataset_files should be type str or a list of strings.")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        # Only no compression (''), 'ZLIB' or 'GZIP' are accepted.
        compression_type = param_dict.get('compression_type')
        if compression_type is not None and compression_type not in ['', 'ZLIB', 'GZIP']:
            raise ValueError("Input compression_type can only be either '' (no compression), 'ZLIB', or 'GZIP', " +
                             "but got '" + str(compression_type) + "'.")
        # With real compression in effect, the number of files must not be
        # smaller than num_shards (a single path string counts as one file).
        # NOTE(review): this constraint is only enforced when num_samples is
        # also given — confirm that gating is intentional.
        if compression_type is not None and compression_type in ['ZLIB', 'GZIP'] and \
            param_dict.get('num_samples') is not None:
            if param_dict.get('num_shards') is not None and ((isinstance(dataset_files, str) and \
                param_dict.get('num_shards') > 1) or (isinstance(dataset_files, list) and \
                len(dataset_files) < param_dict.get('num_shards'))):
                num_files = len(dataset_files) if isinstance(dataset_files, list) else 1
                act_num_shard = param_dict.get('num_shards') if param_dict.get('num_shards') is not None else 1
                raise ValueError("When compression_type is provided, the number of dataset files cannot be less " +
                                 "than num_shards, but the actual number of files is " + str(num_files) +
                                 " and actual num_shards is " + str(act_num_shard) + ".")
            # shard_equal_rows is overridden when compression is used; a unset
            # or falsy value only triggers a warning, not an error.
            if param_dict.get('shard_equal_rows') is None or not param_dict.get('shard_equal_rows'):
                logger.warning("If compression_type is set, shard_equal_rows will be ignored.")

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
612
613
def check_udpos_dataset(method):
    """Validate the user-supplied arguments of UDPOSDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is a required argument and must exist.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects which split of the dataset to read.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
641
642
def check_usps_dataset(method):
    """Validate the user-supplied arguments of USPSDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional usage selects the train/test split.
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
668
669
def check_caltech101_dataset(method):
    """Validate the user-supplied arguments of Caltech101Dataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        # Optional target_type selects what the dataset yields per sample.
        target_type = params.get('target_type')
        if target_type is not None:
            check_valid_str(target_type, ["category", "annotation", "all"], "target_type")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['target_type'], params, str)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
699
700
def check_caltech256_dataset(method):
    """Validate the user-supplied arguments of Caltech256Dataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' is required and must point to an existing directory.
        check_dir(params.get('dataset_dir'))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
724
725
def check_vocdataset(method):
    """Validate the user-supplied arguments of VOCDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        dataset_dir = params.get('dataset_dir')
        check_dir(dataset_dir)

        task = params.get('task')
        type_check(task, (str,), "task")

        usage = params.get('usage')
        type_check(usage, (str,), "usage")
        dataset_dir = os.path.realpath(dataset_dir)

        # Each task reads its image-set list from a different VOC sub-folder.
        if task == "Segmentation":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
            if params.get('class_indexing') is not None:
                raise ValueError("class_indexing is not supported in Segmentation task.")
        elif task == "Detection":
            imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
        else:
            raise ValueError("Invalid task : " + task + ".")

        # 'decrypt', when supplied, must be a callable image-decryption hook.
        decrypt = params.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        check_file(imagesets_file)

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['class_indexing'], params, dict)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
773
774
def check_cocodataset(method):
    """Validate the user-supplied arguments of CocoDataset before construction."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # 'dataset_dir' and 'annotation_file' are both required and must exist.
        check_dir(params.get('dataset_dir'))

        check_file(params.get('annotation_file'))

        task = params.get('task')
        type_check(task, (str,), "task")

        if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint', 'Captioning'}:
            raise ValueError("Invalid task type: " + task + ".")

        # 'decrypt', when supplied, must be a callable image-decryption hook.
        decrypt = params.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)

        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        # CocoDataset rejects PKSampler explicitly.
        sampler = params.get('sampler')
        if sampler is not None and isinstance(sampler, samplers.PKSampler):
            raise ValueError("CocoDataset doesn't support PKSampler.")
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
816
817
def check_celebadataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        decrypt = params.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        # optional arguments grouped by their expected type
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['extensions'], params, list)
        validate_dataset_param_value(['dataset_type'], params, str)

        # usage selects which CelebA split to read
        usage = params.get('usage')
        if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
            raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(params)

        # isinstance(None, ...) is False, so a missing sampler passes through
        if isinstance(params.get('sampler'), samplers.PKSampler):
            raise ValueError("CelebADataset doesn't support PKSampler.")

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
859
860
def check_libri_tts_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(LibriTTSDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # usage selects which LibriTTS subset to load
        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100",
                                    "train-clean-360", "train-other-500", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
888
889
def check_lj_speech_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(LJSpeechDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # optional arguments grouped by expected type
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
914
915
def check_lfw_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(LFWDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # task / usage / image_set are optional but restricted to fixed vocabularies
        task = params.get('task')
        if task is not None:
            check_valid_str(task, ["people", "pairs"], "task")

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["10fold", "train", "test", "all"], "usage")

        image_set = params.get('image_set')
        if image_set is not None:
            check_valid_str(image_set, ["original", "funneled", "deepfunneled"], "image_set")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
951
952
def check_save(method):
    """A wrapper that wraps a parameter checker around the saved operation.

    Validates num_files (int in [1, 1000]), file_name/file_type (str), and
    that file_type is 'mindrecord', the only supported output format.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_files']
        nreq_param_str = ['file_name', 'file_type']
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        num_files = param_dict.get('num_files')
        # valid range is [1, 1000]: at least one output file, at most 1000
        if num_files <= 0 or num_files > 1000:
            raise ValueError("num_files should be between 1 and 1000.")
        validate_dataset_param_value(nreq_param_str, param_dict, str)
        if param_dict.get('file_type') != 'mindrecord':
            raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
        return method(self, *args, **kwargs)

    return new_method
971
972
def check_tuple_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [columns, num_epochs, _, _], params = parse_user_args(method, *args, **kwargs)

        validate_dataset_param_value(['output_numpy'], params, bool)

        if num_epochs is not None:
            # -1 requests unbounded iteration
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        if columns is not None:
            check_columns(columns, "column_names")

        return method(self, *args, **kwargs)

    return new_method
991
992
def check_dict_iterator(method):
    """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_epochs, _, _], params = parse_user_args(method, *args, **kwargs)

        validate_dataset_param_value(['output_numpy'], params, bool)

        if num_epochs is not None:
            # -1 requests unbounded iteration
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return new_method
1008
1009
def check_minddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(MindDataset).

    Validates the MindRecord file path(s) plus the optional int/list/dict
    parameters, sampler/shuffle/shard combinations, and padding options.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded']
        nreq_param_list = ['columns_list']
        nreq_param_dict = ['padded_sample']

        dataset_file = param_dict.get('dataset_files')
        if isinstance(dataset_file, list):
            # warn about very large file lists; every file is still verified below
            if len(dataset_file) > 4096:
                logger.warning("The number of MindRecord files greater than 4096 "
                               "may cause slow dataset initialization.")
            for f in dataset_file:
                check_file(f)
        else:
            check_file(dataset_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)

        check_sampler_shuffle_shard_options(param_dict)

        check_padding_options(param_dict)
        return method(self, *args, **kwargs)

    return new_method
1041
1042
def check_source_function(source):
    """Collect the closure variables and source text of a user supplied source.

    Returns a single string combining the closure variables and the source
    code of the relevant callables; the caller scans it for mixed-in network
    computing operators. Returns "" when a plain function's source cannot be
    read.
    """
    from types import FunctionType
    closure_vars = tuple()
    source_text = ""
    if isinstance(source, FunctionType):
        try:
            closure_vars = ins.getclosurevars(source)
            source_text = ins.getsource(source)
        except OSError:
            return ""
    else:
        # instance of a user defined class: inspect __init__ plus either the
        # random-access (__getitem__) or the iterable (__next__) protocol method
        try:
            attrs = source.__class__.__dict__.keys()
            if '__init__' in attrs:
                closure_vars = closure_vars + ins.getclosurevars(source.__class__.__init__)
                source_text = source_text + ins.getsource(source.__class__.__init__)
            if '__getitem__' in attrs:
                closure_vars = closure_vars + ins.getclosurevars(source.__class__.__getitem__)
                source_text = source_text + ins.getsource(source.__class__.__getitem__)
            elif '__next__' in attrs:
                closure_vars = closure_vars + ins.getclosurevars(source.__class__.__next__)
                source_text = source_text + ins.getsource(source.__class__.__next__)
        except (TypeError, OSError):
            # e.g. LambdaType or GeneratorType instances cannot be inspected this way
            pass
    return str(closure_vars) + source_text
1071
1072
def check_generatordataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset).

    Validates that `source` is callable/iterable/random-accessible, flags
    sources that mix in network computing operators, and checks the
    column/schema, shard and sampler arguments.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        source = param_dict.get('source')

        # source must be callable, iterable (__iter__/__next__) or random accessible (__getitem__)
        if not callable(source):
            try:
                iter(source)
            except TypeError:
                raise TypeError("Input `source` function of GeneratorDataset should be callable, iterable or random"
                                " accessible, commonly it should implement one of the method like yield, __getitem__ or"
                                " __next__(__iter__).")

        # check used variable and function document whether contain computing operator
        check_doc = check_source_function(source)
        check_list = ['mindspore.nn', 'mindspore.ops', 'mindspore.numpy', 'mindspore.compression']
        for item in check_list:
            if item in check_doc:
                setattr(self, 'operator_mixed', True)
                break

        column_names = param_dict.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")
        schema = param_dict.get('schema')
        # at least one of column_names / schema must describe the output columns
        if column_names is None and schema is None:
            raise ValueError("Neither column_names nor schema are provided.")

        if schema is not None:
            if not isinstance(schema, (datasets.Schema, str)):
                raise ValueError("schema should be a path to schema file or a schema object.")

        # check optional argument
        nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        nreq_param_list = ["column_types"]
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        nreq_param_bool = ["shuffle", "python_multiprocessing"]
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_value(param_dict.get("max_rowsize"), [-1, INT32_MAX], "max_rowsize")

        num_shards = param_dict.get("num_shards")
        shard_id = param_dict.get("shard_id")
        check_dataset_num_shards_shard_id(num_shards, shard_id)

        sampler = param_dict.get("sampler")
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("GeneratorDataset doesn't support PKSampler.")
            if not isinstance(sampler, samplers.BuiltinSampler):
                # a user-provided sampler only needs to be iterable
                try:
                    iter(sampler)
                except TypeError:
                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.")

        # sampling / sharding both require random access into the source
        if sampler is not None and not hasattr(source, "__getitem__"):
            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.")
        if num_shards is not None and not hasattr(source, "__getitem__"):
            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.")

        return method(self, *args, **kwargs)

    return new_method
1141
1142
def check_random_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # optional arguments grouped by expected type
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)
        validate_dataset_param_value(['columns_list'], params, list)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
1166
1167
def check_rendered_sst2_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(RenderedSST2Dataset).

    Validates dataset_dir, the optional usage split, and the common
    int/bool/sampler/shard/cache parameters.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        usage = param_dict.get('usage')
        check_dir(dataset_dir)
        if usage is not None:
            # pass the argument name so the error message identifies the bad
            # parameter, consistent with the other dataset checkers
            check_valid_str(usage, ['val', 'all', 'train', 'test'], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
1194
1195
def check_pad_info(key, val):
    """Check one (key, value) pair of pad_info in batch.

    Args:
        key (str): column name to pad.
        val (tuple): (shape, pad_value) pair; either element may be None.

    Raises:
        TypeError: if key, val or their elements have unexpected types.
        ValueError: if val does not contain exactly two elements.
    """
    type_check(key, (str,), "key in pad_info")

    if val is not None:
        # validate the container type before len() so that a non-sized value
        # raises a clear type error instead of a bare len() TypeError
        type_check(val, (tuple,), "value in pad_info")
        if len(val) != 2:
            raise ValueError("value of pad_info should be a tuple of size 2.")

        if val[0] is not None:
            type_check(val[0], (list,), "shape in pad_info")

            for dim in val[0]:
                if dim is not None:
                    check_pos_int32(dim, "dim of shape in pad_info")
        if val[1] is not None:
            type_check(val[1], (int, float, str, bytes), "pad_value")
1213
1214
def check_bucket_batch_by_length(method):
    """check the input arguments of bucket_batch_by_length."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
         pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs)

        type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,),
                        ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'])
        type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,),
                        ['pad_to_bucket_boundary', 'drop_remainder'])

        # column_names: must be list of string.
        check_columns(column_names, "column_names")

        # without an explicit length function, exactly one column drives bucketing
        if element_length_function is None and len(column_names) != 1:
            raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")

        if element_length_function is not None and not callable(element_length_function):
            raise TypeError("element_length_function object is not callable.")

        # bucket_boundaries: non-empty, all int, all positive, strictly increasing
        if not bucket_boundaries:
            raise ValueError("bucket_boundaries cannot be empty.")

        if not all(isinstance(boundary, int) for boundary in bucket_boundaries):
            raise TypeError("bucket_boundaries should be a list of int.")

        if not all(boundary > 0 for boundary in bucket_boundaries):
            raise ValueError("bucket_boundaries must only contain positive numbers.")

        for prev, curr in zip(bucket_boundaries, bucket_boundaries[1:]):
            if curr <= prev:
                raise ValueError("bucket_boundaries should be strictly increasing.")

        # bucket_batch_sizes: one entry per bucket (boundaries + 1), all positive ints
        if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
            raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")

        if not all(isinstance(size, int) for size in bucket_batch_sizes):
            raise TypeError("bucket_batch_sizes should be a list of int.")

        if not all(size > 0 for size in bucket_batch_sizes):
            raise ValueError("bucket_batch_sizes should be a list of positive numbers.")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")

            for key, value in pad_info.items():
                check_pad_info(key, value)

        return method(self, *args, **kwargs)

    return new_method
1276
1277
def get_batch_kwargs_from_dict(param_dict):
    """Get batch operation kwargs parameters.

    Args:
        param_dict (dict or None): keyword arguments passed to batch; None is
            treated the same as an empty dict.

    Returns:
        tuple: (per_batch_map, input_columns, output_columns,
        python_multiprocessing, max_rowsize) with documented defaults filled
        in for missing keys.
    """
    # A None param_dict previously hit an UnboundLocalError on return; treat
    # it as an empty dict so the defaults are produced instead.
    if param_dict is None:
        param_dict = {}
    per_batch_map = param_dict.get("per_batch_map", None)
    input_columns = param_dict.get("input_columns", None)
    output_columns = param_dict.get("output_columns", None)
    python_multiprocessing = param_dict.get("python_multiprocessing", False)
    max_rowsize = param_dict.get("max_rowsize", 16)
    return per_batch_map, input_columns, output_columns, python_multiprocessing, max_rowsize
1287
1288
def check_batch(method):
    """check the input arguments of batch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, drop_remainder, num_parallel_workers, param_dict], _ = parse_user_args(method, *args, **kwargs)

        per_batch_map, input_columns, output_columns, python_multiprocessing, max_rowsize = \
            get_batch_kwargs_from_dict(param_dict)

        # batch_size: either a positive int or a callable taking one BatchInfo argument
        if not (isinstance(batch_size, int) or callable(batch_size)):
            raise TypeError("batch_size should either be an int or a callable.")

        if callable(batch_size):
            if len(ins.signature(batch_size).parameters) != 1:
                raise ValueError("callable batch_size should take one parameter (BatchInfo).")
        else:
            check_pos_int32(int(batch_size), "batch_size")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")

        check_max_rowsize(max_rowsize)

        # input_columns must be None when per_batch_map is not set
        if input_columns is not None and per_batch_map is None:
            raise ValueError("input_columns can be specified only when per_batch_map is set.")

        if input_columns is not None:
            check_columns(input_columns, "input_columns")
            # per_batch_map also receives a trailing BatchInfo argument
            if len(input_columns) != len(ins.signature(per_batch_map).parameters) - 1:
                raise ValueError("The signature of per_batch_map should match with input columns.")

        if output_columns is not None:
            check_columns(output_columns, "output_columns")

        if python_multiprocessing is not None:
            type_check(python_multiprocessing, (bool,), "python_multiprocessing")

        return method(self, *args, **kwargs)

    return new_method
1333
1334
def check_padded_batch(method):
    """check the input arguments of padded_batch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, drop_remainder, num_parallel_workers, pad_info], _ = parse_user_args(method, *args, **kwargs)

        # batch_size: either a positive int or a callable taking one BatchInfo argument
        if not (isinstance(batch_size, int) or callable(batch_size)):
            raise TypeError("batch_size should either be an int or a callable.")

        if callable(batch_size):
            if len(ins.signature(batch_size).parameters) != 1:
                raise ValueError("callable batch_size should take one parameter (BatchInfo).")
        else:
            check_pos_int32(int(batch_size), "batch_size")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")
            for key, value in pad_info.items():
                check_pad_info(key, value)

        return method(self, *args, **kwargs)

    return new_method
1364
1365
def check_sync_wait(method):
    """check the input arguments of sync_wait."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs)

        # condition_name identifies the sync point; num_batch counts batches to wait for
        type_check(condition_name, (str,), "condition_name")
        type_check(num_batch, (int,), "num_batch")
        return method(self, *args, **kwargs)

    return new_method
1379
1380
def check_shuffle(method):
    """check the input arguments of shuffle."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [shuffle_size], _ = parse_user_args(method, *args, **kwargs)

        # buffer must hold at least 2 rows for shuffling to have any effect
        type_check(shuffle_size, (int,), "buffer_size")
        check_value(shuffle_size, [2, INT32_MAX], "buffer_size")
        return method(self, *args, **kwargs)

    return new_method
1395
1396
def get_map_kwargs_from_dict(param_dict):
    """Get map operation kwargs parameters.

    Args:
        param_dict (dict or None): keyword arguments passed to map; None is
            treated the same as an empty dict.

    Returns:
        tuple: (python_multiprocessing, max_rowsize, cache, callbacks, offload)
        with documented defaults filled in for missing keys.
    """
    # A None param_dict previously hit an UnboundLocalError on return; treat
    # it as an empty dict so the defaults are produced instead.
    if param_dict is None:
        param_dict = {}
    python_multiprocessing = param_dict.get("python_multiprocessing", False)
    max_rowsize = param_dict.get("max_rowsize", 16)
    cache = param_dict.get("cache", None)
    callbacks = param_dict.get("callbacks", None)
    offload = param_dict.get("offload", None)
    return python_multiprocessing, max_rowsize, cache, callbacks, offload
1406
1407
def check_max_rowsize(max_rowsize):
    """Check the max_rowsize argument.

    Accepts a single int or a [in_rowsize, out_rowsize] list of length 2;
    every value must fall in [-1, INT32_MAX].

    Raises:
        TypeError: if max_rowsize (or a list element) has the wrong type or
            the list does not have exactly two elements.
        ValueError: if a value falls outside [-1, INT32_MAX].
    """
    type_check(max_rowsize, (int, list), "max_rowsize")
    if isinstance(max_rowsize, int):
        check_value(max_rowsize, [-1, INT32_MAX], "max_rowsize")
    elif len(max_rowsize) == 2:
        for index, value in enumerate(max_rowsize):
            type_check(value, (int,), "max_rowsize[{}]".format(index))
            # fixed typo: error label previously read "max_rowsizei[...]"
            check_value(value, [-1, INT32_MAX], "max_rowsize[{}]".format(index))
    else:
        raise TypeError("max_rowsize should be a single integer or a list[in_rowsize, out_rowsize] of length 2.")
1420
1421
def check_map(method):
    """check the input arguments of map.

    Validates the operations list (rejecting raw network computing operators),
    the removed column_order parameter, column names, worker/cache/rowsize
    options, and the callbacks argument (a DSCallback or a list/tuple of them).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        from mindspore.dataset.callback import DSCallback
        [operations, input_columns, output_columns, column_order, num_parallel_workers, param_dict], _ = \
            parse_user_args(method, *args, **kwargs)

        # column_order was removed from map; point users to .project instead
        if column_order is not None:
            raise ValueError("The parameter 'column_order' had been deleted in map operation. "
                             "Please use '.project' operation instead.\n"
                             ">> # Usage of old api:\n"
                             ">> dataset = dataset.map(operations=PyFunc,\n"
                             ">>                       input_columns=[\"column_a\"],\n"
                             ">>                       output_columns=[\"column_b\", \"column_c\"],\n"
                             ">>                       column_order=[\"column_b\", \"column_c\"])\n"
                             ">> # Usage of new api:\n"
                             ">> dataset = dataset.map(operations=PyFunc,\n"
                             ">>                       input_columns=[\"column_a\"],\n"
                             ">>                       output_columns=[\"column_b\", \"column_c\"])\n"
                             ">> dataset = dataset.project([\"column_b\", \"column_c\"])")

        (python_multiprocessing, max_rowsize, cache, callbacks, offload) = get_map_kwargs_from_dict(param_dict)

        # check whether network computing operator exist in input operations(python function)
        # check used variable and function document whether contain computing operator
        from types import FunctionType
        if isinstance(operations, FunctionType):
            try:
                var = ins.getclosurevars(operations)
                operations_doc = ins.getsource(operations)
                check_list = ['mindspore.nn', 'mindspore.ops', 'mindspore.numpy', 'mindspore.compression']
                check_doc = str(var) + operations_doc
                for item in check_list:
                    if item in check_doc:
                        setattr(self, 'operator_mixed', True)
                        break
            except OSError:
                pass

        operations = operations if isinstance(operations, list) else [operations]
        # import nn and ops locally for type check
        from mindspore import nn, ops
        for item in operations:
            if isinstance(item, (nn.Cell, ops.Primitive)):
                raise ValueError("Input operations should not contain network computing operator like in "
                                 "mindspore.nn or mindspore.ops, got operation: ", str(item))

        nreq_param_columns = ['input_columns', 'output_columns']

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(python_multiprocessing, (bool,), "python_multiprocessing")
        check_cache_option(cache)
        check_max_rowsize(max_rowsize)
        if offload is not None:
            type_check(offload, (bool,), "offload")

        # callbacks: a single DSCallback or a list/tuple of DSCallback.
        # The former duplicate type_check(callbacks, (list, DSCallback)) at the
        # end was removed: it contradicted this check by rejecting tuples that
        # are explicitly validated and accepted here.
        if callbacks is not None:
            if isinstance(callbacks, (list, tuple)):
                type_check_list(callbacks, (DSCallback,), "callbacks")
            else:
                type_check(callbacks, (DSCallback,), "callbacks")

        for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
            if param is not None:
                check_columns(param, param_name)

        return method(self, *args, **kwargs)

    return new_method
1496
1497
def check_filter(method):
    """Check the input arguments of filter."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)

        # the filter predicate must be callable
        if not callable(predicate):
            raise TypeError("Predicate should be a Python function or a callable Python object.")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)

        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(self, *args, **kwargs)

    return new_method
1516
1517
def check_repeat(method):
    """check the input arguments of repeat."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], _ = parse_user_args(method, *args, **kwargs)

        type_check(count, (int, type(None)), "repeat")
        if isinstance(count, int):
            # -1 means repeat indefinitely; otherwise count must be in [1, INT32_MAX]
            is_valid = count == -1 or 0 < count <= INT32_MAX
            if not is_valid:
                raise ValueError("count should be either -1 or positive integer, range[1, INT32_MAX].")
        return method(self, *args, **kwargs)

    return new_method
1532
1533
def check_skip(method):
    """check the input arguments of skip."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [skip_count], _ = parse_user_args(method, *args, **kwargs)

        # count must be a non-negative int no larger than INT32_MAX
        type_check(skip_count, (int,), "count")
        check_value(skip_count, (0, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return new_method
1547
1548
def check_take(method):
    """check the input arguments of take."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        count = user_args[0]
        type_check(count, (int,), "count")
        # accepted values: -1, or anything in (0, INT32_MAX]
        if count != -1 and not 0 < count <= INT32_MAX:
            raise ValueError("count should be either -1 or within the required interval of ({}, {}], got {}."
                             .format(0, INT32_MAX, count))

        return method(self, *args, **kwargs)

    return new_method
1563
1564
def check_positive_int32(method):
    """check whether the input argument is positive and int, only works for functions with one input."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [count], param_dict = parse_user_args(method, *args, **kwargs)
        # Recover the real parameter name by skipping the bound 'self'/'cls' entry.
        para_name = None
        for key in param_dict:
            if key not in ('self', 'cls'):
                para_name = key
        # None may be a legal default, so only validate when a value is present.
        if count is not None:
            check_pos_int32(count, para_name)

        return method(self, *args, **kwargs)

    return new_method
1582
1583
def check_device_send(method):
    """check the input argument of device_que."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [send_epoch_end, create_data_info_queue, queue_name], _ = parse_user_args(method, *args, **kwargs)
        # both flags must be booleans, the queue name must be a string
        for flag, flag_name in ((send_epoch_end, "send_epoch_end"),
                                (create_data_info_queue, "create_data_info_queue")):
            type_check(flag, (bool,), flag_name)
        type_check(queue_name, (str,), "queue_name")

        return method(self, *args, **kwargs)

    return new_method
1597
1598
def check_total_batch(total_batch):
    """Validate `total_batch` via the shared int32 checker (raises on invalid value)."""
    check_int32(total_batch, "total_batch")
1601
1602
def check_zip(method):
    """check the input arguments of zip."""

    @wraps(method)
    def new_method(*args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        # the free-function zip only accepts a tuple of datasets
        type_check(user_args[0], (tuple,), "datasets")

        return method(*args, **kwargs)

    return new_method
1614
1615
def check_zip_dataset(method):
    """check the input arguments of zip method in `Dataset` ."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        # the method form also accepts a single Dataset besides a tuple
        type_check(user_args[0], (tuple, datasets.Dataset), "datasets")

        return method(self, *args, **kwargs)

    return new_method
1627
1628
def check_concat(method):
    """check the input arguments of concat method in `Dataset` ."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        other = user_args[0]
        type_check(other, (list, datasets.Dataset), "datasets")
        # a list input must contain only Dataset objects
        if isinstance(other, list):
            type_check_list(other, (datasets.Dataset,), "dataset")
        return method(self, *args, **kwargs)

    return new_method
1641
1642
def check_rename(method):
    """check the input arguments of rename."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)
        input_columns, output_columns = values

        check_columns(input_columns, 'input_columns')
        check_columns(output_columns, 'output_columns')

        # a single string counts as one column; lists are compared by length
        input_size = len(input_columns) if isinstance(input_columns, list) else 1
        output_size = len(output_columns) if isinstance(output_columns, list) else 1
        if input_size != output_size:
            raise ValueError("Number of column in input_columns and output_columns is not equal.")

        return method(self, *args, **kwargs)

    return new_method
1666
1667
def check_output_shape(method):
    """check the input arguments of output_shape."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        # `estimate` must be a bool
        type_check(param_dict.get('estimate'), (bool,), "estimate")

        return method(self, *args, **kwargs)

    return new_method
1680
1681
def check_project(method):
    """check the input arguments of project."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        # projected columns must be valid column names
        check_columns(user_args[0], 'columns')

        return method(self, *args, **kwargs)

    return new_method
1693
1694
def check_schema(method):
    """check the input arguments of Schema.__init__."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        schema_file = user_args[0]

        # schema_file is optional; validate the path only when one is given
        if schema_file is not None:
            check_file(schema_file)

        return method(self, *args, **kwargs)

    return new_method
1708
1709
def check_add_column(method):
    """check the input arguments of add_column."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)

        # column name: a non-empty string
        type_check(name, (str,), "name")
        if not name:
            raise TypeError("Expected non-empty string for column name.")

        # column type is required: either a mindspore typing.Type or a valid detype
        if de_type is None:
            raise TypeError("Expected non-empty string for de_type.")
        if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
            raise TypeError("Unknown column type: {}.".format(de_type))

        # shape is optional; when present it must be a list of ints
        if shape is not None:
            type_check(shape, (list,), "shape")
            type_check_list(shape, (int,), "shape")

        return method(self, *args, **kwargs)

    return new_method
1735
1736
def check_cluedataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # dataset_files is required and may not be empty
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        # task must be one of the supported CLUE benchmarks
        if param_dict.get('task') not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")

        # usage selects which split to read
        if param_dict.get('usage') not in ['train', 'test', 'eval']:
            raise ValueError("usage should be 'train', 'test' or 'eval'.")

        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
1770
1771
def check_csvdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # dataset_files is required and may not be empty
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        # field_delim: a single character, excluding double quote, CR and LF
        field_delim = param_dict.get('field_delim')
        if field_delim is not None:
            type_check(field_delim, (str,), 'field delim')
            if len(field_delim) > 1 or field_delim in ['"', '\r', '\n']:
                raise ValueError("field_delim is invalid.")

        # column_defaults: when provided, a list whose items are str/int/float
        column_defaults = param_dict.get('column_defaults')
        if column_defaults is not None:
            if not isinstance(column_defaults, list):
                raise TypeError("column_defaults should be type of list.")
            for item in column_defaults:
                if not isinstance(item, (str, int, float)):
                    raise TypeError("column type in column_defaults is invalid.")

        # column_names: every entry must be a string
        column_names = param_dict.get("column_names")
        if column_names is not None:
            if not all(isinstance(item, str) for item in column_names):
                raise TypeError("column_names should be a list of str.")

        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
1819
1820
def check_flowers102dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Flowers102Dataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # the dataset root must contain the image folder and both metadata files
        check_dir(os.path.join(dataset_dir, "jpg"))
        check_file(os.path.join(dataset_dir, "imagelabels.mat"))
        check_file(os.path.join(dataset_dir, "setid.mat"))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Classification", "Segmentation"], "task")
        # segmentation additionally requires the segmentation-image folder
        if task == "Segmentation":
            check_dir(os.path.join(dataset_dir, "segmim"))

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
1857
1858
def check_textfiledataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # dataset_files is required and may not be empty
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
1882
1883
def check_penn_treebank_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(PennTreebankDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # dataset_dir is a required argument
        check_dir(param_dict.get('dataset_dir'))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
1911
1912
def check_split(method):
    """check the input arguments of split."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [sizes, randomize], _ = parse_user_args(method, *args, **kwargs)

        type_check(sizes, (list,), "sizes")
        type_check(randomize, (bool,), "randomize")

        if not sizes:
            raise ValueError("sizes cannot be empty.")

        # sizes must be homogeneous: all absolute counts (int) or all fractions (float)
        all_int = all(isinstance(size, int) for size in sizes)
        all_float = all(isinstance(size, float) for size in sizes)
        if not (all_int or all_float):
            raise ValueError("sizes should be list of int or list of float.")

        if all_int:
            # absolute counts must be strictly positive
            if not all(size > 0 for size in sizes):
                raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")

        if all_float:
            # fractions must lie in (0, 1] and sum to 1 within tolerance
            if not all(0 < size <= 1 for size in sizes):
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")

            epsilon = 0.00001
            if not abs(sum(sizes) - 1) < epsilon:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(self, *args, **kwargs)

    return new_method
1950
1951
def check_hostname(hostname):
    """Return True if `hostname` is a syntactically valid host name.

    Each dot-separated label must be 1-63 characters of letters, digits or
    hyphens, and must not start or end with a hyphen; the total length must
    not exceed 255 characters. A single trailing dot is tolerated.
    """
    # `re` is not imported at this file's top level, so import it locally
    # to avoid a NameError on first call.
    import re

    if not hostname or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        hostname = hostname[:-1]  # strip exactly one dot from the right, if present
    allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
    return all(allowed.match(x) for x in hostname.split("."))
1959
1960
def check_numpyslicesdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        data = param_dict.get("data")
        column_names = param_dict.get("column_names")
        type_check(data, (list, tuple, dict, np.ndarray), "data")
        if data is None or len(data) == 0:  # pylint: disable=len-as-condition
            raise ValueError("Argument data cannot be empty")
        if isinstance(data, tuple):
            type_check(data[0], (list, np.ndarray), "data[0]")

        if column_names is not None:
            check_columns(column_names, "column_names")

            # the number of supplied names must match the number of data columns
            column_num = 1 if isinstance(column_names, str) else len(column_names)
            if isinstance(data, dict):
                expected = len(list(data.keys()))
                if column_num != expected:
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, expected))
            elif isinstance(data, tuple):
                if column_num != len(data):
                    raise ValueError("Num of input column names is {0}, but required is {1}."
                                     .format(column_num, len(data)))
            else:
                # list / ndarray input carries exactly one column
                if column_num != 1:
                    raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
                                     .format(column_num, 1))

        return method(self, *args, **kwargs)

    return new_method
2000
2001
def check_paddeddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # reject None/empty first, then enforce container and element types
        padded_samples = param_dict.get("padded_samples")
        if not padded_samples:
            raise ValueError("padded_samples cannot be empty.")
        type_check(padded_samples, (list,), "padded_samples")
        type_check(padded_samples[0], (dict,), "padded_element")
        return method(self, *args, **kwargs)

    return new_method
2017
2018
def check_cache_option(cache):
    """Sanity check for cache parameter"""
    if cache is None:
        return
    type_check(cache, (cache_client.DatasetCache,), "cache")
2023
2024
def check_to_device_send(method):
    """Check the input arguments of send function for TransferDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_epochs], _ = parse_user_args(method, *args, **kwargs)

        # num_epochs is optional; when given it must be an int in [-1, INT32_MAX]
        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return new_method
2039
2040
def check_emnist_dataset(method):
    """A wrapper that wraps a parameter checker emnist dataset"""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle']

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_dir(param_dict.get('dataset_dir'))

        # name selects one of the EMNIST splits
        check_valid_str(param_dict.get('name'),
                        ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"], "name")

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2072
2073
def check_flickr_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(Flickr8k, Flickr30k)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'decode']

        # both the image directory and the annotation file must exist
        check_dir(param_dict.get('dataset_dir'))
        check_file(param_dict.get('annotation_file'))

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2100
2101
def check_food101_dataset(method):
    """A wrapper that wraps a parameter checker around the Food101Dataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['decode', 'shuffle']

        check_dir(param_dict.get('dataset_dir'))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2130
2131
def check_sb_dataset(method):
    """A wrapper that wraps a parameter checker around the original Semantic Boundaries Dataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'decode']

        check_dir(param_dict.get('dataset_dir'))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Boundaries", "Segmentation"], "task")

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
2161
2162
def check_speech_commands_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SpeechCommandsDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle']

        check_dir(param_dict.get('dataset_dir'))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2191
2192
def check_squad_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SQuADDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        check_dir(param_dict.get('dataset_dir'))

        # usage selects which split to read
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ['train', 'dev', 'all'], "usage")

        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2219
2220
def check_cityscapes_dataset(method):
    """A wrapper that wraps a parameter checker around the original CityScapesDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'decode']

        check_dir(param_dict.get('dataset_dir'))

        check_valid_str(param_dict.get('task'), ["instance", "semantic", "polygon", "color"], "task")

        quality_mode = param_dict.get('quality_mode')
        check_valid_str(quality_mode, ["fine", "coarse"], "quality_mode")

        # the admissible usage values depend on the chosen quality mode
        if quality_mode == "fine":
            usage_options = ["train", "test", "val", "all"]
        else:
            usage_options = ["train", "train_extra", "val", "all"]
        check_valid_str(param_dict.get('usage'), usage_options, "usage")

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
2255
2256
def check_div2k_dataset(method):
    """A wrapper that wraps a parameter checker around the original DIV2KDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'decode']

        check_dir(param_dict.get('dataset_dir'))

        check_valid_str(param_dict.get('usage'), ['train', 'valid', 'all'], "usage")

        downgrade = param_dict.get('downgrade')
        check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade')

        validate_dataset_param_value(['scale'], param_dict, int)
        scale = param_dict.get('scale')
        scale_values = [2, 3, 4, 8]
        if scale not in scale_values:
            raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values)))

        # scale 8 is only defined for the bicubic downgrade
        if scale == 8 and downgrade != "bicubic":
            raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.")

        # the mild/difficult/wild downgrades are only defined at scale 4
        if downgrade in ["mild", "difficult", "wild"] and scale != 4:
            raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade))

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
2297
2298
def check_fake_image_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(FakeImageDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_images', 'num_classes', 'base_seed', 'num_samples',
                      'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle']

        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        check_pos_int32(param_dict.get("num_images"), "num_images")

        # image_size must be a sequence of exactly three positive int32 values
        image_size = param_dict.get("image_size")
        type_check(image_size, (list, tuple), "image_size")
        if len(image_size) != 3:
            raise ValueError("image_size should be a list or tuple of length 3, but got {0}".format(len(image_size)))
        for index, dim in enumerate(image_size):
            check_pos_int32(dim, "image_size[{0}]".format(index))

        check_pos_int32(param_dict.get("num_classes"), "num_classes")

        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2334
2335
def check_ag_news_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(AGNewsDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # dataset_dir is a required argument
        check_dir(param_dict.get('dataset_dir'))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2363
2364
def check_dbpedia_dataset(method):
    """A wrapper that wraps a parameter checker around the original DBpediaDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        check_dir(param_dict.get('dataset_dir'))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(int_params, param_dict, int)

        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
2391
2392
def check_wider_face_dataset(method):
    """A wrapper that wraps a parameter checker around the WIDERFaceDataset."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Optional split selector; WIDERFace also has a 'valid' split.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "valid", "all"], "usage")

        # Integer- and boolean-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['decode', 'shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2421
2422
def check_yelp_review_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(YelpReviewDataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Optional split selector.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "all"], "usage")

        # Integer-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2449
2450
def check_yes_no_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(YesNoDataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Integer- and boolean-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2475
2476
def check_tedlium_dataset(method):
    """Wrapper method to check the parameters of TedliumDataset."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # release picks the TED-LIUM corpus version and constrains the allowed usage.
        release = params.get('release')
        check_valid_str(release, ["release1", "release2", "release3"], "release")

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # release3 only ships a single combined split, so only "all" is valid there.
        split = params.get('usage')
        if split is not None:
            allowed = ["train", "test", "dev", "all"] if release in ["release1", "release2"] else ["all"]
            check_valid_str(split, allowed, "usage")

        # Integer- and boolean-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2511
2512
def check_svhn_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SVHNDataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        root = params.get('dataset_dir')
        check_dir(root)

        # Optional split selector; verify the corresponding .mat file(s) exist.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "extra", "all"], "usage")
            splits = ["train", "test", "extra"] if split == "all" else [split]
            for part in splits:
                check_file(os.path.join(root, part + "_32x32.mat"))

        # Integer- and boolean-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
2542
2543
def check_sst2_dataset(method):
    """A wrapper that wraps a parameter checker around the original SST2 Dataset."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Optional split selector; SST2 has no "all" option.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "dev"], "usage")

        # Integer-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2570
2571
def check_stl10_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(STL10Dataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        root = params.get('dataset_dir')
        check_dir(root)

        # Optional split selector; verify the binary files for the chosen split(s) exist.
        # The "unlabeled" split has images only (no *_y.bin label file).
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "unlabeled", "train+unlabeled", "all"], "usage")
            if split == "all":
                for part in ["train", "test", "unlabeled"]:
                    check_file(os.path.join(root, part + "_X.bin"))
                    if part != "unlabeled":
                        check_file(os.path.join(root, part + "_y.bin"))
            elif split == "train+unlabeled":
                check_file(os.path.join(root, "train_X.bin"))
                check_file(os.path.join(root, "train_y.bin"))
                check_file(os.path.join(root, "unlabeled_X.bin"))
            elif split == "unlabeled":
                check_file(os.path.join(root, "unlabeled_X.bin"))
            else:
                check_file(os.path.join(root, split + "_X.bin"))
                check_file(os.path.join(root, split + "_y.bin"))

        # Integer- and boolean-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2616
2617
def check_sun397_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(SUN397Dataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Integer- and boolean-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2641
2642
def check_yahoo_answers_dataset(method):
    """A wrapper that wraps a parameter checker around the original YahooAnswers Dataset."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Optional split selector.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "all"], "usage")

        # Integer-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2669
2670
def check_conll2000_dataset(method):
    """ A wrapper that wraps a parameter checker around the original Dataset(CoNLL2000Dataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Optional split selector.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "all"], "usage")

        # Integer-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2698
2699
def check_amazon_review_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(AmazonReviewDataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Optional split selector.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "all"], "usage")

        # Integer-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2727
2728
def check_semeion_dataset(method):
    """Wrapper method to check the parameters of SemeionDataset."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Integer- and boolean-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2753
2754
def check_wiki_text_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(WikiTextDataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Optional split selector; WikiText ships a 'valid' split as well.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "valid", "test", "all"], "usage")

        # Integer-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2782
2783
def check_en_wik9_dataset(method):
    """Wrapper method to check the parameters of EnWik9 dataset."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # Required: an existing dataset directory.
        check_dir(params.get('dataset_dir'))

        # Integer-valued options.
        validate_dataset_param_value(
            ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
2804
2805
def check_multi30k_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset (Multi30kDataset).

    Validates dataset_dir (must exist), usage (one of train/test/valid/all),
    language_pair (a 2-element list/tuple, ('en', 'de') or ('de', 'en')), and
    the common int/bool dataset options before delegating to the wrapped method.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        # dataset_dir is required and must be an existing directory.
        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # usage, when supplied, selects which split to read.
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        # language_pair must be exactly two codes, in either order; both list and
        # tuple forms are accepted.
        language_pair = param_dict.get('language_pair')
        support_language_pair = [['en', 'de'], ['de', 'en'], ('en', 'de'), ('de', 'en')]
        if language_pair is not None:
            type_check(language_pair, (list, tuple), "language_pair")
            if len(language_pair) != 2:
                raise ValueError(
                    "language_pair should be a list or tuple of length 2, but got {0}".format(len(language_pair)))
            if language_pair not in support_language_pair:
                # Bug fix: the message previously repeated ['en', 'de'] twice;
                # it now lists both valid orderings.
                raise ValueError(
                    "language_pair can only be ['en', 'de'] or ['de', 'en'], but got {0}".format(language_pair))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
2842
2843
def check_obsminddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(OBSMindDataset)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # dataset_files must be a list whose items are all strings.
        dataset_files = params.get('dataset_files')
        type_check(dataset_files, (list,), "dataset_files")
        for item in dataset_files:
            if not isinstance(item, str):
                raise TypeError("Item of dataset files is not of type [{}], but got {}.".format(type(''),
                                                                                                type(item)))

        # Typed option groups: ints, lists, bools and strings.
        validate_dataset_param_value(['num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['columns_list'], params, list)
        validate_dataset_param_value(['shard_equal_rows'], params, bool)
        validate_dataset_param_value(['server', 'ak', 'sk', 'sync_obs_path'], params, str)

        # server must carry an explicit scheme.
        # NOTE(review): this raises AttributeError rather than a friendly message
        # if server is omitted — presumably it is a required argument; verify.
        server = params.get('server')
        if not server.startswith(('http://', 'https://')):
            raise ValueError("server should be a str that starts with http:// or https://, but got {}.".format(server))

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return wrapper
2876