• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16General Validators.
17"""
18import inspect
19from multiprocessing import cpu_count
20import os
21import numpy as np
22
23import mindspore._c_dataengine as cde
24
# Lower bound used by the "positive" validators so that 0 is rejected.
POS_INT_MIN = 1

# Unsigned integer bounds.
UINT8_MIN = 0
UINT8_MAX = 2 ** 8 - 1
UINT32_MIN = 0
UINT32_MAX = 2 ** 32 - 1
UINT64_MIN = 0
UINT64_MAX = 2 ** 64 - 1

# Signed integer bounds.
INT32_MIN = -(2 ** 31)
INT32_MAX = 2 ** 31 - 1
INT64_MIN = -(2 ** 63)
INT64_MAX = 2 ** 63 - 1

# Largest-magnitude integers representable exactly by float32 (2**24) and
# float64 (2**53); the float validators below use them as range limits.
FLOAT_MIN_INTEGER = -(2 ** 24)
FLOAT_MAX_INTEGER = 2 ** 24
DOUBLE_MIN_INTEGER = -(2 ** 53)
DOUBLE_MAX_INTEGER = 2 ** 53

# Type names accepted by check_valid_detype.
valid_detype = [
    "bool",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "float16", "float32", "float64",
    "string",
]
46
47
def is_iterable(obj):
    """
    Report whether ``iter()`` accepts an object.

    Args:
        obj (any): object to inspect.

    Returns:
        bool, True if ``iter(obj)`` succeeds, False if it raises TypeError.
    """
    try:
        iter(obj)
    except TypeError:
        iterable = False
    else:
        iterable = True
    return iterable
63
64
def pad_arg_name(arg_name):
    """
    Return arg_name followed by one space, or the empty string unchanged.

    This lets messages such as "Input {0}is ..." read "Input foo is ..." when a
    name is given and "Input is ..." when it is not.

    :param arg_name: the input string
    :return: the padded string
    """
    return arg_name if arg_name == "" else arg_name + " "
75
76
def check_value(value, valid_range, arg_name="", left_open_interval=False, right_open_interval=False):
    """
    Validate that a value lies inside an interval whose two ends may each be open or closed.

    :param value: the value to be validated.
    :param valid_range: sequence whose first two items are the lower and upper bounds.
    :param arg_name: name of the variable to be validated, used in the error message.
    :param left_open_interval: True to exclude the lower bound, False to include it.
    :param right_open_interval: True to exclude the upper bound, False to include it.
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    lower, upper = valid_range[0], valid_range[1]
    left_open = bool(left_open_interval)
    right_open = bool(right_open_interval)
    # ``or`` keeps the upper-bound comparison lazy: it only runs when the
    # lower bound is satisfied.
    out_of_range = (value <= lower if left_open else value < lower) or \
                   (value >= upper if right_open else value > upper)
    if not out_of_range:
        return
    templates = {
        (False, False): "Input {0}is not within the required interval of [{1}, {2}].",
        (True, False): "Input {0}is not within the required interval of ({1}, {2}].",
        (False, True): "Input {0}is not within the required interval of [{1}, {2}).",
        (True, True): "Input {0}is not within the required interval of ({1}, {2}).",
    }
    raise ValueError(templates[(left_open, right_open)].format(arg_name, lower, upper))
109
110
def check_value_cutoff(value, valid_range, arg_name=""):
    """
    Validate that a value lies in the half-open interval [lower, upper).

    :param value: the value to be validated
    :param valid_range: the desired range as (lower, upper)
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, left_open_interval=False, right_open_interval=True)
121
122
def check_value_ratio(value, valid_range, arg_name=""):
    """
    Validate that a value lies in the half-open interval (lower, upper].

    :param value: the value to be validated
    :param valid_range: the desired range as (lower, upper)
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, left_open_interval=True, right_open_interval=False)
133
134
def check_value_normalize_std(value, valid_range, arg_name=""):
    """
    Validate that a normalization std value lies in the half-open interval (lower, upper].

    :param value: the value to be validated
    :param valid_range: the desired range as (lower, upper)
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, left_open_interval=True, right_open_interval=False)
145
146
def check_range(values, valid_range, arg_name=""):
    """
    Validate that a (low, high) pair is an ordered sub-range of valid_range.

    Both ends are inclusive. The pair must also be ordered: values[0] <= values[1].

    :param values: the two values to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    within = valid_range[0] <= values[0] <= values[1] <= valid_range[1]
    if within:
        return
    message = "Input {0}is not within the required interval of [{1}, {2}].".format(
        arg_name, valid_range[0], valid_range[1])
    raise ValueError(message)
