# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Other operators."""
import functools
from mindspore import log as logger
from mindspore.ops import signature as sig
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from mindspore.ops.operations._pyfunc_registry import add_pyfunc
from mindspore._c_expression import typing
from mindspore.ops._primitive_cache import _get_cache_prim
from ..auto_generate import Assign, Identity


class Load(PrimitiveWithCheck):
    """
    Loads the value of a `Parameter`.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.

    Outputs:
        Tensor - The loaded parameter tensor value.
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_READ, dtype=sig.sig_dtype.T),
        sig.make_sig('u', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Load."""
        self.init_prim_io_names(inputs=['ref', 'u'], outputs=['output'])

    def __call__(self, *args):
        return _get_cache_prim(Identity)()(args[0])

    def check_dtype(self, variable):
        if variable != mstype.type_refkey:
            validator.check_tensors_dtypes_same_and_valid({"variable": variable}, mstype.number_type, self.name)


class _DynamicLossScale(PrimitiveWithInfer):
    """
    Dynamic multi-layer loss scale operator.

    Inputs:
        - **input_x** (Tensor) - Output of the last operator.
        - **loss_scale** (Tensor) - Dynamic loss scale.

    Outputs:
        Tensor - The same as `input_x`.
    """
    __mindspore_signature__ = (
        sig.make_sig('input_x', dtype=sig.sig_dtype.T),
        sig.make_sig('loss_scale', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self, layer=-1):
        """Initialize DynamicLossScale."""
        validator.check_value_type('layer', layer, (int,), self.name)
        self.init_prim_io_names(inputs=['input_x', 'loss_scale'], outputs=['output'])

    def infer_shape(self, input_x, loss_scale):
        return input_x

    def infer_dtype(self, input_x, loss_scale):
        return input_x


class BoundingBoxEncode(PrimitiveWithInfer):
    """
    Encodes bounding box locations.

    This operator calculates the offsets between the predicted bounding boxes and the ground truth bounding boxes,
    and these offsets are used as a variable for the loss.

    Args:
        means (tuple): Means for encoding bounding boxes calculation. Default: ``(0.0, 0.0, 0.0, 0.0)`` .
        stds (tuple): The standard deviations of deltas calculation. Default: ``(1.0, 1.0, 1.0, 1.0)`` .

    Inputs:
        - **anchor_box** (Tensor) - Anchor boxes. The shape of `anchor_box` must be :math:`(n, 4)`.
        - **groundtruth_box** (Tensor) - Ground truth boxes, which have the same shape as `anchor_box`.

    Outputs:
        Tensor, encoded bounding boxes. It has the same data type and shape as the input `anchor_box`.

    Raises:
        TypeError: If `means` or `stds` is not a tuple.
        TypeError: If `anchor_box` or `groundtruth_box` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> anchor_box = Tensor([[2, 2, 2, 3], [2, 2, 2, 3]], mindspore.float32)
        >>> groundtruth_box = Tensor([[1, 2, 1, 4], [1, 2, 1, 4]], mindspore.float32)
        >>> boundingbox_encode = ops.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
        >>> output = boundingbox_encode(anchor_box, groundtruth_box)
        >>> print(output)
        [[ -1.  0.25  0.  0.40551758]
         [ -1.  0.25  0.  0.40551758]]
    """

    @prim_attr_register
    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
        """Initialize BoundingBoxEncode."""
        validator.check_value_type('means', means, tuple, self.name)
        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type("means[%d]" % i, value, [float], self.name)
        for i, value in enumerate(stds):
            validator.check_value_type("stds[%d]" % i, value, [float], self.name)
        validator.check_equal_int(len(means), 4, "means len", self.name)
        validator.check_equal_int(len(stds), 4, "stds len", self.name)

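
# Illustrative reference only: a minimal NumPy sketch of the conventional delta encoding
# that the BoundingBoxEncode docstring describes (offsets of the ground truth boxes
# relative to the anchor boxes, normalized by `means`/`stds`). The "+1" pixel convention
# below is an assumption that is consistent with the docstring example; the device
# kernels are the source of truth and their exact numerics may differ. This hypothetical
# helper is not part of the public API and is not used by the operator.
def _bbox_encode_numpy_sketch(anchor_box, groundtruth_box,
                              means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
    """Return (n, 4) deltas for boxes given in (x0, y0, x1, y1) corner format."""
    import numpy as np
    anchor = np.asarray(anchor_box, dtype=np.float32)
    gt = np.asarray(groundtruth_box, dtype=np.float32)
    # Convert corners to widths/heights and centers, using the "+1" pixel convention.
    aw = anchor[:, 2] - anchor[:, 0] + 1.0
    ah = anchor[:, 3] - anchor[:, 1] + 1.0
    ax = anchor[:, 0] + 0.5 * (aw - 1.0)
    ay = anchor[:, 1] + 0.5 * (ah - 1.0)
    gw = gt[:, 2] - gt[:, 0] + 1.0
    gh = gt[:, 3] - gt[:, 1] + 1.0
    gx = gt[:, 0] + 0.5 * (gw - 1.0)
    gy = gt[:, 1] + 0.5 * (gh - 1.0)
    # Offsets of the ground truth relative to the anchors, then normalization.
    deltas = np.stack([(gx - ax) / aw, (gy - ay) / ah,
                       np.log(gw / aw), np.log(gh / ah)], axis=-1)
    return (deltas - np.asarray(means, np.float32)) / np.asarray(stds, np.float32)
# For example, _bbox_encode_numpy_sketch([[2, 2, 2, 3]], [[1, 2, 1, 4]]) yields
# approximately [-1., 0.25, 0., 0.4055], matching the docstring example above.
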

class BoundingBoxDecode(Primitive):
    """
    Decodes bounding box locations.

    This operator converts the offsets (deltas) predicted relative to the anchor boxes back into bounding boxes,
    which are used to mark the target in subsequent images, etc.

    Args:
        max_shape (tuple): The max size limit for decoding box calculation.
        means (tuple): The means of deltas calculation. Default: ``(0.0, 0.0, 0.0, 0.0)`` .
        stds (tuple): The standard deviations of deltas calculation. Default: ``(1.0, 1.0, 1.0, 1.0)`` .
        wh_ratio_clip (float): The limit of width and height ratio for decoding box calculation. Default: ``0.016`` .

    Inputs:
        - **anchor_box** (Tensor) - Anchor boxes. The shape of `anchor_box` must be :math:`(n, 4)`.
        - **deltas** (Tensor) - Deltas of boxes, which have the same shape as `anchor_box`.

    Outputs:
        Tensor, decoded boxes. It has the same data type and shape as `anchor_box`.

    Raises:
        TypeError: If `means`, `stds` or `max_shape` is not a tuple.
        TypeError: If `wh_ratio_clip` is not a float.
        TypeError: If `anchor_box` or `deltas` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> anchor_box = Tensor([[4, 1, 2, 1], [2, 2, 2, 3]], mindspore.float32)
        >>> deltas = Tensor([[3, 1, 2, 2], [1, 2, 1, 4]], mindspore.float32)
        >>> boundingbox_decode = ops.BoundingBoxDecode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0),
        ...                                            max_shape=(768, 1280), wh_ratio_clip=0.016)
        >>> output = boundingbox_decode(anchor_box, deltas)
        >>> print(output)
        [[ 4.194528  0.         0.         5.194528]
         [ 2.1408591   0.         3.8591409  60.598152  ]]

    """

    @prim_attr_register
    def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
        """Initialize BoundingBoxDecode."""
        validator.check_value_type('means', means, tuple, self.name)
        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type("means[%d]" % i, value, [float], self.name)
        for i, value in enumerate(stds):
            validator.check_value_type("stds[%d]" % i, value, [float], self.name)
        validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
        validator.check_equal_int(len(means), 4, "means len", self.name)
        validator.check_equal_int(len(stds), 4, "stds len", self.name)
        if max_shape is not None:
            validator.check_value_type('max_shape', max_shape, [tuple], self.name)
            validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name)

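
# Illustrative reference only: a minimal NumPy sketch of the inverse transform that the
# BoundingBoxDecode docstring describes (anchors + deltas -> boxes), mirroring the
# encode sketch above and assuming the same "+1" pixel convention. It reproduces the
# docstring example values, but the device kernels remain the source of truth and their
# clipping details may differ. This hypothetical helper is not part of the public API
# and is not used by the operator.
def _bbox_decode_numpy_sketch(anchor_box, deltas, max_shape,
                              means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0),
                              wh_ratio_clip=0.016):
    """Return (n, 4) decoded boxes in (x0, y0, x1, y1) corner format."""
    import numpy as np
    anchor = np.asarray(anchor_box, dtype=np.float32)
    d = np.asarray(deltas, dtype=np.float32) * np.asarray(stds, np.float32) + np.asarray(means, np.float32)
    # Width/height deltas are clipped so that exp() stays within the configured ratio limit.
    max_ratio = np.abs(np.log(wh_ratio_clip))
    dx, dy = d[:, 0], d[:, 1]
    dw = np.clip(d[:, 2], -max_ratio, max_ratio)
    dh = np.clip(d[:, 3], -max_ratio, max_ratio)
    aw = anchor[:, 2] - anchor[:, 0] + 1.0
    ah = anchor[:, 3] - anchor[:, 1] + 1.0
    ax = anchor[:, 0] + 0.5 * (aw - 1.0)
    ay = anchor[:, 1] + 0.5 * (ah - 1.0)
    # Recover centers and sizes, then convert back to corners.
    gw, gh = aw * np.exp(dw), ah * np.exp(dh)
    gx, gy = ax + aw * dx, ay + ah * dy
    boxes = np.stack([gx - 0.5 * (gw - 1.0), gy - 0.5 * (gh - 1.0),
                      gx + 0.5 * (gw - 1.0), gy + 0.5 * (gh - 1.0)], axis=-1)
    if max_shape is not None:
        height, width = max_shape
        boxes[:, 0::2] = np.clip(boxes[:, 0::2], 0.0, width - 1.0)
        boxes[:, 1::2] = np.clip(boxes[:, 1::2], 0.0, height - 1.0)
    return boxes
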

class SampleDistortedBoundingBoxV2(Primitive):
    r"""
    Creates a single bounding box that is randomly distorted for an image.

    It is often used for object localization and image recognition tasks.
    In such tasks, bounding box annotations are supplied in addition to ground-truth
    labels, and data augmentation techniques are often used to randomly distort an image
    while preserving its content.

    This function takes the `image_size`, `bounding_boxes`, and
    a series of constraints as input, and outputs a randomly distorted localization of an
    object (i.e., bounding box) based on these inputs.

    The output is returned as 3 tensors:
    `begin`, `size` and `bboxes`. The first 2 tensors can be fed directly
    into :class:`mindspore.ops.Slice` to crop the image.
    The latter is the generated distorted bounding box.

    Args:
        seed (int, optional): Random number seed. If either `seed` or `seed2` is set to a non-zero value,
            the seed is set to the given value. Otherwise, a random seed is used. Default: ``0`` .
        seed2 (int, optional): The second seed to avoid seed collision. Default: ``0`` .
        aspect_ratio_range (Union[list(float), tuple(float)], optional): Specifies the valid range of the aspect
            ratio of the cropped area. Aspect ratio of area = area_width / area_height. The value of this
            attribute should be positive. Default: ``(0.75, 1.33)`` .
        area_range (Union[list(float), tuple(float)], optional): The cropped area of the image must contain a
            fraction of the supplied image within this range. The value of this attribute should
            be in range (0.0, 1.0]. Default: ``(0.05, 1.0)`` .
        max_attempts (int, optional): A positive integer specifying the number of attempts that will be made to
            generate a cropped region of the image based on the given constraints. If the maximum number of
            attempts is exceeded without success, the function will return the entire original image.
            Default: ``100`` .
        use_image_if_no_bounding_boxes (bool, optional): Controls the behavior if no bounding boxes are supplied.
            If no bounding boxes are supplied (`bounding_boxes` in shape :math:`(0, N, 4)` or :math:`(batch, 0, 4)`)
            and this attribute is set to True, an implicit bounding box covering the whole input is assumed;
            if this attribute is set to False, an error is raised. Default: ``False`` .

    Inputs:
        - **image_size** (Tensor) - 1-D Tensor, containing [height, width, channels]. The value of this input
          tensor should be positive.
        - **bounding_boxes** (Tensor) - 3-D Tensor with shape :math:`(batch, N, 4)` describing the N
          bounding boxes associated with the image. The value of this input tensor should be in range [0.0, 1.0].
          The data type is float32.
        - **min_object_covered** (Tensor) - The least fraction of the bounding box that the cropped area needs
          to cover. This parameter's value should be between 0.0 and 1.0, inclusive. If the value is 0,
          the cropped area does not need to overlap with any of the supplied bounding boxes.
          The data type is float32.

    Outputs:
        - **begin** (Tensor) - A 1-D Tensor, containing [offset_height, offset_width, 0]. The data type is the same
          as `image_size`.
        - **size** (Tensor) - A 1-D Tensor, containing [target_height, target_width, -1]. The data type is the same
          as `image_size`. When the data type of `image_size` is uint8, the last value of `size`,
          which is originally -1, will be forced to 255.
        - **bboxes** (Tensor) - A 3-D Tensor with shape :math:`(1, 1, 4)`, containing
          the distorted bounding box. The data type is float32.

    Raises:
        TypeError: If `image_size` is not a Tensor.
        TypeError: If `bounding_boxes` is not a Tensor.
        TypeError: If `min_object_covered` is not a Tensor.
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `aspect_ratio_range` is not a list or a tuple with type float.
        TypeError: If `area_range` is not a list or a tuple with type float.
        TypeError: If `use_image_if_no_bounding_boxes` is not a bool.
        ValueError: If the dimension of `image_size` is not 1.
        ValueError: If the number of elements of `image_size` is not 3.
        ValueError: If the dimension of `bounding_boxes` is not 3.
        ValueError: If the number of elements of each bounding box in `bounding_boxes` is not 4.
        ValueError: If the number of elements of `min_object_covered` is not 1.
        ValueError: If the number of elements of the `aspect_ratio_range` list or tuple is not 2.
        ValueError: If the values of `aspect_ratio_range` are not positive.
        ValueError: If the second value of `aspect_ratio_range` is less than or equal to the first one.
        ValueError: If the number of elements of the `area_range` list or tuple is not 2.
        ValueError: If the values of `area_range` are out of range (0.0, 1.0].
        ValueError: If the second value of `area_range` is less than or equal to the first one.
        ValueError: If the value of `max_attempts` is not a positive int.
        ValueError: If `use_image_if_no_bounding_boxes` is False and no bounding boxes are supplied.
        RuntimeError: If the values of `image_size` are not positive.
        RuntimeError: If the values of `bounding_boxes` are out of range [0.0, 1.0].
        RuntimeError: If the `bounding_boxes` cannot make up a bounding box.
        RuntimeError: If the value of `min_object_covered` is out of range [0.0, 1.0].

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> image_size = Tensor([640, 480, 3], mindspore.int32)
        >>> bounding_boxes = Tensor([[[0.38, 0.17, 0.95, 0.40]]], mindspore.float32)
        >>> min_object_covered = Tensor([0.8], mindspore.float32)
        >>> sample_distorted_bounding_box_v2 = \
        ...   ops.SampleDistortedBoundingBoxV2(seed=1, seed2=1, aspect_ratio_range=(0.9, 1.1),
        ...                                    area_range=(0.1, 1.0), max_attempts=100,
        ...                                    use_image_if_no_bounding_boxes=False)
        >>> output = sample_distorted_bounding_box_v2(image_size, bounding_boxes, min_object_covered)
        >>> begin, size, bboxes = output[0], output[1], output[2]
        >>> print(begin)
        [133   1   0]
        >>> print(size)
        [502 457  -1]
        >>> print(bboxes)
        [[[0.2078125  0.00208333 0.9921875  0.95416665]]]
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0,
                 aspect_ratio_range=(0.75, 1.33),
                 area_range=(0.05, 1.0),
                 max_attempts=100,
                 use_image_if_no_bounding_boxes=False):
        validator.check_is_int(seed, "seed", self.name)
        validator.check_is_int(seed2, "seed2", self.name)
        validator.check_value_type("aspect_ratio_range", aspect_ratio_range, [list, tuple], self.name)
        validator.check_value_type("area_range", area_range, [list, tuple], self.name)
        validator.check_positive_int(max_attempts, "max_attempts", self.name)
        validator.check_bool(use_image_if_no_bounding_boxes, "use_image_if_no_bounding_boxes", self.name)
        for i, value in enumerate(aspect_ratio_range):
            validator.check_value_type("aspect_ratio_range[%d]" % i, value, [float], self.name)
        for i, value in enumerate(area_range):
            validator.check_value_type("area_range[%d]" % i, value, [float], self.name)


class CheckValid(Primitive):
    """
    Checks whether bounding boxes are valid.

    Checks whether the bounding boxes specified by `bboxes` are valid.
    Returns True for a box that is within the borders specified by `img_metas`, False otherwise.

    Inputs:
        - **bboxes** (Tensor) - Bounding boxes tensor with shape :math:`(N, 4)`. :math:`N` indicates the number of
          bounding boxes, the value "4" indicates "x0", "y0", "x1", and "y1". Data type must be float16 or float32.
        - **img_metas** (Tensor) - Raw image size information with the format of :math:`(height, width, ratio)`,
          specifying the valid boundary :math:`(height * ratio, width * ratio)`. Data type must be float16 or float32.

    Outputs:
        Tensor, with shape of :math:`(N,)` and dtype of bool, specifying whether the bounding boxes are in the image.
        "True" indicates valid, while "False" indicates invalid.

    Raises:
        TypeError: If `bboxes` or `img_metas` is not a Tensor.
        TypeError: If dtype of `bboxes` or `img_metas` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.check_valid = ops.CheckValid()
        ...     def construct(self, x, y):
        ...         valid_result = self.check_valid(x, y)
        ...         return valid_result
        ...
        >>> bboxes = Tensor(np.linspace(0, 6, 12).reshape(3, 4), mindspore.float32)
        >>> img_metas = Tensor(np.array([2, 1, 3]), mindspore.float32)
        >>> net = Net()
        >>> output = net(bboxes, img_metas)
        >>> print(output)
        [ True False False]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CheckValid."""
        self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])

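
# Illustrative reference only: a minimal NumPy sketch of the validity check that the
# CheckValid docstring describes. The inclusive "-1" offset on the upper bound is an
# assumption (it is consistent with the docstring example, which does not disambiguate
# the convention); the device kernels are the source of truth. This hypothetical helper
# is not part of the public API and is not used by the operator.
def _check_valid_numpy_sketch(bboxes, img_metas):
    """Return an (N,) bool array: True where (x0, y0, x1, y1) lies inside the valid boundary."""
    import numpy as np
    boxes = np.asarray(bboxes, dtype=np.float32)
    height, width, ratio = np.asarray(img_metas, dtype=np.float32)
    valid_h, valid_w = height * ratio, width * ratio
    return ((boxes[:, 0] >= 0.0) & (boxes[:, 1] >= 0.0) &
            (boxes[:, 2] <= valid_w - 1.0) & (boxes[:, 3] <= valid_h - 1.0))
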

class IOU(Primitive):
    r"""
    Calculates intersection over union for boxes.

    Computes the intersection over union (IOU) or the intersection over foreground (IOF) based on the ground-truth and
    predicted regions.

    Refer to :func:`mindspore.ops.iou` for more details.

    Args:
        mode (string): The mode is used to specify the calculation method,
                       now supporting ``'iou'`` (intersection over union) or ``'iof'``
                       (intersection over foreground) mode. Default: ``'iou'`` .

    Inputs:
        - **anchor_boxes** (Tensor) - Anchor boxes, tensor of shape :math:`(N, 4)`.
          "N" indicates the number of anchor boxes,
          and the value "4" refers to "x0", "y0", "x1", and "y1". Data type must be float16 or float32.
        - **gt_boxes** (Tensor) - Ground truth boxes, tensor of shape :math:`(M, 4)`. "M" indicates the number of
          ground truth boxes, and the value "4" refers to "x0", "y0", "x1", and "y1". Data type must be float16 or
          float32.

    Outputs:
        Tensor, the 'iou' values, tensor of shape :math:`(M, N)`, with the same data type as `anchor_boxes`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> iou = ops.IOU(mode='iou')
        >>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
        >>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
        >>> output = iou(anchor_boxes, gt_boxes)
        >>> print(output.shape)
        (3, 3)
    """

    @prim_attr_register
    def __init__(self, mode='iou'):
        """Initialize IOU."""
        if mode not in {'iou', 'iof'}:
            raise KeyError(f"For '{self.name}', only 'iou' or 'iof' are supported, but got 'mode': {mode}.")
        self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])

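
# Illustrative reference only: a minimal NumPy sketch of the pairwise overlap that the
# IOU primitive computes, producing an (M, N) matrix for M ground-truth boxes and N
# anchor boxes, as in the docstring. Plain width/height differences are used here and
# the 'iof' branch is assumed to normalize by the ground-truth area; see
# :func:`mindspore.ops.iou` for the exact definition used by the kernels. This
# hypothetical helper is not part of the public API and is not used by the operator.
def _iou_numpy_sketch(anchor_boxes, gt_boxes, mode='iou'):
    """Return an (M, N) array of 'iou' or 'iof' overlaps for corner-format boxes."""
    import numpy as np
    anchors = np.asarray(anchor_boxes, dtype=np.float32)   # (N, 4)
    gts = np.asarray(gt_boxes, dtype=np.float32)[:, None]  # (M, 1, 4) for broadcasting
    # Intersection rectangle between every (ground truth, anchor) pair.
    ix0 = np.maximum(gts[..., 0], anchors[..., 0])
    iy0 = np.maximum(gts[..., 1], anchors[..., 1])
    ix1 = np.minimum(gts[..., 2], anchors[..., 2])
    iy1 = np.minimum(gts[..., 3], anchors[..., 3])
    inter = np.clip(ix1 - ix0, 0, None) * np.clip(iy1 - iy0, 0, None)
    area_anchor = (anchors[..., 2] - anchors[..., 0]) * (anchors[..., 3] - anchors[..., 1])
    area_gt = (gts[..., 2] - gts[..., 0]) * (gts[..., 3] - gts[..., 1])
    if mode == 'iof':
        return inter / area_gt
    return inter / (area_anchor + area_gt - inter)
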

class Partial(Primitive):
    """
    Makes a partial function instance. Partial functions can be used to derive specialized
    functions from general functions by fixing the values of a certain number of arguments.

    Inputs:
        - **args** (Union[FunctionType, Tensor]) - The function and the bound arguments.

    Outputs:
        FunctionType, partial function bound with arguments.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> def show_input(x, y, z):
        ...     return x, y, z
        >>> partial = ops.Partial()
        >>> partial_show_input = partial(show_input, Tensor(1))
        >>> output1 = partial_show_input(Tensor(2), Tensor(3))
        >>> print(output1)
        (Tensor(shape=[], dtype=Int64, value= 1), Tensor(shape=[], dtype=Int64, value= 2), Tensor(shape=[], dtype=Int64,
         value= 3))
        >>> output2 = partial_show_input(Tensor(3), Tensor(4))
        >>> print(output2)
        (Tensor(shape=[], dtype=Int64, value= 1), Tensor(shape=[], dtype=Int64, value= 3), Tensor(shape=[], dtype=Int64,
         value= 4))
    """

    # Side effects will be propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize Partial."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, *args):
        func = args[0].__call__
        partial_func = functools.partial(func, *args[1:])
        return partial_func


class Depend(Primitive):
    """
    Depend is used for processing dependency operations.

    In most scenarios, if operators have IO side effects or memory side effects,
    they will be executed according to the user's semantics. In some scenarios,
    if two operators A and B have no data dependency on each other, but A must be executed
    before B, we recommend using Depend to specify their execution order. The
    usage method is as follows::

        a = A(x)                --->        a = A(x)
        b = B(y)                --->        y = Depend(y, a)
                                --->        b = B(y)

    Inputs:
        - **value** (Tensor) - the real value to return for the Depend operator.
        - **expr** (Expression) - the expression to execute with no outputs.

    Outputs:
        Tensor, the value passed by the last operator.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.softmax = ops.Softmax()
        ...         self.depend = ops.Depend()
        ...
        ...     def construct(self, x, y):
        ...         mul = x * y
        ...         y = self.depend(y, mul)
        ...         ret = self.softmax(y)
        ...         return ret
        ...
        >>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
        >>> print(output)
        [[0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]]
    """

    # Side effects will be propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize Depend."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, value, expr):
        return value


class UpdateState(Primitive):
    """
    UpdateState is used to update the side-effect state.

    Inputs:
        - **value** (State) - the state value to be updated.
        - **expr** (Expression) - the expression to evaluate before the state changes.

    Outputs:
        State, the updated state value.
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __call__(self, *args):
        return args[0]


class StopGradient(Primitive):
    """
    StopGradient is used for eliminating the effect of a value on the gradient,
    such as truncating the gradient propagation from an output of a function.

    Refer to :func:`mindspore.ops.stop_gradient` for more details.

    Inputs:
        - **value** (Any) - The value whose effect on the gradient is to be eliminated.

    Outputs:
        The same as `value`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>> def net(x, y):
        ...     out1 = ops.MatMul()(x, y)
        ...     out2 = ops.MatMul()(x, y)
        ...     out2 = ops.StopGradient()(out2)
        ...     return out1, out2
        ...
        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        >>> grad_fn = ops.grad(net)
        >>> output = grad_fn(x, y)
        >>> print(output)
        [[1.4100001 1.6       6.5999994]
         [1.4100001 1.6       6.5999994]]
    """

    @prim_attr_register
    def __init__(self):
        pass


class ConfusionMatrix(PrimitiveWithInfer):
    r"""
    Calculates the confusion matrix from labels and predictions.

    Args:
        num_classes (int): The number of classes.
        dtype (str): Data type of the confusion matrix. Default: ``'int32'`` .

    Inputs:
        - **labels** (Tensor) - Real labels, a 1-D tensor. The dtype must be a non-negative integer type.
        - **predictions** (Tensor) - The labels from prediction, a 1-D tensor.
          The shape is the same as `labels` and the dtype must be a non-negative integer type.
        - **weights** (Tensor) - A 1-D tensor. The shape is the same as `predictions`.

    Outputs:
        Tensor, the confusion matrix, with shape (`num_classes`, `num_classes`).

    Raises:
        TypeError: If `num_classes` is not an int.
        TypeError: If `dtype` is not a str.
        TypeError: If `labels`, `predictions` or `weights` is not a Tensor.

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> confusion_matrix = ops.ConfusionMatrix(4)
        >>> labels = Tensor([0, 1, 1, 3], mindspore.int32)
        >>> predictions = Tensor([1, 2, 1, 3], mindspore.int32)
        >>> output = confusion_matrix(labels, predictions)
        >>> print(output)
        [[0 1 0 0]
         [0 1 1 0]
         [0 0 0 0]
         [0 0 0 1]]
    """

    @prim_attr_register
    def __init__(self, num_classes, dtype="int32"):
        """Initialize ConfusionMatrix."""
        validator.check_value_type("num_classes", num_classes, [int], self.name)
        validator.check_value_type("dtype", dtype, [str], self.name)

    def infer_shape(self, labels, predictions, weights=None):
        validator.check('labels dimension', len(labels), '', 1, validator.EQ, self.name)
        validator.check('labels shape', labels, 'predictions shape', predictions, validator.EQ, self.name)
        if weights is not None:
            validator.check('labels shape', labels, 'weights shape', weights, validator.EQ, self.name)
        ret = (self.num_classes, self.num_classes)
        return ret

    def infer_dtype(self, labels, predictions, weights=None):
        validator.check_subclass('labels', labels, mstype.tensor_type, self.name)
        validator.check_subclass('predictions', predictions, mstype.tensor_type, self.name)
        if weights is not None:
            validator.check_subclass('weights', weights, mstype.tensor_type, self.name)
        args = {"labels": labels, "predictions": predictions}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name)
        return labels

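
# Illustrative reference only: a minimal NumPy sketch of the accumulation that
# ConfusionMatrix performs. Entry [i, j] counts (optionally weighted) samples whose real
# label is i and whose predicted label is j; it reproduces the docstring example above.
# This hypothetical helper is not part of the public API and is not used by the operator.
def _confusion_matrix_numpy_sketch(labels, predictions, num_classes, weights=None):
    """Return a (num_classes, num_classes) matrix of (weighted) label/prediction counts."""
    import numpy as np
    labels = np.asarray(labels, dtype=np.int64)
    predictions = np.asarray(predictions, dtype=np.int64)
    weights = np.ones_like(labels) if weights is None else np.asarray(weights)
    matrix = np.zeros((num_classes, num_classes), dtype=weights.dtype)
    for label, pred, weight in zip(labels, predictions, weights):
        matrix[label, pred] += weight
    return matrix
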

class Push(PrimitiveWithInfer):
    """
    Pushes the inputs of the corresponding optimizer to the parameter server.

    Args:
        optim_type (string): The optimizer type. Default: ``'ApplyMomentum'`` .
        only_shape_indices (list): The indices of the inputs of which only the shape
                                   will be pushed to the parameter server. Default: ``None`` .

    Inputs:
        - **optim_inputs** (tuple) - The inputs for this kind of optimizer.
        - **optim_input_shapes** (tuple) - The shapes of the inputs.

    Outputs:
        Tensor, the key of the weight which needs to be updated.
    """

    @prim_attr_register
    def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None):
        """Initialize Push"""
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key'])
        self.add_prim_attr("side_effect_hidden", True)

    def infer_shape(self, inputs, shapes):
        return [1]

    def infer_dtype(self, inputs, shapes):
        return mstype.uint64


class Pull(PrimitiveWithInfer):
    """
    Pulls the weight from the parameter server.

    Inputs:
        - **key** (Tensor) - The key of the weight.
        - **weight** (Tensor) - The weight to be updated.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Pull"""
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output'])

    def infer_shape(self, key_shape, weight_shape):
        return [1]

    def infer_dtype(self, key_dtype, weight_dtype):
        return mstype.float32


class PyInterpret(Primitive):
    r"""
    Interprets a Python expression.
    """

    @prim_attr_register
    def __init__(self):
        super(PyInterpret, self).__init__(self.__class__.__name__)
        self.add_prim_attr('side_effect_io', True)


class PyExecute(PrimitiveWithInfer):
    r"""
    Executes a Python expression.
    """

    @prim_attr_register
    def __init__(self):
        super(PyExecute, self).__init__(self.__class__.__name__)
        self.add_prim_attr('side_effect_io', True)
        self.add_prim_attr("primitive_target", "CPU")

    def infer_shape(self, *args):
        logger.error("The function output is an empty tuple. Adding a placeholder instead. "
                     "Do not use it, as it could be any uninitialized data.")
        return ((1,),)

    def infer_dtype(self, *args):
        logger.error("The function output is an empty tuple. Adding a placeholder instead. "
                     "Do not use it, as it could be any uninitialized data.")
        return (mstype.int32,)


class PyFunc(PrimitiveWithInfer):
    r"""
    Executes a Python function.

    `PyFunc` encapsulates a Python function as an operator that can be compiled into the computation graph.
    Unlike normal operators, it cannot be exported to MindIR, as it is executed in the current Python context.
    Since only the weights of the network are stored in the checkpoint, a network that includes `PyFunc` can save a
    checkpoint and load it into the network again, but any Python function state will be lost.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        fn (function): Python function whose inputs and outputs should be Python built-in scalars or numpy ndarrays.
        in_types (list[:class:`mindspore.dtype`]): The type of the inputs.
        in_shapes (list[tuple[int]]): The dimensionality of the inputs. An empty list represents a scalar, otherwise it
                                      represents a numpy array.
        out_types (list[:class:`mindspore.dtype`]): The type of the outputs.
        out_shapes (list[tuple[int]]): The dimensionality of the outputs. An empty list represents a scalar, otherwise
                                       it represents a numpy array.
        stateful (bool): Whether the function is stateful or not.
                         If True, the execution order is the same as in the model definition.

    Inputs:
        - **input_x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
          is made up of multiple tensors.

    Outputs:
        tuple[Tensor], execution results of the Python function.

    Raises:
        TypeError: The Python function execution failed.
        TypeError: The attributes (in_types/in_shapes/out_types/out_shapes) are inconsistent with the Python function
                   specifications.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as P
        >>> def func(x1, x2):
        ...     return x1 + x2
        >>> x1 = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> x2 = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> op = P.PyFunc(func, [x1.dtype, x2.dtype], [x1.shape, x2.shape], [x1.dtype], [x1.shape])
        >>> output = op((x1, x2))
        >>> print(output[0].asnumpy())
        [2. 4. 6.]
    """

    def __init__(self, fn, in_types, in_shapes, out_types, out_shapes, stateful=True):
        super(PyFunc, self).__init__(self.__class__.__name__)
        add_pyfunc(id(fn), fn)
        self.add_prim_attr('fn_id', id(fn))
        self.add_prim_attr('in_types', in_types)
        self.add_prim_attr('in_shapes', in_shapes)
        self.add_prim_attr('out_types', out_types)
        self.add_prim_attr('out_shapes', out_shapes)
        validator.check_value_type("in_types", in_types, [list, tuple], self.name)
        validator.check_value_type("in_shapes", in_shapes, [list, tuple], self.name)
        validator.check("in_types length", len(in_types), "in_shapes length", len(in_shapes), validator.EQ, self.name)
        validator.check_value_type("out_types", out_types, [list, tuple], self.name)
        validator.check_value_type("out_shapes", out_shapes, [list, tuple], self.name)
        validator.check("out_types length", len(out_types), "out_shapes length",
                        len(out_shapes), validator.EQ, self.name)
        self.add_prim_attr("side_effect_io", stateful)
        self.add_prim_attr("primitive_target", "CPU")
        fake_output = False
        single_scalar_output = False
        if not out_types:
            fake_output = True
        elif not out_shapes:
            single_scalar_output = True
        self.add_prim_attr("fake_output", fake_output)
        self.add_prim_attr("single_scalar_output", single_scalar_output)

    def infer_shape(self, *args):
        if self.out_shapes:
            return tuple(self.out_shapes)

        logger.warning("The function output is an empty tuple. Adding a placeholder instead. "
                       "Do not use it, as it could be any uninitialized data.")
        return ((1,),)

    def infer_dtype(self, *args):
        if self.out_shapes:
            dtype_list = tuple([typing.TensorType(dtype) for dtype in self.out_types])
            return dtype_list

        logger.warning("The function output is an empty tuple. Adding a placeholder instead. "
                       "Do not use it, as it could be any uninitialized data.")
        return (typing.TensorType(mstype.int32),)


class Reusing(Primitive):
    r"""
    Marks the function graph as no-inline, so that it is reused by call rather than inlined.

    Refer to :func:`mindspore.ops.Reusing` for more details.

    Inputs:
        - **input_x** (function) - The function that will be labeled as no-inline.

    Outputs:
        Function, the function that has been labeled as no-inline.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, jit
        >>> from mindspore.common import dtype as mstype
        >>> from mindspore import ops
        >>> def for_body_fun(i, val):
        ...     x = i * 3
        ...     x = x * val * val
        ...     return x
        >>> def fori_loop(lower, upper, body_fun, init_val):
        ...     body_fun = ops.Reusing()(body_fun)
        ...     val = init_val
        ...     for i in range(lower, upper):
        ...         val = body_fun(i, val)
        ...     return val
        >>> @jit
        ... def call_fori_loop(x):
        ...     x = fori_loop(1, 10, for_body_fun, x)
        ...     return x
        >>> x = Tensor([1], mstype.int32)
        >>> x = call_fori_loop(x)
        >>> print(x)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Reusing"""

    def __call__(self, x):
        return x