• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators for data processing operations.
"""
17from functools import wraps
18import inspect
19import numpy as np
20
21from mindspore._c_expression import typing
22from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
23    check_tensor_op, type_check_list, deprecator_factory
24
# Lower bound for arguments that must be strictly positive (0 is excluded).
POS_INT_MIN = 1
# Inclusive bounds of fixed-width integer types, spelled as powers of two so
# each value can be verified at a glance.
UINT8_MAX = 2 ** 8 - 1
UINT8_MIN = 0
UINT32_MAX = 2 ** 32 - 1
UINT32_MIN = 0
UINT64_MAX = 2 ** 64 - 1
UINT64_MIN = 0
INT32_MAX = 2 ** 31 - 1
INT32_MIN = -(2 ** 31)
INT64_MAX = 2 ** 63 - 1
INT64_MIN = -(2 ** 63)
# Largest magnitudes up to which every integer is exactly representable:
# 2**24 for float32 (24-bit significand), 2**53 for float64 (53-bit significand).
FLOAT_MAX_INTEGER = 2 ** 24
FLOAT_MIN_INTEGER = -(2 ** 24)
DOUBLE_MAX_INTEGER = 2 ** 53
DOUBLE_MIN_INTEGER = -(2 ** 53)
41
42
def check_fill_value(method):
    """Decorator validating the `fill_value` argument of the wrapped method.

    `fill_value` must be a str, float, bool, int or bytes; anything else is
    rejected by `type_check` before the wrapped method runs.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (fill_value,), _ = parse_user_args(method, *args, **kwargs)
        allowed_types = (str, float, bool, int, bytes)
        type_check(fill_value, allowed_types, "fill_value")
        return method(self, *args, **kwargs)

    return wrapper
54
55
def check_one_hot_op(method):
    """Decorator validating the arguments of the one-hot operation.

    Checked before the wrapped method runs:
        num_classes: must be an int greater than 0.
        smoothing_rate: must be an int or float within [0, 1].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs)
        type_check(smoothing_rate, (int, float), "smoothing_rate")
        type_check(num_classes, (int,), "num_classes")
        # Name the argument so a failure message says which input was invalid.
        check_positive(num_classes, "num_classes")
        # type_check above only admits int/float, so smoothing_rate cannot be None
        # here; the former `is not None` guard was dead code.
        check_value(smoothing_rate, [0., 1.], "smoothing_rate")

        return method(self, *args, **kwargs)

    return new_method
72
73
def check_num_classes(method):
    """Decorator validating the `num_classes` argument of the wrapped method.

    `num_classes` must be an int greater than 0.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_classes], _ = parse_user_args(method, *args, **kwargs)

        type_check(num_classes, (int,), "num_classes")
        # Pass the argument name so a failure message identifies the bad input,
        # consistent with the other checks in this module.
        check_positive(num_classes, "num_classes")

        return method(self, *args, **kwargs)

    return new_method
87
88
def check_ms_type(method):
    """Decorator requiring the `data_type` argument to be a mindspore type (``typing.Type``)."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (data_type,), _ = parse_user_args(method, *args, **kwargs)
        expected = (typing.Type,)
        type_check(data_type, expected, "data_type")
        return method(self, *args, **kwargs)

    return wrapper
101
102
def check_slice_option(method):
    """Decorator validating the `slice_option` argument of SliceOption.

    None is accepted unchanged. Otherwise the value must be an int, list, slice,
    bool, Ellipsis or _SliceOption, and a list must contain only ints.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (slice_option,), _ = parse_user_args(method, *args, **kwargs)
        # Imported at call time, presumably to avoid a circular import with .transforms.
        from .transforms import _SliceOption

        if slice_option is None:
            return method(self, *args, **kwargs)

        accepted = (int, list, slice, bool, type(Ellipsis), _SliceOption)
        type_check(slice_option, accepted, "slice_option")
        if isinstance(slice_option, list):
            type_check_list(slice_option, (int,), "slice_option")
        return method(self, *args, **kwargs)

    return wrapper
