# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Other operators."""
import functools
import mindspore.common._monad as monad
from mindspore import log as logger
from mindspore.common._decorator import deprecated
from .. import signature as sig
from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from .._register_for_op import PyFuncRegistry


class Assign(Primitive):
    """
    Assigns `Parameter` with a value.

    Inputs of `variable` and `value` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, lower priority data type will be converted to
    relatively highest priority data type.
    RuntimeError exception will be thrown when the data type conversion of Parameter is required.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
          :math:`(N,*)` where :math:`*` means any number of additional dimensions, its rank should be less than 8.
        - **value** (Tensor) - The value to be assigned, has the same shape with `variable`.

    Outputs:
        Tensor, has the same data type and shape as original `variable`.

    Raises:
        TypeError: If `variable` is not a Parameter.
        TypeError: If `value` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> value = Tensor([2.0], mindspore.float32)
        >>> variable = mindspore.Parameter(Tensor([1.0], mindspore.float32), name="variable")
        >>> assign = ops.Assign()
        >>> output = assign(variable, value)
        >>> print(output)
        [2.]
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('value', dtype=sig.sig_dtype.T),
        sig.make_sig('u', default=monad.U, dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Assign."""
        self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
        self.add_prim_attr('side_effect_mem', True)


class InplaceAssign(PrimitiveWithInfer):
    """
    Inplace assign `Parameter` with a value.
    This primitive can only use in graph kernel.

    InplaceAssign is deprecated from version 1.3 and will be removed in a future version, use Assign instead.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
        - **value** (Tensor) - The value to be assigned.
        - **depend** (Tensor) - The dependent tensor to keep this op connected in graph.

    Outputs:
        Tensor, has the same type as original `variable`.

    Raises:
        TypeError: If `value` or `depend` is not a Tensor.

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.inplace_assign = ops.InplaceAssign()
        ...
        ...     def construct(self, x):
        ...         val = x - 1.0
        ...         ret = x + 2.0
        ...         return self.inplace_assign(x, val, ret)
        ...
        >>> x = Tensor([2.0], mindspore.float32)
        >>> net = Net()
        >>> output = net(x)
        >>> print(output)
    """
    @deprecated("1.3", "Assign", False)
    @prim_attr_register
    def __init__(self):
        """Initialize InplaceAssign."""
        self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])

    def infer_shape(self, x, y, z):
        """Return the shape of the third input (`z`) as the output shape."""
        return z

    def infer_dtype(self, x, y, z):
        """Return the dtype of the third input (`z`) as the output dtype."""
        return z


