• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16General Validators.
17"""
from __future__ import absolute_import

import functools
import inspect
import os
from multiprocessing import cpu_count

import mindspore._c_dataengine as cde
from mindspore import log as logger
26
# Lower bound used by the "positive" validators (check_pos_*) so that 0 is rejected.
POS_INT_MIN = 1
UINT8_MAX = 2 ** 8 - 1
UINT8_MIN = 0
UINT32_MAX = 2 ** 32 - 1
UINT32_MIN = 0
UINT64_MAX = 2 ** 64 - 1
UINT64_MIN = 0
INT32_MAX = 2 ** 31 - 1
INT32_MIN = -(2 ** 31)
INT64_MAX = 2 ** 63 - 1
INT64_MIN = -(2 ** 63)
# Every integer in [-2**24, 2**24] is exactly representable as a float32,
# and every integer in [-2**53, 2**53] as a float64.
FLOAT_MAX_INTEGER = 2 ** 24
FLOAT_MIN_INTEGER = -(2 ** 24)
DOUBLE_MAX_INTEGER = 2 ** 53
DOUBLE_MIN_INTEGER = -(2 ** 53)

# Column type names accepted by check_valid_detype.
valid_detype = [
    "bool",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float16", "float32", "float64",
    "string",
]
48
49
def is_iterable(obj):
    """
    Report whether the built-in ``iter()`` accepts the given object.

    Args:
        obj (any): object to inspect.

    Returns:
        bool, True when ``iter(obj)`` succeeds, False when it raises TypeError.
    """
    try:
        iter(obj)
    except TypeError:
        iterable = False
    else:
        iterable = True
    return iterable
65
66
def pad_arg_name(arg_name):
    """
    Append a single trailing space to a non-empty argument name.

    Used so that error messages written as "Input {0}is ..." read correctly both
    with and without a name.

    :param arg_name: the input string
    :return: arg_name followed by " " when arg_name is not "", otherwise "" unchanged
    """
    if arg_name == "":
        return arg_name
    return arg_name + " "
77
78
def check_value(value, valid_range, arg_name="", left_open_interval=False, right_open_interval=False):
    """
    Validates that a value lies inside an interval whose ends may each be open or closed.

    :param value: the value to be validated.
    :param valid_range: indexable whose items 0 and 1 are the lower and upper bounds.
    :param arg_name: name of the variable to be validated, used in the error message.
    :param left_open_interval: True to exclude the lower bound, False to include it.
    :param right_open_interval: True to exclude the upper bound, False to include it.
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    lower, upper = valid_range[0], valid_range[1]

    # The upper bound is only compared when the lower-bound test passes (same short-circuit as "a or b").
    if left_open_interval:
        out_of_range = value <= lower
    else:
        out_of_range = value < lower
    if not out_of_range:
        out_of_range = value >= upper if right_open_interval else value > upper

    if out_of_range:
        left_bracket = "(" if left_open_interval else "["
        right_bracket = ")" if right_open_interval else "]"
        raise ValueError("Input {0}is not within the required interval of {1}{2}, {3}{4}.".format(
            arg_name, left_bracket, lower, upper, right_bracket))
111
112
def check_value_cutoff(value, valid_range, arg_name=""):
    """
    Validates a value is within a half-open range [inclusive, exclusive).

    :param value: the value to be validated
    :param valid_range: the desired range as [lower, upper]; lower is allowed, upper is not
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, left_open_interval=False, right_open_interval=True)
123
124
def check_value_ratio(value, valid_range, arg_name=""):
    """
    Validates a value is within a half-open range (exclusive, inclusive].

    :param value: the value to be validated
    :param valid_range: the desired range as [lower, upper]; lower is excluded, upper is allowed
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, left_open_interval=True, right_open_interval=False)
135
136
def check_value_normalize_std(value, valid_range, arg_name=""):
    """
    Validates a normalization std value is within a half-open range (exclusive, inclusive].

    :param value: the value to be validated
    :param valid_range: the desired range as [lower, upper]; lower is excluded, upper is allowed
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, left_open_interval=True, right_open_interval=False)
147
148
def check_range(values, valid_range, arg_name=""):
    """
    Validates that a (low, high) pair is ordered and lies inside a closed range.

    The check is valid_range[0] <= values[0] <= values[1] <= valid_range[1], so an
    inverted pair (values[0] > values[1]) is rejected as well.

    :param values: the two values to be validated
    :param valid_range: the desired range [inclusive, inclusive]
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    is_inside = valid_range[0] <= values[0] <= values[1] <= valid_range[1]
    if is_inside:
        return
    raise ValueError(
        "Input {0}is not within the required interval of [{1}, {2}].".format(arg_name, valid_range[0],
                                                                             valid_range[1]))
