# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators for data processing operations.

Each ``check_*`` function in this module is a decorator: it wraps a method,
validates the arguments the caller supplied, and then forwards the call
unchanged to the wrapped method.
"""
from functools import wraps
import inspect
import numpy as np

from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
    check_tensor_op, type_check_list, deprecator_factory

# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
UINT8_MIN = 0
UINT32_MAX = 4294967295
UINT32_MIN = 0
UINT64_MAX = 18446744073709551615
UINT64_MIN = 0
INT32_MAX = 2147483647
INT32_MIN = -2147483648
INT64_MAX = 9223372036854775807
INT64_MIN = -9223372036854775808
FLOAT_MAX_INTEGER = 16777216
FLOAT_MIN_INTEGER = -16777216
DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992


def check_fill_value(method):
    """Wrapper method to check the parameters of fill_value.

    The single parsed argument must be a str, float, bool, int or bytes.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [fill_value], _ = parse_user_args(method, *args, **kwargs)
        type_check(fill_value, (str, float, bool, int, bytes), "fill_value")

        return method(self, *args, **kwargs)

    return new_method


def check_one_hot_op(method):
    """Wrapper method to check the parameters of one_hot_op.

    Checks that ``smoothing_rate`` is an int or float within [0, 1] and that
    ``num_classes`` is a positive int.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
        type_check(smoothing_rate, (int, float), "smoothing_rate")
        type_check(num_classes, (int,), "num_classes")
        check_positive(num_classes)

        if smoothing_rate is not None:
            check_value(smoothing_rate, [0., 1.], "smoothing_rate")

        return method(self, *args, **kwargs)

    return new_method


def check_num_classes(method):
    """Wrapper method to check the parameters of number of classes.

    ``num_classes`` must be a positive int.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_classes], _ = parse_user_args(method, *args, **kwargs)

        type_check(num_classes, (int,), "num_classes")
        check_positive(num_classes)

        return method(self, *args, **kwargs)

    return new_method


def check_ms_type(method):
    """Wrapper method to check the parameters of data type.

    ``data_type`` must be an instance of ``typing.Type``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_type], _ = parse_user_args(method, *args, **kwargs)

        type_check(data_type, (typing.Type,), "data_type")

        return method(self, *args, **kwargs)

    return new_method


def check_slice_option(method):
    """Wrapper method to check the parameters of SliceOption.

    ``None`` is accepted as is. Any other value must be an int, list, slice,
    bool, Ellipsis or ``_SliceOption``; a list must contain only ints.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [slice_option], _ = parse_user_args(method, *args, **kwargs)
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import - TODO confirm).
        from .transforms import _SliceOption
        if slice_option is not None:
            type_check(slice_option, (int, list, slice, bool, type(Ellipsis), _SliceOption), "slice_option")

            if isinstance(slice_option, list):
                type_check_list(slice_option, (int,), "slice_option")

        return method(self, *args, **kwargs)

    return new_method


def check_slice_op(method):
    """Wrapper method to check the parameters of slice.

    Every non-None element of the parsed argument must be an int, list, slice,
    bool, Ellipsis or ``_SliceOption``. A non-empty list element whose first
    item is an int must contain only ints.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [slice_op], _ = parse_user_args(method, *args, **kwargs)

        # Import once per call instead of once per element.
        from .transforms import _SliceOption
        for s in slice_op:
            if s is None:
                continue
            type_check(s, (int, list, slice, bool, type(Ellipsis), _SliceOption), "slice")
            # Only lists that start with an int are validated element-wise.
            if isinstance(s, list) and s and isinstance(s[0], int):
                type_check_list(s, (int,), "slice")

        return method(self, *args, **kwargs)

    return new_method


def check_mask_op(method):
    """Wrapper method to check the parameters of mask.

    Validates against the ``Relational`` class defined in ``.c_transforms``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)

        from .c_transforms import Relational
        type_check(operator, (Relational,), "operator")
        type_check(constant, (str, float, bool, int, bytes), "constant")
        type_check(dtype, (typing.Type,), "dtype")

        return method(self, *args, **kwargs)

    return new_method


