• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Other operators."""
17import functools
18import mindspore.common._monad as monad
19from mindspore import log as logger
20from mindspore.common._decorator import deprecated
21from .. import signature as sig
22from ..._checkparam import Validator as validator, Rel
23from ...common import dtype as mstype
24from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
25from .._register_for_op import PyFuncRegistry
26
27
class Assign(Primitive):
    """
    Assigns `Parameter` with a value.

    Inputs of `variable` and `value` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, lower priority data type will be converted to
    relatively highest priority data type.
    RuntimeError exception will be thrown when the data type conversion of Parameter is required.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
          :math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
        - **value** (Tensor) - The value to be assigned, has the same shape with `variable`.

    Outputs:
        Tensor, has the same data type and shape as original `variable`.

    Raises:
        TypeError: If `variable` is not a Parameter.
        TypeError: If `value` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> value = Tensor([2.0], mindspore.float32)
        >>> variable = mindspore.Parameter(Tensor([1.0], mindspore.float32), name="variable")
        >>> assign = ops.Assign()
        >>> output = assign(variable, value)
        >>> print(output)
        [2.]
    """
    # 'variable' is written in place (RW_WRITE) and must be type-consistent with
    # 'value'; the trailing monad input 'u' sequences this side effect in the graph.
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('value', dtype=sig.sig_dtype.T),
        sig.make_sig('u', default=monad.U, dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Assign."""
        self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
        # Mark the primitive as having a memory side effect so the executor
        # preserves its ordering relative to other stateful ops.
        self.add_prim_attr('side_effect_mem', True)
71
72
class InplaceAssign(PrimitiveWithInfer):
    """
    Inplace assign `Parameter` with a value.
    This primitive can only use in graph kernel.

    InplaceAssign is deprecated from version 1.3 and will be removed in a future version, use Assign instead.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
        - **value** (Tensor) - The value to be assigned.
        - **depend** (Tensor) - The dependent tensor to keep this op connected in graph.

    Outputs:
        Tensor, has the same type as original `variable`.

    Raises:
        TypeError: If `value` or `depend` is not a Tensor.

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.inplace_assign = ops.InplaceAssign()
        ...
        ...     def construct(self, x):
        ...         val = x - 1.0
        ...         ret = x + 2.0
        ...         return self.inplace_assign(x, val, ret)
        ...
        >>> x = Tensor([2.0], mindspore.float32)
        >>> net = Net()
        >>> output = net(x)
        >>> print(output)
    """
    @deprecated("1.3", "Assign", False)
    @prim_attr_register
    def __init__(self):
        """Initialize InplaceAssign."""
        self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])

    def infer_shape(self, x, y, z):
        # The output mirrors the `depend` input so the op stays wired into the graph.
        return z

    def infer_dtype(self, x, y, z):
        return z