class Load(PrimitiveWithCheck):
    """
    Load `Parameter` to a value.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.

    Outputs:
        Tensor - The loaded parameter tensor value.
    """
    __mindspore_signature__ = (
        sig.make_sig('variable', sig.sig_rw.RW_READ, dtype=sig.sig_dtype.T),
        sig.make_sig('u', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Load."""
        self.init_prim_io_names(inputs=['ref', 'u'], outputs=['output'])

    def check_dtype(self, variable):
        """Check that `variable` is a number-typed tensor unless it is a ref key."""
        if variable != mstype.type_refkey:
            validator.check_tensors_dtypes_same_and_valid({"variable": variable}, mstype.number_type, self.name)


class BoundingBoxEncode(PrimitiveWithInfer):
    """
    Encodes bounding boxes locations.

    Args:
        means (tuple): Means for encoding bounding boxes calculation. Default: (0.0, 0.0, 0.0, 0.0).
        stds (tuple): The standard deviations of deltas calculation. Default: (1.0, 1.0, 1.0, 1.0).

    Inputs:
        - **anchor_box** (Tensor) - Anchor boxes. The shape of anchor_box must be (n, 4).
        - **groundtruth_box** (Tensor) - Ground truth boxes. Which has the same shape with anchor_box.

    Outputs:
        Tensor, encoded bounding boxes. It has the same data type and shape as input `anchor_box`.

    Raises:
        TypeError: If `means` or `stds` is not a tuple.
        TypeError: If `anchor_box` or `groundtruth_box` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> anchor_box = Tensor([[2, 2, 2, 3], [2, 2, 2, 3]], mindspore.float32)
        >>> groundtruth_box = Tensor([[1, 2, 1, 4], [1, 2, 1, 4]], mindspore.float32)
        >>> boundingbox_encode = ops.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
        >>> output = boundingbox_encode(anchor_box, groundtruth_box)
        >>> print(output)
        [[ -1.  0.25  0.  0.40551758]
         [ -1.  0.25  0.  0.40551758]]
    """

    @prim_attr_register
    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
        """Initialize BoundingBoxEncode."""
        validator.check_value_type('means', means, tuple, self.name)
        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type(f"means[{i}]", value, [float], self.name)
        for i, value in enumerate(stds):
            validator.check_value_type(f"stds[{i}]", value, [float], self.name)
        validator.check_equal_int(len(means), 4, "means len", self.name)
        validator.check_equal_int(len(stds), 4, "stds len", self.name)

    def infer_shape(self, anchor_box, groundtruth_box):
        """Validate that both inputs have shape (n, 4) and return the shape of `anchor_box`."""
        # Check the ranks first so that indexing below cannot raise a bare IndexError
        # on a rank-0 or rank-1 input.
        validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
        validator.check("groundtruth_box rank", len(groundtruth_box), "", 2, Rel.EQ, self.name)
        validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ,
                        self.name)
        validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
        validator.check_equal_int(groundtruth_box[1], 4, 'groundtruth_box shape[1]', self.name)
        return anchor_box

    def infer_dtype(self, anchor_box, groundtruth_box):
        """Validate that both inputs share one number dtype and return the dtype of `anchor_box`."""
        args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return anchor_box


class BoundingBoxDecode(PrimitiveWithInfer):
    """
    Decodes bounding boxes locations.

    Args:
        max_shape (tuple): The max size limit for decoding box calculation. Must have length 2 when not None.
        means (tuple): The means of deltas calculation. Default: (0.0, 0.0, 0.0, 0.0).
        stds (tuple): The standard deviations of deltas calculation. Default: (1.0, 1.0, 1.0, 1.0).
        wh_ratio_clip (float): The limit of width and height ratio for decoding box calculation. Default: 0.016.

    Inputs:
        - **anchor_box** (Tensor) - Anchor boxes. The shape of `anchor_box` must be (n, 4).
        - **deltas** (Tensor) - Delta of boxes. Which has the same shape with `anchor_box`.

    Outputs:
        Tensor, decoded boxes. It has the same data type and shape as `anchor_box`.

    Raises:
        TypeError: If `means`, `stds` or `max_shape` is not a tuple.
        TypeError: If `wh_ratio_clip` is not a float.
        TypeError: If `anchor_box` or `deltas` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> anchor_box = Tensor([[4, 1, 2, 1], [2, 2, 2, 3]], mindspore.float32)
        >>> deltas = Tensor([[3, 1, 2, 2], [1, 2, 1, 4]], mindspore.float32)
        >>> boundingbox_decode = ops.BoundingBoxDecode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0),
        ...                                            max_shape=(768, 1280), wh_ratio_clip=0.016)
        >>> output = boundingbox_decode(anchor_box, deltas)
        >>> print(output)
        [[ 4.1953125  0.         0.         5.1953125]
         [ 2.140625   0.         3.859375  60.59375  ]]

    """

    @prim_attr_register
    def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
        """Initialize BoundingBoxDecode."""
        validator.check_value_type('means', means, tuple, self.name)
        validator.check_value_type('stds', stds, tuple, self.name)
        for i, value in enumerate(means):
            validator.check_value_type(f"means[{i}]", value, [float], self.name)
        for i, value in enumerate(stds):
            validator.check_value_type(f"stds[{i}]", value, [float], self.name)
        validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
        validator.check_equal_int(len(means), 4, "means len", self.name)
        validator.check_equal_int(len(stds), 4, "stds len", self.name)
        if max_shape is not None:
            validator.check_value_type('max_shape', max_shape, [tuple], self.name)
            validator.check_equal_int(len(max_shape), 2, "max_shape len", self.name)

    def infer_shape(self, anchor_box, deltas):
        """Validate that both inputs have shape (n, 4) and return the shape of `anchor_box`."""
        # Check the ranks first so that indexing below cannot raise a bare IndexError
        # on a rank-0 or rank-1 input.
        validator.check("anchor_box rank", len(anchor_box), "", 2, Rel.EQ, self.name)
        validator.check("deltas rank", len(deltas), "", 2, Rel.EQ, self.name)
        validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name)
        validator.check_equal_int(anchor_box[1], 4, 'anchor_box shape[1]', self.name)
        validator.check_equal_int(deltas[1], 4, 'deltas shape[1]', self.name)
        return anchor_box

    def infer_dtype(self, anchor_box, deltas):
        """Validate that both inputs share one number dtype and return the dtype of `anchor_box`."""
        args = {"anchor_box": anchor_box, "deltas": deltas}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return anchor_box


class CheckValid(PrimitiveWithInfer):
    """
    Checks bounding box.

    Checks whether the bounding box cross data and data border are valid.

    .. warning::
        The valid boundary is specified by `img_metas` as (height x ratio, width x ratio).

    Inputs:
        - **bboxes** (Tensor) - Bounding boxes tensor with shape (N, 4). Type inference accepts
          float16, float32, int16 and uint8.
        - **img_metas** (Tensor) - Raw image size information with the format of (height, width, ratio).
          Type inference accepts float16, float32, int16 and uint8.

    Outputs:
        Tensor, with shape of (N,) and dtype of bool.

    Raises:
        TypeError: If `bboxes` or `img_metas` is not a Tensor.
        TypeError: If dtype of `bboxes` or `img_metas` is not one of float16, float32, int16 or uint8.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.check_valid = ops.CheckValid()
        ...     def construct(self, x, y):
        ...         valid_result = self.check_valid(x, y)
        ...         return valid_result
        ...
        >>> bboxes = Tensor(np.linspace(0, 6, 12).reshape(3, 4), mindspore.float32)
        >>> img_metas = Tensor(np.array([2, 1, 3]), mindspore.float32)
        >>> net = Net()
        >>> output = net(bboxes, img_metas)
        >>> print(output)
        [ True False False]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize CheckValid."""
        self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output'])

    def infer_shape(self, bboxes_shape, metas_shape):
        """Validate `bboxes` is (N, 4) and `img_metas` is (3,), then return the shape (N,)."""
        validator.check("bboxes rank", len(bboxes_shape), "", 2, Rel.EQ, self.name)
        validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ, self.name)
        validator.check("img_metas rank", len(metas_shape), "", 1, Rel.EQ, self.name)
        validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ, self.name)
        return bboxes_shape[:-1]

    def infer_dtype(self, bboxes_type, metas_type):
        """Validate the input dtypes against the accepted list and return bool as the output dtype."""
        valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8]
        validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name)
        validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name)
        return mstype.bool_


class IOU(PrimitiveWithInfer):
    r"""
    Calculates intersection over union for boxes.

    Computes the intersection over union (IOU) or the intersection over foreground (IOF) based on the ground-truth and
    predicted regions.

    .. math::
        \text{IOU} = \frac{\text{Area of Overlap}}{\text{Area of Union}}

        \text{IOF} = \frac{\text{Area of Overlap}}{\text{Area of Ground Truth}}

    .. warning::
        In Ascend, only computation of float16 data is supported. To avoid overflow, the input length
        and width are scaled by 0.2 internally.

    Args:
        mode (string): The mode is used to specify the calculation method,
                       now supporting 'iou' (intersection over union) or 'iof'
                       (intersection over foreground) mode. Default: 'iou'.

    Inputs:
        - **anchor_boxes** (Tensor) - Anchor boxes, tensor of shape (N, 4). "N" indicates the number of anchor boxes,
          and the value "4" refers to "x0", "y0", "x1", and "y1". Data type must be float16 or float32.
        - **gt_boxes** (Tensor) - Ground truth boxes, tensor of shape (M, 4). "M" indicates the number of ground
          truth boxes, and the value "4" refers to "x0", "y0", "x1", and "y1". Data type must be float16 or float32.

    Outputs:
        Tensor, the 'iou' values, tensor of shape (M, N), with the same data type as `anchor_boxes`.

    Raises:
        KeyError: When `mode` is not 'iou' or 'iof'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> iou = ops.IOU()
        >>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
        >>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float16)
        >>> output = iou(anchor_boxes, gt_boxes)
        >>> print(output.shape)
        (3, 3)
    """

    @prim_attr_register
    def __init__(self, mode='iou'):
        """Initialize IOU."""
        if mode not in {'iou', 'iof'}:
            raise KeyError(f"For '{self.name}', only 'iou' or 'iof' are supported, but got 'mode': {mode}.")
        self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap'])

    def infer_shape(self, anchor_boxes, gt_boxes):
        """Validate the inputs are (N, 4) and (M, 4) and return the output shape [M, N]."""
        # Check the ranks first so that indexing shape[1] cannot raise a bare IndexError
        # on a rank-0 or rank-1 input.
        validator.check_equal_int(len(anchor_boxes), 2, 'anchor_boxes rank', self.name)
        validator.check_equal_int(len(gt_boxes), 2, 'gt_boxes rank', self.name)
        validator.check_equal_int(gt_boxes[1], 4, 'gt_boxes shape[1]', self.name)
        validator.check_equal_int(anchor_boxes[1], 4, 'anchor_boxes shape[1]', self.name)
        iou = [gt_boxes[0], anchor_boxes[0]]
        return iou

    def infer_dtype(self, anchor_boxes, gt_boxes):
        """Validate both inputs are float16 or float32 tensors and return the dtype of `anchor_boxes`."""
        valid_type = [mstype.float32, mstype.float16]
        validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name)
        validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name)
        return anchor_boxes


class Partial(Primitive):
    """
    Makes a partial function instance, used for pynative mode.

    Inputs:
        - **args** (Union[FunctionType, Tensor]) - The function and bind arguments.

    Outputs:
        FunctionType, partial function binded with arguments.
    """

    # Side effect will be propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize Partial."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, *args):
        """Bind `args[1:]` to the callable `args[0]` and return the resulting partial function."""
        func = args[0].__call__
        partial_func = functools.partial(func, *args[1:])
        return partial_func


class Depend(Primitive):
    """
    Depend is used for processing dependency operations.

    In most scenarios, if operators have IO side effects or memory side effects,
    they will be executed according to the user's semantics. In some scenarios,
    if the two operators A and B have no order dependency, and A must be executed
    before B, we recommend using Depend to specify their execution order. The
    usage method is as follows::

        a = A(x)                --->        a = A(x)
        b = B(y)                --->        y = Depend(y, a)
                                --->        b = B(y)

    Inputs:
        - **value** (Tensor) - the real value to return for depend operator.
        - **expr** (Expression) - the expression to execute with no outputs.

    Outputs:
        Tensor, the value passed by last operator.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.softmax = ops.Softmax()
        ...         self.depend = ops.Depend()
        ...
        ...     def construct(self, x, y):
        ...         mul = x * y
        ...         y = self.depend(y, mul)
        ...         ret = self.softmax(y)
        ...         return ret
        ...
        >>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
        >>> print(output)
        [[0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]
         [0.2 0.2 0.2 0.2 0.2]]
    """

    # Side effect will be propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize Depend."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, value, expr):
        """Return `value` unchanged; `expr` is accepted but not used here."""
        return value


class UpdateState(Primitive):
    """
    UpdateState is used for update side-effect state.

    Inputs:
        - **value** (State) - the state value to be updated.
        - **expr** (Expression) - the expression to evaluate before state changes.

    Outputs:
        State, the updated state value.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize UpdateState."""

    def __call__(self, state, expr):
        """Return `state` unchanged; `expr` is accepted but not used here."""
        return state


class CheckBprop(PrimitiveWithInfer):
    """
    Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.

    Inputs:
        - **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
        - **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.

    Outputs:
        (tuple[Tensor]), the `input_x`,
        if data type and shape of corresponding elements from `input_x` and `input_y` are the same.

    Raises:
        TypeError: If `input_x` or `input_y` is not a Tensor.

    Examples:
        >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
        >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
        >>> out = ops.CheckBprop()(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self, prim_to_check=""):
        """Initialize CheckBprop"""
        self.prim_to_check = prim_to_check

    def infer_shape(self, xshapes, yshapes):
        """Check each shape in `xshapes` against the matching entry of `yshapes` and return `xshapes`."""
        tips = f'Bprop of {self.prim_to_check}'
        validator.check_value_type('grads', xshapes, (tuple,), tips)
        validator.check_value_type('params', yshapes, (tuple,), tips)
        if len(xshapes) < len(yshapes):
            raise ValueError(f"For '{tips}', the size of 'input_x.shape' should not be less than {len(yshapes)}, "
                             f"but got {len(xshapes)}.")
        # zip stops at len(yshapes) because xshapes is at least as long (checked above).
        for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)):
            # Pairs in which either shape is empty are not compared.
            if not xshape or not yshape:
                continue
            if xshape != yshape:
                raise ValueError(f"For '{tips}', the shape of 'input_x' in {i}th index should be {yshape},"
                                 f" but got 'input_x[i]': {xshape}.")
        return xshapes

    def infer_dtype(self, xdtypes, ydtypes):
        """Check each dtype in `xdtypes` against the matching entry of `ydtypes` and return `xdtypes`."""
        tips = f'Bprop of {self.prim_to_check}'
        validator.check_value_type('grads', xdtypes, (tuple,), tips)
        validator.check_value_type('params', ydtypes, (tuple,), tips)
        if len(xdtypes) < len(ydtypes):
            raise ValueError(f"For '{tips}', the size of 'input_x.dtype' should not be less than {len(ydtypes)},"
                             f" but got {len(xdtypes)}.")
        # zip stops at len(ydtypes) because xdtypes is at least as long (checked above).
        for i, (xdtype, ydtype) in enumerate(zip(xdtypes, ydtypes)):
            if isinstance(xdtype, mstype.anything_type) or isinstance(ydtype, mstype.anything_type):
                continue
            # A function-typed input must be paired with an env-typed gradient.
            if isinstance(ydtype, mstype.function_type):
                if not isinstance(xdtype, mstype.env_type_type):
                    raise TypeError(f"For '{tips}', the dtype of 'input_x' in {i}th index should be "
                                    f"{mstype.env_type_type}, but got {xdtype}.")
                continue
            if xdtype != ydtype:
                raise TypeError(f"For '{tips}', the dtype of 'input_x' in {i}th index should be {ydtype},"
                                f" but got {xdtype}.")
        return xdtypes


