1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""array Operations."""
16from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
17from mindspore.common import dtype as mstype
18from mindspore.common._register_for_tensor import tensor_operator_registry
19from mindspore._checkparam import Validator as validator
20from mindspore._checkparam import Rel
21from mindspore.ops.primitive import constexpr
22from mindspore.ops import functional as F
23from .. import operations as P
24
25
@constexpr
def _check_is_int(arg_value, arg_name, op_name):
    """Validate that `arg_value` is an int (reporting `arg_name`/`op_name` on error) and return it."""
    return validator.check_is_int(arg_value, arg_name, op_name)
30
31
@constexpr
def _check_positive_int(arg_value, arg_name, op_name):
    """Validate that `arg_value` is a positive int (reporting `arg_name`/`op_name` on error) and return it."""
    return validator.check_positive_int(arg_value, arg_name, op_name)
36
37
@constexpr
def _check_axis_range(arg_value, limit, arg_name, op_name):
    """Validate that `arg_value` lies in [-limit, limit) — the legal axis range for a rank-`limit` tensor — and return it."""
    return validator.check_int_range(arg_value, -limit, limit, Rel.INC_LEFT, arg_name, op_name)
42
43
@constexpr
def _cal_repeat_dims(x_rank, rep, expand_axis):
    """Build the Tile multiples for an expanded (rank x_rank + 1) tensor: all ones except `rep` at `expand_axis`."""
    multiples = [1 for _ in range(x_rank + 1)]
    # Plain index assignment so a negative expand_axis keeps working.
    multiples[expand_axis] = rep
    return tuple(multiples)
49
50
@constexpr
def _cal_reshape(x_shape, rep, axis):
    """Return `x_shape` with the extent along `axis` scaled by `rep` (negative axis allowed)."""
    new_shape = list(x_shape)
    new_shape[axis] = new_shape[axis] * rep
    return tuple(new_shape)
56
57
def repeat_elements(x, rep, axis=0):
    """
    Repeat elements of a tensor along an axis, like np.repeat.

    Args:
        x (Tensor): The tensor to repeat values for. Must be of type: float16,
            float32, int8, uint8, int16, int32, or int64.
        rep (int): The number of times to repeat, must be positive, required.
        axis (int): The axis along which to repeat, default 0.

    Outputs:
        One tensor with values repeated along the specified axis. If x has shape
        (s1, s2, ..., sn) and axis is i, the output will have shape (s1, s2, ...,
        si * rep, ..., sn). The output type will be the same as the type of `x`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1 : repeat on axis 0
        >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
        >>> output = ops.repeat_elements(x, rep = 2, axis = 0)
        >>> print(output)
        [[0 1 2]
         [0 1 2]
         [3 4 5]
         [3 4 5]]
        >>> # case 2 : repeat on axis 1
        >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
        >>> output = ops.repeat_elements(x, rep = 2, axis = 1)
        >>> print(output)
        [[0 0 1 1 2 2]
         [3 3 4 4 5 5]]
    """
    const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
    rep = _check_positive_int(rep, "rep", "repeat_elements")
    axis = _check_is_int(axis, "axis", "repeat_elements")

    rank = P.Rank()(x)
    axis = _check_axis_range(axis, rank, "axis", "repeat_elements")

    # Insert a unit dimension right after `axis`, tile it `rep` times, then
    # collapse it back so each element's copies sit next to the original.
    insert_pos = axis + 1
    expanded = P.ExpandDims()(x, insert_pos)
    tiled = P.Tile()(expanded, _cal_repeat_dims(rank, rep, insert_pos))
    return P.Reshape()(tiled, _cal_reshape(P.Shape()(x), rep, axis))

tensor_operator_registry.register('repeat_elements', repeat_elements)
116
117
@constexpr
def _check_sequence_mask_input_len(input_shape, prim_name=None):
    """Reject an empty `lengths` shape or one with 7+ dims (broadcast only handles up to a 7d output)."""
    if prim_name:
        msg_prefix = f"For '{prim_name}', the"
    else:
        msg_prefix = "The"
    if not input_shape:
        raise ValueError(f"{msg_prefix} input_shape should be greater than 0, but got {input_shape}.")
    shape_size = len(input_shape)
    # The mask adds one trailing dim, so the input itself must stay below 7d.
    if shape_size >= 7:
        raise ValueError(f"{msg_prefix} dimension of input_shape should be less than 7, but got {shape_size}d.")
127
128
def sequence_mask(lengths, maxlen=None, prim_name='sequence_mask'):
    """
    Returns a mask tensor representing the first N positions of each cell.

    If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])

    Inputs:
        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
          less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
          Must be type int32 or int64.
        - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
          type as elements in `lengths`. Default is None.

    Outputs:
        One mask tensor of shape lengths.shape + (maxlen,).

    Raises:
        TypeError: If `lengths` is not a Tensor.
        TypeError: If `maxlen` is not an int.
        TypeError: If dtype of `lengths` is neither int32 nor int64.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> # case 1: When maxlen is assigned
        >>> x = Tensor(np.array([1, 2, 3, 4]))
        >>> output = ops.sequence_mask(x, 5)
        >>> print(output)
        [[ True False False False False]
         [ True  True False False False]
         [ True  True  True False False]
         [ True  True  True  True False]]
        >>> # case 2: When there is 0 in x
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> output = ops.sequence_mask(x, 5)
        >>> print(output)
        [[[ True False False False False]
          [ True  True  True False False]]
         [[ True  True False False False]
          [False False False False False]]]
        >>> # case 3: when the maxlen is not assigned
        >>> x = Tensor(np.array([[1, 3], [2, 4]]))
        >>> output = ops.sequence_mask(x)
        >>> print(output)
        [[[ True False False False ]
          [ True  True  True False ]]
         [[ True  True False False ]
          [ True  True  True  True ]]]
    """
    shape_op = P.Shape()
    cast_op = P.Cast()
    to_tensor_op = P.ScalarToArray()

    const_utils.check_type_valid(F.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
    _check_sequence_mask_input_len(shape_op(lengths), prim_name)

    if maxlen is None:
        # Derive maxlen as the largest value in `lengths`: flatten, cast to
        # float32 for ArgMaxWithValue, then cast the max back to int32.
        flat = P.Reshape()(lengths, (-1,))
        _, biggest = P.ArgMaxWithValue()(cast_op(flat, mstype.float32))
        maxlen = cast_op(biggest, mstype.int32)
    else:
        maxlen = to_tensor_op(_check_positive_int(maxlen, "maxlen", "sequence_mask"))

    # positions [0, maxlen) broadcast against lengths expanded with a trailing
    # unit dim: mask[..., j] = (j < lengths[...]).
    range_vector = P.Range()(to_tensor_op(0), maxlen, to_tensor_op(1))
    return range_vector < P.ExpandDims()(lengths, -1)
206