# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Internal graph-compatible utility functions."""
from __future__ import absolute_import

import math
from itertools import accumulate
import operator

import mindspore.context as context
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import _primexpr
from mindspore.common import dtype as mstype
from mindspore.common import Tensor
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import typing
from mindspore import _checkparam as validator

from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric

_check_axis_type = constexpr(validator.check_axis_type)


@_primexpr
def _check_shape(shape):
    """check the shape param to match the numpy style"""
    # convert tensor to int/list, then use the following if statements for further conversions
    def _check(shape):
        for s in shape:
            if not isinstance(s, int):
                raise TypeError("each entry in shape should be int.")
            if s < 0:
                raise ValueError("each entry in shape should be no less than 0.")

    if not isinstance(shape, (int, tuple, list, Tensor, typing.Tuple, typing.List)):
        raise TypeError(f"only int, tuple, list and tensor are allowed for shape, but got {type(shape)}")
    if isinstance(shape, Tensor):
        shape = shape.asnumpy().tolist()
        # this may return an int, so don't change the next if to elif
    if isinstance(shape, int):
        shape = (shape,)
    elif isinstance(shape, (list, typing.List)):
        shape = tuple(shape)
    _check(shape)
    return shape


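# Usage sketch for _check_shape (illustrative; results assume the helper runs as
# plain Python, e.g. in PyNative mode):
#   _check_shape(5)          -> (5,)
#   _check_shape([2, 3])     -> (2, 3)
#   _check_shape((2, -1))    -> ValueError: each entry in shape should be no less than 0.

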
@constexpr
def _check_dtype(dtype):
    """check the input dtype and make conversions"""
    # convert the string dtype to mstype.dtype
    if isinstance(dtype, str):
        dtype = dtype.lower()
        dtype = dtype_map.get(dtype, "")
    elif isinstance(dtype, type):
        if dtype is int:
            dtype = mstype.int32
        elif dtype is float:
            dtype = mstype.float32
        else:
            dtype = mstype.pytype_to_dtype(dtype)
    if dtype not in dtype_tuple:
        raise TypeError(f"only {all_types} are allowed for dtype, but got {type(dtype)}")
    return dtype


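# Usage sketch for _check_dtype (illustrative; the string case assumes dtype_map
# carries the standard numpy-style aliases):
#   _check_dtype('float32')  -> mstype.float32
#   _check_dtype(int)        -> mstype.int32
#   _check_dtype(float)      -> mstype.float32

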
@_primexpr
def _is_shape_empty(shp):
    """Check whether shape contains zero"""
    if F.is_sequence_shape_unknown(shp):
        return False
    if isinstance(shp, int):
        return shp == 0
    if isinstance(shp, (tuple, list)):
        return 0 in shp
    return F.shape_mul(shp) == 0


@_primexpr
def _check_start_normalize(start, ndim):
    """check and normalize start argument for rollaxis."""
    def _check():
        if start < -ndim or start > ndim:
            raise ValueError(f"For rollaxis, start {start} is out of bounds. "
                             f"Ranging from {-ndim} to {ndim} is allowed.")
    _check()
    if start < 0:
        start = start + ndim
    return start


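# Usage sketch for _check_start_normalize (illustrative):
#   _check_start_normalize(2, 4)   -> 2
#   _check_start_normalize(-1, 4)  -> 3
#   _check_start_normalize(5, 4)   -> ValueError (outside the range [-4, 4])

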
@_primexpr
def _check_axes_range(axes, ndim):
    """
    Check axes type and normalize the negative axes.

    Args:
        axes: Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Returns:
        Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple.

    Raises:
        TypeError: If the axes are not integer, tuple(int) or list(int).
        ValueError: If duplicate axes exist or some axis is out of bounds.
    """
    _check_axis_type(axes, True, True, True)
    if isinstance(axes, (list, tuple)):
        _check_element_int(axes)
    axes = _canonicalize_axis(axes, ndim)
    return axes


@constexpr
def _get_device():
    """Get the current device (`GPU`, `CPU`, `Ascend`)"""
    return context.get_context('device_target')


@_primexpr
def _infer_out_shape(*shapes):
    """
    Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
    """
    def _check():
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')

    shape_out = list()
    max_len = max([len(it) for it in shapes])
    for i in range(max_len):
        items = [
            it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
        max_size = 0 if 0 in items else max(items)
        _check()
        shape_out.append(max_size)
    return tuple(shape_out)


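# Usage sketch for _infer_out_shape (illustrative; follows numpy broadcasting,
# aligning shapes from the trailing dimensions and treating missing dims as 1):
#   _infer_out_shape((2, 1, 3), (4, 3))  -> (2, 4, 3)
#   _infer_out_shape((2, 3), (3, 2))     -> ValueError (2 vs 3 in the same position)

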
@_primexpr
def _can_broadcast(*shapes):
    """
    Returns True if shapes can be broadcast, False if they cannot.
    """
    max_len = max([len(it) for it in shapes])
    for i in range(max_len):
        items = [
            it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
        max_size = 0 if 0 in items else max(items)
        if any(item not in (1, max_size) for item in items):
            return False
    return True


@_primexpr
def _check_axis_in_range(axis, ndim):
    """Checks axes are within the bounds of ndim"""
    def _check():
        if not isinstance(axis, int):
            raise TypeError(f'axes should be integers, not {type(axis)}')
        if not -ndim <= axis < ndim:
            raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
    _check()
    return axis - axis // ndim * ndim


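# Note: `axis - axis // ndim * ndim` is `axis % ndim` spelled out with floor
# division, which maps a negative in-range axis to its non-negative equivalent.
# Illustrative values:
#   _check_axis_in_range(2, 4)   -> 2
#   _check_axis_in_range(-1, 4)  -> 3    (since -1 // 4 == -1, so -1 - (-1 * 4) == 3)

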
@_primexpr
def _check_axis_valid(axes, ndim):
    """
    Checks axes are valid given ndim, and returns axes that can be passed
    to the built-in operator (non-negative, int or tuple)
    """
    def _check(axes):
        if any(axes.count(el) > 1 for el in axes):
            raise ValueError('duplicate value in "axis"')

    if axes is None:
        axes = F.make_range(ndim)
        return axes
    if isinstance(axes, (tuple, list)):
        axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
        _check(axes)
        return axes
    return (_check_axis_in_range(axes, ndim),)


@_primexpr
def _check_shape_aligned(shape1, shape2):
    """Checks shape1 and shape2 are valid shapes to perform inner product"""
    if shape1[-1] != shape2[-1]:
        raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')


@_primexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape*tile_size = out_shape"""
    size = [1]*ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)


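# Usage sketch for _tile_size (illustrative): the result is suitable as the
# `multiples` argument of a tile op, assuming `shape` differs from `out_shape`
# only in size-1 dimensions:
#   _tile_size((1, 3), (5, 3), 2)  -> (5, 1)    # tiling (1, 3) by (5, 1) yields (5, 3)

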
@constexpr
def _raise_type_error(info, param=None):
    """
    Raise TypeError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, then param's type information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise TypeError(info)
    raise TypeError(info + f"{type(param)}")


@_primexpr
def _raise_value_error(info, param=None):
    """
    Raise ValueError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, then param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise ValueError(info)
    raise ValueError(info + f"{param}")


@constexpr
def _raise_runtime_error(info, param=None):
    """
    Raise RuntimeError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, then param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise RuntimeError(info)
    raise RuntimeError(info + f"{param}")


@constexpr
def _raise_unimplemented_error(info, param=None):
    """
    Raise NotImplementedError in both graph/pynative mode

    Args:
        info(str): info string to display
        param(python obj): any object that can be recognized by graph mode. If it is
            not None, then param's value information will be extracted and displayed.
            Default is None.
    """
    if param is None:
        raise NotImplementedError(info)
    raise NotImplementedError(info + f"{param}")


@_primexpr
def _empty(dtype, shape):
    """Returns an uninitialized array with dtype and shape."""
    return Tensor_(dtype, shape)


@constexpr
def _promote(dtype1, dtype2):
    """Returns the promoted dtype for dtype1 and dtype2 according to promotion_rule."""
    if dtype1 == dtype2:
        return dtype1
    if (dtype1, dtype2) in promotion_rule:
        return promotion_rule.get((dtype1, dtype2))
    res = promotion_rule.get((dtype2, dtype1))
    if res is None:
        raise TypeError(f"input dtype: {dtype1}, {dtype2} are not all mindspore datatypes.")
    return res


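# Usage sketch for _promote (illustrative; the mixed-dtype result assumes the
# corresponding pair is present in promotion_rule, which is looked up in either order):
#   _promote(mstype.float32, mstype.float32)  -> mstype.float32  (equal dtypes short-circuit)
#   _promote(mstype.int32, mstype.float32)    -> mstype.float32  (if that pair is in the rule table)

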
@constexpr
def _promote_for_trigonometric(dtype):
    """Returns the output dtype of a trigonometric op according to rule_for_trigonometric."""
    res_dtype = rule_for_trigonometric.get(dtype)
    if res_dtype is None:
        raise TypeError(f"input dtype: {dtype} is not a mindspore datatype.")
    return res_dtype


@_primexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)


@_primexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)


@constexpr
def _abs(arg):
    """Returns the absolute value."""
    return abs(arg)


@constexpr
def _check_same_type(dtype1, dtype2):
    """Returns True if dtype1 and dtype2 are the same dtype."""
    return dtype1 == dtype2


@constexpr
def _check_is_float(dtype):
    """Returns whether dtype is float16 or float32."""
    return dtype in (mstype.float16, mstype.float32)


@constexpr
def _check_is_int(dtype):
    """Returns whether dtype is an integer type."""
    return isinstance(dtype, typing.Int)


@_primexpr
def _canonicalize_axis(axis, ndim):
    """
    Check axes are within the number of dimensions of tensor x and normalize the negative axes.

    Args:
        axis (Union[int, tuple(int), list(int)]): Axes of the tensor.
        ndim (int): The number of dimensions of the tensor.

    Returns:
        Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple.
    """
    def _check():
        if not all(axis.count(el) <= 1 for el in axis):
            raise ValueError(f"duplicate axes in {axis}.")

    if isinstance(axis, int):
        axis = [axis]
    for ax in axis:
        _check_axis_in_range(ax, ndim)

    def canonicalizer(ax):
        return ax + ndim if ax < 0 else ax

    def _sort_axis(ax):
        """Sort a tuple of axes in ascending order via merge sort."""
        def merge(left, right):
            result = ()
            i = j = 0
            while i < len(left) and j < len(right):
                if left[i] <= right[j]:
                    result += (left[i],)
                    i += 1
                else:
                    result += (right[j],)
                    j += 1

            return result + left[i:] + right[j:]

        if len(ax) <= 1:
            return ax

        middle = len(ax) // 2
        left = _sort_axis(ax[:middle])
        right = _sort_axis(ax[middle:])

        return merge(left, right)

    axis = tuple([canonicalizer(ax) for ax in axis])
    _check()
    return tuple(_sort_axis(axis)) if len(axis) > 1 else axis[0]


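# Usage sketch for _canonicalize_axis (illustrative):
#   _canonicalize_axis(-1, 4)       -> 3        (a single axis stays an int)
#   _canonicalize_axis((2, -3), 3)  -> (0, 2)   (normalized, then sorted)
#   _canonicalize_axis((0, -3), 3)  -> ValueError (both normalize to 0)

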
@constexpr
def _broadcast_tuples(tup1, tup2):
    """
    Broadcast two 1D tuples to the same length, if inputs are ints, convert to
    tuples first.
    """
    def _check():
        if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
            raise TypeError("input shift and axis must be tuple or list or int.")
        if len(tup1) == len(tup2) or len(tup1) == 1 or len(tup2) == 1:
            return
        raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")

    tup1 = (tup1,) if isinstance(tup1, int) else tup1
    tup2 = (tup2,) if isinstance(tup2, int) else tup2
    _check()
    if len(tup1) == len(tup2):
        return tup1, tup2
    if len(tup1) == 1:
        tup1 *= len(tup2)
    else:  # case: len(tup2) == 1
        tup2 *= len(tup1)
    return tup1, tup2


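# Usage sketch for _broadcast_tuples (illustrative):
#   _broadcast_tuples((1, 2, 3), 5)       -> ((1, 2, 3), (5, 5, 5))
#   _broadcast_tuples((1, 2), (3, 4))     -> ((1, 2), (3, 4))
#   _broadcast_tuples((1, 2), (3, 4, 5))  -> ValueError (lengths 2 and 3, neither is 1)

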
@_primexpr
def _expanded_shape(ndim, axis_size, axis):
    """
    Returns a shape with size = 1 for all dimensions
    except at axis.
    """
    return tuple([axis_size if i == axis else 1 for i in range(ndim)])


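# Usage sketch for _expanded_shape (illustrative):
#   _expanded_shape(3, 5, 1)  -> (1, 5, 1)

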
@_primexpr
def _add_unit_axes(shape, ndim, append=False):
    """
    Prepends shape with 1s so that it has the number of dimensions ndim.
    If append is set to True, returns shape appended with 1s instead.
    """
    if isinstance(shape, int):
        shape = (shape,)
    ndim_diff = ndim - len(shape)
    if ndim_diff > 0:
        if append:
            shape = list(shape) + [1] * ndim_diff
        else:
            shape = [1] * ndim_diff + list(shape)
    return tuple(shape)


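# Usage sketch for _add_unit_axes (illustrative):
#   _add_unit_axes((2, 3), 4)               -> (1, 1, 2, 3)
#   _add_unit_axes((2, 3), 4, append=True)  -> (2, 3, 1, 1)

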
@constexpr
def _check_element_int(lst):
    """
    Check whether each element in `lst` is an integer.
    """
    for item in lst:
        if not isinstance(item, int):
            raise TypeError(f"Each element in {lst} should be integer, but got {type(item)}.")
    return True


@constexpr
def _type_convert(force, obj):
    """
    Convert type of `obj` to `force`.
    """
    return force(obj)


@_primexpr
def _list_comprehensions(obj, item=None, return_tuple=False, make_none=False):
    """
    Generates a new list/tuple by list comprehension.

    Args:
        obj (Union[int, list, tuple]):
            If integer, it will be the length of the returned tuple/list.
        item: The value to be filled. Default: ``None``.
            If None, the values in the new list/tuple are the same as obj
            or range(obj) when obj is an integer.
        return_tuple(bool): If true, returns tuple, else returns list.
        make_none(bool): If true, fills the new list/tuple with None,
            ignoring `item`.

    Returns:
        List or tuple.
    """
    res = []
    lst = obj
    if isinstance(obj, int):
        lst = range(obj)
    if make_none:
        res = [None for _ in lst]
    elif item is None:
        res = list(lst)
    else:
        res = [item for _ in lst]
    if return_tuple:
        return tuple(res)
    return res


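# Usage sketch for _list_comprehensions (illustrative):
#   _list_comprehensions(3)                       -> [0, 1, 2]
#   _list_comprehensions(3, 0, True)              -> (0, 0, 0)
#   _list_comprehensions((2, 2), make_none=True)  -> [None, None]

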
@_primexpr
def _tuple_setitem(tup, idx, value):
    """
    Returns a tuple with specified `idx` set to `value`.
    """
    tup = list(tup)
    tup[idx] = value
    return tuple(tup)


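# Usage sketch for _tuple_setitem (illustrative):
#   _tuple_setitem((1, 2, 3), 1, 0)  -> (1, 0, 3)

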
@constexpr
def _iota(dtype, num, increasing=True):
    """Creates a 1-D tensor with value: [0, 1, ..., num-1] and dtype."""
    # Change to P.Linspace when the kernel is implemented on CPU.
    if num <= 0:
        raise ValueError("zero shape Tensor is not currently supported.")
    if increasing:
        return Tensor(list(range(int(num))), dtype)
    return Tensor(list(range(int(num)-1, -1, -1)), dtype)


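# Usage sketch for _iota (illustrative):
#   _iota(mstype.int32, 3)                    -> Tensor with values [0, 1, 2]
#   _iota(mstype.int32, 3, increasing=False)  -> Tensor with values [2, 1, 0]

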
@constexpr
def _ceil(number):
    """Ceils the number in graph mode."""
    return math.ceil(number)


@_primexpr
def _seq_prod(seq1, seq2):
    """Returns the element-wise product of seq1 and seq2."""
    return tuple(map(lambda x, y: x*y, seq1, seq2))


@constexpr
def _make_tensor(val, dtype):
    """Returns the tensor with value `val` and dtype `dtype`."""
    return Tensor(val, dtype)


@_primexpr
def _tuple_slice(tup, start, end):
    """Returns the slice of `tup` from `start` to `end`."""
    return tup[start:end]


@constexpr
def _isscalar(x):
    """Returns True if x is a scalar type"""
    return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float,
                          typing.Bool, typing.String))


@constexpr
def _cumprod(x):
    """Returns the cumulative product of the sequence x as a tuple."""
    return tuple(accumulate(x, operator.mul))


@constexpr
def _in(x, y):
    """Returns True if x is contained in y."""
    return x in y


@constexpr
def _callable_const(x):
    """Returns True if x is a function in graph mode."""
    return isinstance(x, typing.Function)


@constexpr
def _check_is_inf(x, negative=False):
    """Returns whether x is positive infinity, or negative infinity if `negative` is True."""
    if not negative:
        return x == float('inf')
    return x == float('-inf')