# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Utilities for operator shape inference and argument checking."""

from mindspore.common.tensor import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import constexpr


def get_broadcast_shape(x_shape, y_shape, prim_name, shape_type=""):
    """
    Compute the shape obtained by broadcasting tensor x against tensor y.

    Dimensions are aligned from the right. For each aligned pair of sizes:
    a size of 1 takes the other size, equal sizes are kept, and if either
    size is -1 (presumably an unknown/dynamic dimension) the result is -1.
    The leading dimensions of the longer shape are copied through unchanged.

    Args:
        x_shape (Union[list, tuple]): The shape of tensor x.
        y_shape (Union[list, tuple]): The shape of tensor y.
        prim_name (str): Primitive name, used in the error message.
        shape_type (str): How to resolve a pair of sizes that cannot be
            broadcast. "min_shape" keeps the larger of the two sizes,
            "max_shape" keeps the smaller one, and any other value
            (default: "") raises ValueError.

    Returns:
        List, the broadcast shape. If `x_shape == y_shape`, `x_shape` itself
        is returned unchanged (it is neither copied nor converted to a list).

    Raises:
        ValueError: If the shapes differ, cannot be broadcast, and
            `shape_type` is neither "min_shape" nor "max_shape".

    Examples:
        >>> get_broadcast_shape([2, 1, 3], [4, 1], "Add")
        [2, 4, 3]
    """
    if x_shape == y_shape:
        return x_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = min(x_len, y_len)
    broadcast_shape_back = []

    # Walk the overlapping trailing dimensions from the left-most shared one.
    for i in range(-length, 0):
        x_dim = x_shape[i]
        y_dim = y_shape[i]
        if x_dim == 1:
            broadcast_shape_back.append(y_dim)
        elif y_dim == 1:
            broadcast_shape_back.append(x_dim)
        elif x_dim == y_dim:
            broadcast_shape_back.append(x_dim)
        elif x_dim == -1 or y_dim == -1:
            # A -1 on either side propagates to the result.
            broadcast_shape_back.append(-1)
        elif shape_type == "min_shape":
            broadcast_shape_back.append(max(x_dim, y_dim))
        elif shape_type == "max_shape":
            broadcast_shape_back.append(min(x_dim, y_dim))
        else:
            raise ValueError(f"For '{prim_name}', x_shape and y_shape are supposed to broadcast, "
                             f"where broadcast means that "
                             f"x_shape[i] = 1 or -1 or y_shape[i] = 1 or -1 or x_shape[i] = y_shape[i], "
                             f"but now x_shape and y_shape can not broadcast, "
                             f"got i: {i}, x_shape: {x_shape}, y_shape: {y_shape}.")

    # Leading dimensions that only the longer shape has are kept as-is.
    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
    return broadcast_shape


def get_concat_offset(x_shp, x_type, axis, prim_name):
    """
    Check the arguments of Concat/ConcatOffset and compute the concat offsets.

    Args:
        x_shp (Union[tuple, list]): Shapes of the inputs to concatenate. It must
            be non-empty, and every shape must have the same non-zero rank as
            `x_shp[0]`.
        x_type (Union[tuple, list]): Types of the inputs. `x_type[0]` must be a
            tensor type, and every entry must equal `x_type[0]`.
        axis (int): Concatenation axis, in range [-rank, rank).
        prim_name (str): Primitive name, used in error messages.

    Returns:
        Tuple of (offset, all_shp, axis):
            - offset (list): for each input, the position along `axis` at which
              it starts in the output. Inputs that follow an input whose size
              along `axis` is -1 get an offset of -1.
            - all_shp: total output size along `axis`, or -1 if any input has
              size -1 along `axis`.
            - axis (int): `axis` normalized to a non-negative value.

    Raises:
        ValueError: If an input shape differs from `x_shp[0]` on any dimension
            other than `axis`. Other invalid arguments are rejected by the
            `validator` checks below.
    """
    validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
    validator.check_positive_int(len(x_shp), "input_x rank", prim_name)
    validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
    validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
    rank_base = len(x_shp[0])
    # Valid axes are [-rank, rank - 1]. Wider bounds would let `rank` through
    # (IndexError at x_shp[0][axis]) and let `-rank - 1` silently wrap to the
    # last axis.
    validator.check_int_range(axis, -rank_base, rank_base - 1, Rel.INC_BOTH, 'axis', prim_name)
    if axis < 0:
        axis = axis + rank_base
    all_shp = x_shp[0][axis]
    offset = [0]
    for i in range(1, len(x_shp)):
        v = x_shp[i]
        validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
        validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
        for j in range(rank_base):
            if j != axis and v[j] != x_shp[0][j]:
                # Report the input that actually mismatched (index i), not always input 1.
                raise ValueError(f"For '{prim_name}', the shapes of the input elements of the Concat operator "
                                 f"do not match on dimension {j}: "
                                 f"shape[0] = {x_shp[0]} and shape[{i}] = {v}.")
        offset.append(all_shp)
        if all_shp == -1 or v[axis] == -1:
            all_shp = -1
        else:
            all_shp += v[axis]
    return offset, all_shp, axis


@constexpr
def range_op(start, limit, delta, dtype):
    """Return a Tensor of `dtype` holding the values of `range(start, limit, delta)`."""
    output_tensor = Tensor(list(range(start, limit, delta)), dtype)
    return output_tensor


@constexpr
def get_1d_shape(in_shape):
    """
    Return the shape of `in_shape` flattened to one dimension.

    The result is a one-element tuple holding the product of all sizes in
    `in_shape`; an empty shape gives (1,). Sizes are multiplied as-is, so a
    -1 entry is not treated specially.
    """
    out_shape = 1
    for dim in in_shape:
        out_shape *= dim
    return (out_shape,)


@constexpr
def generate_shape_index(out_shape, indices_shape, axis):
    """
    Build a permutation that moves the index dimensions to the front.

    The returned tuple lists the `len(indices_shape)` dimensions starting at
    `axis` first, followed by the remaining dimensions of `out_shape` in their
    original order.

    A negative `axis` is normalized by adding
    `len(out_shape) - len(indices_shape) + 1`. NOTE(review): this looks like
    the rank of the gathered input of a Gather-style op; confirm against
    callers. `axis` is not range-checked here.

    Args:
        out_shape (Union[tuple, list]): Output shape; only its length is used.
        indices_shape (Union[tuple, list]): Indices shape; only its length is used.
        axis (int): Position of the first index dimension in the output.

    Returns:
        tuple[int], a permutation of `range(len(out_shape))`.
    """
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank))
    index = tuple(range(out_rank))
    perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
    return perm