118
119
class Load(PrimitiveWithCheck):
    """
    Load a `Parameter` and yield its current tensor value.

    Inputs:
        - **variable** (Parameter) - The `Parameter` to read.

    Outputs:
        Tensor - The loaded parameter tensor value.
    """
    # Read-only access to 'variable'; the monad input 'u' orders the load
    # with respect to preceding side effects.
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_READ, dtype=sig.sig_dtype.T),
        sig.make_sig('u', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Load."""
        self.init_prim_io_names(inputs=['ref', 'u'], outputs=['output'])

    def check_dtype(self, variable):
        # A ref-key placeholder carries no tensor dtype, so skip validation for it.
        if variable == mstype.type_refkey:
            return
        validator.check_tensors_dtypes_same_and_valid(
            {"variable": variable}, mstype.number_type, self.name)
143
144
class BoundingBoxEncode(PrimitiveWithInfer):
    """
    Encodes bounding boxes locations.

    Args:
        means (tuple): Means for encoding bounding boxes calculation. Default: (0.0, 0.0, 0.0, 0.0).
        stds (tuple): The standard deviations of deltas calculation. Default: (1.0, 1.0, 1.0, 1.0).

    Inputs:
        - **anchor_box** (Tensor) - Anchor boxes. The shape of anchor_box must be (n, 4).
        - **groundtruth_box** (Tensor) - Ground truth boxes. Which has the same shape with anchor_box.

    Outputs:
        Tensor, encoded bounding boxes. It has the same data type and shape as input `anchor_box`.

    Raises:
        TypeError: If `means` or `stds` is not a tuple.
        TypeError: If `anchor_box` or `groundtruth_box` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> anchor_box = Tensor([[2, 2, 2, 3], [2, 2, 2, 3]], mindspore.float32)
        >>> groundtruth_box = Tensor([[1, 2, 1, 4], [1, 2, 1, 4]], mindspore.float32)
        >>> boundingbox_encode = ops.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
        >>> output = boundingbox_encode(anchor_box, groundtruth_box)
        >>> print(output)
        [[ -1.  0.25  0.  0.40551758]
         [ -1.  0.25  0.  0.40551758]]
    """

    @prim_attr_register
    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
        """Initialize BoundingBoxEncode."""
        # Both statistics must be 4-element tuples of floats (one per box coordinate).
        validator.check_value_type('means', means, tuple, self.name)
        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type("means[%d]" % i, value, [float], self.name)
        for i, value in enumerate(stds):
            validator.check_value_type("stds[%d]" % i, value, [float], self.name)
        validator.check_equal_int(len(means), 4, "means len", self.name)
        validator.check_equal_int(len(stds), 4, "stds len", self.name)

    def infer_shape(self, anchor_box, groundtruth_box):
        """Both inputs must be rank-2 (n, 4) with equal row counts; output keeps anchor_box's shape."""
        validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
                        self.name)
        validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
        validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name)
        validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
        validator.check_equal_int(groundtruth_box[1], 4, 'groundtruth_box shape[1]', self.name)
        return anchor_box

    def infer_dtype(self, anchor_box, groundtruth_box):
        """Both inputs must share one valid numeric dtype; output keeps anchor_box's dtype."""
        args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return anchor_box
202
203
class BoundingBoxDecode(PrimitiveWithInfer):
    """
    Decodes bounding boxes locations.

    Args:
        means (tuple): The means of deltas calculation. Default: (0.0, 0.0, 0.0, 0.0).
        stds (tuple): The standard deviations of deltas calculation. Default: (1.0, 1.0, 1.0, 1.0).
        max_shape (tuple): The max size limit for decoding box calculation.
        wh_ratio_clip (float): The limit of width and height ratio for decoding box calculation. Default: 0.016.

    Inputs:
        - **anchor_box** (Tensor) - Anchor boxes. The shape of `anchor_box` must be (n, 4).
        - **deltas** (Tensor) - Delta of boxes. Which has the same shape with `anchor_box`.

    Outputs:
        Tensor, decoded boxes. It has the same data type and shape as `anchor_box`.

    Raises:
        TypeError: If `means`, `stds` or `max_shape` is not a tuple.
        TypeError: If `wh_ratio_clip` is not a float.
        TypeError: If `anchor_box` or `deltas` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> anchor_box = Tensor([[4, 1, 2, 1], [2, 2, 2, 3]], mindspore.float32)
        >>> deltas = Tensor([[3, 1, 2, 2], [1, 2, 1, 4]], mindspore.float32)
        >>> boundingbox_decode = ops.BoundingBoxDecode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0),
        ...                                          max_shape=(768, 1280), wh_ratio_clip=0.016)
        >>> output = boundingbox_decode(anchor_box, deltas)
        >>> print(output)
        [[ 4.1953125  0.         0.         5.1953125]
         [ 2.140625   0.         3.859375  60.59375  ]]

    """

    @prim_attr_register
    def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
        """Initialize BoundingBoxDecode.

        NOTE: `max_shape` is the first *positional* parameter even though the
        class docstring lists it third; positional callers must pass it first.
        """
        validator.check_value_type('means', means, tuple, self.name)
        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type("means[%d]" % i, value, [float], self.name)
        for i, value in enumerate(stds):
            validator.check_value_type("stds[%d]" % i, value, [float], self.name)
        validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
        validator.check_equal_int(len(means), 4, "means len", self.name)
        validator.check_equal_int(len(stds), 4, "stds len", self.name)
        # max_shape may be None, in which case no (height, width) clipping limit is validated.
        if max_shape is not None:
            validator.check_value_type('max_shape', max_shape, [tuple], self.name)
            validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name)

    def infer_shape(self, anchor_box, deltas):
        """Both inputs must be rank-2 (n, 4) with equal row counts; output keeps anchor_box's shape."""
        validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
        validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
        validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name)
        validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
        validator.check_equal_int(deltas[1], 4, 'deltas shape[1]', self.name)
        return anchor_box

    def infer_dtype(self, anchor_box, deltas):
        """Both inputs must share one valid numeric dtype; output keeps anchor_box's dtype."""
        args = {"anchor_box": anchor_box, "deltas": deltas}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return anchor_box
