# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""internal graph-compatible utility functions"""
16import math
17from itertools import zip_longest, accumulate
18from collections import deque
19import operator
20
21import mindspore.context as context
22from ..ops import functional as F
23from ..ops.primitive import constexpr
24from ..common import dtype as mstype
25from ..common import Tensor
26from .._c_expression import Tensor as Tensor_
27from .._c_expression import typing
28from .._checkparam import Validator as validator
29
30from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric
31
32
33_check_axis_type = constexpr(validator.check_axis_type)
34
35
36@constexpr
37def _check_shape(shape):
38    """check the shape param to match the numpy style"""
39    if not isinstance(shape, (int, tuple, list, typing.Tuple, typing.List)):
40        raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}")
41    if isinstance(shape, int):
42        shape = (shape,)
43    if isinstance(shape, (list, typing.List)):
44        shape = tuple(shape)
45    for s in shape:
46        if not isinstance(s, int):
47            raise TypeError("each entry in shape should be int.")
48        if s < 0:
49            raise ValueError("each entry in shape should no less than 0.")
50    return shape
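
# Illustrative usage (not part of the original file): the shape argument is
# normalized to a tuple of non-negative ints, e.g.
#   _check_shape(3)        ->  (3,)
#   _check_shape([2, 3])   ->  (2, 3)
#   _check_shape((2, -1))  ->  ValueError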


@constexpr
def _check_dtype(dtype):
    """check the input dtype and make conversions"""
    # convert the string dtype to mstype.dtype
    if isinstance(dtype, str):
        dtype = dtype.lower()
        dtype = dtype_map[dtype]
    elif isinstance(dtype, type):
        if dtype is int:
            dtype = mstype.int32
        elif dtype is float:
            dtype = mstype.float32
        else:
            dtype = mstype.pytype_to_dtype(dtype)
    if dtype not in dtype_tuple:
        raise TypeError(f"only {all_types} are allowed for dtype, but got {dtype}")
    return dtype
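
# Illustrative usage (not part of the original file; assumes dtype_map maps
# 'float32' to mstype.float32):
#   _check_dtype(int)        ->  mstype.int32
#   _check_dtype('float32')  ->  mstype.float32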


@constexpr
def _is_shape_empty(shp):
    """Check whether shape contains zero"""
    if isinstance(shp, int):
        return shp == 0
    return F.shape_mul(shp) == 0


@constexpr
def _check_start_normalize(start, ndim):
    """check and normalize start argument for rollaxis."""
    if start < -ndim or start > ndim:
        raise ValueError(f"For rollaxis, start {start} is out of bounds; the allowed range is from {-ndim} to {ndim}.")
    if start < 0:
        start = start + ndim
    return start


@constexpr
def _check_axes_range(axes, ndim):
    """
    Check axes type and normalize the negative axes.

    Args:
        axes: Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Returns:
        Axes (Union[int, tuple(int)]). If input is an integer, returns an integer, else a tuple.

    Raises:
        TypeError: If the axes are not integer, tuple(int) or list(int).
        ValueError: If duplicate axes exist or an axis is out of bounds.
    """
    _check_axis_type(axes, True, True, True)
    if isinstance(axes, (list, tuple)):
        _check_element_int(axes)
    axes = _canonicalize_axis(axes, ndim)
    return axes


@constexpr
def _get_device():
    """Get the current device (`GPU`, `CPU`, `Ascend`)"""
    return context.get_context('device_target')


@constexpr
def _infer_out_shape(*shapes):
    """
    Returns the shape of the output after broadcasting. Raises ValueError if the shapes cannot be broadcast.
    """
    shape_out = deque()
    reversed_shapes = map(reversed, shapes)
    for items in zip_longest(*reversed_shapes, fillvalue=1):
        max_size = 0 if 0 in items else max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
        shape_out.appendleft(max_size)
    return tuple(shape_out)
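
# Illustrative usage (not part of the original file): numpy-style broadcasting,
# matching dimensions from the right, e.g.
#   _infer_out_shape((2, 1, 3), (4, 3))  ->  (2, 4, 3)
#   _infer_out_shape((2, 3), (3, 2))     ->  ValueError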


@constexpr
def _can_broadcast(*shapes):
    """
    Returns True if the shapes can be broadcast together, False otherwise.
    """
    try:
        _infer_out_shape(*shapes)
    except ValueError:
        return False
    return True


@constexpr
def _check_axis_in_range(axis, ndim):
    """Checks that the axis is within the bounds of ndim, and returns it normalized to [0, ndim)."""
    if not isinstance(axis, int):
        raise TypeError(f'axes should be integers, not {type(axis)}')
    if not -ndim <= axis < ndim:
        raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
    return axis % ndim
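
# Illustrative usage (not part of the original file): negative axes are wrapped
# into [0, ndim), e.g.
#   _check_axis_in_range(-1, 3)  ->  2
#   _check_axis_in_range(3, 3)   ->  ValueError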


@constexpr
def _check_axis_valid(axes, ndim):
    """
    Checks that the axes are valid given ndim, and returns the axes in a form that
    can be passed to the built-in operators (non-negative, int or tuple).
    """
    if axes is None:
        axes = F.make_range(ndim)
        return axes
    if isinstance(axes, (tuple, list)):
        axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
        if any(axes.count(el) > 1 for el in axes):
            raise ValueError('duplicate value in "axis"')
        return axes
    return (_check_axis_in_range(axes, ndim),)
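
# Illustrative usage (not part of the original file):
#   _check_axis_valid(1, 3)        ->  (1,)
#   _check_axis_valid((0, -1), 3)  ->  (0, 2)
#   _check_axis_valid(None, 3)     ->  all axes, as a graph-compatible
#                                      range over 0..2 (via F.make_range)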


@constexpr
def _check_shape_aligned(shape1, shape2):
    """Checks that shape1 and shape2 are valid shapes to perform an inner product."""
    if shape1[-1] != shape2[-1]:
        raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (last dim) != {shape2[-1]} (last dim)')


@constexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape*tile_size = out_shape."""
    size = [1]*ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)
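
# Illustrative usage (not part of the original file): broadcasting (1, 3) to
# (4, 3) requires tiling the first axis 4 times, e.g.
#   _tile_size((1, 3), (4, 3), 2)  ->  (4, 1)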


@constexpr
def _raise_type_error(info, param=None):
    """
    Raise TypeError in both graph/pynative mode.

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, the param's type information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise TypeError(info)
    raise TypeError(info + f"{type(param)}")


@constexpr
def _raise_value_error(info, param=None):
    """
    Raise ValueError in both graph/pynative mode.

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, the param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise ValueError(info)
    raise ValueError(info + f"{param}")


@constexpr
def _raise_runtime_error(info, param=None):
    """
    Raise RuntimeError in both graph/pynative mode.

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, the param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise RuntimeError(info)
    raise RuntimeError(info + f"{param}")


@constexpr
def _raise_unimplemented_error(info, param=None):
    """
    Raise NotImplementedError in both graph/pynative mode.

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, the param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise NotImplementedError(info)
    raise NotImplementedError(info + f"{param}")


@constexpr
def _empty(dtype, shape):
    """Returns an uninitialized array with dtype and shape."""
    return Tensor_(dtype, shape)


@constexpr
def _promote(dtype1, dtype2):
    """Returns the promoted dtype for dtype1 and dtype2 according to promotion_rule."""
    if dtype1 == dtype2:
        return dtype1
    if (dtype1, dtype2) in promotion_rule:
        return promotion_rule[dtype1, dtype2]
    return promotion_rule[dtype2, dtype1]


@constexpr
def _promote_for_trigonometric(dtype):
    """Returns the promoted dtype for trigonometric operations on dtype."""
    return rule_for_trigonometric[dtype]


@constexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)


@constexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)


@constexpr
def _abs(arg):
    """Returns the absolute value."""
    return abs(arg)


@constexpr
def _check_same_type(dtype1, dtype2):
    """Returns True if dtype1 and dtype2 are the same dtype."""
    return dtype1 == dtype2


@constexpr
def _check_is_float(dtype):
    """Returns whether dtype is float16 or float32."""
    return dtype in (mstype.float16, mstype.float32)


@constexpr
def _check_is_int(dtype):
    """Returns whether dtype is an integer type."""
    return isinstance(dtype, typing.Int)


@constexpr
def _canonicalize_axis(axis, ndim):
    """
    Check that the axes are within the number of dimensions of the tensor and normalize the negative axes.

    Args:
        axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Returns:
        Axis (Union[int, tuple(int)]). If input is an integer, returns an integer, else a tuple.
    """
    if isinstance(axis, int):
        axis = [axis]
    for ax in axis:
        _check_axis_in_range(ax, ndim)

    def canonicalizer(ax):
        return ax + ndim if ax < 0 else ax

    axis = tuple([canonicalizer(ax) for ax in axis])
    if all(axis.count(el) <= 1 for el in axis):
        return tuple(sorted(axis)) if len(axis) > 1 else axis[0]
    raise ValueError(f"duplicate axes in {axis}.")
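
# Illustrative usage (not part of the original file):
#   _canonicalize_axis(-1, 3)       ->  2
#   _canonicalize_axis((2, -3), 3)  ->  (0, 2)
#   _canonicalize_axis((0, -3), 3)  ->  ValueError (0 and -3 are duplicates)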


@constexpr
def _broadcast_tuples(tup1, tup2):
    """
    Broadcast two 1-D tuples to the same length. If an input is an int, it is
    converted to a tuple first.
    """
    tup1 = (tup1,) if isinstance(tup1, int) else tup1
    tup2 = (tup2,) if isinstance(tup2, int) else tup2
    if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
        raise TypeError("input shift and axis must be tuple or list or int.")
    if len(tup1) == len(tup2):
        return tup1, tup2
    if len(tup1) == 1:
        tup1 *= len(tup2)
    elif len(tup2) == 1:
        tup2 *= len(tup1)
    else:
        raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")
    return tup1, tup2
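
# Illustrative usage (not part of the original file):
#   _broadcast_tuples(2, (1, 3))          ->  ((2, 2), (1, 3))
#   _broadcast_tuples((1, 2), (3, 4, 5))  ->  ValueError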


@constexpr
def _expanded_shape(ndim, axis_size, axis):
    """
    Returns a shape with size = 1 for all dimensions
    except at axis.
    """
    return tuple([axis_size if i == axis else 1 for i in range(ndim)])


@constexpr
def _add_unit_axes(shape, ndim, append=False):
    """
    Prepends shape with 1s so that it has the number of dimensions ndim.
    If append is set to True, returns shape appended with 1s instead.
    """
    if isinstance(shape, int):
        shape = (shape,)
    ndim_diff = ndim - len(shape)
    if ndim_diff > 0:
        if append:
            shape = list(shape) + [1]*ndim_diff
        else:
            shape = [1]*ndim_diff + list(shape)
    return tuple(shape)
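
# Illustrative usage (not part of the original file):
#   _add_unit_axes((2, 3), 4)               ->  (1, 1, 2, 3)
#   _add_unit_axes((2, 3), 4, append=True)  ->  (2, 3, 1, 1)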


@constexpr
def _check_element_int(lst):
    """
    Check whether each element in `lst` is an integer.
    """
    for item in lst:
        if not isinstance(item, int):
            raise TypeError(f"Each element in {lst} should be an integer, but got {type(item)}.")
    return True


@constexpr
def _type_convert(force, obj):
    """
    Convert the type of `obj` to `force`.
    """
    return force(obj)


@constexpr
def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
    """
    Generates a new list/tuple by list comprehension.

    Args:
        obj (Union[int, list, tuple]):
            If integer, it will be the length of the returned tuple/list.
        item: The value to be filled. Default: None.
            If None, the values in the new list/tuple are the same as obj
            or range(obj) when obj is an integer.
        return_tuple(bool): If true, returns a tuple, else returns a list.
        make_none(bool): If true, fills the new list/tuple with None. Default: False.

    Returns:
        List or tuple.
    """
    res = []
    lst = obj
    if isinstance(obj, int):
        lst = range(obj)
    if make_none:
        res = [None for _ in lst]
    elif item is None:
        res = [i for i in lst]
    else:
        res = [item for i in lst]
    if return_tuple:
        return tuple(res)
    return res
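
# Illustrative usage (not part of the original file):
#   _list_comprehensions(3)                        ->  [0, 1, 2]
#   _list_comprehensions(3, 0, return_tuple=True)  ->  (0, 0, 0)
#   _list_comprehensions(2, make_none=True)        ->  [None, None]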


@constexpr
def _tuple_setitem(tup, idx, value):
    """
    Returns a tuple with the element at the specified `idx` set to `value`.
    """
    tup = list(tup)
    tup[idx] = value
    return tuple(tup)
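
# Illustrative usage (not part of the original file):
#   _tuple_setitem((1, 2, 3), 1, 9)  ->  (1, 9, 3)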


@constexpr
def _iota(dtype, num, increasing=True):
    """Creates a 1-D tensor with values [0, 1, ... num-1] and the given dtype."""
    # Change to P.Linspace when the kernel is implemented on CPU.
    if num <= 0:
        raise ValueError("zero shape Tensor is not currently supported.")
    if increasing:
        return Tensor(list(range(int(num))), dtype)
    return Tensor(list(range(int(num)-1, -1, -1)), dtype)
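
# Illustrative usage (not part of the original file):
#   _iota(mstype.int32, 3)                    ->  Tensor([0, 1, 2])
#   _iota(mstype.int32, 3, increasing=False)  ->  Tensor([2, 1, 0])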


@constexpr
def _ceil(number):
    """Ceils the number in graph mode."""
    return math.ceil(number)


@constexpr
def _seq_prod(seq1, seq2):
    """Returns the element-wise product of seq1 and seq2."""
    return tuple(map(lambda x, y: x*y, seq1, seq2))


@constexpr
def _make_tensor(val, dtype):
    """Returns a tensor with value `val` and dtype `dtype`."""
    return Tensor(val, dtype)


@constexpr
def _tuple_slice(tup, start, end):
    """Returns the slice of `tup` from `start` to `end`."""
    return tup[start:end]


@constexpr
def _isscalar(x):
    """Returns True if x is a scalar type."""
    return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float,
                          typing.Bool, typing.String))


@constexpr
def _cumprod(x):
    """Returns the cumulative product of the elements of `x` as a tuple."""
    return tuple(accumulate(x, operator.mul))
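
# Illustrative usage (not part of the original file):
#   _cumprod((2, 3, 4))  ->  (2, 6, 24)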


@constexpr
def _in(x, y):
    """Returns whether `x` is contained in `y`."""
    return x in y


@constexpr
def _callable_const(x):
    """Returns true if x is a function in graph mode."""
    return isinstance(x, typing.Function)


@constexpr
def _check_is_inf(x, negative=False):
    """Returns whether `x` is positive infinity, or negative infinity if `negative` is True."""
    if not negative:
        return x == float('inf')
    return x == float('-inf')