# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""basic"""
17import math
18import numpy as np
19import mindspore.common.dtype as mstype
20from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
21from mindspore.common.seed import _get_graph_seed
22from mindspore.common.tensor import Tensor
23from mindspore.common.initializer import initializer
24from mindspore.ops import operations as P
25from mindspore.ops import functional as F
26from mindspore.ops.functional import identity
27from mindspore.ops.operations import _inner_ops as inner
28from mindspore.ops.primitive import constexpr, Primitive
29from mindspore.common.parameter import Parameter
30from mindspore._extends import cell_attr_register
31from mindspore._checkparam import Rel, Validator
32from ..cell import Cell
33from .activation import get_activation
34
35__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
36           'Tril', 'Triu', 'ResizeBilinear', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Roll']
37
38
class L1Regularizer(Cell):
    r"""
    Applies L1 regularization to weights.

    L1 regularization encourages sparsity in the weights.

    .. math::
        \text{loss} = \lambda * \text{reduce\_sum}(\text{abs}(\omega))

    Note:
        `scale` (the regularization factor) should be a number greater than 0.

    Args:
        scale (int, float): L1 regularization factor, which must be greater than 0.

    Inputs:
        - **weights** (Tensor) - The input of L1Regularizer with data type of float16 or float32.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, whose dtype is the higher-precision data type between mindspore.float32 and the dtype of `weights`,
        and whose shape is ().

    Raises:
        TypeError: If `scale` is neither an int nor a float.
        ValueError: If `scale` is not greater than 0.
        ValueError: If `scale` is math.inf or math.nan.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> scale = 0.5
        >>> net = nn.L1Regularizer(scale)
        >>> weights = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
        >>> output = net(weights)
        >>> print(output.asnumpy())
        5.0
    """

    def __init__(self, scale):
        """Initialize L1Regularizer."""
        super(L1Regularizer, self).__init__()
        Validator.check_value_type("scale", scale, [int, float], self.cls_name)
        if scale <= 0:
            raise ValueError(f"For '{self.cls_name}', the 'scale' should be greater than 0, but got {scale}.")
        if math.isinf(scale) or math.isnan(scale):
            raise ValueError(f"For '{self.cls_name}', the 'scale' can not be INF or NAN, but got {scale}.")
        self.abs = P.Abs()
        self.reduce_sum = P.ReduceSum()
        self.scale = Tensor(scale, dtype=mstype.float32)

    def construct(self, weights):
        const_utils.check_type_valid(F.dtype(weights), mstype.number_type, 'weights')
        l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
        return l1_regularization


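# Editorial sketch, not part of the MindSpore API: the loss above is simply a
# scaled sum of absolute values, so it can be cross-checked with plain NumPy.
# The helper name `_np_l1_regularizer` is hypothetical, for illustration only.
def _np_l1_regularizer(weights, scale):
    """NumPy reference for L1Regularizer: scale * sum(abs(weights))."""
    return np.float32(scale) * np.sum(np.abs(weights), dtype=np.float32)

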
class Dropout(Cell):
    r"""
    Dropout layer for the input.

    Randomly sets some elements of the input tensor to zero with probability :math:`1 - keep\_prob` during training,
    using samples from a Bernoulli distribution.

    The outputs are scaled by a factor of :math:`\frac{1}{keep\_prob}` during training so
    that the output layer remains at a similar scale. During inference, this
    layer returns the same tensor as `x`.

    This technique is proposed in paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
    <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ and has proven effective at reducing
    over-fitting and preventing neurons from co-adapting. See more details in `Improving neural networks by
    preventing co-adaptation of feature detectors
    <https://arxiv.org/pdf/1207.0580.pdf>`_.

    Note:
        Each channel will be zeroed out independently on every construct call.

    Args:
        keep_prob (float): The keep rate, greater than 0 and less than or equal to 1. E.g. keep_prob=0.9
                   drops out 10% of input units. Default: 0.5.
        dtype (:class:`mindspore.dtype`): Data type of `x`. Default: mindspore.float32.

    Inputs:
        - **x** (Tensor) - The input of Dropout with data type of float16 or float32.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, output tensor with the same shape as `x`.

    Raises:
        TypeError: If `keep_prob` is not a float.
        TypeError: If dtype of `x` is neither float16 nor float32.
        ValueError: If `keep_prob` is not in range (0, 1].
        ValueError: If length of shape of `x` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
        >>> net = nn.Dropout(keep_prob=0.8)
        >>> net.set_train()
        Dropout<keep_prob=0.8>
        >>> output = net(x)
        >>> print(output.shape)
        (2, 2, 3)
    """

    def __init__(self, keep_prob=0.5, dtype=mstype.float32):
        """Initialize Dropout."""
        super(Dropout, self).__init__()
        if keep_prob <= 0 or keep_prob > 1:
            raise ValueError(f"For '{self.cls_name}', the 'keep_prob' should be a number in range (0, 1], "
                             f"but got {keep_prob}.")
        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
        self.keep_prob = keep_prob
        seed0, seed1 = _get_graph_seed(0, "dropout")
        self.seed0 = seed0
        self.seed1 = seed1
        self.dropout = P.Dropout(keep_prob, seed0, seed1)

    def construct(self, x):
        if not self.training:
            return x

        if self.keep_prob == 1:
            return x

        out, _ = self.dropout(x)
        return out

    def extend_repr(self):
        return 'keep_prob={}'.format(self.keep_prob)


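# Editorial sketch, not part of the MindSpore API: inverted dropout as described
# above, i.e. a Bernoulli(keep_prob) mask followed by scaling with 1/keep_prob so
# the expected activation is unchanged. `_np_dropout` is a hypothetical name.
def _np_dropout(x, keep_prob=0.5, rng=np.random):
    """NumPy reference for training-time Dropout with output scaling."""
    mask = rng.binomial(1, keep_prob, size=x.shape).astype(x.dtype)
    return x * mask / keep_prob

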
class Flatten(Cell):
    r"""
    Flatten layer for the input.

    Flattens a tensor without changing the batch size on the 0th axis.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, \ldots)` to be flattened. The data type is Number.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions
          and the shape can't be ().

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, X)`, where :math:`X` is
        the product of the remaining dimensions.

    Raises:
        TypeError: If `x` is not a subclass of Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
        >>> net = nn.Flatten()
        >>> output = net(x)
        >>> print(output)
        [[1.2 1.2 2.1 2.1]
         [2.2 2.2 3.2 3.2]]
        >>> print(f"before flatten the x shape is {x.shape}")
        before flatten the x shape is (2, 2, 2)
        >>> print(f"after flatten the output shape is {output.shape}")
        after flatten the output shape is (2, 4)
    """

    def __init__(self):
        """Initialize Flatten."""
        super(Flatten, self).__init__()

    def construct(self, x):
        return F.reshape(x, (F.shape(x)[0], -1))


@constexpr
def check_dense_input_shape(x, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x) < 2:
        raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 2, but got {len(x)}.")


class Dense(Cell):
    r"""
    The dense connected layer.

    Applies dense connected layer for the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}),

    where :math:`X` is the input tensor, :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as :math:`X` created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (Union[str, Cell, Primitive]): activation function applied to the output of the fully connected
            layer, e.g. 'relu'. Default: None.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
          to :math:`in\_channels` in `Inputs`.

    Outputs:
        Tensor of shape :math:`(*, out\_channels)`.

    Raises:
        TypeError: If `in_channels` or `out_channels` is not an int.
        TypeError: If `has_bias` is not a bool.
        TypeError: If `activation` is not one of str, Cell, Primitive, None.
        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
                    is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
        ValueError: If length of shape of `bias_init` is not equal to 1
                    or shape[0] of `bias_init` is not equal to `out_channels`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = nn.Dense(3, 4)
        >>> output = net(x)
        >>> print(output.shape)
        (2, 4)
    """

    @cell_attr_register(attrs=['has_bias', 'activation'])
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None):
        """Initialize Dense."""
        super(Dense, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
        self.reshape = P.Reshape()
        self.shape_op = P.Shape()

        if isinstance(weight_init, Tensor):
            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' should "
                                 f"be equal to 2, and the first dim should be equal to 'out_channels', and the "
                                 f"second dim should be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
                                 f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

        self.bias = None
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
                                     f"be equal to 1, and the first dim should be equal to 'out_channels'. But got "
                                     f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
            self.bias_add = P.BiasAdd()

        self.matmul = P.MatMul(transpose_b=True)
        self.activation = get_activation(activation) if isinstance(activation, str) else activation
        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
            raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, but got "
                            f"{type(activation).__name__}.")
        self.activation_flag = self.activation is not None

    def construct(self, x):
        x_shape = self.shape_op(x)
        check_dense_input_shape(x_shape, self.cls_name)
        if len(x_shape) != 2:
            x = self.reshape(x, (-1, x_shape[-1]))
        x = self.matmul(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        if self.activation_flag:
            x = self.activation(x)
        if len(x_shape) != 2:
            out_shape = x_shape[:-1] + (-1,)
            x = self.reshape(x, out_shape)
        return x

    def extend_repr(self):
        s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
        if self.has_bias:
            s += ', has_bias={}'.format(self.has_bias)
        if self.activation_flag:
            s += ', activation={}'.format(self.activation)
        return s


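# Editorial sketch, not part of the MindSpore API: the layer above computes
# activation(x @ kernel.T + bias) with `kernel` stored as (out_channels,
# in_channels), matching P.MatMul(transpose_b=True). `_np_dense` is a
# hypothetical reference helper.
def _np_dense(x, weight, bias=None, activation=None):
    """NumPy reference for Dense: activation(x @ weight.T + bias)."""
    out = x @ weight.T
    if bias is not None:
        out = out + bias
    return activation(out) if activation is not None else out

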
@constexpr
def _is_equal_one(x):
    if x is None:
        return False
    return bool(x.asnumpy().mean() == 1.0)


@constexpr
def _dtype_check(x_dtype, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if x_dtype not in [mstype.float32, mstype.float16]:
        raise TypeError(f"{msg_prefix} x_dtype must be float32 or float16, but got {x_dtype}.")


@constexpr
def _is_float_dtype(dtype):
    if dtype in [mstype.float32, mstype.float16]:
        return True
    return False


@constexpr
def _need_reduce_all(axis):
    if axis == ():
        return True
    return False


class ClipByNorm(Cell):
    r"""
    Clips tensor values to a maximum :math:`L_2`-norm.

    The output of this layer remains the same if the :math:`L_2`-norm of the input tensor
    is not greater than the argument clip_norm. Otherwise the tensor will be normalized as:

    .. math::
        \text{output}(X) = \frac{\text{clip\_norm} * X}{L_2(X)},

    where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.

    Args:
        axis (Union[None, int, tuple(int)]): Computes the L2-norm along the specified dimension.
                                            Default: None, which computes over all dimensions.

    Inputs:
        - **x** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
          or a tensor whose shape can be broadcast to the shape of `x`.

    Outputs:
        Tensor, clipped tensor with the same shape as `x`, whose type is float32.

    Raises:
        TypeError: If `axis` is not one of None, int, tuple.
        TypeError: If dtype of `x` is neither float32 nor float16.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.ClipByNorm()
        >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
        >>> output = net(x, clip_norm)
        >>> print(output.shape)
        (4, 16)

    """

    def __init__(self, axis=None):
        """Initialize ClipByNorm."""
        super(ClipByNorm, self).__init__()
        if axis is None:
            axis = ()
        if isinstance(axis, tuple):
            for idx, item in enumerate(axis):
                Validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
        self.axis = Validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.select_ = P.Select()
        self.greater_ = P.Greater()
        self.cast = P.Cast()
        self.sqrt = P.Sqrt()
        self.max_op = P.Maximum()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.fill = P.Fill()
        self.expand_dims = P.ExpandDims()
        self.dtype = P.DType()

    def construct(self, x, clip_norm):
        mul_x = F.square(x)
        l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32)
        cond = self.greater_(l2sum, 0)
        ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
        l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
        l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)

        _dtype_check(self.dtype(x), self.cls_name)
        if _is_equal_one(clip_norm):
            intermediate = x
        else:
            intermediate = x * clip_norm

        max_norm = self.max_op(l2norm, clip_norm)
        if _need_reduce_all(self.axis):
            max_norm = self.expand_dims(max_norm, -1)
        values_clip = self.cast(intermediate, mstype.float32) / max_norm
        values_clip = self.reshape(values_clip, self.shape(x))
        values_clip = identity(values_clip)
        return values_clip


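# Editorial sketch, not part of the MindSpore API: the clipping rule above leaves
# x unchanged when its L2-norm is within clip_norm and rescales it to clip_norm
# otherwise, which is equivalent to dividing by max(L2(x), clip_norm).
# `_np_clip_by_norm` is a hypothetical reference helper.
def _np_clip_by_norm(x, clip_norm, axis=None):
    """NumPy reference for ClipByNorm: clip_norm * x / max(L2(x), clip_norm)."""
    l2norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    return (x * clip_norm / np.maximum(l2norm, clip_norm)).astype(np.float32)

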
class Norm(Cell):
    r"""
    Computes the norm of vectors, currently including Euclidean norm, i.e., :math:`L_2`-norm.

    .. math::

        norm(x) = \sqrt{\sum_{i=1}^{n} (x_i^2)}

    Args:
        axis (Union[tuple, int]): The axis over which to compute vector norms. Default: ().
        keep_dims (bool): If true, the axes indicated in `axis` are kept in the output with size 1. Otherwise,
                   the dimensions in `axis` are removed from the output shape. Default: False.

    Inputs:
        - **x** (Tensor) - Tensor which is not empty. The data type should be float16 or float32.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, output tensor with dimensions in 'axis' reduced to 1 will be returned if 'keep_dims' is True;
        otherwise a Tensor with dimensions in 'axis' removed is returned. The data type is the same as `x`.

    Raises:
        TypeError: If `axis` is neither an int nor a tuple.
        TypeError: If `keep_dims` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.Norm(axis=0)
        >>> x = Tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), mindspore.float32)
        >>> print(x.shape)
        (2, 4)
        >>> output = net(x)
        >>> print(output)
        [4.472136 4.1231055 9.486833 6.0827627]
        >>> print(output.shape)
        (4,)
        >>> net = nn.Norm(axis=0, keep_dims=True)
        >>> x = Tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), mindspore.float32)
        >>> print(x.shape)
        (2, 4)
        >>> output = net(x)
        >>> print(output)
        [[4.472136 4.1231055 9.486833 6.0827627]]
        >>> print(output.shape)
        (1, 4)
        >>> net = nn.Norm(axis=1)
        >>> x = Tensor(np.array([[4, 4, 9, 1], [2, 1, 3, 6]]), mindspore.float32)
        >>> print(x.shape)
        (2, 4)
        >>> output = net(x)
        >>> print(output)
        [10.677078 7.071068]
        >>> print(output.shape)
        (2,)
    """

    def __init__(self, axis=(), keep_dims=False):
        """Initialize Norm."""
        super(Norm, self).__init__()
        Validator.check_value_type("keep_dims", keep_dims, [bool], self.cls_name)
        self.axis = axis
        self.keep_dims = keep_dims
        self.reduce_sum = P.ReduceSum(True)
        self.sqrt = P.Sqrt()
        self.squeeze = P.Squeeze(self.axis)

    def construct(self, x):
        x = self.sqrt(self.reduce_sum(F.square(x), self.axis))

        if not self.keep_dims:
            x = self.squeeze(x)
        return x

    def extend_repr(self):
        return 'axis={}, keep_dims={}'.format(self.axis, self.keep_dims)


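# Editorial sketch, not part of the MindSpore API: the reduction above is
# sqrt(sum(x**2)) over the chosen axes, i.e. the same quantity numpy computes
# below. `_np_norm` is a hypothetical helper; it takes axis=None (not the
# cell's axis=()) to mean "reduce over all dimensions".
def _np_norm(x, axis=None, keep_dims=False):
    """NumPy reference for Norm: sqrt(sum(x**2)) over `axis`."""
    return np.sqrt(np.sum(np.square(x), axis=axis, keepdims=keep_dims))

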
class OneHot(Cell):
    """
    Returns a one-hot tensor.

    The locations represented by indices in argument `indices` take value on_value,
    while all other locations take value off_value.

    Note:
        If the input `indices` has rank :math:`N`, the output will have rank :math:`N+1`. The new
        axis is created at dimension `axis`.

    If `indices` is a scalar, the output shape will be a vector of length `depth`.

    If `indices` is a vector of length `features`, the output shape will be:

    .. code-block::

        features * depth if axis == -1

        depth * features if axis == 0

    If `indices` is a matrix with shape `[batch, features]`, the output shape will be:

    .. code-block::

        batch * features * depth if axis == -1

        batch * depth * features if axis == 1

        depth * batch * features if axis == 0

    Args:
        axis (int): Features x depth if axis is -1, depth x features
                    if axis is 0. Default: -1.
        depth (int): A scalar defining the depth of the one-hot dimension. Default: 1.
        on_value (float): A scalar defining the value to fill in output[i][j]
                          when indices[j] = i. Default: 1.0.
        off_value (float): A scalar defining the value to fill in output[i][j]
                           when indices[j] != i. Default: 0.0.
        dtype (:class:`mindspore.dtype`): Data type of 'on_value' and 'off_value', not the
                                          data type of indices. Default: mindspore.float32.

    Inputs:
        - **indices** (Tensor) - A tensor of indices with data type of int32 or int64.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, the one-hot tensor of data type `dtype` with dimension at `axis` expanded to `depth` and filled with
        on_value and off_value. The dimension of the `Outputs` is equal to the dimension of the `indices` plus one.

    Raises:
        TypeError: If `axis` or `depth` is not an int.
        TypeError: If dtype of `indices` is neither int32 nor int64.
        ValueError: If `axis` is not in range [-1, len(indices_shape)].
        ValueError: If `depth` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # 1st sample: add new coordinates at axis 1
        >>> net = nn.OneHot(depth=4, axis=1)
        >>> indices = Tensor([[1, 3], [0, 2]], dtype=mindspore.int32)
        >>> output = net(indices)
        >>> print(output)
        [[[0. 0.]
          [1. 0.]
          [0. 0.]
          [0. 1.]]
         [[1. 0.]
          [0. 0.]
          [0. 1.]
          [0. 0.]]]
        >>> # The results are shown below:
        >>> print(output.shape)
        (2, 4, 2)
        >>> # 2nd sample: add new coordinates at axis 0
        >>> net = nn.OneHot(depth=4, axis=0)
        >>> indices = Tensor([[1, 3], [0, 2]], dtype=mindspore.int32)
        >>> output = net(indices)
        >>> print(output)
        [[[0. 0.]
          [1. 0.]]
         [[1. 0.]
          [0. 0.]]
         [[0. 0.]
          [0. 1.]]
         [[0. 1.]
          [0. 0.]]]
        >>> # The results are shown below:
        >>> print(output.shape)
        (4, 2, 2)
        >>> # 3rd sample: add new coordinates at the last dimension.
        >>> net = nn.OneHot(depth=4, axis=-1)
        >>> indices = Tensor([[1, 3], [0, 2]], dtype=mindspore.int32)
        >>> output = net(indices)
        >>> # The results are shown below:
        >>> print(output)
        [[[0. 1. 0. 0.]
          [0. 0. 0. 1.]]
         [[1. 0. 0. 0.]
          [0. 0. 1. 0.]]]
        >>> print(output.shape)
        (2, 2, 4)
        >>> indices = Tensor([1, 3, 0, 2], dtype=mindspore.int32)
        >>> output = net(indices)
        >>> print(output)
        [[0. 1. 0. 0.]
         [0. 0. 0. 1.]
         [1. 0. 0. 0.]
         [0. 0. 1. 0.]]
        >>> print(output.shape)
        (4, 4)
    """

    def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, dtype=mstype.float32):
        """Initialize OneHot."""
        super(OneHot, self).__init__()
        self.onehot = P.OneHot(axis)
        self.depth = depth
        self.dtype = dtype
        self.on_value = on_value
        self.off_value = off_value

    def construct(self, indices):
        return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype))


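# Editorial sketch, not part of the MindSpore API: one-hot expansion at the last
# axis can be written in NumPy with a comparison against arange. `_np_one_hot`
# is a hypothetical helper that covers only the axis=-1 case.
def _np_one_hot(indices, depth, on_value=1.0, off_value=0.0, dtype=np.float32):
    """NumPy reference for OneHot with axis=-1."""
    hit = np.asarray(indices)[..., None] == np.arange(depth)
    return np.where(hit, on_value, off_value).astype(dtype)

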
class Pad(Cell):
    r"""
    Pads the input tensor according to the paddings and mode.

    Args:
        paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of
            paddings are int type. For the `D`-th dimension of `x`, paddings[D, 0] indicates how much to extend
            ahead of the `D`-th dimension of the input tensor, and paddings[D, 1] indicates how much to extend
            behind the `D`-th dimension of the input tensor. The padded size of each dimension D of the
            output is: :math:`paddings[D, 0] + input\_x.dim\_size(D) + paddings[D, 1]`,
            e.g.:

            .. code-block::

                mode = "CONSTANT".
                paddings = [[1,1], [2,2]].
                x = [[1,2,3], [4,5,6], [7,8,9]].
                # The above can be seen: 1st dimension of `x` is 3, 2nd dimension of `x` is 3.
                # Substitute into the formula to get:
                # 1st dimension of output is paddings[0][0] + 3 + paddings[0][1] = 1 + 3 + 1 = 5.
                # 2nd dimension of output is paddings[1][0] + 3 + paddings[1][1] = 2 + 3 + 2 = 7.
                # So the shape of output is (5, 7).

        mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
            Default: "CONSTANT".

    Inputs:
        - **x** (Tensor) - The input tensor.

    Outputs:
        Tensor, the tensor after padding.

        - If `mode` is "CONSTANT", it fills the edge with 0, regardless of the values of the `x`.
          If the `x` is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the
          Outputs is [[0,0,0,0,0,0,0], [0,0,1,2,3,0,0], [0,0,4,5,6,0,0], [0,0,7,8,9,0,0], [0,0,0,0,0,0,0]].
        - If `mode` is "REFLECT", it uses a way of symmetrical copying through the axis of symmetry to fill in.
          If the `x` is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the
          Outputs is [[6,5,4,5,6,5,4], [3,2,1,2,3,2,1], [6,5,4,5,6,5,4], [9,8,7,8,9,8,7], [6,5,4,5,6,5,4]].
        - If `mode` is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied
          according to the symmetry axis, except that it includes the symmetry axis. If the `x`
          is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the Outputs is
          [[2,1,1,2,3,3,2], [2,1,1,2,3,3,2], [5,4,4,5,6,6,5], [8,7,7,8,9,9,8], [8,7,7,8,9,9,8]].

    Raises:
        TypeError: If `paddings` is not a tuple.
        ValueError: If length of `paddings` is more than 4 or its shape is not (n, 2).
        ValueError: If `mode` is not one of 'CONSTANT', 'REFLECT', 'SYMMETRIC'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>> # If `mode` is "CONSTANT"
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="CONSTANT")
        ...     def construct(self, x):
        ...         return self.pad(x)
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.float32)
        >>> pad = Net()
        >>> output = pad(x)
        >>> print(output)
        [[0. 0. 0. 0. 0. 0. 0.]
         [0. 0. 1. 2. 3. 0. 0.]
         [0. 0. 4. 5. 6. 0. 0.]
         [0. 0. 0. 0. 0. 0. 0.]]
        >>> # Another way to call
        >>> pad = ops.Pad(paddings=((1, 1), (2, 2)))
        >>> # From the above code, we can see following:
        >>> # "paddings=((1, 1), (2, 2))",
        >>> # paddings[0][0] = 1, indicates a row of values is filled above the input data in the 1st dimension.
        >>> # Shown as follows:
        >>> # [[0. 0. 0.]
        >>> #  [1. 2. 3.]
        >>> #  [4. 5. 6.]]
        >>> # paddings[0][1] = 1, indicates a row of values is filled below the input data in the 1st dimension.
        >>> # Shown as follows:
        >>> # [[0. 0. 0.]
        >>> #  [1. 2. 3.]
        >>> #  [4. 5. 6.]
        >>> #  [0. 0. 0.]]
        >>> # paddings[1][0] = 2, indicates 2 columns of values are filled in front of the input data in the
        >>> # 2nd dimension. Shown as follows:
        >>> # [[0. 0. 0. 0. 0.]
        >>> #  [0. 0. 1. 2. 3.]
        >>> #  [0. 0. 4. 5. 6.]
        >>> #  [0. 0. 0. 0. 0.]]
        >>> # paddings[1][1] = 2, indicates 2 columns of values are filled behind the input data in the
        >>> # 2nd dimension. Shown as follows:
        >>> # [[0. 0. 0. 0. 0. 0. 0.]
        >>> #  [0. 0. 1. 2. 3. 0. 0.]
        >>> #  [0. 0. 4. 5. 6. 0. 0.]
        >>> #  [0. 0. 0. 0. 0. 0. 0.]]
        >>> output = pad(x)
        >>> print(output)
        [[0. 0. 0. 0. 0. 0. 0.]
         [0. 0. 1. 2. 3. 0. 0.]
         [0. 0. 4. 5. 6. 0. 0.]
         [0. 0. 0. 0. 0. 0. 0.]]
        >>> # if mode is "REFLECT"
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="REFLECT")
        ...     def construct(self, x):
        ...         return self.pad(x)
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.float32)
        >>> pad = Net()
        >>> output = pad(x)
        >>> print(output)
        [[6. 5. 4. 5. 6. 5. 4.]
         [3. 2. 1. 2. 3. 2. 1.]
         [6. 5. 4. 5. 6. 5. 4.]
         [3. 2. 1. 2. 3. 2. 1.]]
        >>> # if mode is "SYMMETRIC"
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="SYMMETRIC")
        ...     def construct(self, x):
        ...         return self.pad(x)
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.float32)
        >>> pad = Net()
        >>> output = pad(x)
        >>> print(output)
        [[2. 1. 1. 2. 3. 3. 2.]
         [2. 1. 1. 2. 3. 3. 2.]
         [5. 4. 4. 5. 6. 6. 5.]
         [5. 4. 4. 5. 6. 6. 5.]]
    """

    def __init__(self, paddings, mode="CONSTANT"):
        """Initialize Pad."""
        super(Pad, self).__init__()
        self.mode = mode
        self.paddings = paddings
        Validator.check_string(self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], 'mode', self.cls_name)
        if not isinstance(paddings, tuple):
            raise TypeError(f"For '{self.cls_name}', the type of 'paddings' must be tuple, "
                            f"but got {type(paddings).__name__}.")
        for item in paddings:
            if len(item) != 2:
                raise ValueError(f"For '{self.cls_name}', the dimension of 'paddings' must be (n, 2), "
                                 f"but got {paddings}.")
        if len(paddings) > 4:
            raise ValueError(f"For '{self.cls_name}', only 'paddings' up to 4 dims is supported, but got "
                             f"{len(paddings)}.")
        if mode == "CONSTANT":
            self.pad = P.Pad(self.paddings)
        else:
            self.paddings = Tensor(np.array(self.paddings), dtype=mstype.int64)
            self.pad = P.MirrorPad(mode=mode)

    def construct(self, x):
        if self.mode == "CONSTANT":
            x = self.pad(x)
        else:
            x = self.pad(x, self.paddings)
        return x


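# Editorial sketch, not part of the MindSpore API: the three padding modes above
# map directly onto numpy.pad modes, which makes the docstring examples easy to
# reproduce. `_np_pad` is a hypothetical helper.
def _np_pad(x, paddings, mode="CONSTANT"):
    """NumPy reference for Pad: CONSTANT/REFLECT/SYMMETRIC via numpy.pad."""
    np_mode = {"CONSTANT": "constant", "REFLECT": "reflect", "SYMMETRIC": "symmetric"}[mode]
    return np.pad(np.asarray(x), paddings, mode=np_mode)

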
@constexpr
def bilinear(shape, size, scale, align_corners, prim_name=None):
    """Check input and calculate shape"""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if not isinstance(align_corners, bool):
        raise TypeError(f"{msg_prefix} type of 'align_corners' should be boolean, "
                        f"but got {type(align_corners).__name__}.")
    if size is None and scale is None:
        raise ValueError(f"{msg_prefix} 'size' and 'scale' cannot both be None.")
    if size is not None and scale is not None:
        raise ValueError(f"{msg_prefix} 'size' and 'scale' cannot both be specified.")
    if size is not None:
        if not isinstance(size, (tuple, list)):
            raise ValueError(f"{msg_prefix} 'size' must be tuple or list or None, but got {type(size).__name__}.")
        Validator.check_int(len(size), 2, Rel.EQ, "size", "bilinear")
        Validator.check_int(size[0], 1, Rel.GE, "size[0]", "bilinear")
        Validator.check_int(size[1], 1, Rel.GE, "size[1]", "bilinear")
        return size
    Validator.check_int(scale, 1, Rel.GE, "scale factor", "bilinear")
    ret = (scale * shape[2], scale * shape[3])
    return ret


class ResizeBilinear(Cell):
    r"""
    Resamples the input tensor to the given size or scale_factor using bilinear interpolation.

    Inputs:
        - **x** (Tensor) - Tensor to be resized. Input tensor must be a 4-D tensor with shape
          :math:`(batch, channels, height, width)`, with data type of float16 or float32.
        - **size** (Union[tuple[int], list[int]]): A tuple or list of 2 int elements
          :math:`(new\_height, new\_width)`, the new size of the tensor.
          One and only one of size and scale_factor can be set to None. Default: None.
        - **scale_factor** (int): The scale factor of the new size of the tensor. The value should be a positive
          integer. One and only one of size and scale_factor can be set to None. Default: None.
        - **align_corners** (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`, which exactly
          aligns the 4 corners of images and resized images. If false, rescale by :math:`new\_height / height`.
          Default: False.

    Outputs:
        Resized tensor.
        If size is set, the result is a 4-D tensor with shape :math:`(batch, channels, new\_height, new\_width)`,
        and the data type is the same as `x`.
        If scale is set, the result is a 4-D tensor with shape
        :math:`(batch, channels, scale\_factor * height, scale\_factor * width)` and the data type is the same as `x`.

    Raises:
        TypeError: If `size` is not one of tuple, list, None.
        TypeError: If `scale_factor` is neither int nor None.
        TypeError: If `align_corners` is not a bool.
        TypeError: If dtype of `x` is neither float16 nor float32.
        ValueError: If `size` and `scale_factor` are both None or both not None.
        ValueError: If length of shape of `x` is not equal to 4.
        ValueError: If `scale_factor` is an int which is less than 1.
        ValueError: If `size` is a list or tuple whose length is not equal to 2.

    Supported Platforms:
        ``Ascend`` ``CPU`` ``GPU``

    Examples:
        >>> x = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)
        >>> resize_bilinear = nn.ResizeBilinear()
        >>> result = resize_bilinear(x, size=(5,5))
        >>> print(x)
        [[[[1. 2. 3. 4.]
           [5. 6. 7. 8.]]]]
        >>> print(result)
        [[[[1.        1.8       2.6       3.4       4.       ]
           [2.6       3.4       4.2000003 5.        5.6000004]
           [4.2       5.0000005 5.8       6.6       7.2      ]
           [5.        5.8       6.6       7.4       8.       ]
           [5.        5.8       6.6       7.4000006 8.       ]]]]
        >>> print(result.shape)
        (1, 1, 5, 5)
    """

    def __init__(self):
        """Initialize ResizeBilinear."""
        super(ResizeBilinear, self).__init__()

    def construct(self, x, size=None, scale_factor=None, align_corners=False):
        shape = bilinear(x.shape, size, scale_factor, align_corners, self.cls_name)
        resize_bilinear = P.ResizeBilinear(shape, align_corners)
        return resize_bilinear(x)


class Unfold(Cell):
    r"""
    Extracts patches from images.
    The input tensor must be a 4-D tensor and the data format is NCHW.

    Args:
        ksizes (Union[tuple[int], list[int]]): The size of the sliding window, must be a tuple or a list of integers,
            and the format is [1, ksize_row, ksize_col, 1].
        strides (Union[tuple[int], list[int]]): Distance between the centers of two consecutive patches,
            must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
            pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
        padding (str): The type of padding algorithm, a string whose value is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.

            - valid: Means that the taken patch area must be completely covered in the original image.

    Inputs:
        - **x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
          data type is number.

    Outputs:
        Tensor, a 4-D tensor whose data type is the same as `x`,
        and the shape is [out_batch, out_depth, out_row, out_col] where `out_batch` is the same as the `in_batch`.

        :math:`out\_depth = ksize\_row * ksize\_col * in\_depth`

        :math:`out\_row = (in\_row - (ksize\_row + (ksize\_row - 1) * (rate\_row - 1))) // stride\_row + 1`

        :math:`out\_col = (in\_col - (ksize\_col + (ksize\_col - 1) * (rate\_col - 1))) // stride\_col + 1`

    Raises:
        TypeError: If `ksizes`, `strides` or `rates` is neither a tuple nor a list.
        ValueError: If shape of `ksizes`, `strides` or `rates` is not (1, x_row, x_col, 1).
        ValueError: If the second and third elements of `ksizes`, `strides` or `rates` are less than 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> net = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 2, 2, 1], rates=[1, 2, 2, 1])
        >>> # As stated in the above code:
        >>> # ksize_row = 2, ksize_col = 2, rate_row = 2, rate_col = 2, stride_row = 2, stride_col = 2.
        >>> image = Tensor(np.ones([2, 3, 6, 6]), dtype=mindspore.float16)
        >>> # in_batch = 2, in_depth = 3, in_row = 6, in_col = 6.
        >>> # Substituting the formula to get:
        >>> # out_batch = in_batch = 2
        >>> # out_depth = 2 * 2 * 3 = 12
        >>> # out_row = (6 - (2 + (2 - 1) * (2 - 1))) // 2 + 1 = 2
        >>> # out_col = (6 - (2 + (2 - 1) * (2 - 1))) // 2 + 1 = 2
        >>> output = net(image)
        >>> print(output.shape)
        (2, 12, 2, 2)
    """

    def __init__(self, ksizes, strides, rates, padding="valid"):
        """Initialize Unfold."""
        super(Unfold, self).__init__()

        def _check_tuple_or_list(arg_name, arg_val, prim_name):
            Validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], prim_name)
            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
                raise ValueError(f"For '{prim_name}' the format of '{arg_name}s' should be [1, {arg_name}_row, "
                                 f"{arg_name}_col, 1], but got {arg_val}.")
            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in '{arg_name}s' should be "
                                 f"a positive integer number, but got {arg_name}_row is {arg_val[1]}, "
                                 f"{arg_name}_col is {arg_val[2]}")

        _check_tuple_or_list("ksize", ksizes, self.cls_name)
        _check_tuple_or_list("stride", strides, self.cls_name)
        _check_tuple_or_list("rate", rates, self.cls_name)
        # Reorder the [1, row, col, 1] arguments into the [1, 1, row, col] (NCHW)
        # layout expected by ExtractImagePatches.
        ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
        strides = strides[0], strides[3], strides[1], strides[2]
        rates = rates[0], rates[3], rates[1], rates[2]
        self.extract_image_patches = inner.ExtractImagePatches(ksizes, strides, rates, padding)

    def construct(self, input_x):
        result = self.extract_image_patches(input_x)
        return result


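# Editorial sketch, not part of the MindSpore API: the "valid" output-shape
# formulas quoted in the docstring above, written out as a checkable helper.
# `_np_unfold_shape` is a hypothetical name.
def _np_unfold_shape(in_shape, ksizes, strides, rates):
    """Output shape of Unfold (padding="valid") for an NCHW input."""
    in_batch, in_depth, in_row, in_col = in_shape
    _, ksize_row, ksize_col, _ = ksizes
    _, stride_row, stride_col, _ = strides
    _, rate_row, rate_col, _ = rates
    eff_row = ksize_row + (ksize_row - 1) * (rate_row - 1)  # effective kernel extent
    eff_col = ksize_col + (ksize_col - 1) * (rate_col - 1)
    out_row = (in_row - eff_row) // stride_row + 1
    out_col = (in_col - eff_col) // stride_col + 1
    return (in_batch, ksize_row * ksize_col * in_depth, out_row, out_col)

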
@constexpr
def tril(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril")
    Validator.check_is_int(k, "k value", "tril")
    mask = np.tril(np.ones(x_shape), k)
    return Tensor(mask, x_dtype)


class Tril(Cell):
    """
    Returns a tensor with elements above the kth diagonal zeroed.

    Inputs:
        - **x** (Tensor) - The input tensor. The data type is Number.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
        - **k** (int) - The index of the diagonal. Default: 0

    Outputs:
        Tensor, has the same shape and type as input `x`.

    Raises:
        TypeError: If `k` is not an int.
        ValueError: If length of shape of `x` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> tril = nn.Tril()
        >>> result = tril(x)
        >>> print(result)
        [[ 1  0  0  0]
         [ 5  6  0  0]
         [10 11 12  0]
         [14 15 16 17]]
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> tril = nn.Tril()
        >>> result = tril(x, 1)
        >>> print(result)
        [[ 1  2  0  0]
         [ 5  6  7  0]
         [10 11 12 13]
         [14 15 16 17]]
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> tril = nn.Tril()
        >>> result = tril(x, 2)
        >>> print(result)
        [[ 1  2  3  0]
         [ 5  6  7  8]
         [10 11 12 13]
         [14 15 16 17]]
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> tril = nn.Tril()
        >>> result = tril(x, -1)
        >>> print(result)
        [[ 0  0  0  0]
         [ 5  0  0  0]
         [10 11  0  0]
         [14 15 16  0]]
    """

    def __init__(self):
        """Initialize Tril."""
        super(Tril, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()
        self.cast = P.Cast()

    def construct(self, x, k=0):
        assist = tril(x.shape, self.dtype(x), k)
        result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
        return self.cast(result, self.dtype(x))


@constexpr
def triu(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu")
    Validator.check_is_int(k, "k value", "triu")
    mask = np.triu(np.ones(x_shape), k)
    return Tensor(mask, x_dtype)


class Triu(Cell):
    """
    Returns a tensor with elements below the kth diagonal zeroed.

    Inputs:
        - **x** (Tensor) - The input tensor. The data type is Number.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
        - **k** (int) - The index of the diagonal. Default: 0

    Outputs:
        Tensor, has the same type and shape as input `x`.

    Raises:
        TypeError: If `k` is not an int.
        ValueError: If length of shape of `x` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> triu = nn.Triu()
        >>> result = triu(x)
        >>> print(result)
        [[ 1  2  3  4]
         [ 0  6  7  8]
         [ 0  0 12 13]
         [ 0  0  0 17]]
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> triu = nn.Triu()
        >>> result = triu(x, 1)
        >>> print(result)
        [[ 0  2  3  4]
         [ 0  0  7  8]
         [ 0  0  0 13]
         [ 0  0  0  0]]
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> triu = nn.Triu()
        >>> result = triu(x, 2)
        >>> print(result)
        [[ 0  0  3  4]
         [ 0  0  0  8]
         [ 0  0  0  0]
         [ 0  0  0  0]]
        >>> x = Tensor(np.array([[ 1,  2,  3,  4],
        ...                      [ 5,  6,  7,  8],
        ...                      [10, 11, 12, 13],
        ...                      [14, 15, 16, 17]]))
        >>> triu = nn.Triu()
        >>> result = triu(x, -1)
        >>> print(result)
        [[ 1  2  3  4]
         [ 5  6  7  8]
         [ 0 11 12 13]
         [ 0  0 16 17]]
    """

    def __init__(self):
        """Initialize Triu."""
        super(Triu, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()
        self.cast = P.Cast()

    def construct(self, x, k=0):
        assist = triu(x.shape, self.dtype(x), k)
        result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
        return self.cast(result, self.dtype(x))


@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
    base_eye = np.eye(x_shape[-1], x_shape[-1]).reshape(-1)
    assist = np.tile(base_eye, x_shape[:-1]).reshape(x_shape + (x_shape[-1],))
    return Tensor(assist, x_dtype)


@constexpr
def _get_matrix_diag_part_assist(x_shape, x_dtype):
    Validator.check_int(len(x_shape), 2, Rel.GE, "x rank", "_get_matrix_diag_part_assist")
    base_eye = np.eye(x_shape[-2], x_shape[-1]).reshape(-1)
    assist = np.tile(base_eye, x_shape[:-2]).reshape(x_shape)
    return Tensor(assist, x_dtype)


class MatrixDiag(Cell):
    r"""
    Returns a batched diagonal tensor with given batched diagonal values.

    Assume `x` has :math:`k` dimensions :math:`[I, J, K, ..., N]`, then the output is a tensor of rank
    :math:`k+1` with dimensions :math:`[I, J, K, ..., N, N]` where:
    :math:`output[i, j, k, ..., m, n] = 1\{m=n\} * x[i, j, k, ..., n]`

    Inputs:
        - **x** (Tensor) - The diagonal values. It can be one of the following data types:
          float32, float16, int32, int8, and uint8.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, has the same type as input `x`. The shape must be x.shape + (x.shape[-1], ).

    Raises:
        TypeError: If dtype of `x` is not one of float32, float16, int32, int8 or uint8.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([1, -1]), mindspore.float32)
        >>> matrix_diag = nn.MatrixDiag()
        >>> output = matrix_diag(x)
        >>> print(x.shape)
        (2,)
        >>> print(output)
        [[ 1.  0.]
         [ 0. -1.]]
        >>> print(output.shape)
        (2, 2)
        >>> x = Tensor(np.array([[1, -1], [1, -1]]), mindspore.float32)
        >>> matrix_diag = nn.MatrixDiag()
        >>> output = matrix_diag(x)
        >>> print(x.shape)
        (2, 2)
        >>> print(output)
        [[[ 1.  0.]
          [ 0. -1.]]
         [[ 1.  0.]
          [ 0. -1.]]]
        >>> print(output.shape)
        (2, 2, 2)
        >>> x = Tensor(np.array([[1, -1, 1], [1, -1, 1]]), mindspore.float32)
        >>> matrix_diag = nn.MatrixDiag()
        >>> output = matrix_diag(x)
        >>> print(x.shape)
        (2, 3)
        >>> print(output)
        [[[ 1.  0.  0.]
          [ 0. -1.  0.]
          [ 0.  0.  1.]]
         [[ 1.  0.  0.]
          [ 0. -1.  0.]
          [ 0.  0.  1.]]]
        >>> print(output.shape)
        (2, 3, 3)
    """

    def __init__(self):
        """Initialize MatrixDiag."""
        super(MatrixDiag, self).__init__()
        self.matrix_diag = inner.MatrixDiag()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_assist(x_shape, x_dtype)
        out_matrix_diag = self.matrix_diag(input_x, assist)
        return out_matrix_diag


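# Editorial sketch, not part of the MindSpore API: the defining rule
# output[..., m, n] = 1{m=n} * x[..., n] is a broadcast multiply against an
# identity matrix, mirroring the "assist" tensor built above. `_np_matrix_diag`
# is a hypothetical helper.
def _np_matrix_diag(x):
    """NumPy reference for MatrixDiag: out[..., m, n] = (m == n) * x[..., n]."""
    x = np.asarray(x)
    return x[..., None, :] * np.eye(x.shape[-1], dtype=x.dtype)

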
class MatrixDiagPart(Cell):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Assume `x` has :math:`k` dimensions :math:`[I, J, K, ..., M, N]`, then the output is a tensor of rank
    :math:`k-1` with dimensions :math:`[I, J, K, ..., min(M, N)]` where:
    :math:`output[i, j, k, ..., n] = x[i, j, k, ..., n, n]`

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, and uint8.

    Outputs:
        Tensor, has the same type as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].

    Raises:
        TypeError: If dtype of `x` is not one of float32, float16, int32, int8 or uint8.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]],
        ...             [[-1, 0], [0, 1]],
        ...             [[-1, 0], [0, 1]]], mindspore.float32)
        >>> matrix_diag_part = nn.MatrixDiagPart()
        >>> output = matrix_diag_part(x)
        >>> print(output)
        [[-1.  1.]
         [-1.  1.]
         [-1.  1.]]
        >>> x = Tensor([[-1, 0, 0, 1],
        ...             [-1, 0, 0, 1],
        ...             [-1, 0, 0, 1],
        ...             [-1, 0, 0, 1]], mindspore.float32)
        >>> matrix_diag_part = nn.MatrixDiagPart()
        >>> output = matrix_diag_part(x)
        >>> print(output)
        [-1.  0.  0.  1.]
    """

    def __init__(self):
        """Initialize MatrixDiagPart."""
        super(MatrixDiagPart, self).__init__()
        self.matrix_diag_part = inner.MatrixDiagPart()
        self.dtype = P.DType()

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_diag_part = self.matrix_diag_part(input_x, assist)
        return out_matrix_diag_part


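# Editorial sketch, not part of the MindSpore API: NumPy's diagonal over the last
# two axes computes the same quantity; `_np_matrix_diag_part` is a hypothetical
# helper.
def _np_matrix_diag_part(x):
    """NumPy reference for MatrixDiagPart: out[..., n] = x[..., n, n]."""
    return np.diagonal(np.asarray(x), axis1=-2, axis2=-1)

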
class MatrixSetDiag(Cell):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Assume `x` has :math:`k+1` dimensions :math:`[I, J, K, ..., M, N]` and `diagonal` has :math:`k`
    dimensions :math:`[I, J, K, ..., min(M, N)]`. Then the output is a tensor of rank :math:`k+1` with dimensions
    :math:`[I, J, K, ..., M, N]` where:

    .. math::
        output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]\ for\ m == n

    .. math::
        output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n]\ for\ m != n

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, and uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.

    Outputs:
        Tensor, has the same type and shape as input `x`.

    Raises:
        TypeError: If dtype of `x` or `diagonal` is not one of float32, float16, int32, int8 or uint8.
        ValueError: If length of shape of `x` is less than 2.
        ValueError: If x_shape[-2] < x_shape[-1] and x_shape[:-1] != diagonal_shape.
        ValueError: If x_shape[-2] >= x_shape[-1] and x_shape[:-2] + x_shape[-1:] != diagonal_shape.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = nn.MatrixSetDiag()
        >>> output = matrix_set_diag(x, diagonal)
        >>> print(output)
        [[[-1.  0.]
          [ 0.  2.]]
         [[-1.  0.]
          [ 0.  1.]]
         [[-1.  0.]
          [ 0.  1.]]]
    """

    def __init__(self):
        """Initialize MatrixSetDiag."""
        super(MatrixSetDiag, self).__init__()
        self.matrix_set_diag = inner.MatrixSetDiag()
        self.dtype = P.DType()

    def construct(self, input_x, diagonal):
        x_shape = F.shape(input_x)
        x_dtype = self.dtype(input_x)
        assist = _get_matrix_diag_part_assist(x_shape, x_dtype)
        out_matrix_set_diag = self.matrix_set_diag(input_x, diagonal, assist)
        return out_matrix_set_diag


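# Editorial sketch, not part of the MindSpore API: for square (..., N, N) inputs
# the rule above is a masked blend of `x` and the broadcast diagonal.
# `_np_matrix_set_diag` is a hypothetical helper.
def _np_matrix_set_diag(x, diagonal):
    """NumPy reference for MatrixSetDiag on (..., N, N) inputs."""
    x = np.asarray(x)
    eye = np.eye(x.shape[-1], dtype=bool)
    return np.where(eye, np.asarray(diagonal)[..., None, :], x)

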
@constexpr
def _check_input_dim(axis, dim, cls_name):
    Validator.check_int_range(axis, -dim, dim, Rel.INC_LEFT, 'axis', cls_name)


class Roll(Cell):
    """
    Rolls the elements of a tensor along an axis.

    The elements are shifted positively (towards larger indices) by the offset of `shift` along the dimension of
    `axis`. Negative `shift` values will shift elements in the opposite direction. Elements that roll past the last
    position will wrap around to the first and vice versa. Multiple shifts along multiple axes may be specified.

    Args:
        shift (Union[list(int), tuple(int), int]): Specifies the number of places by which elements are shifted
            positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
            in the opposite direction.
        axis (Union[list(int), tuple(int), int]): Specifies the dimension indexes of shape to be rolled.

    Inputs:
        - **input_x** (Tensor) - Input tensor.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `shift` is not an int, a tuple or a list.
        TypeError: If `axis` is not an int, a tuple or a list.
        TypeError: If element of `shift` is not an int.
        TypeError: If element of `axis` is not an int.
        ValueError: If `axis` is out of the range [-len(input_x.shape), len(input_x.shape)).
        ValueError: If length of `shift` is not equal to length of `axis`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32))
        >>> op = nn.Roll(shift=2, axis=0)
        >>> output = op(input_x)
        >>> print(output)
        [3. 4. 0. 1. 2.]
        >>> input_x = Tensor(np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).astype(np.float32))
        >>> op = nn.Roll(shift=[1, -2], axis=[0, 1])
        >>> output = op(input_x)
        >>> print(output)
        [[7. 8. 9. 5. 6.]
         [2. 3. 4. 0. 1.]]
    """

    def __init__(self, shift, axis):
        """Initialize Roll"""
        super(Roll, self).__init__()
        Validator.check_value_type("shift", shift, [int, tuple, list], self.cls_name)
        Validator.check_value_type("axis", axis, [int, tuple, list], self.cls_name)
        self.shape_op = P.Shape()
        self.shift = shift
        self.axis = axis
        self.op_list = []

        if not isinstance(self.axis, (list, tuple)):
            self.op_list.append((inner.Roll(shift=self.shift, axis=0), self.axis))
        else:
            if len(self.shift) != len(self.axis):
                raise ValueError(f"For '{self.cls_name}', the lengths of 'shift' and 'axis' must be "
                                 f"the same, but got the length of 'shift' {len(self.shift)} and the length of 'axis'"
                                 f" {len(self.axis)}.")
            for idx, _ in enumerate(self.axis):
                self.op_list.append((inner.Roll(shift=self.shift[idx], axis=0), self.axis[idx]))

    def construct(self, input_x):
        dim = len(self.shape_op(input_x))
        for single_op_roll, single_axis in self.op_list:
            _check_input_dim(single_axis, dim, self.cls_name)
            if single_axis < 0:
                single_axis += dim
            # Swap the target axis with axis 0, roll along axis 0, then apply the
            # same swap again to restore the original layout.
            transpose_perm = []
            for i in range(dim):
                transpose_perm.append(i)
            transpose_perm[0], transpose_perm[single_axis] = single_axis, 0

            input_x = input_x.transpose(transpose_perm)
            input_x = single_op_roll(input_x)
            input_x = input_x.transpose(transpose_perm)
        return input_x

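# Editorial sketch, not part of the MindSpore API: the transpose trick above
# reduces every shift to a roll along axis 0, which NumPy exposes directly with
# multi-axis support. `_np_roll` is a hypothetical helper.
def _np_roll(x, shift, axis):
    """NumPy reference for Roll: np.roll already handles multiple shifts/axes."""
    return np.roll(np.asarray(x), shift, axis=axis)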