# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""utils for operator"""
from __future__ import absolute_import

from mindspore.common.tensor import Tensor
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import _primexpr
from mindspore.common._utils import is_dim_unknown


def get_broadcast_shape(x_shape, y_shape, prim_name, arg_name1="x", arg_name2="y"):
    """
    Doing broadcast between tensor x and tensor y.

    Dimensions are aligned from the right. A dimension of 1 takes the size of the other
    operand, and a dimension of -1 is treated as an unknown size. If either shape is
    reported as unknown-rank by ``is_dim_unknown``, ``[-2]`` is returned.

    Args:
        x_shape (Union[list, tuple]): The shape of tensor x.
        y_shape (Union[list, tuple]): The shape of tensor y.
        prim_name (str): Primitive name, used in the error message.
        arg_name1 (str): The arg name of x_shape, used in the error message. Default: "x".
        arg_name2 (str): The arg name of y_shape, used in the error message. Default: "y".

    Returns:
        List, the shape that broadcast between tensor x and tensor y. When `x_shape` equals
        `y_shape`, `x_shape` itself is returned as-is (not copied or converted to a list).

    Raises:
        ValueError: If tensor x and tensor y are not equal and couldn't broadcast.

    Examples:
        >>> x_shape = [1, 2, 3]
        >>> y_shape = [2, 1]
        >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape, "Add")
        >>> print(broadcast_shape)
        [1, 2, 3]
    """
    if x_shape == y_shape:
        return x_shape
    # Unknown rank on either side means the broadcast rank is unknown as well.
    if is_dim_unknown(x_shape) or is_dim_unknown(y_shape):
        return [-2]

    x_len = len(x_shape)
    y_len = len(y_shape)
    length = min(x_len, y_len)

    broadcast_shape_back = []
    for i in range(-length, 0):
        x_dim = x_shape[i]
        y_dim = y_shape[i]
        if x_dim == 1:
            broadcast_shape_back.append(y_dim)
        elif y_dim == 1 or x_dim == y_dim:
            broadcast_shape_back.append(x_dim)
        elif x_dim == -1 or y_dim == -1:
            # Here the dims differ and neither is 1, so exactly one side is -1 and the
            # other is not +/-1. Keep the larger value, i.e. the concrete size when it
            # is non-negative.
            broadcast_shape_back.append(max(x_dim, y_dim))
        else:
            raise ValueError(f"For '{prim_name}', {arg_name1}.shape and {arg_name2}.shape need to "
                             f"broadcast. The value of {arg_name1}.shape[{i}] or {arg_name2}.shape[{i}]"
                             f" must be 1 or -1 when they are not the same, "
                             f"but got {arg_name1}.shape = {x_shape} "
                             f"and {arg_name2}.shape = {y_shape}.")

    # The leading dims of the longer shape have no counterpart and are copied unchanged.
    if length == x_len:
        broadcast_shape_front = y_shape[0: y_len - length]
    else:
        broadcast_shape_front = x_shape[0: x_len - length]
    return list(broadcast_shape_front) + broadcast_shape_back


def dim_not_equal(dim1, dim2):
    """
    Compare dim in shape.

    Returns True only when both dims are non-negative and different. A negative dim
    (such as -1) never counts as a mismatch.
    """
    return dim1 != dim2 and dim1 >= 0 and dim2 >= 0


def get_concat_offset(x_shp, x_type, axis, prim_name):
    """
    For concat and concatoffset: check args and compute the offset of each input along `axis`.

    Args:
        x_shp (Union[tuple, list]): The shapes of all inputs. Every shape must have the same
            rank as ``x_shp[0]``.
        x_type (Sequence): The types of all inputs. ``x_type[0]`` must be a tensor type, and
            every other entry must equal it.
        axis (int): The concat axis, in the range [-rank, rank - 1].
        prim_name (str): Primitive name, used in error messages.

    Returns:
        Tuple ``(offset, all_shp, axis)``:

        - offset (list): the start position of each input along `axis`. ``offset[0]`` is 0,
          and ``offset[i]`` is -1 if any earlier input (or its running total) has size -1
          along `axis`.
        - all_shp (int): the total size along `axis`, or -1 if any input has size -1 there.
        - axis (int): `axis` normalized to a non-negative value.

    Raises:
        ValueError: If two inputs differ in a non-negative dimension other than `axis`.
            The ``validator`` checks raise their own errors earlier for invalid arguments.
    """
    validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
    validator.check_positive_int(len(x_shp), "input_x rank", prim_name)
    validator.check_subclass("shape0", x_type[0], mstype.tensor_type, prim_name)
    validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
    rank_base = len(x_shp[0])
    for i in range(1, len(x_shp)):
        validator.check('len of x_shp[%d]' % i, len(x_shp[i]), 'len of x_shp[0]',
                        len(x_shp[0]), validator.EQ, prim_name)
        validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], validator.EQ, prim_name)

    validator.check_int_range(axis, -rank_base, rank_base - 1, validator.INC_BOTH, 'axis', prim_name)
    if axis < 0:
        axis = axis + rank_base
    all_shp = x_shp[0][axis]
    offset = [0]
    for i in range(1, len(x_shp)):
        v = x_shp[i]
        for j in range(rank_base):
            if j != axis and dim_not_equal(v[j], x_shp[0][j]):
                raise ValueError(f"The shape of the two input elements of the Concat operator do not match: "
                                 f"shape[0] = {x_shp[0]} and shape[{i}] = {x_shp[i]}.")
        offset.append(all_shp)
        # Once any size along `axis` is -1, the running total stays -1.
        if all_shp == -1 or v[axis] == -1:
            all_shp = -1
        else:
            all_shp += v[axis]
    return offset, all_shp, axis


@_primexpr
def range_op(start, limit, delta, dtype):
    """Helper function to get a Tensor of `dtype` holding ``list(range(start, limit, delta))``."""
    output_tensor = Tensor(list(range(start, limit, delta)), dtype)
    return output_tensor


@_primexpr
def get_1d_shape(in_shape):
    """
    Helper function to get the 1d shape holding the same number of elements as `in_shape`.

    Returns a 1-tuple containing the product of all dims in `in_shape`. An empty shape
    gives ``(1,)``.
    """
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)


@_primexpr
def generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
    """
    Build a permutation that moves a block of dimensions to just after the batch dimensions.

    Only the lengths of `out_shape` and `indices_shape` are used. With
    ``n = len(indices_shape) - batch_dims``, the result is::

        (0, ..., batch_dims - 1,
         axis, ..., axis + n - 1,
         batch_dims, ..., axis - 1,
         axis + n, ..., len(out_shape) - 1)

    A negative `axis` is first shifted by ``len(out_shape) - len(indices_shape) + 1``.

    NOTE(review): presumably used to transpose the output of a gather-style op so the
    index dimensions follow the batch dimensions. This needs confirming with callers, as
    does the assumption that ``batch_dims <= axis`` for a valid permutation.

    Args:
        out_shape (Sequence): Shape whose rank is the length of the permutation.
        indices_shape (Sequence): Shape whose rank determines the size of the moved block.
        axis (int): Start of the moved block in `out_shape`.
        batch_dims (int): Number of leading dims kept in place. Default: 0.

    Returns:
        tuple[int], the permutation.
    """
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank - batch_dims))
    index = tuple(range(out_rank))
    perm = index[:batch_dims] + perm_part1 + index[batch_dims:axis] + index[axis + ind_rank - batch_dims:]
    return perm


def ms_arrange(x):
    """Return the list ``[0, 1, ..., x - 1]`` (empty when ``x <= 0``)."""
    return list(range(x))