def check_mask_op_new(method):
    """Wrapper method to check the parameters of mask.

    Same checks as ``check_mask_op`` but validates against the ``Relational``
    class defined in ``.transforms``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs)

        from .transforms import Relational
        type_check(operator, (Relational,), "operator")
        type_check(constant, (str, float, bool, int, bytes), "constant")
        type_check(dtype, (typing.Type,), "dtype")

        return method(self, *args, **kwargs)

    return new_method


def check_pad_end(method):
    """Wrapper method to check the parameters of PadEnd.

    ``pad_shape`` must be a list whose entries are either None or ints accepted
    by ``check_pos_int64``. ``pad_value``, when given, must be a str, float,
    bool, int or bytes.

    Raises:
        TypeError: If an entry of ``pad_shape`` is neither None nor an int.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):

        [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs)

        if pad_value is not None:
            type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
        type_check(pad_shape, (list,), "pad_shape")

        for dim in pad_shape:
            if dim is None:
                continue
            if not isinstance(dim, int):
                raise TypeError("a value in the list is not an integer.")
            check_pos_int64(dim)

        return method(self, *args, **kwargs)

    return new_method


def check_concat_type(method):
    """Wrapper method to check the parameters of concatenation op.

    Raises:
        ValueError: If ``axis`` is not 0 or -1, or if ``prepend`` / ``append``
            is not a one-dimensional array.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):

        [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs)

        if axis is not None:
            type_check(axis, (int,), "axis")
            if axis not in (0, -1):
                raise ValueError("only 1D concatenation supported.")

        if prepend is not None:
            type_check(prepend, (np.ndarray,), "prepend")
            if len(prepend.shape) != 1:
                raise ValueError("can only prepend 1D arrays.")

        if append is not None:
            type_check(append, (np.ndarray,), "append")
            if len(append.shape) != 1:
                raise ValueError("can only append 1D arrays.")

        return method(self, *args, **kwargs)

    return new_method


def check_random_transform_ops(method):
    """Wrapper method to check the parameters of RandomChoice, RandomApply and Compose.

    The first parsed argument must be a non-empty list of operations. When a
    second argument is present it is treated as a probability in [0, 1].

    Raises:
        ValueError: If the transforms list is empty.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        arg_list, _ = parse_user_args(method, *args, **kwargs)
        type_check(arg_list[0], (list,), "transforms list")
        if not arg_list[0]:
            raise ValueError("transforms list can not be empty.")
        for ind, op in enumerate(arg_list[0]):
            check_tensor_op(op, "transforms[{0}]".format(ind))
            check_transform_op_type(ind, op)
        if len(arg_list) == 2:  # random apply takes an additional arg
            type_check(arg_list[1], (float, int), "prob")
            check_value(arg_list[1], (0, 1), "prob")
        return method(self, *args, **kwargs)

    return new_method


def check_transform_op_type(ind, op):
    """Check that the operation is an instance and not a bare class.

    Args:
        ind (int): Position of the operation in its list, used in the error message.
        op: The operation to check.

    Raises:
        ValueError: If ``op`` is a class object, e.g. ``HWC2CHW`` instead of ``HWC2CHW()``.
    """
    # c_vision.HWC2CHW error
    # py_vision.HWC2CHW error
    if type(op) == type:  # pylint: disable=unidiomatic-typecheck
        raise ValueError("op_list[{}] should be a dataset processing operation instance, "
                         "but got: {}. It may be missing parentheses for instantiation.".format(ind, op))