119
120
def check_slice_op(method):
    """Decorator validating the slice arguments of the wrapped method.

    Every non-None entry of the parsed `slice_op` argument must be an int, list,
    slice, bool, Ellipsis or _SliceOption. Every non-empty list entry must contain
    only ints.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [slice_op], _ = parse_user_args(method, *args, **kwargs)

        # Imported at call time (presumably to avoid a circular import), but only
        # once per call rather than once per loop iteration.
        from .transforms import _SliceOption
        allowed_types = (int, list, slice, bool, type(Ellipsis), _SliceOption)

        for s in slice_op:
            if s is None:
                continue
            type_check(s, allowed_types, "slice")
            # Check every element of a non-empty list, not only lists whose first
            # element is an int; otherwise e.g. ["a"] would slip through while
            # [1, "a"] is caught. This matches check_slice_option.
            if isinstance(s, list) and s:
                type_check_list(s, (int,), "slice")

        return method(self, *args, **kwargs)

    return new_method
139
140
def check_mask_op(method):
    """Decorator validating the arguments of the c_transforms Mask operation.

    `operator` must be an instance of c_transforms ``Relational``. `constant` must
    be a str, float, bool, int or bytes. `dtype` must be a mindspore type
    (``typing.Type``).
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (operator, constant, dtype), _ = parse_user_args(method, *args, **kwargs)

        from .c_transforms import Relational
        checks = (
            (operator, (Relational,), "operator"),
            (constant, (str, float, bool, int, bytes), "constant"),
            (dtype, (typing.Type,), "dtype"),
        )
        for value, allowed, arg_name in checks:
            type_check(value, allowed, arg_name)

        return method(self, *args, **kwargs)

    return wrapper
156
157
def check_mask_op_new(method):
    """Decorator validating the arguments of the transforms Mask operation.

    Same checks as the c_transforms variant, except that `operator` must be an
    instance of ``Relational`` from the ``.transforms`` module.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (operator, constant, dtype), _ = parse_user_args(method, *args, **kwargs)

        from .transforms import Relational
        checks = (
            (operator, (Relational,), "operator"),
            (constant, (str, float, bool, int, bytes), "constant"),
            (dtype, (typing.Type,), "dtype"),
        )
        for value, allowed, arg_name in checks:
            type_check(value, allowed, arg_name)

        return method(self, *args, **kwargs)

    return wrapper
173
174
def check_pad_end(method):
    """Decorator validating the arguments of PadEnd.

    `pad_value`, when not None, must be a str, float, bool, int or bytes.
    `pad_shape` must be a list. Each entry is either None or an int that
    passes `check_pos_int64`.

    Raises:
        TypeError: If a non-None entry of `pad_shape` is not an int.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (pad_shape, pad_value), _ = parse_user_args(method, *args, **kwargs)

        if pad_value is not None:
            type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
        type_check(pad_shape, (list,), "pad_shape")

        # None entries are left unconstrained; everything else must be an int.
        for dim in (d for d in pad_shape if d is not None):
            if not isinstance(dim, int):
                raise TypeError("a value in the list is not an integer.")
            check_pos_int64(dim)

        return method(self, *args, **kwargs)

    return wrapper
