# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators for TensorOps.

Each ``check_*`` function in this module is a decorator for a transform
method. The wrapper validates the arguments the method was called with
(via ``parse_user_args`` and the helpers from ``validator_helpers``), and
then calls the wrapped method with exactly the same arguments.
"""
from functools import wraps
import inspect
import numpy as np

from mindspore._c_expression import typing

from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
    check_tensor_op, type_check_list

# POS_INT_MIN is the smallest allowed positive integer; it keeps ranges from starting at 0.
POS_INT_MIN = 1
UINT8_MAX = 255
UINT8_MIN = 0
UINT32_MAX = 4294967295
UINT32_MIN = 0
UINT64_MAX = 18446744073709551615
UINT64_MIN = 0
INT32_MAX = 2147483647
INT32_MIN = -2147483648
INT64_MAX = 9223372036854775807
INT64_MIN = -9223372036854775808
# +/- 2**24: every integer in this range is exactly representable as a float32.
FLOAT_MAX_INTEGER = 16777216
FLOAT_MIN_INTEGER = -16777216
# +/- 2**53: every integer in this range is exactly representable as a float64.
DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992


def check_fill_value(method):
    """Wrapper method to check the parameters of fill_value.

    ``fill_value`` must be a str, float, bool, int or bytes.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [fill_value], _ = parse_user_args(method, *args, **kwargs)
        type_check(fill_value, (str, float, bool, int, bytes), "fill_value")

        return method(self, *args, **kwargs)

    return new_method


def check_one_hot_op(method):
    """Wrapper method to check the parameters of one_hot_op.

    ``num_classes`` must be a positive int. ``smoothing_rate`` must be an int or
    float in the range [0, 1], as checked by ``check_value``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
        type_check(smoothing_rate, (int, float), "smoothing_rate")
        type_check(num_classes, (int,), "num_classes")
        check_positive(num_classes)

        # NOTE(review): the type_check above already requires int/float, so this
        # None guard looks redundant unless type_check accepts None - confirm.
        if smoothing_rate is not None:
            check_value(smoothing_rate, [0., 1.], "smoothing_rate")

        return method(self, *args, **kwargs)

    return new_method


def check_num_classes(method):
    """Wrapper method to check the parameters of number of classes.

    ``num_classes`` must be a positive int.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_classes], _ = parse_user_args(method, *args, **kwargs)

        type_check(num_classes, (int,), "num_classes")
        check_positive(num_classes)

        return method(self, *args, **kwargs)

    return new_method


def check_ms_type(method):
    """Wrapper method to check the parameters of data type.

    ``data_type`` must be a MindSpore ``typing.Type`` instance.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_type], _ = parse_user_args(method, *args, **kwargs)

        type_check(data_type, (typing.Type,), "data_type")

        return method(self, *args, **kwargs)

    return new_method


def check_slice_option(method):
    """Wrapper method to check the parameters of SliceOption.

    ``slice_option`` may be None, or an int, list of ints, slice, bool,
    Ellipsis or ``_SliceOption``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [slice_option], _ = parse_user_args(method, *args, **kwargs)
        # Imported lazily here, presumably to avoid a circular import with c_transforms.
        from .c_transforms import _SliceOption
        if slice_option is not None:
            type_check(slice_option, (int, list, slice, bool, type(Ellipsis), _SliceOption), "slice_option")

            if isinstance(slice_option, list):
                type_check_list(slice_option, (int,), "slice_option")

        return method(self, *args, **kwargs)

    return new_method


def check_slice_op(method):
    """Wrapper method to check the parameters of slice.

    Every non-None element of ``slice_op`` must be an int, list, slice, bool,
    Ellipsis or ``_SliceOption``. A non-empty list whose first element is an
    int must contain only ints.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [slice_op], _ = parse_user_args(method, *args, **kwargs)

        # Import once, outside the loop (lazy import, as in check_slice_option).
        from .c_transforms import _SliceOption
        allowed_types = (int, list, slice, bool, type(Ellipsis), _SliceOption)
        for s in slice_op:
            if s is not None:
                type_check(s, allowed_types, "slice")
                # NOTE(review): lists whose first element is not an int are not
                # element-checked here, unlike check_slice_option - confirm intent.
                if isinstance(s, list) and s and isinstance(s[0], int):
                    type_check_list(s, (int,), "slice")

        return method(self, *args, **kwargs)

    return new_method


def check_mask_op(method):
    """Wrapper method to check the parameters of mask.

    ``operator`` must be a ``Relational``, ``constant`` a str, float, bool, int
    or bytes, and ``dtype`` a MindSpore ``typing.Type``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)

        from .c_transforms import Relational
        type_check(operator, (Relational,), "operator")
        type_check(constant, (str, float, bool, int, bytes), "constant")
        type_check(dtype, (typing.Type,), "dtype")

        return method(self, *args, **kwargs)

    return new_method