269
270
class CheckValid(PrimitiveWithInfer):
    """
    Checks bounding box.

    Checks whether the bounding box cross data and data border are valid.

    .. warning::
        specifying the valid boundary (heights x ratio, widths x ratio).

    Inputs:
        - **bboxes** (Tensor) - Bounding boxes tensor with shape (N, 4).
          Data type must be float16, float32, int16 or uint8.
        - **img_metas** (Tensor) - Raw image size information with the format of (height, width, ratio).
          Data type must be float16, float32, int16 or uint8.

    Outputs:
        Tensor, with shape of (N,) and dtype of bool.

    Raises:
        TypeError: If `bboxes` or `img_metas` is not a Tensor.
        TypeError: If dtype of `bboxes` or `img_metas` is not one of float16, float32, int16, uint8.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.check_valid = ops.CheckValid()
        ...     def construct(self, x, y):
        ...         valid_result = self.check_valid(x, y)
        ...         return valid_result
        ...
        >>> bboxes = Tensor(np.linspace(0, 6, 12).reshape(3, 4), mindspore.float32)
        >>> img_metas = Tensor(np.array([2, 1, 3]), mindspore.float32)
        >>> net = Net()
        >>> output = net(bboxes, img_metas)
        >>> print(output)
        [ True False False]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CheckValid."""
        self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])

    def infer_shape(self, bboxes_shape, metas_shape):
        # bboxes: (N, 4); img_metas: (3,) for (height, width, ratio).
        validator.check("bboxes rank", len(bboxes_shape), "", 2, Rel.EQ, self.name)
        validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ, self.name)
        validator.check("img_metas rank", len(metas_shape), "", 1, Rel.EQ, self.name)
        validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ, self.name)
        # One boolean verdict per box: output shape is (N,).
        return bboxes_shape[:-1]

    def infer_dtype(self, bboxes_type, metas_type):
        # Note: wider than the float-only set advertised upstream; int16/uint8 are accepted here.
        valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8]
        validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name)
        validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name)
        return mstype.bool_
