• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Validators for TensorOps.
16"""
17from functools import wraps
18import inspect
19import numpy as np
20
21from mindspore._c_expression import typing
22
23from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
24    check_tensor_op, type_check_list
25
# Lower bound for parameters that must be strictly positive: the valid range
# starts at 1 rather than 0.
POS_INT_MIN = 1

# Unsigned integer bounds (all minimums are 0, maximums are 2**bits - 1).
UINT8_MIN = 0
UINT8_MAX = (1 << 8) - 1
UINT32_MIN = 0
UINT32_MAX = (1 << 32) - 1
UINT64_MIN = 0
UINT64_MAX = (1 << 64) - 1

# Signed two's-complement integer bounds.
INT32_MIN = -(1 << 31)
INT32_MAX = (1 << 31) - 1
INT64_MIN = -(1 << 63)
INT64_MAX = (1 << 63) - 1

# +/- 2**24 and +/- 2**53.
# NOTE(review): presumably the ranges in which float32 / float64 represent
# every integer exactly -- confirm against the callers of these names.
FLOAT_MIN_INTEGER = -(1 << 24)
FLOAT_MAX_INTEGER = 1 << 24
DOUBLE_MIN_INTEGER = -(1 << 53)
DOUBLE_MAX_INTEGER = 1 << 53
42
43
def check_fill_value(method):
    """Decorator that validates the `fill_value` argument before calling `method`.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        # Exactly one user argument is expected here.
        (fill_value,) = parsed

        accepted_types = (str, float, bool, int, bytes)
        type_check(fill_value, accepted_types, "fill_value")

        return method(self, *args, **kwargs)

    return validated
55
56
def check_one_hot_op(method):
    """Decorator that validates `num_classes` and `smoothing_rate` before calling `method`.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        num_classes, smoothing_rate = parsed

        # The smoothing rate is type-checked first, then the class count.
        type_check(smoothing_rate, (int, float), "smoothing_rate")
        type_check(num_classes, (int,), "num_classes")
        check_positive(num_classes)

        # Range check is applied only when a smoothing rate is present.
        if smoothing_rate is not None:
            check_value(smoothing_rate, [0., 1.], "smoothing_rate")

        return method(self, *args, **kwargs)

    return validated
73
74
def check_num_classes(method):
    """Decorator that validates the `num_classes` argument before calling `method`.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (num_classes,) = parsed

        # Must be an int, and must pass the positivity check.
        type_check(num_classes, (int,), "num_classes")
        check_positive(num_classes)

        return method(self, *args, **kwargs)

    return validated
88
89
def check_ms_type(method):
    """Decorator that validates the `data_type` argument before calling `method`.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (data_type,) = parsed

        # Only instances of the framework's `typing.Type` are accepted.
        type_check(data_type, (typing.Type,), "data_type")

        return method(self, *args, **kwargs)

    return validated
102
103
def check_slice_option(method):
    """Decorator that validates the `slice_option` argument before calling `method`.

    A `slice_option` of None skips validation entirely. Otherwise it must be one of
    int, list, slice, bool, Ellipsis or `_SliceOption`; a list is additionally
    checked element-wise for ints.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (slice_option,) = parsed
        # Imported at call time rather than at module import time.
        from .c_transforms import _SliceOption

        if slice_option is None:
            return method(self, *args, **kwargs)

        accepted_types = (int, list, slice, bool, type(Ellipsis), _SliceOption)
        type_check(slice_option, accepted_types, "slice_option")

        if isinstance(slice_option, list):
            type_check_list(slice_option, (int,), "slice_option")

        return method(self, *args, **kwargs)

    return validated
120
121
def check_slice_op(method):
    """Wrapper method to check the parameters of slice.

    Every non-None entry of the slice argument must be one of int, list, slice, bool,
    Ellipsis or `_SliceOption`. A non-empty list whose first element is an int is
    additionally checked element-wise for ints.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [slice_op], _ = parse_user_args(method, *args, **kwargs)

        # Import once per call instead of once per loop iteration. It stays a
        # call-time import (NOTE(review): presumably to avoid a circular import
        # with c_transforms -- confirm).
        from .c_transforms import _SliceOption
        accepted_types = (int, list, slice, bool, type(Ellipsis), _SliceOption)

        for s in slice_op:
            if s is None:
                continue
            type_check(s, accepted_types, "slice")
            # Only lists that start with an int get the element-wise int check.
            if isinstance(s, list) and s and isinstance(s[0], int):
                type_check_list(s, (int,), "slice")

        return method(self, *args, **kwargs)

    return new_method
140
141
def check_mask_op(method):
    """Decorator that validates `operator`, `constant` and `dtype` before calling `method`.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        operator, constant, dtype = parsed

        # Imported at call time rather than at module import time.
        from .c_transforms import Relational

        scalar_types = (str, float, bool, int, bytes)
        type_check(operator, (Relational,), "operator")
        type_check(constant, scalar_types, "constant")
        type_check(dtype, (typing.Type,), "dtype")

        return method(self, *args, **kwargs)

    return validated