def check_pad_end(method):
    """Wrapper method to check the parameters of PadEnd.

    ``pad_value``, when not None, must be a str, float, bool, int or bytes.
    ``pad_shape`` must be a list whose entries are each None or an int that
    passes ``check_pos_int64``.

    Raises:
        TypeError: if a non-None entry of ``pad_shape`` is not an int.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)

        if pad_value is not None:
            type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
        type_check(pad_shape, (list,), "pad_shape")

        for dim in pad_shape:
            # None entries are accepted without further checks.
            if dim is None:
                continue
            if not isinstance(dim, int):
                raise TypeError("a value in the list is not an integer.")
            check_pos_int64(dim)

        return method(self, *args, **kwargs)

    return new_method


def check_concat_type(method):
    """Wrapper method to check the parameters of concatenation op.

    ``axis``, when given, must be 0 or -1 (only 1D concatenation is supported).
    ``prepend`` and ``append``, when given, must be 1D numpy arrays.

    Raises:
        ValueError: on an unsupported axis or a non-1D prepend/append array.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)

        if axis is not None:
            type_check(axis, (int,), "axis")
            if axis not in (0, -1):
                raise ValueError("only 1D concatenation supported.")

        if prepend is not None:
            type_check(prepend, (np.ndarray,), "prepend")
            if prepend.ndim != 1:
                raise ValueError("can only prepend 1D arrays.")

        if append is not None:
            type_check(append, (np.ndarray,), "append")
            if append.ndim != 1:
                raise ValueError("can only append 1D arrays.")

        return method(self, *args, **kwargs)

    return new_method


def check_random_transform_ops(method):
    """Wrapper method to check the parameters of RandomChoice, RandomApply and Compose.

    The first argument must be a non-empty list whose elements each pass
    ``check_tensor_op``. If exactly two arguments are parsed, the second is
    ``prob``, which must be an int or float in [0, 1].

    Raises:
        ValueError: if the op list is empty.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        arg_list, _ = parse_user_args(method, *args, **kwargs)
        op_list = arg_list[0]
        type_check(op_list, (list,), "op_list")
        if not op_list:
            raise ValueError("op_list can not be empty.")
        for ind, op in enumerate(op_list):
            check_tensor_op(op, "op_list[{0}]".format(ind))
        if len(arg_list) == 2:  # random apply takes an additional arg
            type_check(arg_list[1], (float, int), "prob")
            check_value(arg_list[1], (0, 1), "prob")
        return method(self, *args, **kwargs)

    return new_method


def check_compose_list(method):
    """Wrapper method to check the transform list of Python Compose.

    ``transforms`` must be a non-empty list of callables.

    Raises:
        ValueError: if the list is empty or an element is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms], _ = parse_user_args(method, *args, **kwargs)

        # arg_name must be the parameter's name, not its value, so that the
        # type error names the offending argument.
        type_check(transforms, (list,), "transforms")
        if not transforms:
            raise ValueError("transforms list is empty.")
        for i, transform in enumerate(transforms):
            if not callable(transform):
                raise ValueError("transforms[{}] is not callable.".format(i))
        return method(self, *args, **kwargs)

    return new_method


def check_compose_call(method):
    """Wrapper method to check the transform list of Compose.

    Rejects a call that supplies no positional input, e.g. ``Compose([...])()``.

    Raises:
        TypeError: if the decorated method is called without positional inputs.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        sig = inspect.signature(method)
        bound = sig.bind_partial(self, *args, **kwargs)
        # Assumes the decorated method declares ``*args``: when it is called with
        # no positional inputs, ``bound.arguments`` has no "args" entry and img is None.
        img = bound.arguments.get("args")
        if img is None:
            raise TypeError(
                "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
        return method(self, *args, **kwargs)

    return new_method


def _check_no_c_transforms(transforms):
    """Raise ValueError if any element of ``transforms`` looks like a c_transforms op.

    The check is textual: an element is rejected when its ``str()`` contains
    the substring "c_transform".
    """
    for i, transform in enumerate(transforms):
        if "c_transform" in str(transform):
            raise ValueError(
                "transforms[{}] is not a py transforms. Should not use a c transform in py transform".format(i))


def check_random_apply(method):
    """Wrapper method to check the parameters of random apply.

    ``transforms`` must be a list containing no c_transforms ops. ``prob``,
    when not None, must be an int or float in [0, 1].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms, prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(transforms, (list,), "transforms")
        _check_no_c_transforms(transforms)

        if prob is not None:
            type_check(prob, (float, int,), "prob")
            check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return new_method


def check_transforms_list(method):
    """Wrapper method to check the parameters of transform list.

    ``transforms`` must be a list containing no c_transforms ops.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms], _ = parse_user_args(method, *args, **kwargs)

        type_check(transforms, (list,), "transforms")
        _check_no_c_transforms(transforms)
        return method(self, *args, **kwargs)

    return new_method


def check_plugin(method):
    """Wrapper method to check the parameters of plugin.

    ``lib_path`` and ``func_name`` must be str; ``user_args``, when not None,
    must also be a str.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [lib_path, func_name, user_args], _ = parse_user_args(method, *args, **kwargs)

        type_check(lib_path, (str,), "lib_path")
        type_check(func_name, (str,), "func_name")
        if user_args is not None:
            type_check(user_args, (str,), "user_args")

        return method(self, *args, **kwargs)

    return new_method