163
164
def check_positive(value, arg_name=""):
    """
    Validates that a value is strictly greater than 0.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when value <= 0, nothing otherwise.
    """
    padded_name = pad_arg_name(arg_name)
    if value <= 0:
        message = "Input {0}must be greater than 0.".format(padded_name)
        raise ValueError(message)
176
177
def check_int32_not_zero(value, arg_name=""):
    """
    Validates that a value is a non-zero int inside the int32 range.

    :param value: the value to be validated; bool is rejected by type_check.
    :param arg_name: name of the variable to be validated, used in error messages.
    :return: Exception: TypeError when value is not an int, ValueError when it is 0 or
        outside [INT32_MIN, INT32_MAX], nothing otherwise.
    """
    # type_check formats the name with its own spacing, so pass the raw name; the padded
    # form would produce a doubled space in its error message.
    type_check(value, (int,), arg_name)
    padded_name = pad_arg_name(arg_name)
    if value < INT32_MIN or value > INT32_MAX or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-2147483648, 0) and (0, 2147483647].".format(
                padded_name))
184
185
def check_odd(value, arg_name=""):
    """
    Validates that a value is odd, i.e. that ``value % 2 == 1``.

    Negative odd ints pass as well, because Python's % result takes the sign of the divisor.

    :param value: the value to be validated
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the value is not odd, nothing otherwise.
    """
    padded_name = pad_arg_name(arg_name)
    remainder = value % 2
    if remainder != 1:
        raise ValueError(
            "Input {0}is not an odd value.".format(padded_name))
191
192
def check_2tuple(value, arg_name=""):
    """
    Validates that a variable is a tuple holding exactly two entries.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when value is not a 2-tuple, nothing otherwise.
    """
    is_pair = isinstance(value, tuple) and len(value) == 2
    if not is_pair:
        raise ValueError("Value {0} needs to be a 2-tuple.".format(arg_name))
203
204
def check_int32(value, arg_name=""):
    """
    Validates that a value is an int inside [INT32_MIN, INT32_MAX].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    int32_bounds = [INT32_MIN, INT32_MAX]
    check_value(value, int32_bounds, arg_name)
215
216
def check_uint8(value, arg_name=""):
    """
    Validates that a value is an int inside [UINT8_MIN, UINT8_MAX], i.e. [0, 255].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    uint8_bounds = [UINT8_MIN, UINT8_MAX]
    check_value(value, uint8_bounds, arg_name)
227
228
def check_uint32(value, arg_name=""):
    """
    Validates that a value is an int inside [UINT32_MIN, UINT32_MAX].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    uint32_bounds = [UINT32_MIN, UINT32_MAX]
    check_value(value, uint32_bounds, arg_name)
239
240
def check_pos_uint32(value, arg_name=""):
    """
    Validates that a value is an int inside [POS_INT_MIN, UINT32_MAX], i.e. a positive uint32.

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    positive_uint32_bounds = [POS_INT_MIN, UINT32_MAX]
    check_value(value, positive_uint32_bounds, arg_name)
251
252
def check_pos_int32(value, arg_name=""):
    """
    Validates that a value is an int inside [POS_INT_MIN, INT32_MAX], i.e. a positive int32.

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    positive_int32_bounds = [POS_INT_MIN, INT32_MAX]
    check_value(value, positive_int32_bounds, arg_name)
263
264
def check_uint64(value, arg_name=""):
    """
    Validates that a value is an int inside [UINT64_MIN, UINT64_MAX].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    uint64_bounds = [UINT64_MIN, UINT64_MAX]
    check_value(value, uint64_bounds, arg_name)