157
158
def check_pad_end(method):
    """Decorator that validates `pad_shape` and `pad_value` before calling `method`.

    `pad_value` is type-checked only when it is not None. `pad_shape` must be a list;
    each entry that is not None must be an int and is passed to `check_pos_int64`.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.

    Raises:
        TypeError: If an entry of `pad_shape` is neither None nor an int.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        pad_shape, pad_value = parsed

        if pad_value is not None:
            type_check(pad_value, (str, float, bool, int, bytes), "pad_value")
        type_check(pad_shape, (list,), "pad_shape")

        for extent in pad_shape:
            # None entries are allowed and skipped.
            if extent is None:
                continue
            if not isinstance(extent, int):
                raise TypeError("a value in the list is not an integer.")
            check_pos_int64(extent)

        return method(self, *args, **kwargs)

    return validated
181
182
def check_concat_type(method):
    """Decorator that validates `axis`, `prepend` and `append` before calling `method`.

    Each argument is validated only when it is not None: `axis` must be an int equal
    to 0 or -1, and `prepend` / `append` must be one-dimensional numpy arrays.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.

    Raises:
        ValueError: If `axis` is not 0 or -1, or if `prepend` / `append` is not 1D.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        axis, prepend, append = parsed

        if axis is not None:
            type_check(axis, (int,), "axis")
            # Both 0 and -1 address the only axis of a 1D tensor.
            if not (axis == 0 or axis == -1):
                raise ValueError("only 1D concatenation supported.")

        if prepend is not None:
            type_check(prepend, (np.ndarray,), "prepend")
            prepend_rank = len(prepend.shape)
            if prepend_rank != 1:
                raise ValueError("can only prepend 1D arrays.")

        if append is not None:
            type_check(append, (np.ndarray,), "append")
            append_rank = len(append.shape)
            if append_rank != 1:
                raise ValueError("can only append 1D arrays.")

        return method(self, *args, **kwargs)

    return validated
209
210
def check_random_transform_ops(method):
    """Decorator that validates the arguments of RandomChoice, RandomApply and Compose.

    The first argument must be a non-empty list whose entries each pass
    `check_tensor_op`. When exactly two arguments are parsed, the second one is
    validated as `prob`.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.

    Raises:
        ValueError: If the op list is empty.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        arg_list, _ = parse_user_args(method, *args, **kwargs)

        op_list = arg_list[0]
        type_check(op_list, (list,), "op_list")
        if not op_list:
            raise ValueError("op_list can not be empty.")

        for position, tensor_op in enumerate(op_list):
            check_tensor_op(tensor_op, "op_list[{0}]".format(position))

        # RandomApply passes a probability as a second argument.
        if len(arg_list) == 2:
            prob = arg_list[1]
            type_check(prob, (float, int), "prob")
            check_value(prob, (0, 1), "prob")

        return method(self, *args, **kwargs)

    return validated
228
229
def check_compose_list(method):
    """Wrapper method to check the transform list of Python Compose.

    The `transforms` argument must be a non-empty list in which every entry is callable.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.

    Raises:
        ValueError: If `transforms` is empty or one of its entries is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms], _ = parse_user_args(method, *args, **kwargs)

        # The third argument is the parameter *name* used for reporting (as in
        # every other validator in this file), not the value being checked.
        type_check(transforms, (list,), "transforms")
        if not transforms:
            raise ValueError("transforms list is empty.")
        for i, transform in enumerate(transforms):
            if not callable(transform):
                raise ValueError("transforms[{}] is not callable.".format(i))
        return method(self, *args, **kwargs)

    return new_method
246
247
def check_compose_call(method):
    """Decorator that rejects a Compose call made without any positional input.

    The call arguments are bound against the signature of `method`; if nothing is
    bound to its ``args`` parameter, the call is refused.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once the check has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.

    Raises:
        TypeError: If no value is bound to the ``args`` parameter of `method`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        # NOTE(review): `method` (not `self`) fills the first positional slot of the
        # binding; only the "args" entry is read afterwards, so the outcome does not
        # depend on which object is used there.
        bound = inspect.signature(method).bind_partial(method, *args, **kwargs)
        if bound.arguments.get("args") is None:
            raise TypeError(
                "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])()).")
        return method(self, *args, **kwargs)

    return validated
262
263
def check_random_apply(method):
    """Decorator that validates `transforms` and `prob` before calling `method`.

    `transforms` must be a list, and no entry may have the text "c_transform" in its
    string form. `prob` is validated only when it is not None.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.

    Raises:
        ValueError: If the string form of an entry of `transforms` contains "c_transform".
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        transforms, prob = parsed
        type_check(transforms, (list,), "transforms")

        for position, candidate in enumerate(transforms):
            # A c transform is recognised by its string representation.
            if "c_transform" in str(candidate):
                message = "transforms[{}] is not a py transforms. Should not use a c transform in py transform"
                raise ValueError(message.format(position))

        if prob is not None:
            type_check(prob, (float, int,), "prob")
            check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return validated
285
286
def check_transforms_list(method):
    """Decorator that validates the `transforms` list before calling `method`.

    `transforms` must be a list, and no entry may have the text "c_transform" in its
    string form.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.

    Raises:
        ValueError: If the string form of an entry of `transforms` contains "c_transform".
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (transforms,) = parsed

        type_check(transforms, (list,), "transforms")
        for position, candidate in enumerate(transforms):
            # A c transform is recognised by its string representation.
            if "c_transform" in str(candidate):
                message = "transforms[{}] is not a py transforms. Should not use a c transform in py transform"
                raise ValueError(message.format(position))

        return method(self, *args, **kwargs)

    return validated
303
304
def check_plugin(method):
    """Decorator that validates `lib_path`, `func_name` and `user_args` before calling `method`.

    `lib_path` and `func_name` must be strings; `user_args` must be a string when it
    is not None.

    Args:
        method: The method to wrap; it is called as ``method(self, *args, **kwargs)``
            once validation has passed.

    Returns:
        The wrapping function, carrying the metadata of `method` via `functools.wraps`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        lib_path, func_name, user_args = parsed

        type_check(lib_path, (str,), "lib_path")
        type_check(func_name, (str,), "func_name")
        # user_args is optional.
        if user_args is not None:
            type_check(user_args, (str,), "user_args")

        return method(self, *args, **kwargs)

    return validated
320