# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""internal utility functions

Helpers for input conversion, type checking, shape expansion/broadcasting and
slicing, built on mindspore ``Tensor`` and ``mindspore.ops.functional``.

NOTE(review): several helpers use graph-compiler primitives (``F.make_range``,
``F.tuple_len``, ``F.isconstant``) and are presumably compiled by MindSpore's
static-graph parser, which accepts only a subset of Python syntax. Keep edits
to these bodies conservative; constructs that are fine in plain Python may not
compile there.
"""
from __future__ import absolute_import

import types

from mindspore.common import Tensor
from mindspore._c_expression import Tensor as Tensor_
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype

from mindspore.numpy.utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert, \
    _tuple_setitem, _callable_const, _check_is_float, _get_device


def _deep_list(array_like):
    """
    convert nested tuple/list mixtures to pure nested list

    Every list or tuple, at any depth, is rebuilt as a new ``list``. Any other
    object (scalar, Tensor, ...) is a leaf and is returned unchanged.

    Args:
        array_like: Arbitrary object, possibly a nested mixture of lists/tuples.

    Returns:
        A new nested list with the same leaves, or ``array_like`` itself if it
        is neither a list nor a tuple.
    """
    if isinstance(array_like, (list, tuple)):
        return list(map(_deep_list, array_like))
    return array_like


def _deep_tensor_to_nparray(array_like):
    """
    convert a nested list of tensor to nested list of np_array.

    Lists are modified IN PLACE: each element is replaced by its converted
    value, and the same list object is returned. Only ``list`` containers are
    traversed; tensors inside a ``tuple`` are left untouched (presumably
    callers normalise with ``_deep_list`` first — confirm at call sites).

    Args:
        array_like(list(tensor)): In any format of nested lists that may contain
            tensors.

    Returns:
        array_like(list(np_array)): Formatted array that can be directly processed
        by numpy.array(), with all tensor elements converted to numpy_array.
        A bare Tensor argument is returned as ``array_like.asnumpy()``; any
        other non-list object is returned unchanged.
    """
    # Recursively check whether each element is a tensor or not, if is tensor,
    # convert it to a numpy array in place
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()

    if isinstance(array_like, list):
        for idx, value in enumerate(array_like):
            # In-place write: the caller's list object is mutated.
            array_like[idx] = _deep_tensor_to_nparray(value)

    return array_like


def _check_input_for_asarray(array_like):
    """
    check whether array_like argument is a valid type for np.asarray conversion

    Accepted types are ``Tensor``, ``list``, ``tuple``, ``int``, ``float`` and
    ``bool``. Only the top-level object is checked; the contents of a
    list/tuple are not validated here.

    Raises:
        Whatever ``_raise_type_error`` raises (presumably ``TypeError``) when
        ``array_like`` is of any other type.
    """
    if not isinstance(array_like, (Tensor, list, tuple, int, float, bool)):
        _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like)


def _is_scalar(shape):
    """
    check whether input shape is a scalar

    Returns True when ``F.shape_mul(shape)`` equals 1. Assuming ``shape_mul``
    is the product of the dimensions, this holds for ``()`` as well as any
    all-ones shape such as ``(1,)`` or ``(1, 1)`` — i.e. "holds exactly one
    element", not strictly "rank 0".
    """
    return F.shape_mul(shape) == 1


def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
    """
    Convert a list of tensor to a tuple of tensor

    A ``list`` input yields a new tuple with the same elements in the same
    order. Any other input is returned unchanged.
    """
    if isinstance(list_of_tensor, list):
        # Built by element-wise concatenation instead of tuple(...).
        # NOTE(review): presumably for static-graph compatibility — confirm
        # before replacing with tuple(list_of_tensor).
        tuple_of_tensor = ()
        for tensor in list_of_tensor:
            tuple_of_tensor += (tensor,)
        return tuple_of_tensor
    return list_of_tensor


def _expand(x, ndim, axis=0):
    """
    Expand x to ndim from axis, which can be 0 or -1.

    The new shape comes from ``_add_unit_axes(F.shape(x), ndim, axis == -1)``
    and ``x`` is reshaped to it. Any ``axis`` other than -1 behaves like 0,
    because only ``axis == -1`` is passed on. Presumably the unit axes are
    prepended for 0 and appended for -1 — confirm in ``_add_unit_axes``.
    """
    shape = _add_unit_axes(F.shape(x), ndim, axis == -1)
    return F.reshape(x, shape)


def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """
    Broadcasts x from shape_cur to shape_to.

    Implemented with ``F.tile`` using multiples computed by ``_tile_size``, so
    the result holds materialised, repeated data rather than a view.
    ``x`` is expected to already have ``ndim_to`` dimensions (see
    ``_broadcast_to_shape``).
    """
    size = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, size)


def _broadcast_to_shape(x, shape):
    """
    Broadcasts x from current shape to shape

    First raises ``x`` to ``len(shape)`` dimensions with ``_expand`` (default
    ``axis=0``), then tiles it up to ``shape`` with ``_broadcast_to``.
    """
    ndim_to = len(shape)
    x = _expand(x, ndim_to)
    return _broadcast_to(x, F.shape(x), shape, ndim_to)


def _get_size(x, axis=None):
    """
    Get the number of elements along the given axis of tensor x.

    Args:
        x (Tensor): Input tensor.
        axis: ``None`` or a sequence of axis indices. It is passed to
            ``F.tuple_len``, so it appears to be expected as a tuple of ints.
            ``None`` or an empty tuple means all axes.

    Returns:
        The product of ``x.shape[ax]`` over the selected axes. The axes are
        not deduplicated, so a repeated axis is counted twice.
    """
    if axis is None or F.tuple_len(axis) == 0:
        axis = F.make_range(x.ndim)
    nums = 1
    for ax in axis:
        nums *= x.shape[ax]
    return nums


def _check_input_tensor(*tensors):
    """
    Check that every positional argument is a ``Tensor``.

    Returns:
        True when all arguments are Tensors (also for zero arguments).

    Raises:
        Whatever ``_raise_type_error`` raises (presumably ``TypeError``) for
        the first non-Tensor argument. The message includes ``F.typeof`` of it.
    """
    for tensor in tensors:
        if not isinstance(tensor, Tensor):
            _raise_type_error('expect Tensor, but got ', F.typeof(tensor))
    return True


def _convert_64_to_32(tensor):
    """
    Convert tensor with float64/int64 types to float32/int32.

    Any other dtype is returned unchanged (same object).
    """
    if tensor.dtype == mstype.float64:
        return tensor.astype("float32")
    if tensor.dtype == mstype.int64:
        return tensor.astype("int32")
    return tensor


def _to_tensor(*args):
    """
    Returns each input as Tensor

    Python ``int``/``float``/``bool``/``list``/``tuple`` inputs are converted
    with ``_type_convert(Tensor, ...)`` and then narrowed from 64-bit to
    32-bit by ``_convert_64_to_32``. Existing ``Tensor`` inputs pass through
    unchanged, with no dtype narrowing. Any other type triggers
    ``_raise_type_error``.

    Returns:
        The single converted value when called with one argument, otherwise a
        tuple of converted values in argument order.
    """
    res = ()
    for arg in args:
        if isinstance(arg, (int, float, bool, list, tuple)):
            if isinstance(arg, (list, tuple)) and not arg:
                # Empty sequences are first wrapped in the C++ Tensor_.
                # NOTE(review): presumably because _type_convert cannot build a
                # Tensor from an empty sequence directly — confirm.
                arg = Tensor_(arg)
            arg = _convert_64_to_32(_type_convert(Tensor, arg))
        elif not isinstance(arg, Tensor):
            # _raise_type_error is expected to raise. If it ever returned, the
            # invalid arg would still be appended below.
            _raise_type_error("Expect input to be array like.")
        res += (arg,)
    if len(res) == 1:
        return res[0]
    return res


def _get_dtype_from_scalar(*input_numbers):
    """
    Get the final dtype from series of input numbers, compared with F.typeof, we
    return int32/float32 for python int/float instead.

    ``None`` entries are skipped. The result is:
      * ``mstype.bool_`` if every non-None entry is a ``bool``, including the
        case where there are no non-None entries at all;
      * otherwise ``mstype.int32`` if every non-None entry is an ``int``
        (``bool`` counts as ``int`` here, since it subclasses ``int``);
      * otherwise ``mstype.float32``.
    """
    bool_flag = True
    int_flag = True
    for number in input_numbers:
        if number is not None:
            if not isinstance(number, bool):
                bool_flag = False
            if not isinstance(number, int):
                int_flag = False
    if bool_flag:
        return mstype.bool_
    if int_flag:
        return mstype.int32
    return mstype.float32


def _convert_bool_to_int(tensor):
    """
    Convert tensor with bool type to int32.

    Tensors of any other dtype are returned unchanged (same object).
    """
    if tensor.dtype == mstype.bool_:
        return tensor.astype("int32")
    return tensor


def _slice_along_axis(f, axis, slice_start, slice_end):
    """
    Slice a tensor along a given axis

    Elements ``[slice_start, slice_end)`` are taken along ``axis``. All other
    axes are kept whole.

    Args:
        f (Tensor): Input Tensor.
        axis (int): Specified axis.
        slice_start (int): The start of the slice.
        slice_end (int): The end of the slice.

    Returns:
        Sliced tensor.
    """
    index_start = (0,) * f.ndim
    # Despite its name, index_end holds per-axis SIZES: F.tensor_slice appears
    # to take (begin, size). It starts as the full shape, and the sliced axis
    # is then overwritten with the slice length.
    index_end = f.shape
    slice_size = slice_end - slice_start
    index_start = _tuple_setitem(index_start, axis, slice_start)
    index_end = _tuple_setitem(index_end, axis, slice_size)
    return F.tensor_slice(f, index_start, index_end)


def _to_tensor_origin_dtype(*args):
    """
    Returns each input as Tensor and remains original dtype.

    Unlike ``_to_tensor``:
      * there is no 64 -> 32 bit narrowing;
      * there is no special case for empty lists/tuples;
      * with several arguments the result is a ``list``, not a tuple.

    Existing ``Tensor`` inputs pass through unchanged. Any other non
    array-like type triggers ``_raise_type_error``.

    Returns:
        The single converted value for one argument, otherwise a list of
        converted values in argument order.
    """
    res = []
    for arg in args:
        if isinstance(arg, (int, float, bool, list, tuple)):
            arg = _type_convert(Tensor, arg)
        elif not isinstance(arg, Tensor):
            _raise_type_error("Expect input to be array like.")
        res.append(arg)
    if len(res) == 1:
        return res[0]
    return res


def _callable(tensor, obj):
    """
    Returns True if `obj` is a function.

    ``tensor`` is only used to select the check via ``F.isconstant(tensor)``:
      * constant: a plain ``isinstance(obj, types.FunctionType)`` test. This
        matches ``def``/``lambda`` functions, but not builtins or objects that
        define ``__call__``;
      * otherwise: ``_callable_const(F.typeof(obj))`` decides from the
        MindSpore type of ``obj``.
    """
    if F.isconstant(tensor):
        return isinstance(obj, types.FunctionType)
    return _callable_const(F.typeof(obj))


def _isnan(x):
    """
    Element-wise NaN test for tensor ``x``.

    On the 'Ascend' device, a non-float ``x`` short-circuits to an all-False
    ``bool_`` tensor of the same shape. Non-float values can never be NaN.
    NOTE(review): presumably this is because ``F.isnan`` rejects non-float
    inputs on Ascend — confirm.
    Every other case delegates to ``F.isnan(x)``.
    """
    if _get_device() == 'Ascend' and not _check_is_float(F.dtype(x)):
        return F.fill(mstype.bool_, F.shape(x), False)
    return F.isnan(x)