# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""utils for operator"""
17from __future__ import absolute_import
18
19from mindspore.common.tensor import Tensor
20from mindspore import _checkparam as validator
21from mindspore.common import dtype as mstype
22from mindspore.ops.primitive import _primexpr
23from mindspore.common._utils import is_dim_unknown
24
25
def get_broadcast_shape(x_shape, y_shape, prim_name, arg_name1="x", arg_name2="y"):
    """
    Compute the broadcast shape of tensor x and tensor y.

    Args:
        x_shape (list): The shape of tensor x.
        y_shape (list): The shape of tensor y.
        prim_name (str): Primitive name, used in error messages.
        arg_name1 (str): The argument name of x_shape.
        arg_name2 (str): The argument name of y_shape.

    Returns:
        List, the broadcast shape of tensor x and tensor y.

    Raises:
        ValueError: If the shapes differ in a dimension where neither side is 1 or -1,
            so they cannot be broadcast.

    Examples:
        >>> x_shape = [1, 2, 3]
        >>> y_shape = [2, 3]
        >>> broadcast_shape = get_broadcast_shape(x_shape, y_shape, "BroadcastTo")
        >>> print(broadcast_shape)
        [1, 2, 3]
    """
    if x_shape == y_shape:
        return x_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = x_len if x_len < y_len else y_len
    broadcast_shape_back = []
    # A rank-unknown input makes the output rank unknown; [-2] denotes an unknown rank.
    if is_dim_unknown(x_shape) or is_dim_unknown(y_shape):
        return [-2]

    # Walk the trailing `length` dimensions from right to left, applying the
    # standard broadcast rules; -1 denotes a dynamic (unknown) dimension.
    for i in range(-length, 0):
        if x_shape[i] == 1:
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] == 1:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        elif (x_shape[i] == -1 and abs(y_shape[i]) != 1) or \
             (y_shape[i] == -1 and abs(x_shape[i]) != 1):
            # One side is dynamic and the other is a known dim other than 1:
            # the known dim wins, since max(-1, n) == n for n >= 0.
            broadcast_shape_back.append(max(x_shape[i], y_shape[i]))
        elif x_shape[i] == -1 or y_shape[i] == -1:
            broadcast_shape_back.append(-1)
        else:
            raise ValueError(f"For '{prim_name}', {arg_name1}.shape and {arg_name2}.shape need to "
                             f"broadcast. The value of {arg_name1}.shape[{i}] or {arg_name2}.shape[{i}]"
                             f" must be 1 or -1 when they are not the same, "
                             f"but got {arg_name1}.shape = {x_shape} "
                             f"and {arg_name2}.shape = {y_shape}.")

    # The leading dimensions of the longer shape pass through unchanged.
    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
    return broadcast_shape


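# A worked sketch of get_broadcast_shape with NumPy-style trailing-dim
# alignment (the shapes and primitive names below are illustrative, not
# taken from any caller in this file):
#
#     get_broadcast_shape([8, 1, 6, 1], [7, 1, 5], "Mul")  -> [8, 7, 6, 5]
#     get_broadcast_shape([-1, 3], [2, 3], "Add")          -> [2, 3]    # -1 resolved by the known dim
#     get_broadcast_shape([2, 3], [4, 3], "Add")           -> ValueError (2 vs 4, neither is 1 or -1)

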
def dim_not_equal(dim1, dim2):
    """Return True only if both dims are known (non-negative) and differ."""
    return dim1 != dim2 and dim1 >= 0 and dim2 >= 0


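# A micro-example of the semantics above (values are illustrative): a dynamic
# dim (-1) is never reported as a mismatch, since it may match anything at runtime.
#
#     dim_not_equal(3, 4)   -> True
#     dim_not_equal(3, 3)   -> False
#     dim_not_equal(3, -1)  -> False

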
def get_concat_offset(x_shp, x_type, axis, prim_name):
    """For Concat and ConcatOffset, check args and compute the offsets along `axis`."""
    validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
    validator.check_positive_int(len(x_shp), "input_x rank", prim_name)
    validator.check_subclass("shape0", x_type[0], mstype.tensor_type, prim_name)
    validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name)
    rank_base = len(x_shp[0])
    # Every input must have the same rank and element type as the first input.
    for i in range(1, len(x_shp)):
        validator.check('len of x_shp[%d]' % i, len(x_shp[i]), 'len of x_shp[0]',
                        len(x_shp[0]), validator.EQ, prim_name)
        validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], validator.EQ, prim_name)

    validator.check_int_range(axis, -rank_base, rank_base - 1, validator.INC_BOTH, 'axis', prim_name)
    if axis < 0:
        axis = axis + rank_base
    all_shp = x_shp[0][axis]
    offset = [0]
    for i in range(1, len(x_shp)):
        v = x_shp[i]
        # All dims except `axis` must match (a dynamic dim -1 matches anything).
        for j in range(rank_base):
            if j != axis and dim_not_equal(v[j], x_shp[0][j]):
                raise ValueError(f"The shape of the two input elements of the Concat operator do not match: "
                                 f"shape[0] = {x_shp[0]} and shape[{i}] = {x_shp[i]}.")
        offset.append(all_shp)
        # Once any input is dynamic along `axis`, the total extent is dynamic too.
        if all_shp == -1 or v[axis] == -1:
            all_shp = -1
        else:
            all_shp += v[axis]
    return offset, all_shp, axis


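# A worked sketch of get_concat_offset (the shapes and primitive name are
# illustrative; x_type would be the per-input tensor types, identical here):
#
#     x_shp = ((2, 3), (4, 3), (1, 3)), axis = 0
#     -> offset = [0, 2, 6], all_shp = 7, axis = 0
#
# i.e. input k starts at row offset[k] of the concatenated result, whose
# extent along axis 0 is 2 + 4 + 1 = 7.

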
@_primexpr
def range_op(start, limit, delta, dtype):
    """helper function to get tensor in specified range."""
    output_tensor = Tensor(list(range(start, limit, delta)), dtype)
    return output_tensor


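# For instance (illustrative values; `_primexpr` helpers are normally folded
# at graph-compile time rather than called directly by user code):
#
#     range_op(0, 6, 2, mstype.int32)  -> Tensor([0, 2, 4], dtype=int32)

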
@_primexpr
def get_1d_shape(in_shape):
    """Helper function to flatten a shape to 1-D: the product of all dims."""
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)


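# For instance: get_1d_shape((2, 3, 4)) -> (24,), the flattened element count.

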
@_primexpr
def generate_shape_index(out_shape, indices_shape, axis, batch_dims=0):
    """Generate the transpose permutation that moves the indices dims of a
    gathered output (the `ind_rank - batch_dims` dims starting at `axis`) to
    the front, right after the leading `batch_dims` dims."""
    out_rank = len(out_shape)
    ind_rank = len(indices_shape)
    if axis < 0:
        axis += out_rank - ind_rank + 1
    perm_part1 = tuple(range(axis, axis + ind_rank - batch_dims))
    index = tuple(range(out_rank))
    perm = index[:batch_dims] + perm_part1 + index[batch_dims:axis] + index[axis + ind_rank - batch_dims:]
    return perm


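# A worked sketch (hypothetical Gather-like case, not from this file): a rank-4
# output whose dims 1..2 came from 2-D indices gathered at axis 1, no batch dims:
#
#     generate_shape_index(out_shape=(5, 6, 7, 8), indices_shape=(6, 7), axis=1)
#     -> (1, 2, 0, 3)   # indices dims first, then the remaining output dims

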
def ms_arrange(x):
    """Return the list [0, 1, ..., x - 1]."""
    out = list(range(x))
    return out