class ConfusionMatrix(PrimitiveWithInfer):
    r"""
    Calculates the confusion matrix from labels and predictions.

    Args:
        num_classes (int): The num of classes.
        dtype (str): Data type of confusion matrix. Default: 'int32'.

    Inputs:
        - **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
        - **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
          the shape same as `labels` and the dtype must be non-negative Integer.
        - **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.

    Outputs:
        Tensor, the confusion matrix, with shape (`num_classes`, `num_classes`).

    Raises:
        TypeError: If `num_classes` is not an int.
        TypeError: If `dtype` is not a str.
        TypeError: If `labels`, `predictions` or `weights` is not a Tensor.

    Examples:
        >>> confusion_matrix = ops.ConfusionMatrix(4)
        >>> labels = Tensor([0, 1, 1, 3], mindspore.int32)
        >>> predictions = Tensor([1, 2, 1, 3], mindspore.int32)
        >>> output = confusion_matrix(labels, predictions)
        >>> print(output)
        [[0 1 0 0]
         [0 1 1 0]
         [0 0 0 0]
         [0 0 0 1]]
    """

    @prim_attr_register
    def __init__(self, num_classes, dtype="int32"):
        """Initialize ConfusionMatrix."""
        validator.check_value_type("num_classes", num_classes, [int], self.name)
        validator.check_value_type("dtype", dtype, [str], self.name)

    def infer_shape(self, labels, predictions, weights=None):
        """Validate the 1-D input shapes match and return (`num_classes`, `num_classes`)."""
        validator.check('labels dimension', len(labels), '', 1, Rel.EQ, self.name)
        validator.check('labels shape', labels, 'predictions shape', predictions, Rel.EQ, self.name)
        if weights is not None:
            validator.check('labels shape', labels, 'weights shape', weights, Rel.EQ, self.name)
        ret = (self.num_classes, self.num_classes)
        return ret

    def infer_dtype(self, labels, predictions, weights=None):
        """Validate the inputs are tensors with one shared number dtype and return the dtype of `labels`."""
        validator.check_subclass('labels', labels, mstype.tensor, self.name)
        validator.check_subclass('predictions', predictions, mstype.tensor, self.name)
        if weights is not None:
            validator.check_subclass('weights', weights, mstype.tensor, self.name)
        args = {"labels": labels, "predictions": predictions}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return labels