161
162
def check_positive(value, arg_name=""):
    """
    Validate that a value is strictly greater than zero.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when value <= 0, nothing otherwise.
    """
    padded_name = pad_arg_name(arg_name)
    if value <= 0:
        message = "Input {0}must be greater than 0.".format(padded_name)
        raise ValueError(message)
174
175
def check_int32_not_zero(value, arg_name=""):
    """
    Validate that a value is a non-zero int within the int32 range.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is 0 or outside [INT32_MIN, INT32_MAX]; nothing otherwise.
    """
    # Type-check with the raw name: type_check adds its own spacing, so a padded
    # name would produce "Argument x  with value ..." (double space).
    type_check(value, (int,), arg_name)
    arg_name = pad_arg_name(arg_name)
    if value < INT32_MIN or value > INT32_MAX or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-2147483648, 0) and (0, 2147483647].".format(arg_name))
182
183
def check_odd(value, arg_name=""):
    """
    Validate that a value is odd, i.e. ``value % 2 == 1``.

    Python's ``%`` with a positive divisor returns a non-negative remainder, so
    negative odd ints such as -3 also pass.

    :param value: the value to be validated
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    padded_name = pad_arg_name(arg_name)
    remainder = value % 2
    if remainder != 1:
        raise ValueError("Input {0}is not an odd value.".format(padded_name))
189
190
def check_2tuple(value, arg_name=""):
    """
    Validate that a variable is a tuple of exactly two entries.

    Note that a non-tuple (e.g. a list) is reported with ValueError, not TypeError.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    is_pair = isinstance(value, tuple) and len(value) == 2
    if is_pair:
        return
    raise ValueError("Value {0} needs to be a 2-tuple.".format(arg_name))
201
202
def check_int32(value, arg_name=""):
    """
    Validate that a value is an int within [INT32_MIN, INT32_MAX].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is out of range; nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    int32_range = [INT32_MIN, INT32_MAX]
    check_value(value, int32_range, arg_name=arg_name)
213
214
def check_uint8(value, arg_name=""):
    """
    Validate that a value is an int within the uint8 range [UINT8_MIN, UINT8_MAX].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is out of range; nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # Forward arg_name so the range error names the offending argument,
    # as check_int32 and check_pos_int32 already do.
    check_value(value, [UINT8_MIN, UINT8_MAX], arg_name)
225
226
def check_uint32(value, arg_name=""):
    """
    Validate that a value is an int within the uint32 range [UINT32_MIN, UINT32_MAX].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is out of range; nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # Forward arg_name so the range error names the offending argument.
    check_value(value, [UINT32_MIN, UINT32_MAX], arg_name)
237
238
def check_pos_uint32(value, arg_name=""):
    """
    Validate that a value is an int within [POS_INT_MIN, UINT32_MAX], i.e. a positive uint32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is out of range; nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # Forward arg_name so the range error names the offending argument.
    check_value(value, [POS_INT_MIN, UINT32_MAX], arg_name)
249
250
def check_pos_int32(value, arg_name=""):
    """
    Validate that a value is an int within [POS_INT_MIN, INT32_MAX], i.e. a positive int32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is out of range; nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    positive_int32_range = [POS_INT_MIN, INT32_MAX]
    check_value(value, positive_int32_range, arg_name=arg_name)
261
262
def check_uint64(value, arg_name=""):
    """
    Validate that a value is an int within the uint64 range [UINT64_MIN, UINT64_MAX].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is out of range; nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # Forward arg_name so the range error names the offending argument.
    check_value(value, [UINT64_MIN, UINT64_MAX], arg_name)