def check_compose_list(method):
    """Wrapper method to check the transform list of Python Compose.

    Raises:
        ValueError: If the list is empty or one of its entries is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms], _ = parse_user_args(method, *args, **kwargs)

        # The third argument is the argument *name* used in the error message.
        type_check(transforms, (list,), "transforms")
        if not transforms:
            raise ValueError("transforms list is empty.")
        for i, transform in enumerate(transforms):
            if not callable(transform):
                raise ValueError("transforms[{}] is not callable.".format(i))
            check_transform_op_type(i, transform)
        return method(self, *args, **kwargs)

    return new_method


def check_compose_call(method):
    """Wrapper method to check the transform list of Compose.

    Raises:
        TypeError: If the wrapped method is invoked without any positional input.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        sig = inspect.signature(method)
        # Bind against the wrapped signature; "args" is looked up by name, so the
        # wrapped method is expected to declare a ``*args`` parameter.
        ba = sig.bind_partial(self, *args, **kwargs)
        img = ba.arguments.get("args")
        if img is None:
            raise TypeError(
                "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
        return method(self, *args, **kwargs)

    return new_method


def _check_py_transforms(transforms):
    """Reject C transforms and uninstantiated operations in a list of Python transforms."""
    for i, transform in enumerate(transforms):
        # A C transform is recognised by "c_transform" appearing in its string form.
        if str(transform).find("c_transform") >= 0:
            raise ValueError(
                "transforms[{}] is not a py transforms. Should not use a c transform in py transform"
                .format(i))
        check_transform_op_type(i, transform)


def check_random_apply(method):
    """Wrapper method to check the parameters of random apply.

    ``transforms`` must be a list of Python transform instances. ``prob``,
    when given, must be a float or int within [0, 1].

    Raises:
        ValueError: If a transform's string form contains "c_transform".
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms, prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(transforms, (list,), "transforms")
        _check_py_transforms(transforms)

        if prob is not None:
            type_check(prob, (float, int,), "prob")
            check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return new_method


def check_transforms_list(method):
    """Wrapper method to check the parameters of transform list.

    ``transforms`` must be a list of Python transform instances.

    Raises:
        ValueError: If a transform's string form contains "c_transform".
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms], _ = parse_user_args(method, *args, **kwargs)

        type_check(transforms, (list,), "transforms")
        _check_py_transforms(transforms)
        return method(self, *args, **kwargs)

    return new_method


def check_plugin(method):
    """Wrapper method to check the parameters of plugin.

    ``lib_path`` and ``func_name`` must be str; ``user_args`` must be str when given.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [lib_path, func_name, user_args], _ = parse_user_args(method, *args, **kwargs)

        type_check(lib_path, (str,), "lib_path")
        type_check(func_name, (str,), "func_name")
        if user_args is not None:
            type_check(user_args, (str,), "user_args")

        return method(self, *args, **kwargs)

    return new_method


def invalidate_callable(method):
    """Wrapper method to invalidate cached callable_op_ used in eager mode.

    This decorator must be added to any method which modifies the state of transform.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Drop the cached op before the wrapped method runs.
        self.callable_op_ = None
        return method(self, *args, **kwargs)

    return new_method


def check_type_cast(method):
    """Wrapper method to check the parameters of TypeCast.

    ``data_type`` is accepted when it is a ``typing.Type`` instance or when
    ``np.dtype(data_type)`` succeeds. ``None`` is explicitly rejected.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_type], _ = parse_user_args(method, *args, **kwargs)

        # Check if data_type is mindspore.dtype
        if isinstance(data_type, (typing.Type,)):
            return method(self, *args, **kwargs)

        # Special case: Check if data_type is None (which is invalid).
        # NumPy accepts None as a dtype specifier, so the conversion attempt
        # below would not catch it.
        if data_type is None:
            # Use type_check to raise error with descriptive error message
            type_check(data_type, (typing.Type, np.dtype,), "data_type")

        try:
            # Check if data_type can be converted to numpy type
            _ = np.dtype(data_type)
        except (TypeError, ValueError):
            # Use type_check to raise error with descriptive error message
            type_check(data_type, (typing.Type, np.dtype,), "data_type")

        return method(self, *args, **kwargs)

    return new_method


def deprecated_c_transforms(substitute_name=None, substitute_module=None):
    """Decorator for version 1.8 deprecation warning for legacy mindspore.dataset.transforms.c_transforms operation.

    Args:
        substitute_name (str, optional): The substitute name for deprecated operation.
        substitute_module (str, optional): The substitute module for deprecated operation.
    """
    return deprecator_factory("1.8", "mindspore.dataset.transforms.c_transforms", "mindspore.dataset.transforms",
                              substitute_name, substitute_module)


def deprecated_py_transforms(substitute_name=None, substitute_module=None):
    """Decorator for version 1.8 deprecation warning for legacy mindspore.dataset.transforms.py_transforms operation.

    Args:
        substitute_name (str, optional): The substitute name for deprecated operation.
        substitute_module (str, optional): The substitute module for deprecated operation.
    """
    return deprecator_factory("1.8", "mindspore.dataset.transforms.py_transforms", "mindspore.dataset.transforms",
                              substitute_name, substitute_module)