# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Internal utility functions.

Small helpers for input conversion, dtype selection, shape expansion and
broadcasting, and slicing on `Tensor` objects.

NOTE(review): several helpers are built from `F.*` primitives and appear to be
used inside graph-compiled code, which only accepts a restricted Python
subset. Keep that in mind before "simplifying" the explicit loops below.
"""
import types

from ..common import Tensor
from ..ops import functional as F
from ..common import dtype as mstype

from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert, \
    _tuple_setitem, _callable_const, _check_is_float, _get_device


def _deep_list(array_like):
    """
    Convert nested tuple/list mixtures to a pure nested list.

    Args:
        array_like: Any object. Lists and tuples are converted recursively;
            every other object is returned unchanged.

    Returns:
        A new nested list when `array_like` is a list or tuple, otherwise
        `array_like` itself.
    """
    if isinstance(array_like, (list, tuple)):
        return list(map(_deep_list, array_like))
    return array_like


def _deep_tensor_to_nparray(array_like):
    """
    Convert a nested list of tensors to a nested list of numpy arrays.

    Args:
        array_like (Union[Tensor, list, Any]): In any format of nested lists
            that may contain tensors.

    Returns:
        The input with every Tensor element replaced by ``Tensor.asnumpy()``,
        so it can be processed directly by ``numpy.array()``. A top-level
        Tensor is returned as its numpy array; any other non-list object is
        returned unchanged.

    Note:
        Lists are modified in place (and also returned). Only `list`
        containers are descended into; tensors nested inside a `tuple` are
        left untouched. NOTE(review): callers presumably normalize tuples to
        lists first (e.g. via `_deep_list`) — confirm against call sites.
    """
    # Leaf case: a Tensor becomes a numpy array.
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()

    # Recurse into list elements, overwriting each slot in place.
    if isinstance(array_like, list):
        for idx, value in enumerate(array_like):
            array_like[idx] = _deep_tensor_to_nparray(value)

    return array_like


def _check_input_for_asarray(array_like):
    """
    Check whether `array_like` is a valid type for np.asarray conversion.

    Accepted types are Tensor, list, tuple, int, float and bool (container
    contents are not inspected). Any other type is reported through
    `_raise_type_error`.
    """
    if not isinstance(array_like, (Tensor, list, tuple, int, float, bool)):
        _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like)


def _is_scalar(shape):
    """
    Check whether `shape` describes exactly one element.

    The test is ``product(shape) == 1``, so shapes such as ``(1, 1)`` also
    count as scalar here, not only the empty shape ``()``.
    """
    return F.shape_mul(shape) == 1


def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
    """
    Convert a list of tensors to a tuple of tensors.

    Non-list inputs (including tuples) are returned unchanged.
    """
    if isinstance(list_of_tensor, list):
        # Built element by element with tuple concatenation.
        # NOTE(review): presumably kept as an explicit loop for graph-mode
        # compatibility instead of `tuple(...)` — confirm before changing.
        tuple_of_tensor = ()
        for tensor in list_of_tensor:
            tuple_of_tensor += (tensor,)
        return tuple_of_tensor
    return list_of_tensor


def _expand(x, ndim, axis=0):
    """
    Expand `x` to `ndim` dimensions by adding unit axes, then reshape.

    Args:
        x (Tensor): Input tensor.
        ndim (int): Target number of dimensions.
        axis (int): 0 or -1. Only ``axis == -1`` changes the flag passed to
            `_add_unit_axes`; every other value behaves like 0.
            NOTE(review): the exact placement of the new unit axes (leading vs
            trailing) is defined by `_add_unit_axes` — confirm there.

    Returns:
        Tensor: `x` reshaped to the expanded shape.
    """
    shape = _add_unit_axes(F.shape(x), ndim, axis == -1)
    return F.reshape(x, shape)


def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """
    Broadcast `x` from `shape_cur` to `shape_to` by tiling.

    The per-axis repeat counts come from `_tile_size`. NOTE(review): this
    assumes `shape_cur` already has `ndim_to` dimensions and is tile-compatible
    with `shape_to` — confirm in `_tile_size`.
    """
    size = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, size)


def _broadcast_to_shape(x, shape):
    """
    Broadcast `x` from its current shape to `shape`.

    First adds unit axes so `x` has ``len(shape)`` dimensions (via `_expand`
    with its default axis), then tiles it to the target shape.
    """
    ndim_to = len(shape)
    x = _expand(x, ndim_to)
    return _broadcast_to(x, F.shape(x), shape, ndim_to)


def _get_size(x, axis=None):
    """
    Get the number of elements along the given axes of tensor `x`.

    Args:
        x (Tensor): Input tensor.
        axis (Union[None, tuple]): Axes to count over. ``None`` or an empty
            tuple means all axes. NOTE(review): `F.tuple_len` is applied to
            `axis`, so a bare int is presumably not supported — callers
            should pass a tuple.

    Returns:
        int: The product of ``x.shape[ax]`` over the selected axes.
    """
    if axis is None or F.tuple_len(axis) == 0:
        axis = F.make_range(x.ndim)
    nums = 1
    for ax in axis:
        nums *= x.shape[ax]
    return nums


def _check_input_tensor(*tensors):
    """
    Check that every positional argument is a Tensor.

    The first non-Tensor argument is reported through `_raise_type_error`,
    together with its type from `F.typeof`.

    Returns:
        bool: True when all arguments are Tensors.
    """
    for tensor in tensors:
        if not isinstance(tensor, Tensor):
            _raise_type_error('expect Tensor, but got ', F.typeof(tensor))
    return True


def _convert_64_to_32(tensor):
    """
    Convert a tensor with float64/int64 dtype to float32/int32.

    Tensors of any other dtype are returned unchanged.
    """
    if tensor.dtype == mstype.float64:
        return tensor.astype("float32")
    if tensor.dtype == mstype.int64:
        return tensor.astype("int32")
    return tensor


def _to_tensor(*args):
    """
    Return each input as a Tensor.

    Python int/float/bool/list/tuple inputs are converted with `_type_convert`
    and then narrowed from 64-bit to 32-bit dtypes. Inputs that are already
    Tensors are passed through as-is (they are NOT narrowed). Any other input
    type is reported through `_raise_type_error`.

    Returns:
        A single Tensor when called with one argument, otherwise a tuple of
        Tensors in argument order.
    """
    res = ()
    for arg in args:
        if isinstance(arg, (int, float, bool, list, tuple)):
            arg = _convert_64_to_32(_type_convert(Tensor, arg))
        elif not isinstance(arg, Tensor):
            _raise_type_error("Expect input to be array like.")
        res += (arg,)
    if len(res) == 1:
        return res[0]
    return res


def _get_dtype_from_scalar(*input_numbers):
    """
    Get the final dtype from a series of input numbers.

    Unlike F.typeof, this returns int32/float32 for Python int/float. `None`
    entries are ignored.

    Returns:
        - mstype.bool_ if every non-None input is a bool (this includes the
          case where all inputs are None or there are no inputs);
        - mstype.int32 if every non-None input is an int (bools count as
          ints, so a bool/int mix gives int32);
        - mstype.float32 otherwise.
    """
    bool_flag = True
    int_flag = True
    for number in input_numbers:
        if number is not None:
            if not isinstance(number, bool):
                bool_flag = False
            if not isinstance(number, int):
                int_flag = False
    if bool_flag:
        return mstype.bool_
    if int_flag:
        return mstype.int32
    return mstype.float32


def _convert_bool_to_int(tensor):
    """Convert a tensor with bool dtype to int32; other dtypes are returned unchanged."""
    if tensor.dtype == mstype.bool_:
        return tensor.astype("int32")
    return tensor


def _slice_along_axis(f, axis, slice_start, slice_end):
    """
    Slice a tensor along a given axis, taking ``[slice_start, slice_end)``.

    Args:
        f (Tensor): Input Tensor.
        axis (int): Axis to slice along.
        slice_start (int): The start of the slice (inclusive).
        slice_end (int): The end of the slice (exclusive).

    Returns:
        Tensor: The sliced tensor. All other axes are kept in full.
    """
    # F.tensor_slice takes (begin, size): begin is 0 on every axis except
    # `axis`, and size is the full extent except along `axis`.
    index_start = (0,) * f.ndim
    # Despite its name, `index_end` holds per-axis sizes, not end indices.
    index_end = f.shape
    slice_size = slice_end - slice_start
    index_start = _tuple_setitem(index_start, axis, slice_start)
    index_end = _tuple_setitem(index_end, axis, slice_size)
    return F.tensor_slice(f, index_start, index_end)


def _to_tensor_origin_dtype(*args):
    """
    Return each input as a Tensor, keeping its original dtype.

    Like `_to_tensor`, but without the 64-bit to 32-bit narrowing. Any input
    that is not int/float/bool/list/tuple/Tensor is reported through
    `_raise_type_error`.

    Returns:
        A single Tensor when called with one argument, otherwise a `list` of
        Tensors. NOTE(review): `_to_tensor` returns a tuple in the multi-arg
        case; the two differ only if a caller relies on the container type.
    """
    res = []
    for arg in args:
        if isinstance(arg, (int, float, bool, list, tuple)):
            arg = _type_convert(Tensor, arg)
        elif not isinstance(arg, Tensor):
            _raise_type_error("Expect input to be array like.")
        res.append(arg)
    if len(res) == 1:
        return res[0]
    return res


def _callable(tensor, obj):
    """
    Return True if `obj` is a function.

    When `tensor` is a constant (per `F.isconstant`), this is a plain Python
    ``isinstance(obj, types.FunctionType)`` check, which excludes builtins,
    bound methods and other callable objects. Otherwise the decision is
    delegated to `_callable_const` using ``F.typeof(obj)``.
    """
    if F.isconstant(tensor):
        return isinstance(obj, types.FunctionType)
    return _callable_const(F.typeof(obj))


def _isnan(x):
    """
    Return an elementwise NaN mask for tensor `x`.

    On the 'Ascend' device, or when `x` does not have a float dtype (per
    `_check_is_float`), an all-False bool tensor with `x`'s shape is returned
    instead of calling `F.isnan`.
    """
    if _get_device() == 'Ascend' or not _check_is_float(F.dtype(x)):
        return F.fill(mstype.bool_, F.shape(x), False)
    return F.isnan(x)