273
274
def check_pos_int64(value, arg_name=""):
    """
    Validate that a value is an int within [POS_INT_MIN, INT64_MAX], i.e. a positive int64.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is out of range; nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # Forward arg_name so the range error names the offending argument.
    check_value(value, [POS_INT_MIN, INT64_MAX], arg_name)
285
286
def check_non_negative_int32(value, arg_name=""):
    """
    Validate that a value lies within [0, INT32_MAX].

    Only the range is checked; no type check is performed here.

    :param value: the value of the variable.
    :param arg_name: name of the variable to be validated.
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    non_negative_int32_range = [UINT32_MIN, INT32_MAX]
    check_value(value, non_negative_int32_range, arg_name=arg_name)
296
297
def check_float32(value, arg_name=""):
    """
    Validate that a value lies within [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER] (both inclusive).

    Only the range is checked; no type check is performed here.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    float32_range = [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER]
    check_value(value, float32_range, arg_name=arg_name)
307
308
def check_float64(value, arg_name=""):
    """
    Validate that a value lies within [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER] (both inclusive).

    Only the range is checked; no type check is performed here.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    float64_range = [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER]
    check_value(value, float64_range, arg_name=arg_name)
318
319
def check_pos_float32(value, arg_name=""):
    """
    Validate that a value lies within (0, FLOAT_MAX_INTEGER]: zero itself is rejected.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    positive_float32_range = [UINT32_MIN, FLOAT_MAX_INTEGER]
    check_value(value, positive_float32_range, arg_name, left_open_interval=True)
329
330
def check_pos_float64(value, arg_name=""):
    """
    Validate that a value lies within (0, DOUBLE_MAX_INTEGER]: zero itself is rejected.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    positive_float64_range = [UINT64_MIN, DOUBLE_MAX_INTEGER]
    check_value(value, positive_float64_range, arg_name, left_open_interval=True)
340
341
def check_non_negative_float32(value, arg_name=""):
    """
    Validate that a value lies within [0, FLOAT_MAX_INTEGER] (both inclusive).

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    non_negative_float32_range = [UINT32_MIN, FLOAT_MAX_INTEGER]
    check_value(value, non_negative_float32_range, arg_name=arg_name)
351
352
def check_non_negative_float64(value, arg_name=""):
    """
    Validate that a value lies within [0, DOUBLE_MAX_INTEGER] (both inclusive).

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    non_negative_float64_range = [UINT32_MIN, DOUBLE_MAX_INTEGER]
    check_value(value, non_negative_float64_range, arg_name=arg_name)
362
363
def check_float32_not_zero(value, arg_name=""):
    """
    Validate that a value is a non-zero int within [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER].

    NOTE(review): despite the name, only int values are accepted (the type check
    below uses ``(int,)``); confirm this is intended before widening it.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not an int (bool is rejected),
        ValueError if it is 0 or out of range; nothing otherwise.
    """
    # Type-check with the raw name: type_check adds its own spacing, so a padded
    # name would produce "Argument x  with value ..." (double space).
    type_check(value, (int,), arg_name)
    arg_name = pad_arg_name(arg_name)
    if value < FLOAT_MIN_INTEGER or value > FLOAT_MAX_INTEGER or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-16777216, 0) and (0, 16777216].".format(arg_name))
370
371
def check_valid_detype(type_):
    """
    Validate that a type name is one of the names listed in valid_detype.

    :param type_: the type name to be validated
    :return: True when the name is listed; raises TypeError otherwise.
    """
    if type_ in valid_detype:
        return True
    raise TypeError("Unknown column type.")
382
383
def check_valid_str(value, valid_strings, arg_name=""):
    """
    Validate that a value is a str and is one of the allowed strings.

    :param value: the value to be validated
    :param valid_strings: a list/set of valid strings
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError if value is not a str, ValueError if it is not
        in valid_strings; nothing otherwise.
    """
    type_check(value, (str,), arg_name)
    if value in valid_strings:
        return
    raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings)))
