• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""internal utility functions"""
16from __future__ import absolute_import
17
18import types
19
20from mindspore.common import Tensor
21from mindspore._c_expression import Tensor as Tensor_
22from mindspore.ops import functional as F
23from mindspore.common import dtype as mstype
24
25from mindspore.numpy.utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert, \
26    _tuple_setitem, _callable_const, _check_is_float, _get_device
27
28
def _deep_list(array_like):
    """
    Recursively rebuild a nested list/tuple mixture as nested lists.

    Every list or tuple found at any depth is replaced by a freshly built
    list; any other object is returned unchanged (not copied).

    Args:
        array_like: Any object; lists and tuples may be nested arbitrarily.

    Returns:
        The same structure with all tuples (and lists) rebuilt as new lists.
    """
    if not isinstance(array_like, (list, tuple)):
        return array_like
    return [_deep_list(element) for element in array_like]
34
35
def _deep_tensor_to_nparray(array_like):
    """
    Replace every Tensor inside a nested list with its numpy array, in place.

    Args:
        array_like: A Tensor, a (possibly nested) list whose leaves may be
            Tensors, or any other object.

    Returns:
        ``array_like.asnumpy()`` when `array_like` is a Tensor. When it is a
        list, the same list object, mutated so that every Tensor reachable
        through nested lists has been replaced by its numpy array. Anything
        else (including tuples, whose contents are not visited) is returned
        unchanged.
    """
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()
    if not isinstance(array_like, list):
        # Tuples and scalars are passed through without being traversed.
        return array_like
    for position, element in enumerate(array_like):
        # Write the converted child back into the caller's list.
        array_like[position] = _deep_tensor_to_nparray(element)
    return array_like
58
59
def _check_input_for_asarray(array_like):
    """
    Validate that `array_like` has a type accepted for np.asarray conversion.

    Accepted types are Tensor, list, tuple, int, float and bool. Only the
    outermost object is checked; nested elements are not inspected. Any
    other type is reported through ``_raise_type_error``.
    """
    accepted_types = (Tensor, list, tuple, int, float, bool)
    if isinstance(array_like, accepted_types):
        return
    _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like)
64
65
def _is_scalar(shape):
    """
    Return True when `shape` describes exactly one element.

    The check is on the element count reported by ``F.shape_mul`` (assumed
    to be the product of the dimensions — confirm), so a shape such as
    ``(1,)`` or ``(1, 1)`` is treated as scalar too.
    """
    element_count = F.shape_mul(shape)
    return element_count == 1
69
70
def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
    """
    Turn a list into a tuple holding the same elements in the same order.

    Inputs that are not lists (including tuples) are returned as-is.
    """
    if not isinstance(list_of_tensor, list):
        return list_of_tensor
    # The tuple is built by concatenation rather than tuple(); presumably this
    # form is kept for MindSpore graph-mode compilation — confirm before changing.
    packed = ()
    for item in list_of_tensor:
        packed = packed + (item,)
    return packed
79
80
def _expand(x, ndim, axis=0):
    """
    Reshape `x` to rank `ndim` by inserting unit axes.

    Args:
        x (Tensor): Input tensor.
        ndim (int): Target rank.
        axis (int): Expected to be 0 or -1. With -1 the flag passed to
            ``_add_unit_axes`` is True (presumably meaning the unit axes go
            at the end — confirm); any other value passes False.

    Returns:
        `x` reshaped to the shape produced by ``_add_unit_axes``.
    """
    at_end = axis == -1
    target_shape = _add_unit_axes(F.shape(x), ndim, at_end)
    return F.reshape(x, target_shape)
85
86
def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """
    Broadcast `x` from `shape_cur` to `shape_to` by tiling.

    The per-axis tile multiples are computed by
    ``_tile_size(shape_cur, shape_to, ndim_to)``.
    """
    return F.tile(x, _tile_size(shape_cur, shape_to, ndim_to))
91
92
def _broadcast_to_shape(x, shape):
    """
    Broadcast `x` from its current shape to `shape`.

    `x` is first expanded with unit axes (via ``_expand`` with its default
    axis) to ``len(shape)`` dimensions, then tiled up to `shape`.
    """
    target_rank = len(shape)
    expanded = _expand(x, target_rank)
    return _broadcast_to(expanded, F.shape(expanded), shape, target_rank)
98
99
def _get_size(x, axis=None):
    """
    Count the elements of `x` covered by the given axes.

    Args:
        x (Tensor): Input tensor.
        axis (tuple of int, optional): Axes whose lengths are multiplied.
            ``None`` or an empty tuple selects every axis of `x`.

    Returns:
        The product of ``x.shape[a]`` over the selected axes.
    """
    if axis is None or F.tuple_len(axis) == 0:
        # No axis given: cover the whole tensor.
        axis = F.make_range(x.ndim)
    total = 1
    for dim in axis:
        total = total * x.shape[dim]
    return total
108
109
def _check_input_tensor(*tensors):
    """
    Ensure every positional argument is a Tensor.

    Each non-Tensor argument is reported, with its ``F.typeof`` type, through
    ``_raise_type_error`` (expected to raise, so normally only the first one
    is reported). Returns True once all arguments have been checked.
    """
    for candidate in tensors:
        if isinstance(candidate, Tensor):
            continue
        _raise_type_error('expect Tensor, but got ', F.typeof(candidate))
    return True
115
116
def _convert_64_to_32(tensor):
    """
    Downcast 64-bit tensors to their 32-bit counterparts.

    int64 becomes int32 and float64 becomes float32; tensors of any other
    dtype are returned unchanged.
    """
    source_dtype = tensor.dtype
    if source_dtype == mstype.int64:
        return tensor.astype("int32")
    if source_dtype == mstype.float64:
        return tensor.astype("float32")
    return tensor
124
125
def _to_tensor(*args):
    """
    Convert every argument to a Tensor.

    Python scalars (int, float, bool) and lists/tuples are converted with
    ``_type_convert(Tensor, ...)`` and then passed through
    ``_convert_64_to_32``, so int64/float64 results become int32/float32.
    Empty lists/tuples are first wrapped in the low-level ``Tensor_`` (from
    ``mindspore._c_expression``). Tensors are kept as-is; anything else is
    reported with ``_raise_type_error``.

    Returns:
        A single Tensor when exactly one argument is given, otherwise a tuple
        of Tensors in argument order.
    """
    converted = ()
    for item in args:
        if isinstance(item, (int, float, bool, list, tuple)):
            is_empty_sequence = isinstance(item, (list, tuple)) and not item
            if is_empty_sequence:
                # Empty sequences go through Tensor_ before the regular conversion.
                item = Tensor_(item)
            item = _type_convert(Tensor, item)
            item = _convert_64_to_32(item)
        elif not isinstance(item, Tensor):
            _raise_type_error("Expect input to be array like.")
        converted = converted + (item,)
    if len(converted) != 1:
        return converted
    return converted[0]
140
141
def _get_dtype_from_scalar(*input_numbers):
    """
    Infer a common MindSpore dtype for a series of Python scalars.

    Unlike F.typeof, Python int and float map to int32 and float32. ``None``
    entries are ignored. The result is:

    * ``mstype.bool_`` if every non-None input is a bool (this includes the
      case where there are no non-None inputs at all),
    * ``mstype.int32`` if every non-None input is an int (bools count as ints),
    * ``mstype.float32`` otherwise.
    """
    all_bool = True
    all_int = True
    for value in input_numbers:
        if value is None:
            continue
        all_bool = all_bool and isinstance(value, bool)
        all_int = all_int and isinstance(value, int)
    if all_bool:
        return mstype.bool_
    if not all_int:
        return mstype.float32
    return mstype.int32
160
161
def _convert_bool_to_int(tensor):
    """
    Cast a bool tensor to int32; tensors of any other dtype pass through unchanged.
    """
    is_bool = tensor.dtype == mstype.bool_
    if is_bool:
        return tensor.astype("int32")
    return tensor
167
168
def _slice_along_axis(f, axis, slice_start, slice_end):
    """
    Take the sub-tensor ``[slice_start, slice_end)`` of `f` along one axis.

    All other axes are kept whole.

    Args:
        f (Tensor): Input tensor.
        axis (int): Axis to slice.
        slice_start (int): First index kept along `axis`.
        slice_end (int): Index one past the last one kept along `axis`.

    Returns:
        The sliced tensor.
    """
    # F.tensor_slice receives a begin offset and a size per axis (not an end
    # index): every axis starts at 0 with its full length, except `axis`.
    begin = _tuple_setitem((0,) * f.ndim, axis, slice_start)
    size = _tuple_setitem(f.shape, axis, slice_end - slice_start)
    return F.tensor_slice(f, begin, size)
188
189
def _to_tensor_origin_dtype(*args):
    """
    Convert every argument to a Tensor, keeping the dtype the conversion yields.

    Python scalars (int, float, bool) and lists/tuples are converted with
    ``_type_convert(Tensor, ...)``; no 64-to-32-bit downcast is applied here.
    Tensors are kept as-is; anything else is reported with
    ``_raise_type_error``.

    Returns:
        A single Tensor when exactly one argument is given, otherwise a list
        of Tensors in argument order.
    """
    converted = []
    for item in args:
        needs_conversion = isinstance(item, (int, float, bool, list, tuple))
        if needs_conversion:
            item = _type_convert(Tensor, item)
        elif not isinstance(item, Tensor):
            _raise_type_error("Expect input to be array like.")
        converted.append(item)
    if len(converted) != 1:
        return converted
    return converted[0]
202
203
def _callable(tensor, obj):
    """
    Return whether `obj` is a function.

    When ``F.isconstant(tensor)`` holds, this is a plain Python
    ``isinstance(obj, types.FunctionType)`` check. Otherwise the decision is
    delegated to ``_callable_const`` applied to ``F.typeof(obj)``.
    """
    if F.isconstant(tensor):
        return isinstance(obj, types.FunctionType)
    obj_type = F.typeof(obj)
    return _callable_const(obj_type)
209
210
def _isnan(x):
    """
    Elementwise NaN test for tensor `x`.

    On the Ascend device, inputs whose dtype is not floating point skip
    ``F.isnan`` and get an all-False bool tensor of the same shape instead
    (non-float values can never be NaN). In every other case ``F.isnan(x)``
    is returned.
    """
    if _get_device() == 'Ascend':
        if not _check_is_float(F.dtype(x)):
            # Presumably F.isnan does not accept non-float dtypes on Ascend — confirm.
            return F.fill(mstype.bool_, F.shape(x), False)
    return F.isnan(x)
215