class PopulationCount(PrimitiveWithInfer):
    r"""
    Calculates population count.

    Inputs:
        - **input** (Tensor) - The data type must be int16 or uint16.

    Outputs:
        Tensor, with the same shape as the input.

    Raises:
        TypeError: If `input` is not a Tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> population_count = ops.PopulationCount()
        >>> x_input = Tensor([0, 1, 3], mindspore.int16)
        >>> output = population_count(x_input)
        >>> print(output)
        [0 1 2]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PopulationCount."""

    def infer_shape(self, x_shape):
        """Return the input shape unchanged."""
        return x_shape

    def infer_dtype(self, x_dtype):
        """Validate the input is an int16 or uint16 tensor and return a uint8 tensor type."""
        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name)
        return mstype.tensor_type(mstype.uint8)


class Push(PrimitiveWithInfer):
    """
    Pushes the inputs of the corresponding optimizer to parameter server.

    Args:
        optim_type (string): The optimizer type. Default: 'ApplyMomentum'.
        only_shape_indices (list): The indices of input of which only shape
                                   will be pushed to parameter server. Default: None.

    Inputs:
        - **optim_inputs** (tuple) - The inputs for this kind of optimizer.
        - **optim_input_shapes** (tuple) - The shapes of the inputs.

    Outputs:
        Tensor, the key of the weight which needs to be updated.
    """

    @prim_attr_register
    def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None):
        """Initialize Push"""
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr("_side_effect", True)
        self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key'])

    def infer_shape(self, inputs, shapes):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self, inputs, shapes):
        """Return the fixed output dtype uint64."""
        return mstype.uint64