396
397
def check_columns(columns, name):
    """
    Validate a single column name or a list of column names.

    A single name must be a non-empty str. A list must be non-empty, contain only
    non-empty str entries, and contain no duplicates.

    Args:
        columns (Union[str, list[str]]): column name(s) to validate.
        name (str): name of the argument, used in error messages.

    Returns:
        Exception: TypeError or ValueError when the value is not correct, otherwise nothing.
    """
    type_check(columns, (list, str), name)
    if isinstance(columns, str):
        if not columns:
            raise ValueError("{0} should not be an empty str.".format(name))
        return
    if not isinstance(columns, list):
        return
    if not columns:
        raise ValueError("{0} should not be empty.".format(name))
    # Emptiness is reported (by index) before any per-element type error.
    first_empty = next((idx for idx, column in enumerate(columns) if not column), None)
    if first_empty is not None:
        raise ValueError("{0}[{1}] should not be empty.".format(name, first_empty))
    element_names = ["{0}[{1}]".format(name, idx) for idx, _ in enumerate(columns)]
    type_check_list(columns, (str,), element_names)
    if len(set(columns)) != len(columns):
        raise ValueError("Every column name should not be same with others in column_names.")
424
425
def parse_user_args(method, *args, **kwargs):
    """
    Bind user-supplied arguments to a function's signature, filling in defaults.

    When the signature has a parameter named ``self`` or ``cls``, ``method`` itself
    is bound as a placeholder for the first positional parameter, and that first
    parameter is left out of the returned value list.

    Args:
        method (method): a callable function.
        args: user passed args.
        kwargs: user passed kwargs.

    Returns:
        user_filled_args (list): values of what the user passed in for the arguments.
        ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
    """
    signature = inspect.signature(method)
    names = list(signature.parameters.keys())
    if 'self' in signature.parameters or 'cls' in signature.parameters:
        bound = signature.bind(method, *args, **kwargs)
        names = names[1:]
    else:
        bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    return [bound.arguments.get(name) for name in names], bound.arguments
451
452
def type_check_list(args, types, arg_names):
    """
    Check the type of every element of a list or tuple.

    Args:
        args (Union[list, tuple]): the elements to check.
        types (tuple): tuple of all valid types for each element.
        arg_names (Union[str, list, tuple of str]): one name per element, or a single
            str from which per-element names "name[i]" are derived.

    Returns:
        Exception: TypeError when args is not a list/tuple or an element has the wrong
        type, ValueError when a list of names differs from args in length; otherwise nothing.
    """
    type_check(args, (list, tuple,), arg_names)
    if isinstance(arg_names, str):
        element_names = ["{0}[{1}]".format(arg_names, idx) for idx in range(len(args))]
    elif len(args) != len(arg_names):
        raise ValueError("List of arguments is not the same length as argument_names.")
    else:
        element_names = arg_names
    for element, element_name in zip(args, element_names):
        type_check(element, types, element_name)
472
473
def type_check(arg, types, arg_name):
    """
    Check that a parameter is an instance of one of the given types.

    Args:
        arg (Any) : any variable.
        types (tuple): tuple of all valid types for arg.
        arg_name (str): the name of arg.

    Returns:
        Exception: TypeError when the validation fails, otherwise nothing.
    """
    # Show an empty string as "" so the message does not contain a blank value.
    print_value = '""' if repr(arg) == repr('') else arg

    # bool is a subclass of int, so isinstance(True, int) is True; reject bools
    # explicitly unless the caller also listed bool as an accepted type.
    rejects_bool = int in types and bool not in types
    if rejects_bool and isinstance(arg, bool):
        raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(
            arg_name, print_value, types, type(arg)))

    if not isinstance(arg, types):
        raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(
            arg_name, print_value, list(types), type(arg)))
496
497
def check_filename(path):
    """
    Check the final component (file name) of a path for forbidden characters.

    Args:
        path (str): the path.

    Returns:
        Exception: TypeError if path is not a str; ValueError if the file name contains
        any of \\/:*?"<>|`&;' or starts/ends with a space; otherwise nothing.
    """
    if not isinstance(path, str):
        raise TypeError("path: {} is not string".format(path))
    filename = os.path.basename(os.path.realpath(path))
    forbidden_symbols = set(r'\/:*?"<>|`&\';')

    if any(symbol in forbidden_symbols for symbol in filename):
        raise ValueError(r"filename should not contain \/:*?\"<>|`&;\'")

    # Stripping spaces changes the name exactly when it starts or ends with one.
    if filename != filename.strip(' '):
        raise ValueError("filename should not start/end with space.")