275
276
def check_pos_int64(value, arg_name=""):
    """
    Validates that a value is an int inside [POS_INT_MIN, INT64_MAX], i.e. a positive int64.

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    positive_int64_bounds = [POS_INT_MIN, INT64_MAX]
    check_value(value, positive_int64_bounds, arg_name)
287
288
def check_non_negative_int32(value, arg_name=""):
    """
    Validates that a value is an int inside [0, INT32_MAX].

    :param value: the value of the variable; bool is rejected by type_check.
    :param arg_name: name of the variable to be validated.
    :return: Exception: TypeError for a non-int, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    non_negative_int32_bounds = [0, INT32_MAX]
    check_value(value, non_negative_int32_bounds, arg_name)
299
300
def check_float32(value, arg_name=""):
    """
    Validates that a float or int lies inside [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER], i.e. [-2**24, 2**24].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-numeric value, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    float32_bounds = [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER]
    check_value(value, float32_bounds, arg_name)
311
312
def check_float64(value, arg_name=""):
    """
    Validates that a float or int lies inside [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER], i.e. [-2**53, 2**53].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-numeric value, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    float64_bounds = [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER]
    check_value(value, float64_bounds, arg_name)
323
324
def check_pos_float32(value, arg_name=""):
    """
    Validates that a float or int lies inside (0, FLOAT_MAX_INTEGER]; 0 itself is rejected.

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-numeric value, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [0, FLOAT_MAX_INTEGER], arg_name, left_open_interval=True)
335
336
def check_pos_float64(value, arg_name=""):
    """
    Validates that a float or int lies inside (0, DOUBLE_MAX_INTEGER]; 0 itself is rejected.

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-numeric value, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [0, DOUBLE_MAX_INTEGER], arg_name, left_open_interval=True)
347
348
def check_non_negative_float32(value, arg_name=""):
    """
    Validates that a float or int lies inside the closed range [0, FLOAT_MAX_INTEGER].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-numeric value, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [0, FLOAT_MAX_INTEGER], arg_name)
359
360
def check_non_negative_float64(value, arg_name=""):
    """
    Validates that a float or int lies inside the closed range [0, DOUBLE_MAX_INTEGER].

    :param value: the value of the variable; bool is rejected by type_check
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-numeric value, ValueError when out of range, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [0, DOUBLE_MAX_INTEGER], arg_name)
371
372
def check_float32_not_zero(value, arg_name=""):
    """
    Validates that a float or int is non-zero and inside [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER].

    :param value: the value to be validated; bool is rejected by type_check.
    :param arg_name: name of the variable to be validated, used in error messages.
    :return: Exception: TypeError when value is not a float/int, ValueError when it is 0 or
        outside [-16777216, 16777216], nothing otherwise.
    """
    # type_check formats the name with its own spacing, so pass the raw name; the padded
    # form would produce a doubled space in its error message.
    type_check(value, (float, int), arg_name)
    padded_name = pad_arg_name(arg_name)
    if value < FLOAT_MIN_INTEGER or value > FLOAT_MAX_INTEGER or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-16777216, 0) and (0, 16777216].".format(padded_name))
379
380
def check_valid_detype(type_):
    """
    Validates that a type name is one of the supported column types listed in valid_detype.

    :param type_: the type name to be validated
    :return: True when type_ is in valid_detype.
    :raises TypeError: when type_ is not in valid_detype.
    """
    if type_ in valid_detype:
        return True
    raise TypeError("Unknown column type.")
391
392
def check_valid_str(value, valid_strings, arg_name=""):
    """
    Validates that a string is one of an allowed set of strings.

    :param value: the value to be validated; must be a str
    :param valid_strings: a list/set of valid strings
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError when value is not a str, ValueError when it is not in
        valid_strings, nothing otherwise.
    """
    type_check(value, (str,), arg_name)
    if value in valid_strings:
        return
    raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings)))