333
334
class IOU(PrimitiveWithInfer):
    r"""
    Calculates intersection over union for boxes.

    Computes the intersection over union (IOU) or the intersection over foreground (IOF) based on the ground-truth and
    predicted regions.

    .. math::
        \text{IOU} = \frac{\text{Area of Overlap}}{\text{Area of Union}}

        \text{IOF} = \frac{\text{Area of Overlap}}{\text{Area of Ground Truth}}

    .. warning::
        In Ascend, only computation of float16 data is supported. To avoid overflow, the input length
        and width are scaled by 0.2 internally.

    Args:
        mode (string): The mode is used to specify the calculation method,
                       now supporting 'iou' (intersection over union) or 'iof'
                       (intersection over foreground) mode. Default: 'iou'.

    Inputs:
        - **anchor_boxes** (Tensor) - Anchor boxes, tensor of shape (N, 4). "N" indicates the number of anchor boxes,
          and the value "4" refers to "x0", "y0", "x1", and "y1". Data type must be float16 or float32.
        - **gt_boxes** (Tensor) - Ground truth boxes, tensor of shape (M, 4). "M" indicates the number of ground
          truth boxes, and the value "4" refers to "x0", "y0", "x1", and "y1". Data type must be float16 or float32.

    Outputs:
        Tensor, the 'iou' values, tensor of shape (M, N), with the same data type as `anchor_boxes`.

    Raises:
        KeyError: When `mode` is not 'iou' or 'iof'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> iou = ops.IOU()
        >>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
        >>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
        >>> output = iou(anchor_boxes, gt_boxes)
        >>> print(output.shape)
        (3, 3)
    """

    @prim_attr_register
    def __init__(self, mode='iou'):
        """Initialize IOU."""
        if mode not in ('iou', 'iof'):
            raise KeyError(f"For '{self.name}', only 'iou' or 'iof' are supported, but got 'mode': {mode}.")
        self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])

    def infer_shape(self, anchor_boxes, gt_boxes):
        # Validation order kept as-is so error reporting is unchanged.
        validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name)
        validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name)
        validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name)
        validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name)
        # One overlap per (ground-truth, anchor) pair: shape (M, N).
        return [gt_boxes[0], anchor_boxes[0]]

    def infer_dtype(self, anchor_boxes, gt_boxes):
        permitted = [mstype.float32, mstype.float16]
        validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, permitted, self.name)
        validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, permitted, self.name)
        return anchor_boxes
400
401
class Partial(Primitive):
    """
    Makes a partial function instance, used for pynative mode.

    Inputs:
        - **args** (Union[FunctionType, Tensor]) - The function followed by the
          arguments to bind to it.

    Outputs:
        FunctionType, partial function bound with the given arguments.
    """

    # The side effect is propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize Partial."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, *args):
        # First positional argument is the callable; the rest are bound to it.
        fn, *bound_args = args
        return functools.partial(fn.__call__, *bound_args)
425
426
class Depend(Primitive):
    """
    Depend is used for processing dependency operations.

    In most scenarios, operators with IO or memory side effects are executed
    according to the user's semantics. In some scenarios, when two operators A
    and B have no data dependency but A must be executed before B, we recommend
    inserting a Depend to pin their execution order. The usage is as follows::

        a = A(x)                --->        a = A(x)
        b = B(y)                --->        y = Depend(y, a)
                                --->        b = B(y)

    Inputs:
        - **value** (Tensor) - the real value to return for depend operator.
        - **expr** (Expression) - the expression to execute with no outputs.

    Outputs:
        Tensor, the value passed by last operator.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.softmax = ops.Softmax()
        ...         self.depend = ops.Depend()
        ...
        ...     def construct(self, x, y):
        ...         mul = x * y
        ...         y = self.depend(y, mul)
        ...         ret = self.softmax(y)
        ...         return ret
        ...
        >>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
        >>> print(output)
        [[0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]]
    """

    # The side effect is propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize Depend."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, value, expr):
        # `expr` exists purely for ordering; only `value` is forwarded.
        return value
490
491
class UpdateState(Primitive):
    """
    UpdateState is used to update the side-effect state.

    Inputs:
        - **value** (State) - The state value to be updated.
        - **expr** (Expression) - The expression evaluated before the state changes.

    Outputs:
        State, the updated state value.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize UpdateState."""

    def __call__(self, state, expr):
        # `expr` only sequences the update; the state value passes straight through.
        return state
