# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""utils for operator"""

from mindspore.common.tensor import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import constexpr

def get_broadcast_shape(x_shape, y_shape, prim_name, shape_type=""):
    """
    Compute the broadcast shape of tensor x and tensor y.

    Args:
        x_shape (list): The shape of tensor x.
        y_shape (list): The shape of tensor y.
        prim_name (str): Primitive name.
        shape_type (str): "min_shape" or "max_shape" when broadcasting
            dynamic-shape bounds, or "" (default) for static shapes.

    Returns:
        List, the shape to which tensor x and tensor y broadcast.

    Raises:
        ValueError: If the shapes of tensor x and tensor y cannot be broadcast.

    Examples:
        >>> x_shape = [1, 2, 3]
        >>> y_shape = [2, 1]
        >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape, "test_op")
        >>> print(broadcast_shape)
        [1, 2, 3]
    """
    if x_shape == y_shape:
        return x_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = x_len if x_len < y_len else y_len
    broadcast_shape_back = []

    for i in range(-length, 0):
        if x_shape[i] == 1:
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] == 1:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == -1 or y_shape[i] == -1:
            broadcast_shape_back.append(-1)
        else:
            if shape_type == "min_shape":
                broadcast_shape_back.append(min(x_shape[i], y_shape[i]))
            elif shape_type == "max_shape":
                broadcast_shape_back.append(max(x_shape[i], y_shape[i]))
            else:
                raise ValueError(f"For '{prim_name}', x_shape and y_shape must be broadcastable, "
                                 f"which requires that for each dimension i, "
                                 f"x_shape[i] = 1 or -1, or y_shape[i] = 1 or -1, "
                                 f"or x_shape[i] = y_shape[i], "
                                 f"but got i: {i}, x_shape: {x_shape}, y_shape: {y_shape}.")

    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
    return broadcast_shape
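
# A usage sketch for get_broadcast_shape (illustrative only; "demo" is a
# placeholder primitive name used in error messages). With static shapes,
# dimensions are aligned from the right, size-1 dimensions expand to match,
# and missing leading dimensions come from the longer shape:
#
#     get_broadcast_shape([8, 1, 6, 1], [7, 1, 5], "demo")
#     # -> [8, 7, 6, 5]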


def get_concat_offset(x_shp, x_type, axis, prim_name):
    """for concat and concatoffset check args and compute offset"""
    validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
    validator.check_positive_int(len(x_shp), "input_x rank", prim_name)
    validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
    validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
    rank_base = len(x_shp[0])
    validator.check_int_range(axis, -rank_base, rank_base - 1, Rel.INC_BOTH, 'axis', prim_name)
    if axis < 0:
        axis = axis + rank_base
    all_shp = x_shp[0][axis]
    offset = [0]
    for i in range(1, len(x_shp)):
        v = x_shp[i]
        validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
        validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
        for j in range(rank_base):
            if j != axis and v[j] != x_shp[0][j]:
                raise ValueError(f"For '{prim_name}', the shapes of all input elements must match "
                                 f"except on the concat axis, but got shape[0] = {x_shp[0]} "
                                 f"and shape[{i}] = {v}.")
        offset.append(all_shp)
        if all_shp == -1 or v[axis] == -1:
            all_shp = -1
        else:
            all_shp += v[axis]
    return offset, all_shp, axis
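
# A usage sketch for get_concat_offset (illustrative only): here x_type stands
# for a sequence of matching tensor types, one per input, which this function
# only type-checks. Concatenating shapes (2, 3) and (4, 3) along axis 0 gives
# per-input offsets [0, 2], a total extent of 6 on that axis, and the
# normalized axis:
#
#     get_concat_offset(((2, 3), (4, 3)), x_type, 0, "demo")
#     # -> ([0, 2], 6, 0)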


@constexpr
def range_op(start, limit, delta, dtype):
    """helper function to get tensor in specified range."""
    output_tensor = Tensor(list(range(start, limit, delta)), dtype)
    return output_tensor
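
# A usage sketch for range_op (illustrative only): it builds a constant 1-D
# tensor at graph-compile time from a Python range, e.g.
#
#     range_op(0, 10, 2, mstype.int32)
#     # -> Tensor([0, 2, 4, 6, 8], dtype=mstype.int32)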


@constexpr
def get_1d_shape(in_shape):
    """helper function to get 1d shape."""
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)
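
# A usage sketch for get_1d_shape (illustrative only): the flattened length is
# the product of all dimensions, e.g.
#
#     get_1d_shape((2, 3, 4))
#     # -> (24,)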


@constexpr
def generate_shape_index(out_shape, indices_shape, axis):
    """helper function to generate a permutation that moves the indices
    dimensions of a gather output to the front."""
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank))
    index = tuple(range(out_rank))
    perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
    return perm
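
# A usage sketch for generate_shape_index (illustrative only): for a gather
# output of shape (4, 5, 6) where the indices contributed one dimension at
# axis 1, the permutation moves that dimension to the front:
#
#     generate_shape_index((4, 5, 6), (5,), 1)
#     # -> (1, 0, 2)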