405
406
def check_valid_list_tuple(value, valid_list_tuple, data_type, arg_name=""):
    """
    Validate that value is a list/tuple of the expected element types and is one of the allowed entries.

    Args:
        value (Union[list, tuple]): the value to be validated.
        valid_list_tuple (Union[list, tuple]): the allowed values; the length of its first entry is
            taken as the required length of value.
        data_type (tuple): tuple of all valid types for each element of value.
        arg_name (str): the name of value, used in error messages.

    Returns:
        Exception: TypeError when value or one of its elements has the wrong type, ValueError when
        its length or content is not allowed, otherwise nothing.
    """
    valid_length = len(valid_list_tuple[0])
    type_check(value, (list, tuple), arg_name)
    type_check_list(value, data_type, arg_name)
    if len(value) != valid_length:
        raise ValueError("Input {0} should be a list or tuple of length {1}.".format(arg_name, valid_length))
    # Membership uses ==, so NOTE(review): a tuple value never equals a list entry; confirm callers
    # supply valid_list_tuple entries of the same container type as value.
    if value not in valid_list_tuple:
        raise ValueError(
            "Input {0}{1} is not within the valid set of {2}.".format(pad_arg_name(arg_name), value,
                                                                      valid_list_tuple))
427
428
def check_columns(columns, name):
    """
    Validate a single column name or a list of column names.

    Args:
        columns (Union[str, list]): one column name or a list of column names.
        name (str): name of the parameter, used in error messages.

    Returns:
        Exception: TypeError when columns (or an entry) has the wrong type; ValueError when it is
        empty, contains an empty entry, or contains duplicates; otherwise nothing.
    """
    type_check(columns, (list, str), name)
    if isinstance(columns, str):
        if not columns:
            raise ValueError("{0} should not be an empty str.".format(name))
        return

    # From here on columns is a list (type_check above only admits list or str).
    if not columns:
        raise ValueError("{0} should not be empty.".format(name))

    # Falsy entries (including non-str ones such as None or 0) are reported as empty
    # before the per-element type check runs.
    first_empty = next((index for index, column_name in enumerate(columns) if not column_name), None)
    if first_empty is not None:
        raise ValueError("{0}[{1}] should not be empty.".format(name, first_empty))

    element_names = ["{0}[{1}]".format(name, index) for index in range(len(columns))]
    type_check_list(columns, (str,), element_names)
    if len(set(columns)) != len(columns):
        raise ValueError("Every column name should not be same with others in column_names.")
455
456
def parse_user_args(method, *args, **kwargs):
    """
    Bind user-supplied arguments to the parameters of method, filling in defaults.

    When the signature contains a ``self`` or ``cls`` parameter, ``method`` itself is bound
    to the first parameter as a placeholder, and that first parameter is left out of the
    returned value list.

    Args:
        method (method): a callable function.
        args: user passed args.
        kwargs: user passed kwargs.

    Returns:
        user_filled_args (list): argument values in parameter order (first parameter excluded
            when the self/cls placeholder was bound).
        ba.arguments (dict): mapping of every parameter name to its bound value
            (an OrderedDict on older Python versions).
    """
    signature = inspect.signature(method)
    param_names = list(signature.parameters.keys())
    has_receiver = 'self' in signature.parameters or 'cls' in signature.parameters
    if has_receiver:
        bound = signature.bind(method, *args, **kwargs)
        param_names = param_names[1:]
    else:
        bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()

    user_filled_args = [bound.arguments.get(param_name) for param_name in param_names]
    return user_filled_args, bound.arguments
482
483
def type_check_list(args, types, arg_names):
    """
    Check the type of every element of a list or tuple.

    Args:
        args (Union[list, tuple]): a list or tuple of any variable.
        types (tuple): tuple of all valid types for each element.
        arg_names (Union[str, list, tuple of str]): either one base name (elements are then
            reported as "name[i]") or one name per element.

    Returns:
        Exception: TypeError when args or an element has the wrong type, ValueError when a
        per-element name list has a different length than args, otherwise nothing.
    """
    type_check(args, (list, tuple,), arg_names)
    if isinstance(arg_names, str):
        element_names = ["{0}[{1}]".format(arg_names, index) for index, _ in enumerate(args)]
    else:
        if len(args) != len(arg_names):
            raise ValueError("List of arguments is not the same length as argument_names.")
        element_names = arg_names
    for element, element_name in zip(args, element_names):
        type_check(element, types, element_name)
