# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Inner operators."""
from types import FunctionType, MethodType
from collections.abc import Iterable
import os
import numpy as np

from mindspore.common import Tensor
from mindspore.common._stub_tensor import StubTensor
from mindspore.ops import composite as C
from mindspore.ops.operations.array_ops import Cast
from mindspore.ops.operations._scalar_ops import bit_or, bit_and
from mindspore.ops import signature as sig
from mindspore.ops.operations.math_ops import _infer_shape_reduce
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
    _run_op, _check_contains_variable
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import typing
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.communication.management import GlobalComm, get_rank, _get_group, get_group_size
from mindspore.common.api import _pynative_executor
from mindspore.common._register_for_adapter import ms_adapter_registry
from mindspore import ops
from ..auto_generate import TensorCopySlices, SiLU, Cummin, TopKRouter, ExtractImagePatches, DecoderKVCache, \
    PromptKVCache, ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4

# Bit operation
bit_and = bit_and()
bit_or = bit_or()
bit_xor = Primitive("bit_xor")
bit_left_shift = Primitive("bit_left_shift")
bit_right_shift = Primitive("bit_right_shift")
# String operation
string_lt = Primitive("string_lt")
string_gt = Primitive("string_gt")
string_le = Primitive("string_le")
string_ge = Primitive("string_ge")
string_not = Primitive("string_not")
string_in = Primitive("string_in")
string_mul = Primitive("string_mul")
string_getitem = Primitive("string_getitem")


class Generator(Primitive):
    r"""
    Manage the state of random number generation.

    Inputs:
        - **cmd** (int) : operation to be executed.
        - **inputs** (tuple[tensor]) : inputs for the operation.

    Outputs:
        - **seed** (Tensor): Seed for the random number generation algorithm.
        - **offset** (Tensor): Offset of the random number sequence.
        - **state** (Tensor): State tensor, can be used to restore the current state.
    """

    @prim_attr_register
    def __init__(self):
        self.add_prim_attr("side_effect_mem", True)

    def __call__(self, cmd, inputs):
        if cmd == 0:  # step cmd
            return inputs[0], inputs[1]
        return super().__call__(cmd, inputs)


class Quant(PrimitiveWithInfer):
    r"""
    Returns the quantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = round(scale * x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = round(scale * x * scale + offset)

    Note:
        This operation only supports the Atlas 200/300/500 inference product.

    Args:
        scale (float) : Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
        round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
            Default: "Round".

    Inputs:
        - **input_x** (Tensor) : Input tensor. Its data type must be mindspore.float16 or mindspore.float32.

    Outputs:
        - Tensor: The quantized output tensor of type mindspore.int8.

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> quant = ops.Quant(80.0, 0.0, False, "Round")
        >>> y = quant(input_x)
    """

    @prim_attr_register
    def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
        self.scale = validator.check_value_type("scale", scale, [float], self.name)
        self.offset = validator.check_value_type("offset", offset, [float], self.name)
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                 "round_mode", self.name)
        self.add_prim_attr("dst_type", mstype.int8)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
        return self.get_attr_dict()['dst_type']


class Lamb(PrimitiveWithInfer):
    r"""
    LAMB optimizer algorithm.

    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    Inputs:
        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means
          any number of additional dimensions. The data type can be float16 or float32.
        - **m** (Tensor) - The 1st moment vector in the updating formula,
          the shape and data type should be the same as `var`.
        - **v** (Tensor) - The 2nd moment vector in the updating formula,
          the shape and data type should be the same as `var`. Mean square gradients with the same type as `var`.
        - **lr** (float) - :math:`l` in the updating formula. The paper suggested value is :math:`10^{-8}`,
          the data type should be the same as `var`.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimations,
          the data type should be the same as `var`. The paper suggested value is :math:`0.9`.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations,
          the data type should be the same as `var`. The paper suggested value is :math:`0.999`.
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
        - **decay** (float) - The weight decay value, must be a scalar tensor with float data type.
          Default: 0.0.
        - **global_step** (Tensor) - Tensor to record the current global step.
        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.

    Outputs:
        Tensor, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Lamb."""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, global_step_shape, gradient_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
        validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, validator.EQ, self.name)
        return var_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
                "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype


class Dequant(PrimitiveWithInfer):
    r"""
    Returns the dequantized value of input_x.
    This operation will do ReLU to the dequantized value if `relu_flag` is True.

    If `sqrt_mode` is False:

    .. math::
        y = x * deq\_scale

    If `sqrt_mode` is True:

    .. math::
        y = x * deq\_scale * deq\_scale

    Note:
        This operation only supports the Atlas 200/300/500 inference product.

    Args:
        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
        relu_flag (bool): Specifies whether to perform ReLU. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) : Input tensor. Must be mindspore.int32.
        - **deq_scale** (Tensor) : Specifies the scaling ratio.
          Data type must be mindspore.float16 or mindspore.uint64.

    Outputs:
        - Tensor: The dequantized output tensor of type mindspore.float16.

    Examples:
        >>> input_x = Tensor([100, 150], mstype.int32)
        >>> deq_scale = Tensor([1.0, 1.0], mstype.float16)
        >>> dequant = ops.Dequant(False, False)
        >>> y = dequant(input_x, deq_scale)
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, relu_flag=False, dtype=mstype.float16):
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
        self.dtype = dtype

    def infer_shape(self, x_shape, deq_scale_shape):
        return x_shape

    def infer_dtype(self, x_type, deq_scale_type):
        validator.check_subclass("x", x_type, mstype.tensor_type, self.name)
        validator.check_type_name("x", x_type, [mstype.int32], self.name)
        validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
        return mstype.float16


class AntiQuant(Primitive):
    r"""
    Returns the antiquantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = scale * (x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = scale * scale * (x + offset)

    Note:
        This operation only supports the Atlas 200/300/500 inference product.

    Args:
        scale (float) : Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) : Input tensor. Must be mindspore.int8.

    Outputs:
        - Tensor: The antiquantized output tensor of type mindspore.float32.

    Examples:
        >>> from mindspore.ops.operations._inner_ops import AntiQuant
        >>> input_x = Tensor([50.0, 20.0], mstype.int8)
        >>> antiquant = AntiQuant(2.0, 1.0, False)
        >>> y = antiquant(input_x)
        >>> print(y)
        [102. 42.]
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, dtype=mstype.float16):
        super().__init__("AntiQuant")
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.dtype = dtype

        self.init_prim_io_names(inputs=['x', 'scale', 'offset'],
                                outputs=['y'])


class MatrixDiag(PrimitiveWithInfer):
    """
    Returns a batched diagonal tensor with given batched diagonal values.

    Inputs:
        - **x** (Tensor) - A tensor to be element-wise multiplied by `assist`. It can be one of the following data
          types: float32, float16, int32, int8, and uint8.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. Its rank must be greater than or equal to 2
          and its last dimension must be equal to the second to last dimension.

    Outputs:
        Tensor, has the same type and shape as input `assist`.

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag = ops.MatrixDiag()
        >>> result = matrix_diag(x, assist)
        >>> print(result)
        [[[-12.  11.]
          [-10.   9.]]
         [[ -8.   7.]
          [ -6.   5.]]
         [[ -4.   3.]
          [ -2.   1.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiag"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(assist_shape), 2, validator.GE, "assist rank", self.name)
        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), validator.LE, self.name)
        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
                        assist_shape[-1], validator.EQ, self.name)

        r_end_dim = -len(x_shape)
        r_idx = -1
        while r_idx >= r_end_dim:
            if x_shape[r_idx] != 1:
                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
                                assist_shape[r_idx - 1], assist_shape[r_idx - 1], validator.EQ, self.name)
            r_idx = r_idx - 1

        return assist_shape


class MatrixDiagPart(PrimitiveWithInfer):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye tensor of the same type as `x`, with the same shape as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag_part = ops.MatrixDiagPart()
        >>> result = matrix_diag_part(x, assist)
        >>> print(result)
        [[12., -9.], [8., -5.], [4., -1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiagPart"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        if assist_shape[-2] < assist_shape[-1]:
            out_shape = assist_shape[:-1]
        else:
            out_shape = assist_shape[:-2] + assist_shape[-1:]
        return out_shape


class MatrixSetDiag(PrimitiveWithInfer):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
        - **assist** (Tensor) - An eye tensor of the same type as `x`, with the same shape as `x`.

    Outputs:
        Tensor, data type same as input `x`. The shape is the same as `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = ops.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal)
        >>> print(result)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]

    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixSetDiag"""

    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)

        if x_shape[-2] < x_shape[-1]:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the last dimension",
                            x_shape[:-1], validator.EQ, self.name)
        else:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
                            x_shape[:-2] + x_shape[-1:], validator.EQ, self.name)

        return assist_shape


class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the dot product result of input0 and input1.

    `output1` is the dot product result of input0 and input1, to which the reducesum operation is then applied.

    Args:
        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
            Default: (), reduce all dimensions. Only constant value is allowed.
        keep_dims (bool):

            - If true, keep these reduced dimensions and the length as 1.
            - If false, don't keep these dimensions. Default: False.

    Inputs:
        - **input_0** (Tensor) - The input Tensor.
        - **input_1** (Tensor) - The input Tensor.
        - **input_2** (Tensor) - The input Tensor.

    Outputs:
        - **output_0** (Tensor) - The same shape as `input0`.
        - **output_1** (Tensor)

          - If axis is (), and keep_dims is false, the output is a 0-D array representing
            the sum of all elements in the input array.
          - If axis is int, set as 2, and keep_dims is false,
            the shape of output is :math:`(x_1, x_3, ..., x_R)`.
          - If axis is tuple(int), set as (2, 3), and keep_dims is false,
            the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Examples:
        >>> confusion_mul_grad = ops.ConfusionMulGrad()
        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
        output_0:
        [[ 3.  1.  0.]
         [-6.  2. -2.]]
        output_1:
        -3.0
    """

    @prim_attr_register
    def __init__(self, axis=(), keep_dims=False):
        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)

    def infer_shape(self, input0_shape, input1_shape, input2_shape):
        outshape0 = input0_shape
        outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
        return outshape0, outshape1

    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
        validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor_type, self.name)
        return input0_dtype, input1_dtype


class ConvertToDynamic(PrimitiveWithCheck):
    """
    This op is used for dynamic rank testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically ranked.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic rank support.

    Args:
        is_dynamic_rank (bool): If true, convert to dynamic rank.
            If false, convert to dynamic shape. Default: ``False``.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicNet, self).__init__()
        >>>         self.convert_to_dynamic = inner.ConvertToDynamic()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_input = self.convert_to_dynamic(input)
        >>>         reshaped_input = self.reshape(dynamic_input, new_shape)
        >>>         return reshaped_input
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self, is_dynamic_rank=False):
        validator.check_value_type('is_dynamic_rank', is_dynamic_rank, [bool], self.name)
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)


class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    This op is used for dynamic shape testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically shaped.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic shape support.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class TestDynamicShapeReshapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicShapeReshapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = P.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         reshaped_input = self.reshape(dynamic_shape_input, new_shape)
        >>>         return reshaped_input
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)

    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)


class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. The only purpose of this operator is
    that it will throw a value error if the input is dynamically shaped.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> from mindspore.ops import operations as P
        >>> class AssertDynamicShapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(AssertDynamicShapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
        >>>
        >>>     def construct(self, input):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         return self.error_on_dynamic_shape_input(dynamic_shape_input)
        >>>
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0]))
        >>> net = AssertDynamicShapeNet()
        >>> output = net(input)
        ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        shape = list(input_shape)

        for dim in shape:
            if dim == -1:
                raise ValueError("Input is dynamically shaped.")

        return input_shape

    def infer_type(self, input_dtype):
        """Infer the dtype of input for ErrorOnDynamicShapeInput."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
        return input_dtype

    def infer_value(self, input_tensor):
        return input_tensor


class SequenceMask(PrimitiveWithCheck):
    """
    Returns a mask tensor representing the first N positions of each cell.

    If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type and shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]).

    Inputs:
        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
          less than or equal to `maxlen`; values greater than `maxlen` will be treated as `maxlen`.
          Must be type int32 or int64.

        - **maxlen** (int) - Size of the last dimension of the returned tensor. Must be positive and the same
          type as elements in `lengths`.

    Outputs:
        One mask tensor of shape lengths.shape + (maxlen,).

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> import numpy as np
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> sequence_mask = ops.SequenceMask()
        >>> output = sequence_mask(x, 3)
        >>> print(output)
        [[[True False False]
          [True True True]]
         [[True True False]
          [False False False]]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

    def check_shape(self, lengths_shape, maxlen_shape):
        validator.check("lengths_shape", len(lengths_shape), "", 0, validator.GT, self.name)
        validator.check("maxlen_shape", len(maxlen_shape), "", 0, validator.EQ, self.name)

    def check_dtype(self, lengths_dtype, maxlen_dtype):
        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor_type, self.name)
        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)


class SyncBatchNorm(Primitive):
    r"""
    Sync Batch Normalization for input data and updated parameters.

    Sync Batch Normalization is cross device synchronized Batch Normalization. Batch Normalization is
    widely used in convolutional neural networks. This operation applies Batch Normalization over input
    to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
    It rescales and recenters the features using a mini-batch of data and the learned parameters which
    can be described in the following formula,

    .. math::
        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

    Args:
        epsilon (float): A small value added for numerical stability. Default: 1e-5.
        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be in [0, 1]. Default: 0.1.
        group (str): The communication group to work on. Default: "sync_bn_group0".
        device_num (int): The number of devices in each group. Default: 2.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`.

    Outputs:
        Tuple of 5 Tensors, the normalized inputs and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to nn.SyncBatchNorm for direct use.
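        >>> # Imports assumed by this example (not part of the original snippet):
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops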
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00, 1.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
        validator.check_isinstance("group", group, str)
        validator.check_int(device_num, 2, validator.GE, "device_num", self.name)
        self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
        self.add_prim_attr('side_effect_mem', True)
        self.add_prim_attr('format', 'NCHW')


class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
        - **axis** (Union[int, Tuple(int), List(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as the `input_x`.

    Raises:
        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
        TypeError: If `axis` has non-Int elements.
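        ValueError: If `axis` is not given as a constant value.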

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
         [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        x_shape = list(input_x['shape'])
        x_dtype = input_x['dtype']
        axis_v = axis['value']
        rank = len(x_shape)

        args = {'input_x': input_x['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        if axis_v is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)

        if isinstance(axis_v, int):
            validator.check_int_range(axis_v, -rank, rank, validator.INC_LEFT, 'axis', self.name)
        elif axis:
            for index, one_axis in enumerate(axis_v):
                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)

        out = {'shape': x_shape,
               'dtype': x_dtype,
               'value': None}
        return out


class StackInit(PrimitiveWithInfer):
    """
    Create a stack that produces tensors in first-in last-out order.

    After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
    from the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> index = 0
        >>> stack = ops.StackInit(index)
        >>> push = ops.StackPush(index)
        >>> pop = ops.StackPop(index, x.shape, x.dtype)
        >>> destroy = ops.StackDestroy(index)
        >>> stack()
        >>> push(x)
        >>> y = pop()
        >>> destroy()
        >>> print(y)
        [[1 3]
         [2 0]]
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackInit"""
        validator.check_value_type("index", index, [int], self.name)


class StackPush(PrimitiveWithInfer):
    """
    Push a tensor onto the stack.

    Before `StackPush`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Inputs:
        - **input** (Tensor) - A tensor to be pushed onto the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackPush"""
        validator.check_value_type("index", index, [int], self.name)
        self.init_prim_io_names(inputs=['input'], outputs=[])


class StackPop(PrimitiveWithInfer):
    """
    Pop the tensor at the top of the stack.

    Before `StackPop`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.
        shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
        dtype (mindspore.dtype): The type of the tensor at the top of the stack.
            Default: mindspore.float32.

    Outputs:
        - **output** (Tensor) - The tensor at the top of the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
        """StackPop"""
        validator.check_value_type("index", index, [int], self.name)

        validator.check_value_type('shape type', shape, [list, tuple], self.name)
        validator.check_int(len(np.array(shape).shape), 1, validator.EQ, "dim of shape", self.name)
        for elem in shape:
            validator.check_int(elem, 1, validator.GE, 'shape element', self.name)
            validator.check_value_type('type of shape element', elem, [int], self.name)

        validator.check_type_name("dtype", dtype, (mstype.bool_,) + mstype.number_type, self.name)
        self.shape = shape
        self.dtype = dtype

        self.init_prim_io_names(inputs=[], outputs=['output'])

    def __infer__(self):
        return {'shape': (list(self.shape)),
                'dtype': (self.dtype),
                'value': None}


class StackDestroy(PrimitiveWithInfer):
    """
    Destroy the stack.

    Before `StackDestroy`, the stack should be created using `StackInit`.
    Please refer to the usage in source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackDestroy"""
        validator.check_value_type("index", index, [int], self.name)


class DynamicStitch(PrimitiveWithCheck):
    r"""
    Interleave the values from the data tensors into a single tensor.

    Inputs:
        - **indices** (Union[tuple, list]) - A tuple or list of Tensor objects with the same shape and type.
        - **data** (Union[tuple, list]) - A tuple or list of Tensor objects with the same shape and type.

    Outputs:
        Tensor. A stacked Tensor with the same type as `data`.

    Raises:
        TypeError: If the data types of elements in `data` or `indices` are not the same.
        ValueError: If the length of `data` or `indices` is not greater than 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x1 = Tensor([6], mstype.int32)
        >>> x2 = Tensor(np.array([4, 1]), mstype.int32)
        >>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
        >>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
        >>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
        >>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
        >>> stitch = ops.DynamicStitch()
        >>> output = stitch([x1, x2, x3], [y1, y2, y3])
        >>> print(output)
        [[ 1  2]
         [11 12]
         [21 22]
         [31 32]
         [41 42]
         [51 52]
         [61 62]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicStitch"""

    def check_shape(self, indices_shape, data_shape):
        validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
        validator.check_int(len(indices_shape), 1, validator.GE, "len of indices_shape", self.name)
        indices_dim0 = len(indices_shape[0])
        indices_num = len(indices_shape)

        validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
        validator.check_int(len(data_shape), 1, validator.GE, "len of data_shape", self.name)
        data_dim0 = len(data_shape[0])
        data_num = len(data_shape)

        validator.check("size of indices", indices_num, 'size of data', data_num, validator.EQ, self.name)

        # shape of `data` must start with shape of `indices`
        for i in range(0, indices_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim,
                            validator.LE, self.name)
            if data_shape[i][:indices_dim] != indices_shape[i]:
                raise ValueError(f"data[{i}].shape: {data_shape[i]} does not start with "
                                 f"indices[{i}].shape: {indices_shape[i]}")

        # the last (data_dim0 - indices_dim0) dims of every data shape must be the same
        base_extra = data_dim0 - indices_dim0
        for i in range(0, data_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            extra = data_dim - indices_dim
            validator.check(f"extra dim of data[{i}]", extra,
                            f"extra dim of data[0]", base_extra, validator.EQ, self.name)
            validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
                            f"data[{i}].shape[{len(indices_shape[i])}:]",
                            data_shape[i][indices_dim:], validator.EQ, self.name)

        out_shape = [-1] + data_shape[0][indices_dim0:]
        return out_shape

    def check_dtype(self, indices_type, data_type):
        validator.check_subclass("indices[0]", indices_type[0], mstype.tensor_type, self.name)
        validator.check_subclass("data[0]", data_type[0], mstype.tensor_type, self.name)
        indices_num = len(indices_type)
        for i in range(0, indices_num):
            validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
            validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
                                               mstype.number_type + (mstype.bool_,), self.name)
            validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]",
                            data_type[0], validator.EQ, self.name)
        return data_type[0]


class DynamicBroadcastGradientArgs(Primitive):
    """
    Broadcast the two input shapes, return the dimensions that each need to be broadcast.
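
    In other words, for each input shape the operator returns the axes along which that input was expanded by
    broadcasting; see the example below, where shape `(4, 2, 1)` needs axis 2 and shape `(2, 7)` needs axis 0.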

    Input shape `s0` and shape `s1` can be broadcast to a common shape if, for each dimension pair, they are
    either equal, or one of them is 1, or the target dimension is -1. In case of -1 in the target shape, it will
    be replaced by the input shape's value in that dimension.

    Inputs:
        - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
          uint32, uint64.
        - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.

    Outputs:
        Tuple(Tensor), tuple of 2 tensors, r0 and r1. The first one is the index tensor and the other one is the
        mask tensor.

        - **r0** (Tensor) - The output shape is 1-D with the same type as s0.
        - **r1** (Tensor) - The output shape is 1-D with the same type as s0.

    Raises:
        ValueError: If the `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid
            location.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape0 = (4, 2, 1)
        >>> shape1 = (2, 7)
        >>> from mindspore.ops.operations import _inner_ops
        >>> args = _inner_ops.DynamicBroadcastGradientArgs()
        >>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
        >>> print(r0, r1)
        [2], [0]
    """

    @prim_attr_register
    def __init__(self):
        """Init BroadcastGradientArgs"""


class DSDMatmul(PrimitiveWithInfer):
    """
    The definition of the DSDMatmul primitive.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_w1', 'input_w2', 'input_v'], outputs=['output_y'])

    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape):
        batch_size = input_w1_shape[0]
        head = input_w1_shape[1]
        v_embedding = input_v_shape[1] * 16 // head
        seq_len = input_v_shape[0] * 16 // batch_size
        return (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)

    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3):
        return data_dtype1


class MatmulDDS(PrimitiveWithInfer):
    """MatmulDDS definition"""

    @prim_attr_register
    def __init__(self, bs, heads):
        """init MatmulDDS"""
        self.init_prim_io_names(inputs=['q', 'k', 'local_mask', 'global_mask'],
                                outputs=['local_prob', 'global_prob'])

        self.heads = heads

    def infer_shape(self, q, k, local_mask, global_mask):
        seq_len = local_mask[0] * local_mask[-1]
        bs = q[1] * q[2] // seq_len
        global_size = seq_len // 4
        size_per_head = q[0] * q[-1] // self.heads
        heads = q[0] * q[-1] // size_per_head
        block_size = local_mask[1] * local_mask[2] // bs
        block_num = seq_len // block_size
        l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16)
        g_size = (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16)

        return l_size, g_size

    def infer_dtype(self, q, k, local_mask, global_mask):
        return q, q


class DSDGrad(PrimitiveWithInfer):
    """
    The definition of the DSDGrad primitive, the backward of DSDMatmul.
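
    Given the forward inputs `w1`, `w2`, `v`, the forward output `a` and the output gradient `d_a`, it returns
    the gradients with respect to `w1`, `w2` and `v`.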
1165 """ 1166 1167 @prim_attr_register 1168 def __init__(self): 1169 self.init_prim_io_names(inputs=['w1_gm', 'w2_gm', 'v_gm', 'a_gm', 'd_a_gm'], 1170 outputs=['d_w1_gm', 'd_w2_gm', 'd_v_gm']) 1171 1172 def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape, input_a_shape, input_da_shape): 1173 return input_w1_shape, input_w2_shape, input_v_shape 1174 1175 def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3, data_dtype4, data_dtype5): 1176 return data_dtype1, data_dtype1, data_dtype1 1177 1178 1179class MatmulDDSGrad(PrimitiveWithInfer): 1180 """MatmulDDS definition""" 1181 1182 @prim_attr_register 1183 def __init__(self): 1184 """init MatmulDDS""" 1185 self.init_prim_io_names(inputs=['q', 'k', 'local_prob', 'global_prob', 'local_prob_grad', 'global_prob_grad'], 1186 outputs=['dq', 'dk']) 1187 1188 def infer_shape(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad): 1189 k_size = (q[1], q[0], q[3], q[2]) 1190 1191 return q, k_size 1192 1193 def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad): 1194 return q, k 1195 1196 1197class NonZeroWithValue(Primitive): 1198 """ 1199 Returns the value of elements that are non-zero (in row-major order - by dimension). 1200 1201 Inputs: 1202 - **x** (Tensor), input array of rank >= 2. 1203 1204 Outputs: 1205 elements that are non-zero. 1206 1207 Supported Platforms: 1208 ``Ascend`` 1209 1210 Examples: 1211 >>> op = NonZeroWithValue() 1212 >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32) 1213 >>> value, index, count = op(data) 1214 >>> print(value) 1215 [1.0, 1.0] 1216 """ 1217 1218 @prim_attr_register 1219 def __init__(self, transpose=False): 1220 """Initialize NonZeroWithValue""" 1221 validator.check_value_type("transpose", transpose, [bool], self.name) 1222 self.init_prim_io_names(inputs=['x'], outputs=['value', 'index', 'count']) 1223 1224 1225class NonZeroWithValueShape(Primitive): 1226 """ 1227 Returns the value and index of elements that are non-zero (in row-major order - by dimension). 1228 1229 Inputs: 1230 - **x** (Tensor), input array of rank >= 2. 1231 1232 Outputs: 1233 elements that are non-zero. 1234 1235 Supported Platforms: 1236 ``Ascend`` 1237 1238 Examples: 1239 >>> non_zero = NonZeroWithValue() 1240 >>> op = NonZeroWithValueShape() 1241 >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32) 1242 >>> value, index, count = non_zero(data) 1243 >>> out_value, out_index = op(value, index, count) 1244 >>> print(out_index) 1245 [[0, 1], [0, 2]] 1246 """ 1247 1248 @prim_attr_register 1249 def __init__(self): 1250 """Initialize NonZeroWithValueShape""" 1251 self.init_prim_io_names(inputs=['value', 'index', 'count'], outputs=['out_value', 'out_index']) 1252 1253 1254class DecodeImage(PrimitiveWithInfer): 1255 """ 1256 Returns image data that parse from string Tensor. 1257 1258 Inputs: 1259 - **x** (Tensor), a Tensor of type string. 0-D. The jPEG, GIF, PNG, BMP-encoded image. 1260 1261 Outputs: 1262 A Tensor of type uint8, uint16, float. 

    Supported Platforms:
        ``Ascend``

    Examples:
    """

    @prim_attr_register
    def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
                 _op_max_size=[8000000]):
        self.init_prim_io_names(inputs=["contents"], outputs=["image"])
        self.res_type = dtype

    def infer_shape(self, x):
        return (-1, -1, 3)

    def infer_dtype(self, x):
        return self.res_type


class SliceGetItem(Primitive):
    """
    Use SliceGetItem to get a slice's 'start', 'stop' or 'step' attribute.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SliceGetItem"""
        self.init_prim_io_names(inputs=['slice', 'attr'], outputs=['slice_item'])

    def __call__(self, slice_value, value):
        if not isinstance(slice_value, slice):
            raise TypeError(
                "Primitive[SliceGetItem] only supports getting a slice type element but got {}".format(slice_value))
        if value == "start":
            if hasattr(slice_value.start, "ndim") and slice_value.start.ndim == 1:
                return slice_value.start.item()
            return slice_value.start
        if value == "stop":
            if hasattr(slice_value.stop, "ndim") and slice_value.stop.ndim == 1:
                return slice_value.stop.item()
            return slice_value.stop
        if value == "step":
            if hasattr(slice_value.step, "ndim") and slice_value.step.ndim == 1:
                return slice_value.step.item()
            return slice_value.step
        raise AttributeError("'slice' object has no attribute {}".format(value))


class DynamicBroadcastTo(Primitive):
    """
    Broadcasts input tensor to a given shape.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
          float16, float32, int32, int8, uint8.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **shape** (Tensor): The target shape to broadcast.

    Outputs:
        Tensor, with the given `shape` and the same data type as `input_x`.

    Raises:
        ValueError: If the target and input shapes are incompatible.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicBroadcastTo"""
        self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y'])


class DynamicResizeNearestNeighbor(Primitive):
    r"""
    Resizes the input tensor by using the nearest neighbor algorithm.

    Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
    neighbor algorithm selects the value of the nearest point and does not consider the
    values of neighboring points at all, yielding a piecewise-constant interpolant.

    Note:
        The operator supports dynamic shape.

    Args:
        align_corners (bool): Whether the centers of the 4 corner pixels of the input
            and output tensors are aligned. Default: ``False``.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
        - **size** (Union[tuple, list]): The target size. The dimension of size must be 2.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
        The data type is the same as the `input_x`.
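
    Examples:
        >>> # A minimal usage sketch with hypothetical values, based on the documented inputs above;
        >>> # this inner op is normally created by the framework when the resize target size is dynamic.
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> resize = inner.DynamicResizeNearestNeighbor(align_corners=False)
        >>> x = Tensor(np.ones((1, 1, 2, 2)).astype(np.float32))
        >>> output = resize(x, (4, 4))
        >>> print(output.shape)
        (1, 1, 4, 4)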
1360 """ 1361 1362 @prim_attr_register 1363 def __init__(self, align_corners=False): 1364 """Initialize ResizeNearestNeighbor""" 1365 validator.check_value_type("align_corners", align_corners, [bool], self.name) 1366 self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) 1367 1368 1369class PsROIPooling(PrimitiveWithInfer): 1370 r""" 1371 Position Sensitive ROI-Pooling 1372 Inputs: 1373 - feature(Tensor) 1374 - rois(Tensor) 1375 1376 - **features** (Tensor) - The input features, whose shape must be :math:`(N, C, H, W)`. 1377 - **rois** (Tensor) - The shape is :math:`(rois\_n, 5)`. With data type of float16 or float32. 1378 `rois_n` represents the number of RoI. The size of the second dimension must be `5` and the `5` colunms 1379 are :math:`(image\_index, top\_left\_x, top\_left\_y, bottom\_right\_x, bottom\_right\_y)`. 1380 `image_index` represents the index of image. `top_left_x` and `top_left_y` represent the `x, y` 1381 coordinates of the top left corner of corresponding RoI, respectively. `bottom_right_x` and `bottom_right_y` 1382 represent the `x, y` coordinates of the bottom right corner of corresponding RoI, respectively. 1383 1384 Outputs: 1385 - out shape(rois_num, out_channel, pool_height, pool_width), the result after pooling. 1386 - channel_map shape(rois_num, out_channel, pool_height, pool_width), use for back forward to compute grad 1387 Supported Platforms: 1388 ``GPU`` 1389 1390 Examples: 1391 >>> import mindspore 1392 >>> import numpy as np 1393 >>> from mindspore import Tensor 1394 >>> from mindspore.ops.operations import _inner_ops as inner 1395 >>> features = np.random.randn(4, 21 * 7 * 7, 80, 48) 1396 >>> features = Tensor.from_numpy(features).astype(mindspore.float32) 1397 >>> rois = Tensor.from_numpy( 1398 >>> np.array([ 1399 >>> [0.0000, 150.3563, 200.1320, 579.3563, 602.3452], 1400 >>> [1.0000, 657.1263, 302.8564, 762.4214, 567.9854], 1401 >>> [2.0000, 321.3122, 232.2410, 679.0281, 587.6346], 1402 >>> [3.0000, 664.1630, 387.4919, 778.7322, 562.7321], 1403 >>> ])).astype(mindspore.float32) 1404 >>> psRoIPooling = inner.PsROIPooling(pooled_height=7, pooled_width=7, num_rois=4, 1405 >>> spatial_scale=1.0/16, out_dim=21, 1406 >>> group_size=7) 1407 >>> out, channel_map = psRoIPooling(features, rois) 1408 >>> print(out.shape) 1409 [4, 21, 7, 7] 1410 >>> print(channel_map.shape) 1411 [4, 21, 7, 7] 1412 """ 1413 1414 @prim_attr_register 1415 def __init__(self, pooled_height, pooled_width, num_rois, spatial_scale, out_dim, group_size): 1416 """Initialize PsROIPooling""" 1417 validator.check_value_type("pooled_height", pooled_height, [int], self.name) 1418 validator.check_value_type("pooled_width", pooled_width, [int], self.name) 1419 validator.check_value_type("num_rois", pooled_width, [int], self.name) 1420 validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) 1421 validator.check_value_type("out_dim", out_dim, [int], self.name) 1422 validator.check_value_type("group_size", group_size, [int], self.name) 1423 self.pooled_height = pooled_height 1424 self.pooled_width = pooled_width 1425 self.num_rois = num_rois 1426 self.spatial_scale = spatial_scale 1427 self.out_dim = out_dim 1428 self.group_size = group_size 1429 1430 def infer_shape(self, inputs_shape, rois_shape): 1431 output_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width] 1432 output_map_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width] 1433 return output_shape, output_map_shape 1434 1435 def infer_dtype(self, inputs_type, 
        map_type = mstype.TensorType(mstype.int32)
        return inputs_type, map_type


class ParallelResizeBilinear(PrimitiveWithInfer):
    """ParallelResizeBilinear ops"""

    @prim_attr_register
    def __init__(self, ori_image_size, split_size, src_start_w, dst_start_w, align_corners):
        """Initialize ParallelResizeBilinear."""
        validator.check_value_type("ori_image_size", ori_image_size, [list, tuple], self.name)
        validator.check_value_type("split_size", split_size, [list, tuple], self.name)
        validator.check_int(len(split_size), 2, validator.EQ, "len of split_size", self.name)
        validator.check_value_type("src_start_w", src_start_w, [int], self.name)
        validator.check_value_type("dst_start_w", dst_start_w, [int], self.name)
        validator.check_value_type("align_corners", align_corners, [bool], self.name)
        self.ori_image_size = list(ori_image_size)
        self.split_size = list(split_size)
        self.src_start_w = src_start_w
        self.dst_start_w = dst_start_w
        self.align_corners = align_corners
        self.half_pixel_centers = False
        self.add_prim_attr('ori_image_size', self.ori_image_size)
        self.add_prim_attr('split_size', self.split_size)
        self.add_prim_attr('src_start_w', self.src_start_w)
        self.add_prim_attr('dst_start_w', self.dst_start_w)
        self.add_prim_attr('align_corners', self.align_corners)
        self.add_prim_attr('half_pixel_centers', self.half_pixel_centers)

    def __infer__(self, x, size):
        size_val = size['value']
        x_shape = x['shape']
        x_dtype = x['dtype']
        validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
        if size_val is None:
            raise ValueError("size must be const input")
        output_shape = [x_shape[0], x_shape[1], self.split_size[0], self.split_size[1]]

        return {'shape': output_shape,
                'dtype': x_dtype,
                'value': None}


class PartitionedCall(PrimitiveWithInfer):
    """
    Pass the input tensors to the subgraph and return the output tensors.

    Inputs:
        - **inputs** (Tuple), the input tensors, which will be passed to the subgraph.

    Outputs:
        - **outputs** (Tuple), the output tensors returned by the subgraph.

    Supported Platforms:
        ``Ascend``

    Examples:
    """

    @prim_attr_register
    def __init__(self, graph, executor_type=""):
        super(PartitionedCall, self).__init__(self.__class__.__name__)
        self.add_prim_attr("executor_type", executor_type)
        self.graph = graph

    def infer_shape(self, *inputs):
        raise NotImplementedError

    def infer_dtype(self, *inputs):
        raise NotImplementedError


class CellBackwardHook(PrimitiveWithInfer):
    r"""
    This operator is used to hook the input gradient and output gradient of a Cell object.

    Note:
        This operator is only used in the backward hook function of a Cell object in PyNative mode.

    Args:
        cell_id (str): Used to identify which cell object the hook function is registered on. For example,
            'nn.Add()' is a cell object.

    Inputs:
        - **input** - The variable to hook.

    Outputs:
        - **output** - Returns `input` directly. `CellBackwardHook` does not affect the forward result.
1524 1525 Supported Platforms: 1526 ``Ascend`` ``GPU`` ``CPU`` 1527 1528 Examples: 1529 >>> import mindspore as ms 1530 >>> from mindspore import Tensor 1531 >>> from mindspore.ops import GradOperation 1532 >>> from mindspore.ops.operations import _inner_ops as inner 1533 >>> ms.set_context(mode=ms.PYNATIVE_MODE) 1534 >>> def hook_fn(grad): 1535 ... print(grad) 1536 ... 1537 >>> hook = inner.CellBackwardHook() 1538 >>> hook_fn_key = hook.register_backward_hook(hook_fn) 1539 >>> def hook_test(x, y): 1540 ... z = x * y 1541 ... z = hook(z) 1542 ... z = z * y 1543 ... return z 1544 ... 1545 >>> grad_all = GradOperation(get_all=True) 1546 >>> def backward(x, y): 1547 ... return grad_all(hook_test)(x, y) 1548 ... 1549 >>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32)) 1550 (Tensor(shape=[], dtype=Float32, value= 2),) 1551 >>> print(output) 1552 (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4)) 1553 >>> hook.remove_backward_hook(hook_fn_key) 1554 >>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32)) 1555 >>> print(output) 1556 (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4)) 1557 """ 1558 1559 def __init__(self, cell_id=""): 1560 """Initialize CellBackwardHook""" 1561 super(CellBackwardHook, self).__init__(self.__class__.__name__) 1562 self.cell_id = cell_id 1563 self.add_prim_attr("cell_id", cell_id) 1564 self.init_attrs["cell_id"] = cell_id 1565 1566 def __call__(self, args): 1567 if not isinstance(args, tuple): 1568 args = (args,) 1569 return _run_op(self, self.name, args) 1570 1571 def infer_shape(self, *inputs_shape): 1572 if len(inputs_shape) == 1: 1573 return inputs_shape[0] 1574 return inputs_shape 1575 1576 def infer_dtype(self, *inputs_type): 1577 if len(inputs_type) == 1: 1578 return inputs_type[0] 1579 return inputs_type 1580 1581 def register_backward_hook(self, hook_fn): 1582 r""" 1583 This function is used to register backward hook function. Note that this function is only supported in pynative 1584 mode. 1585 1586 Note: 1587 The 'hook_fn' must be defined as the following code. 1588 `cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell. 1589 `grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by 1590 returning a new output gradient. 1591 The 'hook_fn' should have the following signature: 1592 hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none. 1593 The 'hook_fn' is executed in the python environment. 1594 1595 Args: 1596 hook_fn (Function): Python function. Backward hook function. 1597 1598 Returns: 1599 - **key** (int) - The key of 'hook_fn'. 1600 1601 Raises: 1602 TypeError: If the `hook_fn` is not a function of python. 1603 """ 1604 if not isinstance(hook_fn, (FunctionType, MethodType)): 1605 raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python " 1606 f"function, but got {type(hook_fn)}.") 1607 key = self.add_backward_hook_fn(hook_fn) 1608 return key 1609 1610 def remove_backward_hook(self, key): 1611 r""" 1612 This function is used to remove backward hook function. Note that this operation is only supported in pynative 1613 mode. 1614 1615 Note: 1616 The 'key' is the object returned by 'register_backward_hook' function of the same CellBackwardHook 1617 operator. 1618 1619 Args: 1620 key (int): The key corresponding to the 'hook_fn'. 1621 1622 Returns: 1623 None. 
        """
        self.remove_backward_hook_fn(key)


class Format(PrimitiveWithInfer):
    r"""
    This operator is used to format a string.

    Note:
        This operator is not currently intended for direct use by users.
        Only calls to str.format() in user code are supported; they are converted to the Format
        operation automatically by the ME compiler.

    Inputs:
        - **string** (str) - The string to be formatted.
        - **args** - The format arguments.

    Outputs:
        - **output** - Returns the formatted string.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])

    def __infer__(self, str_, *var):
        def check_variable(str_, var):
            if _check_contains_variable(str_['dtype'], str_['value']):
                return True

            for item in var:
                if _check_contains_variable(item['dtype'], item['value']):
                    return True
            return False

        if check_variable(str_, var):
            return {'dtype': mstype.string, 'shape': [], 'value': None}

        str_value = str_['value']
        kwargs = dict()
        var_value = list()

        for item in var:
            if isinstance(item["dtype"], typing.Keyword):
                kwargs.update(item["value"])
            var_value.append(item["value"])

        value = str_value.format(*var_value, **kwargs)
        return {'dtype': mstype.string, 'shape': [], 'value': value}


class FlattenConcat(Primitive):
    """
    Flatten input tensors and concatenate them into several chunk tensors grouped by data types.

    Args:
        fusion_size (int): Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.

    Inputs:
        - **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.

    Outputs:
        tuple[Tensor], result chunk tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> t1 = Tensor(np.array([1]).astype(np.float32))
        >>> t2 = Tensor(np.array([2]).astype(np.float32))
        >>> t3 = Tensor(np.array([3]).astype(np.float64))
        >>> t4 = Tensor(np.array([4]).astype(np.float32))
        >>> t5 = Tensor(np.array([5]).astype(np.float64))
        >>> chunks = inner.FlattenConcat()([t1, t2, t3, t4, t5])
        >>> print(chunks[0].asnumpy())
        [1. 2. 4.]
        >>> print(chunks[1].asnumpy())
        [3. 5.]
    """

    @prim_attr_register
    def __init__(self, fusion_size=0):
        """Initialize FlattenConcat"""
        validator.check_non_negative_int(fusion_size, 'fusion_size', self.name)
        self.fusion_size = fusion_size
        self.add_prim_attr('fusion_size', fusion_size)


class KMeansCentroids(PrimitiveWithInfer):
    """
    Calculate the segment_sum, segment_count and kmean_total_sum that make up the clustering result.

    Args:
        use_actual_distance (bool): Whether to perform the complete distance calculation.

    Inputs:
        - **x** (Tensor(float32)) - Input data used for clustering.
        - **y** (Tensor(float32)) - Initial centroids of the clustering.
        - **sum_square_y** (Tensor(float32)) - The result of preprocessing (square, reduce and transpose) of `y`.
        - **sum_square_x** (Tensor(float32)) - The result of preprocessing (square and reduce) of `x`.

    Outputs:
        - **segment_sum** (Tensor(float32)) - Clustering result w.r.t. each centroid.
        - **segment_count** (Tensor(float32)) - Clustering count w.r.t. each centroid.
        - **kmean_total_sum** (Tensor(float32)) - The sum of the distances from all vectors to their nearest centroid.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.common.dtype as mstype
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> from mindspore.ops import operations as P
        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.reduce_sum = P.ReduceSum(keep_dims=True)
        ...         self.square = P.Square()
        ...         self.transpose = P.Transpose()
        ...         self.k_means_centroids = P.KMeansCentroids(True)
        ...
        ...     def construct(self, x, y):
        ...         p1 = self.reduce_sum(self.square(x), -1)
        ...         p2 = self.transpose(self.reduce_sum(self.square(y), -1), (1, 0))
        ...         return self.k_means_centroids(x, y, p2, p1)
        ...
        >>> def test_net():
        ...     data_type = np.float32
        ...     x = Tensor(np.random.uniform(-10, 10, (65536, 128)).astype(data_type))
        ...     y = P.Ones()((1048576, 128), mstype.float32)
        ...     net = Net()
        ...     local_sum, local_count, local_avg_distance = net(x, y)
    """

    @prim_attr_register
    def __init__(self, use_actual_distance):
        validator.check_value_type('use_actual_distance', use_actual_distance, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'y', 'sum_square_y', 'sum_square_x'],
                                outputs=['segment_sum', 'segment_count', 'kmean_total_sum'])

    def infer_shape(self, x_shape, y_shape, sum_square_y_shape, sum_square_x_shape):
        """infer shape of primitive"""
        expected_shape_size = 2
        validator.check_int(len(x_shape), expected_shape_size, validator.EQ, "dims of x", self.name)
        validator.check_int(len(y_shape), expected_shape_size, validator.EQ, "dims of y", self.name)
        validator.check_int(len(sum_square_y_shape), expected_shape_size, validator.EQ,
                            "dims of sum_square_y", self.name)
        validator.check_int(len(sum_square_x_shape), expected_shape_size, validator.EQ,
                            "dims of sum_square_x", self.name)

        validator.check_int(x_shape[1], y_shape[1], validator.EQ,
                            "the second dim of x and the second dim of y", self.name)
        validator.check_int(y_shape[0], sum_square_y_shape[1], validator.EQ,
                            "the first dim of y and the second dim of sum_square_y", self.name)
        validator.check_int(x_shape[0], sum_square_x_shape[0], validator.EQ,
                            "the first dim of x and the first dim of sum_square_x", self.name)
        validator.check_int(sum_square_y_shape[0], sum_square_x_shape[1], validator.EQ,
                            "the first dim of sum_square_y and the first dim of sum_square_x",
                            self.name)
        validator.check_int(sum_square_y_shape[0], 1, validator.EQ,
                            "the first dim of sum_square_y", self.name)

        k = y_shape[0]
        em_size = x_shape[1]
        return (k, em_size), (k, 1), (1,)


class ClipByNorm(PrimitiveWithInfer):
    r"""
    Clips tensor values to a maximum :math:`L_2`-norm.

    Note:
        The output tensor of this operator remains the same as the input tensor if the :math:`L_2`-norm of the input
        tensor is not greater than the argument `clip_norm`. Otherwise the output tensor will be normalized as:

        ..
math:: 1811 \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)}, 1812 1813 where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`. 1814 1815 Args: 1816 axis (Union[None, int, tuple(int), list(int)]): Compute the `L_2`-norm along the specific dimension. 1817 Default: ``None``, all dimensions to calculate. 1818 1819 Inputs: 1820 - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32. 1821 - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`. 1822 Or a Tensor which shape can be broadcast to the shape of `x`. The type must be float16 or float32. 1823 1824 Outputs: 1825 Tensor, clipped Tensor with the same shape as the `x`, whose type is float32. 1826 1827 Raises: 1828 TypeError: If `axis` is not one of None, int, tuple(int) and list(int). 1829 TypeError: If dtype of `x` is neither float16 nor float32. 1830 TypeError: If dtype of `clip_norm` is neither float16 nor float32. 1831 1832 Supported Platforms: 1833 ``Ascend`` ``GPU`` ``CPU`` 1834 1835 Examples: 1836 >>> import numpy as np 1837 >>> import mindspore 1838 >>> from mindspore import Tensor 1839 >>> from mindspore.ops.operations import _inner_ops as inner 1840 >>> clip_by_norm = inner.ClipByNorm() 1841 >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32) 1842 >>> clip_norm = Tensor(np.array([100]).astype(np.float32)) 1843 >>> output = clip_by_norm(x, clip_norm) 1844 >>> print(output.shape) 1845 (4, 16) 1846 """ 1847 1848 @prim_attr_register 1849 def __init__(self, axis=None): 1850 """Initialize ClipByNorm""" 1851 self.axis = () if axis is None else axis 1852 validator.check_value_type('axis', self.axis, [int, tuple, list], self.name) 1853 axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,) 1854 for i, value in enumerate(axis_check): 1855 validator.check_value_type('axis[%d]' % i, value, [int], self.name) 1856 self.init_attrs['axis'] = self.axis 1857 self.add_prim_attr('axis', self.axis) 1858 self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output']) 1859 1860 def infer_shape(self, x_shape, clip_norm_shape): 1861 """Infer shape for ClipByNorm""" 1862 x_dim = len(x_shape) 1863 axis = self.axis if isinstance(self.axis, Iterable) else (self.axis,) 1864 for _, value in enumerate(axis): 1865 validator.check_int_range(value, -x_dim, x_dim, validator.INC_LEFT, 'axis', self.name) 1866 return x_shape 1867 1868 def infer_dtype(self, x_type, clip_norm_type): 1869 """Infer data type for ClipByNorm""" 1870 validator.check_tensor_dtype_valid("x_type", x_type, [mstype.float16, mstype.float32], self.name) 1871 validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type, 1872 [mstype.float16, mstype.float32], self.name) 1873 return mstype.float32 1874 1875 1876class TopTypeof(Primitive): 1877 """ 1878 Internal primitive method, to speed up mindspore.ops.typeof. 1879 1880 Returns the top type of the input data. 1881 1882 In Pynative mode, returns the top type in cache. 
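
    Inputs:
        - **x** (Any) - The object whose top type is queried.

    Outputs:
        mindspore.dtype, the top type of `x`. For example, a Python int gives `mstype.Int()` and a Tensor
        gives `mstype.tensor_type` (see `typeof_cache` in the implementation below).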
1883 1884 Supported Platforms: 1885 ``Ascend`` ``GPU`` ``CPU`` 1886 """ 1887 1888 @prim_attr_register 1889 def __init__(self): 1890 self.prim = Primitive('TopTypeof') 1891 self.typeof_cache = { 1892 'slice': mstype.Slice(), 1893 'list': mstype.List(), 1894 'tuple': mstype.Tuple(), 1895 'Tensor': mstype.tensor_type, 1896 'NoneType': mstype.NoneType(), 1897 'int': mstype.Int(), 1898 'bool': mstype.Bool(), 1899 'ellipsis': mstype.Ellipsis_(), 1900 'dict': mstype.Dict() 1901 } 1902 1903 def __call__(self, x): 1904 index_type = type(x).__name__ 1905 if 'Tensor' in index_type: 1906 index_type = 'Tensor' 1907 if index_type in self.typeof_cache: 1908 return self.typeof_cache.get(index_type) 1909 return _pynative_executor.constant_folding(self.prim, x) 1910 1911 1912class MixedPrecisionCast(Primitive): 1913 r""" 1914 Internal primitive method, to achieve mindspore.functional.mixed_precision_cast. 1915 1916 Note: 1917 This internal primitive method used to do mixed precision conversion. 1918 Only the input object with float dtype will be cast. 1919 1920 Inputs: 1921 - **dtype** (Union[Float16, Float32]) - The data type of the output object. 1922 - **input** (Union[Tensor, Tuple, Dictionary, KeywordArg]) - The object to be cast. 1923 1924 Outputs: 1925 Object, its dtype is the same as `dtype` and shape is the same as 'input'. 1926 1927 Supported Platforms: 1928 ``Ascend`` ``GPU`` ``CPU`` 1929 1930 Examples: 1931 >>> import numpy as np 1932 >>> from mindspore import Tensor 1933 >>> from mindspore import dtype as mstype 1934 >>> from mindspore.ops.operations import _inner_ops as inner 1935 >>> x = Tensor(np.ones([2, 3], dtype=np.float32)) 1936 >>> out = inner.MixedPrecisionCast(mstype.float16, x) 1937 >>> print(out.dtype) 1938 Float16 1939 """ 1940 1941 @prim_attr_register 1942 def __init__(self): 1943 """Initialize MixedPrecisionCast""" 1944 self.init_prim_io_names(inputs=['dst_dtype', 'input_x'], outputs=['output']) 1945 self.cast = Cast() 1946 self.hyper_map = C.HyperMap() 1947 1948 def __call__(self, dst_dtype, x): 1949 def cast_inner(data): 1950 if isinstance(data, Tensor) and data.dtype in (mstype.float16, mstype.float32, 1951 mstype.float64, mstype.bfloat16): 1952 return self.cast(data, dst_dtype) 1953 return data 1954 1955 return self.hyper_map(cast_inner, x) 1956 1957 1958class CheckBprop(PrimitiveWithInfer): 1959 """ 1960 Checks whether the data type and the shape of corresponding elements from tuples x and y are the same. 1961 1962 Args: 1963 prim_to_check (str): The name of the primitive being checked. Default: ''. 1964 1965 Inputs: 1966 - **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked. 1967 - **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against. 1968 1969 Outputs: 1970 Tuple[Tensor], the `input_x`, 1971 if data type and shape of corresponding elements from `input_x` and `input_y` are the same. 1972 1973 Raises: 1974 TypeError: If `input_x` or `input_y` is not a Tensor. 1975 1976 Supported Platforms: 1977 ``Ascend`` ``GPU`` ``CPU`` 1978 1979 Examples: 1980 >>> class Net(nn.Cell): 1981 ... def __init__(self): 1982 ... super(Net, self).__init__() 1983 ... self.op = ops.CheckBprop() 1984 ... def construct(self, x, y): 1985 ... return self.op(x, y) 1986 ... 
1987 >>> net = Net() 1988 >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),) 1989 >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),) 1990 >>> output = net(input_x, input_y) 1991 >>> print(output) 1992 (Tensor(shape=[2, 2], dtype=Float32, value= 1993 [[ 2.00000000e+00, 2.00000000e+00], 1994 [ 2.00000000e+00, 2.00000000e+00]]),) 1995 """ 1996 1997 @prim_attr_register 1998 def __init__(self, prim_to_check=""): 1999 """Initialize CheckBprop""" 2000 self.prim_to_check = prim_to_check 2001 2002 def infer_shape(self, xshapes, yshapes): 2003 """infer shape""" 2004 tips = f"user defined method 'bprop'" 2005 validator.check_value_type('grads', xshapes, (tuple,), tips) 2006 validator.check_value_type('params', yshapes, (tuple,), tips) 2007 if not len(xshapes) == len(yshapes): 2008 raise ValueError(f"For {tips} the number of return values(gradients) must be equal to " 2009 f"the number of input arguments except 'out' and 'dout', " 2010 f"which is:{len(yshapes)} but got {len(xshapes)}.") 2011 2012 def shape_equal(shape1, shape2): 2013 if len(shape1) != len(shape2): 2014 return False 2015 for shape_axis1, shape_axis2 in zip(shape1, shape2): 2016 if shape_axis1 == -1 or shape_axis2 == -1: 2017 continue 2018 if shape_axis1 != shape_axis2: 2019 return False 2020 return True 2021 2022 for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)): 2023 if not xshape or not yshape: 2024 continue 2025 2026 if not shape_equal(xshape, yshape): 2027 raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) " 2028 f"should have the same shape as the {i}th argument, " 2029 f"which is:{yshape}, but got: {xshape}.") 2030 return xshapes 2031 2032 def infer_dtype(self, xdtypes, ydtypes): 2033 """infer dtype""" 2034 tips = f"user defined method 'bprop'" 2035 validator.check_value_type('grads', xdtypes, (tuple,), tips) 2036 validator.check_value_type('params', ydtypes, (tuple,), tips) 2037 if not len(xdtypes) == len(ydtypes): 2038 raise ValueError(f"For {tips}, the number of return values(gradients) must be equal to " 2039 f"the number of input arguments except 'out' and 'dout', " 2040 f"which is:{len(ydtypes)} but got {len(xdtypes)}.") 2041 checking_range = len(ydtypes) 2042 for i in range(checking_range): 2043 xdtype = xdtypes[i] 2044 ydtype = ydtypes[i] 2045 if isinstance(xdtype, mstype.AnythingType) or isinstance(ydtype, mstype.AnythingType): 2046 continue 2047 if isinstance(ydtype, mstype.FunctionType): 2048 if not isinstance(xdtype, mstype.EnvType): 2049 raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type " 2050 f"should be {mstype.EnvType}, but got {xdtype}.") 2051 if xdtype != ydtype: 2052 raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) " 2053 f"should have the same dtype as the {i}th argument, " 2054 f"which is:{ydtype}, but got: {xdtype}.") 2055 return xdtypes 2056 2057 2058check_bprop = CheckBprop() 2059 2060 2061class SameTypeShape(PrimitiveWithInfer): 2062 """ 2063 Checks whether the data type and shape of two tensors are the same. 2064 2065 Refer to :func:`mindspore.ops.same_type_shape` for more detail. 2066 2067 Supported Platforms: 2068 ``Ascend`` ``GPU`` ``CPU`` 2069 2070 Examples: 2071 >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32) 2072 >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32) 2073 >>> output = ops.SameTypeShape()(input_x, input_y) 2074 >>> print(output) 2075 [[2. 2.] 2076 [2. 
2.]] 2077 """ 2078 2079 @prim_attr_register 2080 def __init__(self): 2081 """Initialize Same""" 2082 2083 def __call__(self, x, y): 2084 """run in PyNative mode""" 2085 validator.check_value_type('x', x, Tensor, self.name) 2086 validator.check_value_type('y', y, Tensor, self.name) 2087 validator.check('x dtype', x.dtype, 'y dtype', y.dtype, validator.EQ, self.name, TypeError) 2088 validator.check('x shape', x.shape, 'y shape', y.shape, validator.EQ, self.name) 2089 return x 2090 2091 def __infer__(self, x, y): 2092 validator.check_subclass('x', x['dtype'], mstype.tensor_type, self.name) 2093 validator.check_subclass('y', y['dtype'], mstype.tensor_type, self.name) 2094 validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], validator.EQ, self.name, TypeError) 2095 validator.check('x shape', x['shape'], 'y shape', y['shape'], validator.EQ, self.name) 2096 return x 2097 2098 2099same_type_shape_ = SameTypeShape() 2100 2101 2102def _is_subclass_(type_, dtype): 2103 if not isinstance(type_, typing.Type): 2104 return False 2105 return typing.is_subclass(type_, dtype) 2106 2107 2108class IsSubClass(PrimitiveWithInfer): 2109 """ 2110 Checks whether this type is a sub-class of another type. 2111 2112 Inputs: 2113 - **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed. 2114 - **type_** (mindspore.dtype) - The target type. Only constant value is allowed. 2115 2116 Outputs: 2117 bool, the check result. 2118 2119 Raises: 2120 TypeError: If `sub_type` or `type_` is not a Type. 2121 2122 Supported Platforms: 2123 ``Ascend`` ``GPU`` ``CPU`` 2124 2125 Examples: 2126 >>> output = ops.IsSubClass()(mindspore.int32, mindspore.intc) 2127 >>> print(output) 2128 True 2129 """ 2130 2131 @prim_attr_register 2132 def __init__(self): 2133 pass 2134 2135 def __infer__(self, sub_type, type_): 2136 sub_type_t = sub_type['value'] 2137 type_v = type_['value'] 2138 2139 validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name) 2140 validator.check_value_type("type_", type_v, [mstype.Type], self.name) 2141 2142 value = _is_subclass_(sub_type_t, type_v) 2143 2144 out = {'shape': (), 2145 'dtype': mstype.type_type, 2146 'value': value} 2147 return out 2148 2149 2150issubclass_ = IsSubClass() 2151 2152 2153class IsInstance(PrimitiveWithInfer): 2154 """ 2155 Checks whether an object is an instance of a target type. 2156 2157 Inputs: 2158 - **inst** (Any Object) - The instance to be checked. Only constant value is allowed. 2159 - **type_** (mindspore.dtype) - The target type. Only constant value is allowed. 2160 2161 Outputs: 2162 bool, the check result. 2163 2164 Raises: 2165 TypeError: If `type_` is not a Type. 
2166 2167 Supported Platforms: 2168 ``Ascend`` ``GPU`` ``CPU`` 2169 2170 Examples: 2171 >>> inst = 1 2172 >>> output = ops.IsInstance()(inst, mindspore.int32) 2173 >>> print(output) 2174 False 2175 """ 2176 2177 @prim_attr_register 2178 def __init__(self): 2179 pass 2180 2181 def __infer__(self, inst, type_): 2182 sub_type_t = inst['dtype'] 2183 type_v = type_['value'] 2184 2185 validator.check_value_type("type_", type_v, [mstype.Type], self.name) 2186 2187 if type_v == mstype.list_: 2188 value = isinstance(sub_type_t, list) 2189 elif type_v == mstype.tuple_: 2190 value = isinstance(sub_type_t, tuple) 2191 else: 2192 value = _is_subclass_(sub_type_t, type_v) 2193 2194 out = {'shape': (), 2195 'dtype': mstype.type_type, 2196 'value': value} 2197 return out 2198 2199 2200class ConvertToAdapterTensor(Primitive): 2201 """ 2202 Convert a tensor from MindSpore's Tensor type to MSAdapter's Tensor type, 2203 where MSAdapter's Tensor is a subclass of MindSpore's Tensor. 2204 2205 Inputs: 2206 - **x** (Tensor) - The input tensor. 2207 2208 Outputs: 2209 A tensor, whose type is MSAdapter's Tensor. 2210 2211 Supported Platforms: 2212 ``Ascend`` ``GPU`` ``CPU`` 2213 2214 Examples: 2215 >>> x = Tensor([1, 2 ,3]) 2216 >>> x = ops.ConvertToAdapterTensor()(x) 2217 >>> print(x) 2218 [1 2 3] 2219 """ 2220 2221 @prim_attr_register 2222 def __init__(self): 2223 """Initialize""" 2224 2225 def __call__(self, x): 2226 """Run in PyNative mode""" 2227 return ms_adapter_registry.tensor(x, cast_tensor=True) 2228 2229 2230convert_to_adapter_tensor = ConvertToAdapterTensor() 2231 2232 2233class ConvertToMsTensor(Primitive): 2234 """ 2235 Convert a tensor from MSAdapter's Tensor type to MindSpore's Tensor type, 2236 where MSAdapter's Tensor is a subclass of MindSpore's Tensor. 2237 2238 Inputs: 2239 - **x** (Tensor) - The input tensor. 2240 2241 Outputs: 2242 A tensor, whose type is MindSpore's Tensor. 2243 2244 Supported Platforms: 2245 ``Ascend`` ``GPU`` ``CPU`` 2246 2247 Examples: 2248 >>> x = Tensor([1, 2 ,3]) 2249 >>> x = ops.ConvertToMsTensor()(x) 2250 >>> print(x) 2251 [1 2 3] 2252 """ 2253 2254 @prim_attr_register 2255 def __init__(self): 2256 """Initialize""" 2257 2258 def __call__(self, x): 2259 """Run in PyNative mode""" 2260 if isinstance(x, StubTensor): 2261 return StubTensor(stub=x.stub, tensor=x.tensor) 2262 return ops.auto_generate.deepcopy(x) 2263 2264 2265convert_to_ms_tensor = ConvertToMsTensor() 2266 2267 2268class GetGrad(Primitive): 2269 """ 2270 Use the position id or Parameter object to get the gradient from the output 2271 which returned by the :func:`mindspore.ops.grad`. 
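
    Inputs:
        - **gradients** (tuple) - The gradients output of :func:`mindspore.ops.grad`, in which each
          gradient is paired with its position id or parameter name.
        - **x** (Union[int, Parameter]) - The position id or `Parameter` whose gradient is requested.

    Outputs:
        Tensor, the gradient that matches `x`.

    Raises:
        TypeError: If `x` is neither an int nor a `Parameter`.
        RuntimeError: If no gradient matching `x` can be found in `gradients`.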
2272 """ 2273 2274 @prim_attr_register 2275 def __init__(self): 2276 """Initialize ScatterElements""" 2277 self.init_prim_io_names( 2278 inputs=['gradients', 'x'], outputs=['gradient']) 2279 2280 def __call__(self, gradients, x): 2281 if not isinstance(x, int) and not isinstance(x, Parameter): 2282 raise TypeError( 2283 f"For `get_grad`, the `x` should be an integer or a Parameter, but got {x}") 2284 hash_id = x 2285 if isinstance(x, Parameter): 2286 hash_id = x.name 2287 output = None 2288 2289 def _get_grad(grads, identifier): 2290 if isinstance(grads, tuple): 2291 if len(grads) != 2 or identifier != grads[0]: 2292 for gradient in grads: 2293 _get_grad(gradient, identifier) 2294 else: 2295 nonlocal output 2296 output = grads[1] 2297 return 2298 2299 _get_grad(gradients, hash_id) 2300 if output is None: 2301 raise RuntimeError( 2302 f"Can not find the gradient for position or Parameter {x}") 2303 return output 2304 2305 2306class IsParameter(PrimitiveWithInfer): 2307 """ 2308 Check if input is `Parameter` 2309 """ 2310 2311 @prim_attr_register 2312 def __init__(self): 2313 """Initialize IsParameter""" 2314 2315 def __call__(self, x): 2316 return isinstance(x, Parameter) 2317 2318 def __infer__(self, x): 2319 return {'shape': [], 2320 'dtype': mstype.bool_, 2321 'value': isinstance(x['dtype'], mstype.RefType)} 2322 2323 2324class TileSize(Primitive): 2325 r""" 2326 Tile size for matmul 2327 """ 2328 2329 @prim_attr_register 2330 def __init__(self): 2331 """Initialize TileSize""" 2332 self.init_prim_io_names(inputs=['shape', 'out_shape', 'ndim'], outputs=['output']) 2333 2334 def __call__(self, shape, out_shape, ndim): 2335 size = [1] * ndim 2336 for idx, (i, j) in enumerate(zip(shape, out_shape)): 2337 if i != j: 2338 size[idx] = j 2339 return tuple(size) 2340 2341 2342class GetitemTensorIndexInfo(Primitive): 2343 r""" 2344 Get getitem tensor index info 2345 """ 2346 2347 @prim_attr_register 2348 def __init__(self, is_ascend): 2349 """Initialize GetitemTensorIndexInfo""" 2350 self.init_prim_io_names(inputs=['data', 'index'], 2351 outputs=["new_index", "tensor_update_types", "tensor_update_args"]) 2352 validator.check_value_type('is_ascend', is_ascend, [bool], self.name) 2353 self.is_ascend = is_ascend 2354 2355 def __call__(self, data, index): 2356 return Tensor_.getitem_index_info(data, index, self.is_ascend) 2357 2358 2359class SetitemTensorIndexInfo(Primitive): 2360 r""" 2361 Get setitem tensor index info 2362 """ 2363 2364 @prim_attr_register 2365 def __init__(self, is_ascend): 2366 """Initialize GetitemTensorIndexInfo""" 2367 self.init_prim_io_names( 2368 inputs=['data', 'index', 'value'], outputs=['new_index', 2369 'v_transfer_types', 2370 'v_transfer_args', 2371 'tensor_update_types', 2372 'tensor_update_args']) 2373 validator.check_value_type('is_ascend', is_ascend, [bool], self.name) 2374 self.is_ascend = is_ascend 2375 2376 def __call__(self, data, index, value): 2377 return Tensor_.setitem_index_info(data, index, value, self.is_ascend) 2378 2379 2380class IsConstant(Primitive): 2381 r""" 2382 Check if the input is constant 2383 """ 2384 2385 @prim_attr_register 2386 def __init__(self): 2387 """Initialize IsConstant""" 2388 2389 def __call__(self, x): 2390 return True 2391 2392 2393class SelectView(Primitive): 2394 r""" 2395 Select tensor of view 2396 """ 2397 2398 @prim_attr_register 2399 def __init__(self): 2400 self.init_prim_io_names(inputs=['input_tensor', 'input_indices', 'axis'], outputs=['output']) 2401 2402 2403class CopyWithSlice(Primitive): 2404 r""" 2405 Copy data to 
a non-contiguous tensor.
    """

    @prim_attr_register
    def __init__(self):
        self.add_prim_attr('side_effect_mem', True)
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])


class FFN(Primitive):
    r"""
    The FFN computation is similar to a Feed-Forward Network: it contains matmul + gelu + matmul.

    Args:
        activation (str): The activation type, set to 'fastgelu' or 'gelu'.
            Only 'fastgelu' is supported for now. Default: "fastgelu".
        inner_precise (int): The precision mode, set to 0 for high precision or 1 for high performance.
            Only 1 is supported for now. Default: 0.

    Inputs:
        - **x** (Tensor) - The input tensor with data type of int8 or float16.
          Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
        - **expert_tokens** (Tensor) - The expert tokens tensor with data type of int64.
          Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
          indicates that the 0th expert deals with 2 tokens, the 1st expert deals with 1 token,
          the 2nd expert does nothing, and so on.
        - **bias1** (Tensor) - The bias1 tensor with data type of float16.
          Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
        - **bias2** (Tensor) - The bias2 tensor with data type of float16.
          Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
        - **scale** (Tensor) - The scale tensor with data type of float16. Not enabled for now.
        - **offset** (Tensor) - The offset tensor with data type of float16. Not enabled for now.
        - **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enabled for now.
        - **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enabled for now.

    Outputs:
        Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`, with data type of float16.
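
    Raises:
        TypeError: If `activation` is not a str.
        TypeError: If `inner_precise` is not an int.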

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops
        >>> b = 4
        >>> s = 128
        >>> h = 1024
        >>> h_f = 4 * h
        >>> e = 16
        >>> x = Tensor(np.random.randn(s, h).astype(np.float16))
        >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
        >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
        >>> expert_tokens = Tensor(np.full(e, 8))
        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
        >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
        >>> ffn = _inner_ops.FFN("fastgelu", 1)
        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self, activation, inner_precise):
        """Initialize FFN."""
        self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2",
                                        "antiquant_scale1", "antiquant_scale2",
                                        "antiquant_offset1", "antiquant_offset2"],
                                outputs=["y"])
        cls_name = self.name
        validator.check_value_type("activation", activation, [str], cls_name)
        validator.check_value_type("inner_precise", inner_precise, [int], cls_name)


class _MirrorSilentCheck(PrimitiveWithInfer):
    """
    The operator _MirrorSilentCheck implements accuracy-sensitive detection on a tensor input during
    backpropagation. Call _MirrorSilentCheck in the __call__ method of a derived class to implement
    accuracy-sensitive detection.

    Inputs:
        - **input** (Tensor) - The tensor used for detection.
          Its data type must be mindspore.float16, mindspore.float32 or mindspore.bfloat16.
        - **pre_val** (Parameter(Tensor)) - Supporting parameter for accuracy-sensitive detection.
          It should only be generated by the generate_params() method of ASDBase.
        - **min_val** (Parameter(Tensor)) - Supporting parameter for accuracy-sensitive detection.
          It should only be generated by the generate_params() method of ASDBase.
        - **max_val** (Parameter(Tensor)) - Supporting parameter for accuracy-sensitive detection.
          It should only be generated by the generate_params() method of ASDBase.
        - **cnt** (Parameter(Tensor)) - Supporting parameter for accuracy-sensitive detection.
          It should only be generated by the generate_params() method of ASDBase.
          After each invocation of _MirrorSilentCheck, the value of `cnt` is incremented by one.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.
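
    Note:
        The detection thresholds are read from the environment variables `NPU_ASD_UPPER_THRESH` and
        `NPU_ASD_SIGMA_THRESH` (see `get_thresh` below). Each variable takes two comma-separated integer
        values; malformed values fall back to the built-in defaults.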
2501 """ 2502 @prim_attr_register 2503 def __init__(self, min_steps=8): 2504 upper_thresh, sigma_thresh = self.get_thresh() 2505 self.min_steps = min_steps 2506 self.thresh_l1 = upper_thresh[0] 2507 self.coeff_l1 = sigma_thresh[0] 2508 self.thresh_l2 = upper_thresh[1] 2509 self.coeff_l2 = sigma_thresh[1] 2510 self.add_prim_attr('side_effect_mem', True) 2511 2512 def parse_thresh(self, env_var_name, default_value, min_value): 2513 env_var = os.environ.get(env_var_name, default=default_value) 2514 thresh = [value.strip() for value in env_var.split(",")] 2515 if len(thresh) != 2 or not all(value.isdigit() for value in thresh): 2516 thresh = default_value.split(",") 2517 thresh = [float(max(int(value), min_value)) for value in thresh] 2518 if thresh[0] <= thresh[1]: 2519 thresh = [float(value) for value in default_value.split(",")] 2520 2521 return thresh 2522 2523 def get_thresh(self): 2524 upper_thresh = self.parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3) 2525 sigma_thresh = self.parse_thresh("NPU_ASD_SIGMA_THRESH", "100000,5000", 3) 2526 return upper_thresh, sigma_thresh 2527 2528 def infer_shape(self, x_shape, pre_shape, min_shape, max_shape, n_step, loss_scale_shape): 2529 return x_shape 2530 2531 def infer_dtype(self, x_dtype, pre_dtype, min_dtype, max_dtype, n_dtype, loss_scale_dtype): 2532 return x_dtype 2533 2534 2535class _VirtualConverterEnd(PrimitiveWithInfer): 2536 """ 2537 Auto parallel virtual operator. 2538 """ 2539 2540 @prim_attr_register 2541 def __init__(self, input_nums): 2542 """Initialize _VirtualConverterEnd.""" 2543 self.input_nums = input_nums 2544 2545 def infer_shape(self, *args): 2546 return (args[0][0] * self.input_nums,) + tuple(args[0][1:]) 2547 2548 def infer_dtype(self, *args): 2549 return args[0] 2550 2551 2552class _VirtualConverterBegin(PrimitiveWithInfer): 2553 """ 2554 Auto parallel virtual operator. 2555 """ 2556 2557 @prim_attr_register 2558 def __init__(self, output_nums): 2559 """Initialize _VirtualConverterBegin.""" 2560 self.output_nums = output_nums 2561 2562 def infer_shape(self, arg): 2563 new_arg = (arg[0] / self.output_nums,) + tuple(arg[1:]) 2564 return (new_arg,) * self.output_nums 2565 2566 def infer_dtype(self, arg): 2567 return (arg,) * self.output_nums 2568
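

# A minimal shape-inference sketch (illustration only): the virtual converter operators above are
# inserted by the auto-parallel pass and are not meant to be called directly. With output_nums=2,
# _VirtualConverterBegin splits the first dimension across its outputs, so an input of shape (8, 16)
# infers to two outputs of shape (4, 16); with input_nums=2, _VirtualConverterEnd joins it back,
# so (4, 16) inputs infer to (8, 16), as implemented in their infer_shape methods.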