class Pull(PrimitiveWithInfer):
    """
    Pulls weight from parameter server.

    Inputs:
        - **key** (Tensor) - The key of the weight.
        - **weight** (Tensor) - The weight to be updated.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Pull"""
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output'])

    def infer_shape(self, key_shape, weight_shape):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self, key_dtype, weight_dtype):
        """Return the fixed output dtype float32."""
        return mstype.float32


class PullWeight(PrimitiveWithInfer):
    """
    Pull weight by its names from server.

    Inputs:
        - **weight** (Tensor) - The weight to be pulled.
        - **name** (String) - The full name of the weight.
        - **index** (Int) - The index of the weight.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PullWeight"""
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=['weight', "name", "index"], outputs=['output'])

    def infer_shape(self, weight, name, index):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self, weight, name, index):
        """Return the fixed output dtype float32."""
        return mstype.float32


class PushWeight(PrimitiveWithInfer):
    """
    Upload weight by its names to server.

    Inputs:
        - **weight** (Tensor) - The weight to be uploaded.
        - **name** (String) - The full name of the weight.
        - **index** (Int) - The index of the weight.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PushWeight"""
        self.add_prim_attr("primitive_target", "CPU")
        self.init_prim_io_names(inputs=["weight", "name", "index"], outputs=["output"])

    def infer_shape(self, weight, name, index):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self, weight, ps_key, index):
        """Return the fixed output dtype float32."""
        return mstype.float32