503
504
def type_check(arg, types, arg_name):
    """
    Check the type of the parameter.

    bool is a subclass of int, so when ``types`` allows int but not bool, a bool argument is
    rejected explicitly. That rejection reports ``types`` as given (a tuple); the general type
    mismatch reports it converted to a list.

    Args:
        arg (Any) : any variable.
        types (tuple): tuple of all valid types for arg.
        arg_name (str): the name of arg.

    Returns:
        Exception: TypeError when the validation fails, otherwise nothing.
    """
    if int in types and bool not in types and isinstance(arg, bool):
        shown_types = types
    elif not isinstance(arg, types):
        shown_types = list(types)
    else:
        return
    # An empty string would be invisible in the message, so show it as "".
    print_value = '\"\"' if repr(arg) == repr('') else arg
    raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(arg_name, print_value,
                                                                                          shown_types, type(arg)))
528
529
def check_filename(path):
    """
    Check the final component of a path as a file name.

    Args:
        path (str): the path; it is resolved with os.path.realpath before its basename is taken.

    Returns:
        Exception: TypeError when path is not a str, ValueError when the file name contains a
        forbidden character or starts/ends with a space, otherwise nothing.
    """
    if not isinstance(path, str):
        raise TypeError("path: {} is not string".format(path))
    filename = os.path.basename(os.path.realpath(path))

    forbidden_symbols = frozenset(r'\/:*?"<>|`&\';')
    if not forbidden_symbols.isdisjoint(filename):
        raise ValueError(r"filename should not contain \/:*?\"<>|`&;\'")

    has_edge_space = filename.startswith(' ') or filename.endswith(' ')
    if has_edge_space:
        raise ValueError("filename should not start/end with space.")
547
548    if filename.startswith(' ') or filename.endswith(' '):
549        raise ValueError("filename should not start/end with space.")
550
551
def check_dir(dataset_dir):
    """
    Validates that the argument names an existing, readable directory.

    :param dataset_dir: string containing directory path
    :return: Exception: TypeError when dataset_dir is not a str, ValueError when it is not a
        readable directory, nothing otherwise.
    """
    type_check(dataset_dir, (str,), "dataset_dir")
    is_readable_dir = os.path.isdir(dataset_dir) and os.access(dataset_dir, os.R_OK)
    if not is_readable_dir:
        raise ValueError("The folder {} does not exist or is not a directory or permission denied!".format(dataset_dir))
562
563
def check_list_same_size(list1, list2, list1_name="", list2_name=""):
    """
    Validates that two sequences have the same length.

    :param list1: the first list to be validated
    :param list2: the second list to be validated
    :param list1_name: name of list1, used in the error message
    :param list2_name: name of list2, used in the error message
    :return: Exception: ValueError when the lengths differ, nothing otherwise.
    """
    if len(list1) == len(list2):
        return
    raise ValueError("The size of {0} should be the same as that of {1}.".format(list1_name, list2_name))
576
577
def check_file(dataset_file):
    """
    Validates that the argument has a valid file name and points to an existing, readable file.

    :param dataset_file: string containing file path
    :return: Exception: TypeError/ValueError from check_filename, ValueError when the resolved
        path is not a readable file, nothing otherwise.
    """
    check_filename(dataset_file)
    resolved_path = os.path.realpath(dataset_file)
    if os.path.isfile(resolved_path) and os.access(resolved_path, os.R_OK):
        return
    raise ValueError("The file {} does not exist or permission denied!".format(resolved_path))