518
519
def check_dir(dataset_dir):
    """
    Validate that the argument names an existing, readable directory.

    :param dataset_dir: string containing directory path
    :return: Exception: TypeError if not a str, ValueError if the path is not a
        directory or is not readable; nothing otherwise.
    """
    type_check(dataset_dir, (str,), "dataset_dir")
    readable_dir = os.path.isdir(dataset_dir) and os.access(dataset_dir, os.R_OK)
    if not readable_dir:
        raise ValueError("The folder {} does not exist or is not a directory or permission denied!".format(dataset_dir))
530
531
def check_list_same_size(list1, list2, list1_name="", list2_name=""):
    """
    Validate that two sized collections have the same length.

    :param list1: the first list to be validated
    :param list2: the second list to be validated
    :param list1_name: name of list1, used in the error message
    :param list2_name: name of list2, used in the error message
    :return: Exception: ValueError when the lengths differ, nothing otherwise.
    """
    if len(list1) == len(list2):
        return
    raise ValueError("The size of {0} should be the same as that of {1}.".format(list1_name, list2_name))
544
545
def check_file(dataset_file):
    """
    Validate that the argument names an existing, readable regular file.

    The file name is first checked with check_filename; the path is then resolved
    with os.path.realpath before the existence and permission checks.

    :param dataset_file: string containing file path
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_filename(dataset_file)
    resolved_path = os.path.realpath(dataset_file)
    if os.path.isfile(resolved_path) and os.access(resolved_path, os.R_OK):
        return
    raise ValueError("The file {} does not exist or permission denied!".format(resolved_path))
557
558
def check_sampler_shuffle_shard_options(param_dict):
    """
    Check for a consistent combination of sampler, shuffle, num_samples, num_shards and shard_id.

    A sampler excludes shuffle, sharding and num_samples. num_shards and shard_id
    must be given together, with shard_id in [0, num_shards - 1].

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    shuffle = param_dict.get('shuffle')
    sampler = param_dict.get('sampler')
    num_shards = param_dict.get('num_shards')
    shard_id = param_dict.get('shard_id')
    num_samples = param_dict.get('num_samples')

    if sampler is not None:
        conflicts = (
            (shuffle is not None, "sampler and shuffle cannot be specified at the same time."),
            (num_shards is not None or shard_id is not None,
             "sampler and sharding cannot be specified at the same time."),
            (num_samples is not None, "sampler and num_samples cannot be specified at the same time."),
        )
        for clashes, message in conflicts:
            if clashes:
                raise RuntimeError(message)

    if num_shards is None:
        if shard_id is not None:
            raise RuntimeError("shard_id is specified but num_shards is not.")
        return

    check_pos_int32(num_shards, "num_shards")
    if shard_id is None:
        raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
    check_value(shard_id, [0, num_shards - 1], "shard_id")
588
589
def check_padding_options(param_dict):
    """
    Check for valid padded_sample and num_padded of padded samples.

    padded_sample and num_padded must be given together. num_padded must not be
    negative, and every name in columns_list must be a key of padded_sample.

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    columns_list = param_dict.get('columns_list')
    padded_sample = param_dict.get('padded_sample')
    num_padded = param_dict.get('num_padded')

    if padded_sample is None:
        if num_padded is not None:
            raise RuntimeError("num_padded is specified but padded_sample is not.")
        return

    if num_padded is None:
        raise RuntimeError("padded_sample is specified and requires num_padded as well.")
    if num_padded < 0:
        raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
    if columns_list is None:
        raise RuntimeError("padded_sample is specified and requires columns_list as well.")
    if any(column not in padded_sample for column in columns_list):
        raise ValueError("padded_sample cannot match columns_list.")
615
616
def check_num_parallel_workers(value):
    """
    Validate that num_parallel_workers is an int between 1 and cpu_count(), inclusive.

    :param value: an integer corresponding to the number of parallel workers
    :return: Exception: TypeError if not an int (bool is rejected), ValueError if out
        of range; nothing otherwise.
    """
    type_check(value, (int,), "num_parallel_workers")
    if not 1 <= value <= cpu_count():
        raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))
627
628
def check_num_samples(value):
    """
    Validate that num_samples is an int between 0 and INT64_MAX, inclusive.

    :param value: an integer corresponding to the number of samples.
    :return: Exception: TypeError if not an int (bool is rejected), ValueError if out
        of range; nothing otherwise.
    """
    type_check(value, (int,), "num_samples")
    if not 0 <= value <= INT64_MAX:
        raise ValueError(
            "num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))
640
641
def validate_dataset_param_value(param_list, param_dict, param_type):
    """
    Validate each parameter named in param_list whose value in param_dict is not None.

    num_parallel_workers and num_samples are checked by their dedicated validators;
    every other parameter is type-checked against param_type.

    :param param_list: a list of parameter names.
    :param param_dict: a dictionary mapping parameter names to their values.
    :param param_type: the type (or tuple of types) accepted for the remaining parameters.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    for param_name in param_list:
        param_value = param_dict.get(param_name)
        if param_value is None:
            continue
        if param_name == 'num_parallel_workers':
            check_num_parallel_workers(param_value)
        # ``elif`` (not a second ``if``) so num_parallel_workers is not also sent to
        # the generic type_check branch after its dedicated validator.
        elif param_name == 'num_samples':
            check_num_samples(param_value)
        else:
            type_check(param_value, (param_type,), param_name)