class PushMetrics(PrimitiveWithInfer):
    """
    Push metrics like loss and accuracy for federated learning worker.

    Inputs:
        - **loss** (Tensor) - The loss.
        - **accuracy** (Tensor) - The accuracy.

    Outputs:
        None.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize PushMetrics"""
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr("side_effect_mem", True)
        self.init_prim_io_names(inputs=["loss", "accuracy"], outputs=["result"])

    def infer_shape(self, loss, accuracy):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self, loss, accuracy):
        """Return the fixed output dtype float32."""
        return mstype.float32


class StartFLJob(PrimitiveWithInfer):
    """
    StartFLJob for federated learning worker.
    """
    @prim_attr_register
    def __init__(self, data_size):
        """Initialize StartFLJob and record `data_size` as a primitive attribute."""
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr("data_size", data_size)
        self.init_prim_io_names(inputs=[], outputs=["result"])

    def infer_shape(self):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self):
        """Return the fixed output dtype float32."""
        return mstype.float32


class UpdateModel(PrimitiveWithInfer):
    """
    UpdateModel for federated learning worker.
    """
    @prim_attr_register
    def __init__(self):
        """Initialize UpdateModel."""
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr('side_effect_mem', True)
        self.init_prim_io_names(inputs=["weights"], outputs=["result"])

    def infer_shape(self, weights):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self, weights):
        """Return the fixed output dtype float32."""
        return mstype.float32