510
511
class CheckBprop(PrimitiveWithInfer):
    """
    Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.

    Inputs:
        - **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
        - **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.

    Outputs:
        (tuple[Tensor]), the `input_x`,
        if data type and shape of corresponding elements from `input_x` and `input_y` are the same.

    Raises:
        TypeError: If `input_x` or `input_y` is not a tuple, or corresponding dtypes mismatch.
        ValueError: If `input_x` has fewer elements than `input_y`, or corresponding shapes mismatch.

    Examples:
        >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
        >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
        >>> out = ops.CheckBprop()(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self, prim_to_check=""):
        """Initialize CheckBprop"""
        # Name of the primitive whose bprop is being verified; used in error messages.
        self.prim_to_check = prim_to_check

    def infer_shape(self, xshapes, yshapes):
        tips = f'Bprop of {self.prim_to_check}'
        validator.check_value_type('grads', xshapes, (tuple,), tips)
        validator.check_value_type('params', yshapes, (tuple,), tips)
        # The bprop must produce at least one gradient per forward input.
        if len(xshapes) < len(yshapes):
            raise ValueError(f"For '{tips}', the size of 'input_x.shape' should not be less than {len(yshapes)}, "
                             f"but got {len(xshapes)}.")
        checking_range = len(yshapes)
        for i in range(checking_range):
            xshape = xshapes[i]
            yshape = yshapes[i]
            # An empty shape carries nothing to compare, so skip that pair.
            if not xshape or not yshape:
                continue
            if xshape != yshape:
                raise ValueError(f"For '{tips}', the shape of 'input_x' in {i}th index should be {yshape},"
                                 f" but got 'input_x[i]': {xshape}.")
        return xshapes

    def infer_dtype(self, xdtypes, ydtypes):
        tips = f'Bprop of {self.prim_to_check}'
        validator.check_value_type('grads', xdtypes, (tuple,), tips)
        validator.check_value_type('params', ydtypes, (tuple,), tips)
        if len(xdtypes) < len(ydtypes):
            raise ValueError(f"For '{tips}', the size of 'input_x.dtype' should not be less than {len(ydtypes)},"
                             f" but got {len(xdtypes)}.")
        checking_range = len(ydtypes)
        for i in range(checking_range):
            xdtype = xdtypes[i]
            ydtype = ydtypes[i]
            # 'anything' placeholders match any dtype on either side.
            if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type):
                continue
            # A function-typed forward input expects an env-typed gradient.
            if isinstance(ydtype, mstype.function_type):
                if not isinstance(xdtype, mstype.env_type_type):
                    raise TypeError(f"For '{tips}', the dtype of 'input_x' in {i}th index should be "
                                    f"{mstype.env_type_type}, but got {xdtype}.")
                continue
            if xdtype != ydtype:
                raise TypeError(f"For '{tips}', the dtype of 'input_x' in {i}th index should be {ydtype},"
                                f" but got {xdtype}.")
        return xdtypes
578
579
class ConfusionMatrix(PrimitiveWithInfer):
    r"""
    Calculates the confusion matrix from labels and predictions.

    Args:
        num_classes (int): The num of classes.
        dtype (str): Data type of confusion matrix. Default: 'int32'.

    Inputs:
        - **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
        - **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
          the shape same as `labels` and the dtype must be non-negative Integer.
        - **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.

    Outputs:
        Tensor, the confusion matrix, with shape (`num_classes`, `num_classes`).

    Raises:
        TypeError: If `num_classes` is not an int.
        TypeError: If `dtype` is not a str.
        TypeError: If `labels`, `predictions` or `weights` is not a Tensor.

    Examples:
        >>> confusion_matrix = ops.ConfusionMatrix(4)
        >>> labels = Tensor([0, 1, 1, 3], mindspore.int32)
        >>> predictions = Tensor([1, 2, 1, 3], mindspore.int32)
        >>> output = confusion_matrix(labels, predictions)
        >>> print(output)
        [[0 1 0 0]
         [0 1 1 0]
         [0 0 0 0]
         [0 0 0 1]]
    """

    @prim_attr_register
    def __init__(self, num_classes, dtype="int32"):
        """Initialize ConfusionMatrix."""
        validator.check_value_type("num_classes", num_classes, [int], self.name)
        validator.check_value_type("dtype", dtype, [str], self.name)

    def infer_shape(self, labels, predictions, weights=None):
        # All inputs must be rank-1 and the same length.
        validator.check('labels dimension', len(labels), '', 1, Rel.EQ, self.name)
        validator.check('labels shape', labels, 'predictions shape', predictions, Rel.EQ, self.name)
        if weights is not None:
            validator.check('labels shape', labels, 'weights shape', weights, Rel.EQ, self.name)
        # NOTE(review): self.num_classes is presumably attached by @prim_attr_register
        # from the __init__ argument — confirm against Primitive's attribute handling.
        ret = (self.num_classes, self.num_classes)
        return ret

    def infer_dtype(self, labels, predictions, weights=None):
        validator.check_subclass('labels', labels, mstype.tensor, self.name)
        validator.check_subclass('predictions', predictions, mstype.tensor, self.name)
        if weights is not None:
            validator.check_subclass('weights', weights, mstype.tensor, self.name)
        args = {"labels": labels, "predictions": predictions}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name)
        return labels
636
637
class PopulationCount(PrimitiveWithInfer):
    r"""
    Calculates population count (the number of set bits in each element).

    Inputs:
        - **input** (Tensor) -  The data type must be int16 or uint16.

    Outputs:
        Tensor, with the same shape as the input.

    Raises:
        TypeError: If `input` is not a Tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> population_count = ops.PopulationCount()
        >>> x_input = Tensor([0, 1, 3], mindspore.int16)
        >>> output = population_count(x_input)
        >>> print(output)
        [0 1 2]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PopulationCount."""

    def infer_shape(self, x_shape):
        # Element-wise op: the shape passes through unchanged.
        return x_shape

    def infer_dtype(self, x_dtype):
        allowed = (mstype.int16, mstype.uint16,)
        validator.check_tensor_dtype_valid("x", x_dtype, allowed, self.name)
        # A 16-bit element has at most 16 set bits, so uint8 suffices for the result.
        return mstype.tensor_type(mstype.uint8)
672
673
class Push(PrimitiveWithInfer):
    """
    Pushes the inputs of the corresponding optimizer to parameter server.

    Args:
        optim_type (string): The optimizer type. Default: 'ApplyMomentum'.
        only_shape_indices (list): The indices of input of which only shape
                                   will be pushed to parameter server. Default: None.

    Inputs:
        - **optim_inputs** (tuple) - The inputs for this kind of optimizer.
        - **optim_input_shapes** (tuple) - The shapes of the inputs.

    Outputs:
        Tensor, the key of the weight which needs to be updated.
    """

    @prim_attr_register
    def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None):
        """Initialize Push"""
        # The parameter-server interaction is a host-side kernel.
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr("_side_effect", True)
        self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key'])

    def infer_shape(self, inputs, shapes):
        # Output is a single weight key.
        return [1]

    def infer_dtype(self, inputs, shapes):
        return mstype.uint64
703
704
class Pull(PrimitiveWithInfer):
    """
    Pulls weight from parameter server.

    Inputs:
        - **key** (Tensor) - The key of the weight.
        - **weight** (Tensor) - The weight to be updated.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Pull"""
        # The parameter-server interaction is a host-side kernel.
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output'])

    def infer_shape(self, key_shape, weight_shape):
        # Placeholder scalar shape; the op is executed for its side effect.
        return [1]

    def infer_dtype(self, key_dtype, weight_dtype):
        return mstype.float32
728
729
class PullWeight(PrimitiveWithInfer):
    """
    Pull weight by its names from server.

    Inputs:
        - **weight** (Tensor) - The weight to be pulled.
        - **name** (String) - The full name of the weight.
        - **index** (Int) - The index of the weight.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PullWeight"""
        # Runs on host: the server interaction is a CPU-side kernel.
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=['weight', 'name', 'index'], outputs=['output'])

    def infer_shape(self, weight, name, index):
        # Placeholder scalar shape; the op is executed for its side effect.
        return [1]

    def infer_dtype(self, weight, name, index):
        return mstype.float32
754
755
class PushWeight(PrimitiveWithInfer):
    """
    Upload weight by its names to server.

    Inputs:
        - **weight** (Tensor) - The weight to be uploaded.
        - **name** (String) - The full name of the weight.
        - **index** (Int) - The index of the weight.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PushWeight"""
        # Runs on host: the server interaction is a CPU-side kernel.
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=["weight", "name", "index"], outputs=["output"])

    def infer_shape(self, weight, name, index):
        # Placeholder scalar shape; the op is executed for its side effect.
        return [1]

    def infer_dtype(self, weight, name, index):
        # Parameter renamed from `ps_key` to `name` for consistency with
        # infer_shape above and with the sibling PullWeight primitive.
        return mstype.float32
780
781
class PushMetrics(PrimitiveWithInfer):
    """
    Push metrics like loss and accuracy for federated learning worker.

    Inputs:
        - **loss** (Tensor) - The loss.
        - **accuracy** (Tensor) - The accuracy.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PushMetrics"""
        # Host-side kernel with a memory side effect so it is never pruned/reordered.
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr("side_effect_mem", True)
        self.init_prim_io_names(inputs=["loss", "accuracy"], outputs=["result"])

    def infer_shape(self, loss, accuracy):
        # Placeholder scalar shape; the op is executed for its side effect.
        return [1]

    def infer_dtype(self, loss, accuracy):
        return mstype.float32
806
807
class StartFLJob(PrimitiveWithInfer):
    """
    StartFLJob for federated learning worker.

    Args:
        data_size (int): Size of the local data, stored as the primitive
            attribute ``data_size``.

    Outputs:
        Tensor, a fixed placeholder result.
    """
    @prim_attr_register
    def __init__(self, data_size):
        """Initialize StartFLJob."""
        # Reject a non-int data_size early, at primitive-construction time.
        validator.check_value_type("data_size", data_size, [int], self.name)
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr("data_size", data_size)
        self.init_prim_io_names(inputs=[], outputs=["result"])

    def infer_shape(self):
        # Fixed [1] placeholder output shape.
        return [1]

    def infer_dtype(self):
        # Fixed float32 placeholder output dtype.
        return mstype.float32
823
824
class UpdateModel(PrimitiveWithInfer):
    """
    UpdateModel for federated learning worker.

    Inputs:
        - **weights** (Tensor) - The weights used for the model update.

    Outputs:
        Tensor, a fixed placeholder result.
    """
    @prim_attr_register
    def __init__(self):
        """Initialize UpdateModel."""
        # Host-side op; marked with a memory side effect.
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr('side_effect_mem', True)
        self.init_prim_io_names(inputs=["weights"], outputs=["result"])

    def infer_shape(self, weights):
        # Fixed [1] placeholder output shape.
        result_shape = [1]
        return result_shape

    def infer_dtype(self, weights):
        # Fixed float32 placeholder output dtype.
        return mstype.float32
840
841
class GetModel(PrimitiveWithInfer):
    """
    GetModel for federated learning worker.

    Inputs:
        - **weights** (Tensor) - The weights of the model.

    Outputs:
        Tensor, a fixed placeholder result.
    """
    @prim_attr_register
    def __init__(self):
        """Initialize GetModel."""
        # Host-side op; marked with a memory side effect.
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr('side_effect_mem', True)
        self.init_prim_io_names(inputs=["weights"], outputs=["result"])

    def infer_shape(self, weights):
        # Fixed [1] placeholder output shape.
        result_shape = [1]
        return result_shape

    def infer_dtype(self, weights):
        # Fixed float32 placeholder output dtype.
        return mstype.float32
857
858
class identity(Primitive):
    """
    Makes an identity primitive that returns its input unchanged; used for pynative mode.

    Inputs:
        - **x** (Any) - identity input value.

    Outputs:
        The same as input.
    """

    # Side effects are propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize identity."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, x):
        # Pure pass-through: the output is exactly the input object.
        return x
880
# Module-level registry mapping id(fn) -> fn; populated by PyFunc.__init__.
pyfunc_register = PyFuncRegistry()


def get_pyfunc(fn_id):
    """Return the Python function previously registered under `fn_id` (see PyFunc)."""
    return pyfunc_register.get(fn_id)
886
887
class PyFunc(PrimitiveWithInfer):
    r"""
    Execute Python function.

    `PyFunc` encapsulates Python functions as an operator which could be compiled into computation graph.
    Unlike normal operators, it cannot be exported to MindIR as it is executed in current Python context.
    As only the weights of the network is stored in the checkpoint, network include `PyFunc` could save
    checkpoint and load to the network again, but will lose any Python function state.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        fn (function): Python function which inputs and outputs should be Python built-in scalar or numpy ndarray.
        in_types (list[:class:`mindspore.dtype`]): The type of the inputs.
        in_shapes (list[tuple[int]]): The dimensionality of the inputs. An empty list represents a scalar, otherwise it
                                      represents a numpy array.
        out_types (list[:class:`mindspore.dtype`]): The type of the outputs.
        out_shapes (list[tuple[int]]): The dimensionality of the outputs. An empty list represents a scalar, otherwise
                                       it represents a numpy array.
        stateful (bool): Whether the function is stateful or not.
                         If True, the execution order is same with model definition.

    Inputs:
        - **input_x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
          is made up of multiple tensors.

    Outputs:
        tuple[Tensor], execution results Python functions.

    Raises:
        TypeError: The Python function execution failed.
        TypeError: The attributes(in_types/in_shapes/out_types/out_shapes) are inconsistent with Python function
                   specifications.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> def func(x1, x2):
        ...     return x1 + x2
        >>> x1 = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> x2 = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> op = P.PyFunc(func, [x1.dtype, x2.dtype], [x1.shape, x2.shape], [x1.dtype], [x1.shape])
        >>> output = op((x1, x2))
        >>> print(output[0].asnumpy())
        [2. 4. 6.]
    """

    def __init__(self, fn, in_types, in_shapes, out_types, out_shapes, stateful=True):
        super(PyFunc, self).__init__(self.__class__.__name__)
        # Validate the signature metadata first, so an invalid configuration
        # fails before any attribute or registry state is recorded.
        validator.check_value_type("in_types", in_types, [list, tuple], self.name)
        validator.check_value_type("in_shapes", in_shapes, [list, tuple], self.name)
        validator.check("in_types length", len(in_types), "in_shapes length", len(in_shapes), Rel.EQ, self.name)
        validator.check_value_type("out_types", out_types, [list, tuple], self.name)
        validator.check_value_type("out_shapes", out_shapes, [list, tuple], self.name)
        validator.check("out_types length", len(out_types), "out_shapes length", len(out_shapes), Rel.EQ, self.name)
        # The function itself cannot be stored as a graph attribute; register it
        # in the module-level registry and record only its id for later lookup.
        # NOTE(review): this assumes the registry keeps `fn` alive; if `fn` were
        # collected, its id could be reused — confirm PyFuncRegistry semantics.
        pyfunc_register.register(id(fn), fn)
        self.add_prim_attr('fn_id', id(fn))
        self.add_prim_attr('in_types', in_types)
        self.add_prim_attr('in_shapes', in_shapes)
        self.add_prim_attr('out_types', out_types)
        self.add_prim_attr('out_shapes', out_shapes)
        # A stateful function is flagged with an I/O side effect so the executor
        # preserves its definition order.
        self.add_prim_attr("side_effect_io", stateful)
        self.add_prim_attr("primitive_target", "CPU")

    def infer_shape(self, *args):
        """Return the declared output shapes, or a [1] placeholder when the function has no outputs."""
        if self.out_shapes:
            return tuple(self.out_shapes)

        logger.warning("The function output are empty tuple. Add a placeholder instead. "
                       "Do not use it as it could be any uninitialized data.")
        return ((1,),)

    def infer_dtype(self, *args):
        """Return the declared output dtypes, or an int32 placeholder when the function has no outputs."""
        # Fix: test out_types (the attribute returned here) rather than out_shapes.
        # Both have equal length by construction, so observable behavior is unchanged,
        # but the check no longer relies on that invariant.
        if self.out_types:
            return tuple(self.out_types)

        logger.warning("The function output are empty tuple. Add a placeholder instead. "
                       "Do not use it as it could be any uninitialized data.")
        return (mstype.int32,)
969