197
198
def check_concat_type(method):
    """Decorator validating the arguments of the concatenation operation.

    `axis`, when given, must be an int equal to 0 or -1, since only 1D
    concatenation is supported. `prepend` and `append`, when given, must be 1D
    numpy arrays.

    Raises:
        ValueError: On an unsupported axis or a non-1D prepend/append array.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (axis, prepend, append), _ = parse_user_args(method, *args, **kwargs)

        if axis is not None:
            type_check(axis, (int,), "axis")
            if axis not in (0, -1):
                raise ValueError("only 1D concatenation supported.")

        optional_arrays = (
            (prepend, "prepend", "can only prepend 1D arrays."),
            (append, "append", "can only append 1D arrays."),
        )
        for array, arg_name, error_message in optional_arrays:
            if array is None:
                continue
            type_check(array, (np.ndarray,), arg_name)
            if len(array.shape) != 1:
                raise ValueError(error_message)

        return method(self, *args, **kwargs)

    return wrapper
225
226
def check_random_transform_ops(method):
    """Decorator validating the arguments of RandomChoice, RandomApply and Compose.

    The first argument must be a non-empty list of operation instances. When a
    second argument is present (RandomApply's probability), it must be an int or
    float within [0, 1].

    Raises:
        ValueError: If the transforms list is empty.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        arg_list, _ = parse_user_args(method, *args, **kwargs)
        transforms = arg_list[0]

        type_check(transforms, (list,), "transforms list")
        if not transforms:
            raise ValueError("transforms list can not be empty.")
        for position, op in enumerate(transforms):
            check_tensor_op(op, "transforms[{0}]".format(position))
            check_transform_op_type(position, op)

        # Only RandomApply supplies this extra probability argument.
        if len(arg_list) == 2:
            prob = arg_list[1]
            type_check(prob, (float, int), "prob")
            check_value(prob, (0, 1), "prob")
        return method(self, *args, **kwargs)

    return wrapper
245
246
def check_transform_op_type(ind, op):
    """Reject a transform that was passed as a class instead of an instance.

    Catches the common mistake of writing ``HWC2CHW`` instead of ``HWC2CHW()``.

    Args:
        ind (int): Position of `op` in the transform list, used in the error message.
        op: The transform object to check.

    Raises:
        ValueError: If `op` is a class. This includes classes created with a custom
            metaclass, such as subclasses of ``abc.ABC``.
    """
    # isinstance(op, type) matches every class, including ones whose metaclass is a
    # subclass of `type` (e.g. abc.ABCMeta). The previous `type(op) == type`
    # comparison let those through without the helpful error.
    if isinstance(op, type):
        raise ValueError("op_list[{}] should be a dataset processing operation instance, "
                         "but got: {}. It may be missing parentheses for instantiation.".format(ind, op))
254
255
def check_compose_list(method):
    """Decorator validating the transform list given to the Python Compose.

    `transforms` must be a non-empty list of callables, and each callable must be
    an operation instance rather than a class.

    Raises:
        ValueError: If the list is empty, an item is not callable, or an item is a class.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms], _ = parse_user_args(method, *args, **kwargs)

        # Pass the argument *name*. The value itself used to be passed here, so
        # the error message showed the offending object where "transforms" belongs.
        type_check(transforms, (list,), "transforms")
        if not transforms:
            raise ValueError("transforms list is empty.")
        for i, transform in enumerate(transforms):
            if not callable(transform):
                raise ValueError("transforms[{}] is not callable.".format(i))
            check_transform_op_type(i, transform)
        return method(self, *args, **kwargs)

    return new_method
273
274
def check_compose_call(method):
    """Decorator ensuring Compose is invoked with at least one positional input.

    The wrapped method's signature is bound against the call. If nothing gets
    bound to its parameter named ``args``, the call is rejected.

    Raises:
        TypeError: If the wrapped method is called without any input.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        # The instance fills the first parameter; only the binding of "args" matters.
        bound = inspect.signature(method).bind_partial(self, *args, **kwargs)
        if bound.arguments.get("args") is None:
            raise TypeError(
                "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
        return method(self, *args, **kwargs)

    return wrapper
289
290
def check_random_apply(method):
    """Decorator validating the arguments of the Python RandomApply.

    `transforms` must be a list. Items whose string form contains "c_transform"
    are rejected as C transforms, and classes are rejected in favour of instances.
    `prob`, when given, must be an int or float within [0, 1].

    Raises:
        ValueError: If a C transform or a class is found in `transforms`.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (transforms, prob), _ = parse_user_args(method, *args, **kwargs)
        type_check(transforms, (list,), "transforms")

        for position, transform in enumerate(transforms):
            # C transforms are recognised by "c_transform" appearing in their string form.
            if "c_transform" in str(transform):
                message = "transforms[{}] is not a py transforms. Should not use a c transform in py transform"
                raise ValueError(message.format(position))
            check_transform_op_type(position, transform)

        if prob is not None:
            type_check(prob, (float, int,), "prob")
            check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return wrapper
313
314
def check_transforms_list(method):
    """Decorator validating a list of Python transforms.

    `transforms` must be a list. Items whose string form contains "c_transform"
    are rejected as C transforms, and classes are rejected in favour of instances.

    Raises:
        ValueError: If a C transform or a class is found in `transforms`.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (transforms,), _ = parse_user_args(method, *args, **kwargs)

        type_check(transforms, (list,), "transforms")
        for position, transform in enumerate(transforms):
            # C transforms are recognised by "c_transform" appearing in their string form.
            if "c_transform" in str(transform):
                message = "transforms[{}] is not a py transforms. Should not use a c transform in py transform"
                raise ValueError(message.format(position))
            check_transform_op_type(position, transform)
        return method(self, *args, **kwargs)

    return wrapper
332
333
def check_plugin(method):
    """Decorator validating the arguments of the plugin operation.

    `lib_path` and `func_name` must be strings. `user_args` must be a string
    when it is not None.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (lib_path, func_name, user_args), _ = parse_user_args(method, *args, **kwargs)

        for value, arg_name in ((lib_path, "lib_path"), (func_name, "func_name")):
            type_check(value, (str,), arg_name)
        if user_args is not None:
            type_check(user_args, (str,), "user_args")

        return method(self, *args, **kwargs)

    return wrapper
349
350
def invalidate_callable(method):
    """Decorator that discards the cached eager-mode op before running `method`.

    Every call resets ``self.callable_op_`` to None. Any method that modifies a
    transform's state must therefore carry this decorator.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        # Cleared before the call, so the cache is dropped even if `method` raises.
        self.callable_op_ = None
        result = method(self, *args, **kwargs)
        return result

    return wrapper
361
362
def check_type_cast(method):
    """Decorator validating the `data_type` argument of TypeCast.

    Accepted values are mindspore types (``typing.Type``) and anything that
    ``np.dtype`` can convert. None is rejected explicitly because ``np.dtype(None)``
    does not raise; it would otherwise slip through the conversion probe.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (data_type,), _ = parse_user_args(method, *args, **kwargs)

        if not isinstance(data_type, typing.Type):
            if data_type is None:
                # type_check produces the descriptive error for the invalid None.
                type_check(data_type, (typing.Type, np.dtype,), "data_type")
            try:
                # Probe only: succeeds for anything numpy can interpret as a dtype.
                np.dtype(data_type)
            except (TypeError, ValueError):
                type_check(data_type, (typing.Type, np.dtype,), "data_type")

        return method(self, *args, **kwargs)

    return wrapper
389
390
def deprecated_c_transforms(substitute_name=None, substitute_module=None):
    """Build the version 1.8 deprecation decorator for legacy c_transforms operations.

    Args:
        substitute_name (str, optional): The substitute name for the deprecated operation.
        substitute_module (str, optional): The substitute module for the deprecated operation.

    Returns:
        Whatever ``deprecator_factory`` returns for the
        ``mindspore.dataset.transforms.c_transforms`` -> ``mindspore.dataset.transforms`` move.
    """
    legacy_module = "mindspore.dataset.transforms.c_transforms"
    current_module = "mindspore.dataset.transforms"
    return deprecator_factory("1.8", legacy_module, current_module, substitute_name, substitute_module)
400
401
def deprecated_py_transforms(substitute_name=None, substitute_module=None):
    """Build the version 1.8 deprecation decorator for legacy py_transforms operations.

    Args:
        substitute_name (str, optional): The substitute name for the deprecated operation.
        substitute_module (str, optional): The substitute module for the deprecated operation.

    Returns:
        Whatever ``deprecator_factory`` returns for the
        ``mindspore.dataset.transforms.py_transforms`` -> ``mindspore.dataset.transforms`` move.
    """
    legacy_module = "mindspore.dataset.transforms.py_transforms"
    current_module = "mindspore.dataset.transforms"
    return deprecator_factory("1.8", legacy_module, current_module, substitute_name, substitute_module)
411