658
659
def check_gnn_list_of_pair_or_ndarray(param, param_name):
    """
    Check that the input is a list of int pairs or a 2-D int32 numpy.ndarray of pairs.

    A list must contain only 2-element tuples of int. An ndarray must have ndim == 2,
    shape[1] == 2 and dtype int32.

    Args:
        param (Union[list[tuple], np.ndarray]): param.
        param_name (str): name of the parameter, used in error messages.

    Returns:
        Exception: TypeError or ValueError if error.
    """
    type_check(param, (list, np.ndarray), param_name)
    if isinstance(param, list):
        # Derive element names from param_name so errors name the real argument
        # rather than a hard-coded "node_list".
        param_names = ["{0}[{1}]".format(param_name, i) for i in range(len(param))]
        type_check_list(param, (tuple,), param_names)
        for idx, pair in enumerate(param):
            if len(pair) != 2:
                raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
                    param_names[idx], len(pair)))
            column_names = ["{0}, number #{1} element".format(param_names[idx], i + 1) for i in range(len(pair))]
            type_check_list(pair, (int,), column_names)
    elif isinstance(param, np.ndarray):
        if param.ndim != 2:
            raise ValueError("Input ndarray must be in dimension 2. Got {0}".format(param.ndim))
        if param.shape[1] != 2:
            raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
                param_name, param.shape[1]))
        if param.dtype != np.int32:
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                param_name, param.dtype))
691
def check_gnn_list_or_ndarray(param, param_name):
    """
    Check that the input is a list of int or an int32 numpy.ndarray.

    Args:
        param (Union[list, np.ndarray]): param.
        param_name (str): name of the parameter, used in error messages.

    Returns:
        Exception: TypeError if error.
    """
    type_check(param, (list, np.ndarray), param_name)
    if isinstance(param, list):
        # Name elements "param_name[i]" (the convention used by type_check_list and
        # check_columns) instead of a generic "param_i", so errors name the argument.
        param_names = ["{0}[{1}]".format(param_name, i) for i in range(len(param))]
        type_check_list(param, (int,), param_names)
    elif isinstance(param, np.ndarray):
        if param.dtype != np.int32:
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                param_name, param.dtype))
713
714
def check_tensor_op(param, param_name):
    """
    Accept a cde.TensorOp, any callable, or an object with a truthy ``parse`` attribute.

    Raises TypeError otherwise.
    """
    is_tensor_op = isinstance(param, cde.TensorOp)
    if is_tensor_op or callable(param) or getattr(param, 'parse', None):
        return
    raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
719
720
def check_c_tensor_op(param, param_name):
    """
    Like check_tensor_op, but first reject callables whose str() contains "py_transform".

    Raises TypeError when the check fails.
    """
    if callable(param) and "py_transform" in str(param):
        raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name))
    # The remaining acceptance rule and its error message are identical to check_tensor_op.
    check_tensor_op(param, param_name)
727
728
def replace_none(value, default):
    """Return value, or default when value is None (other falsy values are kept)."""
    if value is None:
        return default
    return value
732