589
590
def check_sampler_shuffle_shard_options(param_dict):
    """
    Check for valid shuffle, sampler, num_shards, shard_id and num_samples combinations.

    A sampler cannot be combined with shuffle, sharding or num_samples. num_shards and
    shard_id must be given together, with 0 <= shard_id <= num_shards - 1.

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    shuffle = param_dict.get('shuffle')
    sampler = param_dict.get('sampler')
    num_shards = param_dict.get('num_shards')
    shard_id = param_dict.get('shard_id')
    num_samples = param_dict.get('num_samples')

    if sampler is not None:
        conflicts = (
            (shuffle is not None, "sampler and shuffle cannot be specified at the same time."),
            (num_shards is not None or shard_id is not None,
             "sampler and sharding cannot be specified at the same time."),
            (num_samples is not None, "sampler and num_samples cannot be specified at the same time."),
        )
        for is_conflict, message in conflicts:
            if is_conflict:
                raise RuntimeError(message)

    if num_shards is None:
        if shard_id is not None:
            raise RuntimeError("shard_id is specified but num_shards is not.")
        return

    check_pos_int32(num_shards, "num_shards")
    if shard_id is None:
        raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
    check_value(shard_id, [0, num_shards - 1], "shard_id")
620
621
def check_padding_options(param_dict):
    """
    Check for valid padded_sample and num_padded of padded samples.

    padded_sample and num_padded must be given together; num_padded must be non-negative,
    and every name in columns_list must be a key of padded_sample.

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    columns_list = param_dict.get('columns_list')
    padded_sample = param_dict.get('padded_sample')
    num_padded = param_dict.get('num_padded')

    if padded_sample is None:
        if num_padded is not None:
            raise RuntimeError("num_padded is specified but padded_sample is not.")
        return

    if num_padded is None:
        raise RuntimeError("padded_sample is specified and requires num_padded as well.")
    if num_padded < 0:
        raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
    if columns_list is None:
        raise RuntimeError("padded_sample is specified and requires columns_list as well.")
    if any(column not in padded_sample for column in columns_list):
        raise ValueError("padded_sample cannot match columns_list.")
647
648
def check_num_parallel_workers(value):
    """
    Validates the value for num_parallel_workers.

    :param value: an integer corresponding to the number of parallel workers; it must lie in
        [1, cpu_count()].
    :return: Exception: TypeError for a non-int (bool included), ValueError when out of range,
        nothing otherwise.
    """
    type_check(value, (int,), "num_parallel_workers")
    max_workers = cpu_count()
    if not 1 <= value <= max_workers:
        raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(max_workers))