class GetModel(PrimitiveWithInfer):
    """
    GetModel for federated learning worker.
    """
    @prim_attr_register
    def __init__(self):
        """Initialize GetModel."""
        self.add_prim_attr("primitive_target", "CPU")
        self.add_prim_attr('side_effect_mem', True)
        self.init_prim_io_names(inputs=["weights"], outputs=["result"])

    def infer_shape(self, weights):
        """Return the fixed output shape [1]."""
        return [1]

    def infer_dtype(self, weights):
        """Return the fixed output dtype float32."""
        return mstype.float32


class identity(Primitive):
    """
    Makes an identity primitive, used for pynative mode.

    Inputs:
        - **x** (Any) - identity input value.

    Outputs:
        The same as input.
    """

    # Side effect will be propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize identity."""
        self.add_prim_attr('side_effect_propagate', 1)

    def __call__(self, x):
        """Return `x` unchanged."""
        return x


# Module-level registry that maps a function id to the Python function wrapped by a `PyFunc`.
pyfunc_register = PyFuncRegistry()


def get_pyfunc(fn_id):
    """Look up the Python function registered under `fn_id` in the module-level registry."""
    return pyfunc_register.get(fn_id)


class PyFunc(PrimitiveWithInfer):
    r"""
    Execute Python function.

    `PyFunc` encapsulates Python functions as an operator which could be compiled into computation graph.
    Unlike normal operators, it cannot be exported to MindIR as it is executed in current Python context.
    As only the weights of the network is stored in the checkpoint, network include `PyFunc` could save
    checkpoint and load to the network again, but will lose any Python function state.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        fn (function): Python function which inputs and outputs should be Python built-in scalar or numpy ndarray.
        in_types (list[:class:`mindspore.dtype`]): The type of the inputs.
        in_shapes (list[tuple[int]]): The dimensionality of the inputs. An empty list represents a scalar, otherwise it
                                      represent a numpy array.
        out_types (list[:class:`mindspore.dtype`]): The type of the outputs.
        out_shapes (list[tuple[int]]): The dimensionality of the outputs. An empty list represents a scalar, otherwise
                                       it represent a numpy array.
        stateful (bool): Whether the function is stateful or not.
                         If True, the execution order is same with model definition.

    Inputs:
        - **input_x** (Union(tuple[Tensor], list[Tensor])) - The input tuple or list
          is made up of multiple tensors.

    Outputs:
        tuple[Tensor], execution results Python functions.

    Raises:
        TypeError: The Python function execution failed.
        TypeError: The attributes(in_types/in_shapes/out_types/out_shapes) are inconsistent with Python function
                   specifications.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> def func(x1, x2):
        ...     return x1 + x2
        >>> x1 = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> x2 = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> op = P.PyFunc(func, [x1.dtype, x2.dtype], [x1.shape, x2.shape], [x1.dtype], [x1.shape])
        >>> output = op((x1, x2))
        >>> print(output[0].asnumpy())
        [2. 4. 6.]
    """

    def __init__(self, fn, in_types, in_shapes, out_types, out_shapes, stateful=True):
        """Initialize PyFunc: validate the type/shape lists, register `fn` and record the primitive attributes."""
        super(PyFunc, self).__init__(self.__class__.__name__)
        # Validate before touching the global registry so that invalid arguments
        # do not leave a stale entry behind.
        validator.check_value_type("in_types", in_types, [list, tuple], self.name)
        validator.check_value_type("in_shapes", in_shapes, [list, tuple], self.name)
        validator.check("in_types length", len(in_types), "in_shapes length", len(in_shapes), Rel.EQ, self.name)
        validator.check_value_type("out_types", out_types, [list, tuple], self.name)
        validator.check_value_type("out_shapes", out_shapes, [list, tuple], self.name)
        validator.check("out_types length", len(out_types), "out_shapes length", len(out_shapes), Rel.EQ, self.name)
        pyfunc_register.register(id(fn), fn)
        self.add_prim_attr('fn_id', id(fn))
        self.add_prim_attr('in_types', in_types)
        self.add_prim_attr('in_shapes', in_shapes)
        self.add_prim_attr('out_types', out_types)
        self.add_prim_attr('out_shapes', out_shapes)
        self.add_prim_attr("side_effect_io", stateful)
        self.add_prim_attr("primitive_target", "CPU")

    def infer_shape(self, *args):
        """Return the declared output shapes, or a ((1,),) placeholder when none were declared."""
        if self.out_shapes:
            return tuple(self.out_shapes)

        logger.warning("The function output are empty tuple. Add a placeholder instead. "
                       "Do not use it as it could be any uninitialized data.")
        return ((1,),)

    def infer_dtype(self, *args):
        """Return the declared output dtypes, or an (int32,) placeholder when none were declared."""
        # out_types and out_shapes are validated to have equal length in __init__,
        # so testing out_types here is equivalent to testing out_shapes.
        if self.out_types:
            return tuple(self.out_types)

        logger.warning("The function output are empty tuple. Add a placeholder instead. "
                       "Do not use it as it could be any uninitialized data.")
        return (mstype.int32,)