659
660
def check_num_samples(value):
    """
    Validates that the number of samples is valid.

    :param value: an integer corresponding to the number of samples; it must lie in [0, INT64_MAX].
    :return: Exception: TypeError for a non-int (bool included), ValueError when out of range,
        nothing otherwise.
    """
    type_check(value, (int,), "num_samples")
    if 0 <= value <= INT64_MAX:
        return
    raise ValueError(
        "num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))
672
673
def validate_dataset_param_value(param_list, param_dict, param_type):
    """
    Validates the listed dataset parameters that are present (not None) in param_dict.

    num_parallel_workers and num_samples get their dedicated validators, which already
    include an int type check. Every other listed parameter is only type-checked against
    param_type.

    :param param_list: a list of parameter names.
    :param param_dict: a dictionary containing parameter names and their values.
    :param param_type: the type expected for parameters without a dedicated validator.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    for param_name in param_list:
        param_value = param_dict.get(param_name)
        if param_value is None:
            continue
        # elif chain: previously num_parallel_workers also fell through to the generic
        # type_check against param_type after its dedicated check.
        if param_name == 'num_parallel_workers':
            check_num_parallel_workers(param_value)
        elif param_name == 'num_samples':
            check_num_samples(param_value)
        else:
            type_check(param_value, (param_type,), param_name)
690
691
def check_tensor_op(param, param_name):
    """Check whether param is a tensor op, a callable Python function, or has a truthy ``parse`` attribute."""
    is_tensor_op = isinstance(param, cde.TensorOp)
    if is_tensor_op or callable(param) or getattr(param, 'parse', None):
        return
    raise TypeError("{0} is neither a transforms op (TensorOperation) nor a callable pyfunc.".format(param_name))
696
697
def check_c_tensor_op(param, param_name):
    """Check whether param is a tensor op or a callable Python function, rejecting callables whose str mentions py_transform."""
    if callable(param) and "py_transform" in str(param):
        raise TypeError("{0} is a py_transform op which is not allowed to use.".format(param_name))
    is_tensor_op = isinstance(param, cde.TensorOp)
    if is_tensor_op or callable(param) or getattr(param, 'parse', None):
        return
    raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
704
705
def replace_none(value, default):
    """Return default when value is None, otherwise return value unchanged (falsy values are kept)."""
    if value is None:
        return default
    return value
709
710
def check_dataset_num_shards_shard_id(num_shards, shard_id):
    """
    Validate the sharding parameters of a dataset.

    :param num_shards: number of shards (positive int32) or None.
    :param shard_id: index of this shard, an int in [0, num_shards), or None.
    :return: Exception: ValueError when only one of the two is given or shard_id is out of
        range, TypeError when num_shards or shard_id is not an int, nothing otherwise.
    """
    if (num_shards is None) != (shard_id is None):
        # These two parameters appear together.
        raise ValueError("num_shards and shard_id need to be passed in together.")
    if num_shards is not None:
        check_pos_int32(num_shards, "num_shards")
        type_check(shard_id, (int,), "shard_id")
        # Previously only the upper bound was checked, so a negative shard_id was accepted.
        if shard_id < 0:
            raise ValueError("shard_id should be greater than or equal to 0.")
        if shard_id >= num_shards:
            raise ValueError("shard_id should be less than num_shards.")
719
720
def deprecator_factory(version, old_module, new_module, substitute_name=None, substitute_module=None):
    """Decorator factory function for deprecated operation to log deprecation warning message.

    The returned decorator wraps ``op``. Each call to the wrapper logs a warning through
    ``logger.warning``, then forwards all arguments to ``op`` and returns its result. The
    wrapper keeps ``op``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``) via
    ``functools.wraps``.

    Args:
        version (str): Version that the operation is deprecated.
        old_module (str): Old module for deprecated operation.
        new_module (str): New module for deprecated operation.
        substitute_name (str, optional): The substitute name for deprecated operation.
        substitute_module (str, optional): The substitute module for deprecated operation.

    Returns:
        Callable, a decorator intended for the deprecated operation class's ``__init__``.
    """

    def decorator(op):
        # For a method such as Foo.__init__, str(op) reads "<function Foo.__init__ at 0x...>",
        # so this extracts the owning class name "Foo".
        name = str(op).split()[1].split(".")[0]
        # The message depends only on the factory arguments and op, so build it once here
        # rather than on every call.
        message = f"'{name}' from {old_module} is deprecated from version {version}" \
                  " and will be removed in a future version."
        message += f" Use '{substitute_name}'" if substitute_name else f" Use '{name}'"
        message += f" from {substitute_module} instead." if substitute_module else f" from {new_module} instead."

        @functools.wraps(op)
        def wrapper(*args, **kwargs):
            # Log warning message on every use of the deprecated operation.
            logger.warning(message)
            return op(*args, **kwargs)

        return wrapper

    return decorator
752
753
def check_dict(data, key_type, value_type, param_name):
    """
    Check the key and value types of an optional dict.

    None is accepted without checks. Otherwise data must be a dict whose keys are all
    instances of key_type and whose values are all instances of value_type.

    :raises TypeError: on the first offending container, key, or value.
    """
    if data is None:
        return
    if not isinstance(data, dict):
        raise TypeError("{0} should be dict type, but got: {1}".format(param_name, type(data)))

    for key, value in data.items():
        if not isinstance(key, key_type):
            raise TypeError("key '{0}' in parameter {1} should be {2} type, but got: {3}"
                            .format(key, param_name, key_type, type(key)))
        if not isinstance(value, value_type):
            raise TypeError("value of '{0}' in parameter {1} should be {2} type, but got: {3}"
                            .format(key, param_name, value_type, type(value)))
767
768
def check_feature_shape(data, shape, param_name):
    """
    Check that every value in a dict has a 2-D ``.shape`` whose first dimension equals ``shape``.

    Non-dict data is ignored.

    :param data: mapping from key to an object exposing ``.shape``; only checked when a dict.
    :param shape: required size of the first dimension.
    :param param_name: parameter name used in the error message.
    :return: Exception: ValueError when an entry has the wrong rank or first dimension, nothing otherwise.
    """
    if not isinstance(data, dict):
        return
    for key, value in data.items():
        value_shape = value.shape
        if len(value_shape) == 2 and value_shape[0] == shape:
            continue
        raise ValueError("Shape of '{0}' in '{1}' should be of 2 dimension, and shape of first dimension "
                         "should be: {2}, but got: {3}.".format(key, param_name, shape, value_shape))
775