# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for nn."""
from __future__ import absolute_import
from __future__ import division

import numbers
import math
import numpy as np
from mindspore.ops import signature as sig
from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.auto_generate import gen_arg_handler as handler
from mindspore.common import Tensor, CSRTensor, COOTensor
from mindspore.common._stub_tensor import _convert_stub
from mindspore._c_expression import typing
from mindspore._c_expression import Tensor as Tensor_
from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
from mindspore.common import dtype as mstype
from mindspore.common._utils import is_shape_unknown
from mindspore import _checkparam as validator
from mindspore.ops.operations.manually_defined._inner import ScalarCast
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
from mindspore.common.initializer import Zero
from mindspore.common.parameter import Parameter
from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore


dtype_to_type_id = DtypeToEnum()


class ScalarDiv(Primitive):
    r"""
    Computes the quotient of dividing the first input scalar by the second input scalar element-wise.

    .. math::

        out_{i} = \frac{x_i}{y_i}

    .. note::
        The inputs can be constant/variable values. Usage is the same as '/' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is float.

    Raises:
        TypeError: If `x` and `y` are not scalar.
        ValueError: If `y` is 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarDiv"""

    def __call__(self, x, y):
        if y == 0:
            raise ValueError('The divisor cannot be zero.')
        return x / y


class ScalarFloorDiv(Primitive):
    r"""
    Computes the floored quotient of dividing the first input scalar by the second input scalar element-wise.

    .. math::

        out_{i} = \left\lfloor \frac{x_i}{y_i} \right\rfloor

    .. note::
        The inputs can be constant/variable values. Usage is the same as '//' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is float.
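
    Examples:
        A minimal usage sketch (assuming eager PyNative execution, where the call
        falls back to Python's ``//``):

        >>> floor_div = ScalarFloorDiv()
        >>> print(floor_div(7, 2))
        3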

    Raises:
        TypeError: If `x` and `y` are not scalar.
        ValueError: If `y` is 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarFloorDiv"""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def __call__(self, x, y):
        if y == 0:
            raise ValueError('The divisor cannot be zero.')
        return x // y


class ScalarAdd(Primitive):
    r"""
    Adds two input scalars.

    .. note::
        The inputs can be constant/variable values. Usage is the same as '+' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarAdd"""

    def __call__(self, x, y):
        return x + y


class ScalarPow(Primitive):
    r"""
    Raises the first input scalar to the power of the second input scalar.

    .. note::
        The inputs can be constant/variable values. Usage is the same as '**' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarPow"""

    def __call__(self, x, y):
        return pow(x, y)


class ScalarLog(Primitive):
    r"""
    Computes the natural logarithm of the input scalar.

    .. note::
        The input can be a constant/variable value. Usage is the same as 'math.log(x)' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is float.

    Raises:
        TypeError: If `x` is not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarLog"""

    def __call__(self, x):
        return math.log(x)


class ScalarUadd(Primitive):
    r"""
    Returns the input scalar unchanged (unary plus).

    .. note::
        The input can be a constant/variable value. Usage is the same as unary '+' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, with the same data type as `x`.

    Raises:
        TypeError: If `x` is not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarUadd"""

    def __call__(self, x):
        return x


class ScalarUsub(Primitive):
    r"""
    Negates the input scalar (unary minus).

    .. note::
        The input can be a constant/variable value. Usage is the same as unary '-' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, with the same data type as `x`.

    Raises:
        TypeError: If `x` is not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarUsub"""

    def __call__(self, x):
        return -x


class ScalarSub(Primitive):
    r"""
    Subtracts the second input scalar from the first input scalar.

    .. note::
        The inputs can be constant/variable values. Usage is the same as '-' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarSub"""

    def __call__(self, x, y):
        return x - y


class ScalarMul(Primitive):
    r"""
    Multiplies two input scalars.

    .. note::
        The inputs can be constant/variable values. Usage is the same as '*' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarMul"""

    def __call__(self, x, y):
        return x * y


class ScalarEq(Primitive):
    r"""
    Computes the equivalence between two scalars.

    .. note::
        The inputs can be constant/variable values. Usage is the same as '==' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is bool.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarEq"""

    def __call__(self, x, y):
        return x == y


class ScalarGt(Primitive):
    r"""
    Computes the boolean value of :math:`x > y`.

    .. note::
        The inputs can be constant/variable values. Usage is the same as '>' in Python.
        This primitive only has a 'CPU' implementation; on other platforms, it runs through
        heterogeneous execution.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, the type of scalar is bool.
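
    Examples:
        A minimal usage sketch (assuming eager PyNative execution, where the call
        reduces to Python's ``>``):

        >>> greater = ScalarGt()
        >>> print(greater(3, 2))
        True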
364 365 Raises: 366 TypeError: If `x` and `y` are not scalar. 367 368 Supported Platforms: 369 ``Ascend`` ``GPU`` ``CPU`` 370 """ 371 @prim_attr_register 372 def __init__(self): 373 """Initialize scalar_gt""" 374 375 def __call__(self, x, y): 376 return x > y 377 378 379class ScalarLt(Primitive): 380 r""" 381 Computes the boolean value of :math:`x < y`. 382 383 .. note:: 384 The inputs can be constant/variable value. Usage is the same as '<' in Python. 385 This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous. 386 387 Inputs: 388 - **x** (Scalar) - A constant or variable scalar. 389 - **y** (Scalar) - A constant or variable scalar. 390 391 Outputs: 392 Scalar, the type of scalar is bool. 393 394 Raises: 395 TypeError: If `x` and `y` are not scalar. 396 397 Supported Platforms: 398 ``Ascend`` ``GPU`` ``CPU`` 399 """ 400 @prim_attr_register 401 def __init__(self): 402 """Initialize scalar_lt""" 403 404 def __call__(self, x, y): 405 return x < y 406 407 408class ScalarGe(Primitive): 409 r""" 410 Compare the value of the input scalars :math:`x,y`, and the output result is a bool value. 411 412 .. note:: 413 The inputs can be constant/variable value. Usage is the same as '>=' in Python. 414 This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous. 415 416 Inputs: 417 - **x** (Scalar) - A constant or variable scalar. 418 - **y** (Scalar) - A constant or variable scalar. 419 420 Outputs: 421 Scalar, the type of scalar is bool. 422 423 Raises: 424 TypeError: If `x` and `y` are not scalar. 425 426 Supported Platforms: 427 ``Ascend`` ``GPU`` ``CPU`` 428 """ 429 @prim_attr_register 430 def __init__(self): 431 """Initialize scalar_ge""" 432 433 def __call__(self, x, y): 434 return x >= y 435 436 437class ScalarLe(Primitive): 438 r""" 439 Compare the value of the input scalars :math:`x,y`, and the output result is a bool value. 440 441 .. note:: 442 The inputs can be constant/variable value. Usage is the same as '<=' in Python. 443 This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous. 444 445 Inputs: 446 - **x** (Scalar) - A constant or variable scalar. 447 - **y** (Scalar) - A constant or variable scalar. 448 449 Outputs: 450 Scalar, the type of scalar is bool. 451 452 Raises: 453 TypeError: If `x` and `y` are not scalar. 454 455 Supported Platforms: 456 ``Ascend`` ``GPU`` ``CPU`` 457 """ 458 @prim_attr_register 459 def __init__(self): 460 """Initialize scalar_le""" 461 462 def __call__(self, x, y): 463 return x <= y 464 465 466class ScalarMod(Primitive): 467 r""" 468 Computes the remainder of dividing the first input scalar by the second input scalar element-wise. 469 470 .. math:: 471 472 out_{i} = x_{i} \text{ % } y_{i} 473 474 .. note:: 475 The inputs can be constant/variable value. Usage is the same as '%' in Python. 476 This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous. 477 478 Inputs: 479 - **x** (Scalar) - A constant or variable scalar. 480 - **y** (Scalar) - A constant or variable scalar. 481 482 Outputs: 483 Scalar, the type is the one with higher precision or higher digits among the two inputs. 484 485 Raises: 486 TypeError: If `x` and `y` are not scalar. 
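
    Examples:
        A minimal usage sketch of ScalarMod (assuming eager PyNative execution,
        where the call reduces to Python's ``%``):

        >>> mod = ScalarMod()
        >>> print(mod(7, 3))
        1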
487 488 Supported Platforms: 489 ``Ascend`` ``GPU`` ``CPU`` 490 """ 491 @prim_attr_register 492 def __init__(self): 493 """Initialize ScalarMod""" 494 495 def __call__(self, x, y): 496 if y == 0: 497 raise ValueError('Cannot perform modulo operation on zero.') 498 return x % y 499 500 501class ScalarBool(Primitive): 502 r""" 503 Computes the input scalar true or false. 504 505 .. note:: 506 The inputs can be constant/variable value. 507 This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous. 508 509 Inputs: 510 - **x** (Scalar) - A constant or variable scalar. 511 512 Outputs: 513 Scalar, the type is bool. 514 515 Raises: 516 TypeError: If `x` are not scalar. 517 518 Supported Platforms: 519 ``Ascend`` ``GPU`` ``CPU`` 520 """ 521 @prim_attr_register 522 def __init__(self): 523 """Initialize ScalarBool""" 524 525 def __call__(self, x): 526 return bool(x) 527 528 529scalar_div = ScalarDiv() 530scalar_mod = ScalarMod() 531scalar_add = ScalarAdd() 532scalar_mul = ScalarMul() 533scalar_sub = ScalarSub() 534scalar_gt = ScalarGt() 535scalar_ge = ScalarGe() 536scalar_le = ScalarLe() 537scalar_lt = ScalarLt() 538scalar_eq = ScalarEq() 539scalar_bool = ScalarBool() 540scalar_floordiv = ScalarFloorDiv() 541scalar_log = ScalarLog() 542scalar_pow = ScalarPow() 543scalar_uadd = ScalarUadd() 544scalar_usub = ScalarUsub() 545 546 547class BatchNorm(Primitive): 548 r""" 549 Batch Normalization for input data and updated parameters. 550 551 Batch Normalization is widely used in convolutional neural networks. This operation 552 applies Batch Normalization over inputs to avoid internal covariate shift as described 553 in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal 554 Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the 555 features using a mini-batch of data and the learned parameters can be described 556 in the following formula, 557 558 .. math:: 559 560 y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta 561 562 where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon, 563 :math:`mean` is the mean of :math:`x`, 564 :math:`variance` is the variance of :math:`x`. 565 566 .. warning:: 567 - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available, 568 then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance". 569 - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. 570 571 Args: 572 is_training (bool): If `is_training` is ``True`` , `mean` and `variance` are computed during training. 573 If `is_training` is ``False`` , they're loaded from checkpoint during inference. Default: ``False`` . 574 epsilon (float): A small value added for numerical stability. Default: ``1e-5``, value must be (0, 1] . 575 momentum (float): The hyper parameter to compute moving average for running_mean and running_var 576 (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). 577 Momentum value must be [0, 1]. Default: ``0.1`` . 578 data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``, and the ``'NHWC'`` format 579 is only supported in GPU target. Default: ``"NCHW"`` . 580 581 Inputs: 582 If `is_training` is ``False`` , inputs are Tensors. 583 584 - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. 
585 - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. 586 - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. 587 - **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. 588 - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. 589 590 If `is_training` is ``True`` , `scale`, `bias`, `mean` and `variance` are Parameters. 591 592 - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. 593 - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type. 594 - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. 595 - **mean** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. 596 - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. 597 598 Outputs: 599 Tuple of 5 Tensors, the normalized inputs and the updated parameters. 600 601 - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`. 602 - **batch_mean** (Tensor) - The mean calculated per-dimension over the mini-batches, 603 shape is :math:`(C,)`. 604 - **batch_variance** (Tensor) - The variance calculated per-dimension over the mini-batches, 605 shape is :math:`(C,)`. 606 - **reserve_space_1** (Tensor) - The mean that needs to be reused when calculating gradients, 607 one-dimensional Tensor. The shape is :math:`(C,)`. 608 - **reserve_space_2** (Tensor) - The variance that needs to be reused when calculating gradients, 609 one-dimensional Tensor. The shape is :math:`(C,)`. 610 611 Raises: 612 TypeError: If `is_training` is not a bool. 613 TypeError: If dtype of `epsilon` or `momentum` is not float. 614 TypeError: If `data_format` is not a str. 615 TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor. 616 TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32. 617 618 Supported Platforms: 619 ``Ascend`` ``GPU`` ``CPU`` 620 621 Examples: 622 >>> import mindspore 623 >>> import numpy as np 624 >>> from mindspore import Tensor, ops 625 >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32) 626 >>> scale = Tensor(np.ones([2]), mindspore.float32) 627 >>> bias = Tensor(np.ones([2]), mindspore.float32) 628 >>> mean = Tensor(np.ones([2]), mindspore.float32) 629 >>> variance = Tensor(np.ones([2]), mindspore.float32) 630 >>> batch_norm = ops.BatchNorm() 631 >>> output = batch_norm(input_x, scale, bias, mean, variance) 632 >>> print(output[0]) 633 [[1. 1.] 634 [1. 
1.]] 635 """ 636 __mindspore_signature__ = (sig.make_sig('input_x', dtype=sig.sig_dtype.T1), 637 sig.make_sig('scale', 638 sig.sig_rw.RW_WRITE, 639 dtype=sig.sig_dtype.T2), 640 sig.make_sig('bias', 641 sig.sig_rw.RW_WRITE, 642 dtype=sig.sig_dtype.T2), 643 sig.make_sig('mean', 644 sig.sig_rw.RW_WRITE, 645 dtype=sig.sig_dtype.T3), 646 sig.make_sig('variance', 647 sig.sig_rw.RW_WRITE, 648 dtype=sig.sig_dtype.T3)) 649 650 @prim_arg_register 651 def __init__(self, 652 is_training=False, 653 epsilon=1e-5, 654 momentum=0.1, 655 data_format="NCHW"): 656 """Initialize BatchNorm.""" 657 if is_training is False: 658 self.set_signatures(tuple()) 659 else: 660 self.add_prim_attr('side_effect_mem', True) 661 self.is_training = is_training 662 self.epsilon = epsilon 663 self.momentum = momentum 664 self.data_format = handler.str_to_enum("BatchNorm", "data_format", data_format) 665 666 def __call__(self, *args): 667 return super().__call__(*args, self.is_training, self.epsilon, 668 self.momentum, self.data_format) 669 670 671def batch_norm_(input_x, 672 scale, 673 bias, 674 mean, 675 variance, 676 is_training=False, 677 epsilon=1e-5, 678 momentum=0.1, 679 data_format="NCHW"): 680 r""" 681 Batch Normalization for input data and updated parameters. 682 683 Batch Normalization is widely used in convolutional neural networks. This operation 684 applies Batch Normalization over inputs to avoid internal covariate shift as described 685 in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal 686 Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the 687 features using a mini-batch of data and the learned parameters can be described 688 in the following formula, 689 690 .. math:: 691 692 y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta 693 694 where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon, 695 :math:`mean` is the mean of :math:`x`, 696 :math:`variance` is the variance of :math:`x`. 697 698 .. warning:: 699 - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available, 700 then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance". 701 - For Atlas 200/300/500 inference product, 702 the result accuracy fails to reach 1‰ due to the square root instruction. 703 704 Note: 705 - If `training` is `False`, `weight`, `bias`, `running_mean` and `running_var` are tensors. 706 - If `training` is `True`, `weight`, `bias`, `running_mean` and `running_var` are Parameters. 707 708 Args: 709 input_x (tensor): tensor of shape :math:`(N, C)`, with float16 or float32 data type. 710 scale (Union[tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`. 711 bias (Union[tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`. 712 mean (Union[tensor, Parameter]): The shape :math:`(C,)`, with float16 or float32 data type. 713 variance (Union[tensor, Parameter]): The shape :math:`(C,)`, has the same data type with `weight`. 714 is_training (bool, optional): If `training` is `True`, `mean` and `variance` are computed during training. 715 If `training` is `False`, they're loaded from checkpoint during inference. Default: False. 716 epsilon (float): A small value added for numerical stability. 717 Default: ``1e-5``, value must be (0, 1] . 718 momentum (float): The hyper parameter to compute moving average for running_mean and running_var 719 (e.g. 
            :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be [0, 1].
            Default: ``0.1`` .
        data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
            and the ``'NHWC'`` format is only supported in GPU target.
            Default: ``"NCHW"`` .

    Returns:
        output_x (Tensor): The same type and shape as the input_x. The shape is :math:`(N, C)`.
        batch_mean (Tensor): Tensor of shape :math:`(C,)`.
        batch_variance (Tensor): Tensor of shape :math:`(C,)`.
        reserve_space_1 (Tensor): Tensor of shape :math:`(C,)`.
        reserve_space_2 (Tensor): Tensor of shape :math:`(C,)`.

    Raises:
        TypeError: If `is_training` is not a bool.
        TypeError: If dtype of `epsilon` or `momentum` is not float.
        TypeError: If `data_format` is not a str.
        TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
        TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> output = ops.batch_norm_(input_x, scale, bias, mean, variance)
        >>> print(output[0])
        [[1. 1.]
         [1. 1.]]
    """
    batch_norm_op = _get_cache_prim(BatchNorm)(is_training, epsilon, momentum,
                                               data_format)
    return batch_norm_op(input_x, scale, bias, mean, variance)


class Rank(Primitive):
    """
    Returns the rank of a tensor.

    Refer to :func:`mindspore.ops.rank` for more details.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> rank = ops.Rank()
        >>> output = rank(input_tensor)
        >>> print(output)
        2
        >>> print(type(output))
        <class 'int'>
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Rank"""

    def __call__(self, x):
        if not isinstance(x, (Tensor, Tensor_)):
            raise TypeError("The input x must be a Tensor!")
        return len(x.shape)


def rank(input_x):
    """
    Returns the rank of a tensor.

    Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
    is the number of indices required to uniquely select each element of the tensor.

    Args:
        input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.

    Returns:
        Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`. The data type is an int.

    Raises:
        TypeError: If `input_x` is not a Tensor.
809 810 Supported Platforms: 811 ``Ascend`` ``GPU`` ``CPU`` 812 813 Examples: 814 >>> import mindspore 815 >>> import numpy as np 816 >>> from mindspore import Tensor, ops 817 >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32) 818 >>> output = ops.rank(input_tensor) 819 >>> print(output) 820 2 821 >>> print(type(output)) 822 <class 'int'> 823 824 """ 825 rank_op = _get_cache_prim(Rank)() 826 return rank_op(input_x) 827 828 829class Shape(Primitive): 830 """ 831 Returns the shape of the input tensor. 832 833 Refer to :func:`mindspore.ops.shape` for more details. 834 835 Inputs: 836 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 837 838 Outputs: 839 tuple[int], the output tuple is constructed by multiple integers, 840 :math:`(x_1, x_2, ..., x_R)`. 841 842 Supported Platforms: 843 ``Ascend`` ``GPU`` ``CPU`` 844 845 Examples: 846 >>> import mindspore 847 >>> import numpy as np 848 >>> from mindspore import Tensor, ops 849 >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32) 850 >>> shape = ops.Shape() 851 >>> output = shape(input_x) 852 >>> print(output) 853 (3, 2, 1) 854 """ 855 856 @prim_attr_register 857 def __init__(self): 858 """Initialize Shape""" 859 860 def __call__(self, x): 861 if isinstance(x, (Tensor, COOTensor, CSRTensor, Tensor_)): 862 return x.shape 863 raise TypeError(f"For primitive[{self.name}], the input argument must be Tensor, but got {type(x)}.") 864 865 866def shape_(input_x): 867 """ 868 Returns the shape of the input tensor. 869 870 Args: 871 input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 872 873 Returns: 874 tuple[int], the output tuple is constructed by multiple integers, 875 :math:`(x_1, x_2, ..., x_R)`. 876 877 Raises: 878 TypeError: If `input_x` is not a Tensor. 879 880 Supported Platforms: 881 ``Ascend`` ``GPU`` ``CPU`` 882 883 Examples: 884 >>> import mindspore 885 >>> import numpy as np 886 >>> from mindspore import Tensor, ops 887 >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32) 888 >>> output = ops.shape(input_x) 889 >>> print(output) 890 (3, 2, 1) 891 """ 892 shape_op = _get_cache_prim(Shape)() 893 return shape_op(input_x) 894 895 896class ScalarToTensor(PrimitiveWithInfer): 897 """ 898 Converts a scalar to a `Tensor`, and converts the data type to the specified type. 899 900 Refer to :func:`mindspore.ops.scalar_to_tensor` for more details. 901 902 Inputs: 903 - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed. 904 - **dtype** (mindspore.dtype) - The target data type. Default: ``mindspore.float32`` . Only 905 constant value is allowed. 906 907 Outputs: 908 Tensor. 0-D Tensor and the content is the input. 909 910 Supported Platforms: 911 ``Ascend`` ``GPU`` ``CPU`` 912 913 Examples: 914 >>> import mindspore 915 >>> from mindspore import ops 916 >>> op = ops.ScalarToTensor() 917 >>> data = 1 918 >>> output = op(data, mindspore.float32) 919 >>> print(output) 920 1.0 921 """ 922 923 @prim_attr_register 924 def __init__(self): 925 self.init_prim_io_names(inputs=['input_scalar', 'dtype'], outputs=['output_data']) 926 927 def __call__(self, x, dtype=mstype.float32): 928 validator.check_value_type("x", x, [bool, int, float], self.name) 929 validator.check_subclass("dtype", dtype, mstype.number, self.name) 930 data_type = mstype.dtype_to_nptype(dtype) 931 return Tensor(np.array(x, data_type), dtype=dtype) 932 933 934class Tile(Primitive): 935 r""" 936 Replicates an input tensor with given multiple times. 
937 938 Refer to :func:`mindspore.ops.tile` for more details. 939 940 Inputs: 941 - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as 942 :math:`(x_1, x_2, ..., x_S)` . 943 - **dims** (tuple[int]) - The parameter that specifies the number of replications, 944 the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`. 945 Only constant value is allowed. 946 947 Outputs: 948 Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`, 949 the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`. 950 951 - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and 952 the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`. 953 - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent. 954 Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`, 955 then the shape of their corresponding positions can be multiplied, and the shape of Outputs is 956 :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`. 957 - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the 958 `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions 959 can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`. 960 961 Raises: 962 TypeError: If `dims` is not a tuple or its elements are not all int. 963 ValueError: If the elements of `dims` are not all greater than or equal to 0. 964 965 Supported Platforms: 966 ``Ascend`` ``GPU`` ``CPU`` 967 968 Examples: 969 >>> import mindspore 970 >>> import numpy as np 971 >>> from mindspore import Tensor, ops 972 >>> tile = ops.Tile() 973 >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32) 974 >>> dims = (2, 3) 975 >>> output = tile(input, dims) 976 >>> print(output) 977 [[1. 2. 1. 2. 1. 2.] 978 [3. 4. 3. 4. 3. 4.] 979 [1. 2. 1. 2. 1. 2.] 980 [3. 4. 3. 4. 3. 4.]] 981 >>> dims = (2, 3, 2) 982 >>> output = tile(input, dims) 983 >>> print(output) 984 [[[1. 2. 1. 2.] 985 [3. 4. 3. 4.] 986 [1. 2. 1. 2.] 987 [3. 4. 3. 4.] 988 [1. 2. 1. 2.] 989 [3. 4. 3. 4.]] 990 [[1. 2. 1. 2.] 991 [3. 4. 3. 4.] 992 [1. 2. 1. 2.] 993 [3. 4. 3. 4.] 994 [1. 2. 1. 2.] 995 [3. 4. 3. 4.]]] 996 """ 997 998 @prim_attr_register 999 def __init__(self): 1000 """Initialize.""" 1001 1002 def __call__(self, input, dims): 1003 return _convert_stub(pyboost_tile(self, [input, dims])) 1004 1005 # pylint: disable=missing-docstring 1006 def check_elim(self, *args): 1007 base_tensor, dims = args 1008 if not isinstance(base_tensor, Tensor): 1009 raise TypeError(f"For '{self.name}', the type of 'input' must be Tensor, " 1010 f"but got {type(base_tensor).__name__}.") 1011 if not isinstance(dims, tuple): 1012 raise TypeError(f"For '{self.name}', the type of 'dims' must be tuple, " 1013 f"but got {type(dims).__name__}.") 1014 1015 if all(v == 1 for v in dims) and len(base_tensor.shape) >= len(dims): 1016 from mindspore.ops.auto_generate.gen_ops_def import Identity 1017 ret = Identity()(base_tensor) 1018 return (True, ret) 1019 return (False, None) 1020 1021 1022def tile(input, dims): 1023 r""" 1024 Creates a new tensor by replicating `input` `dims` times. The i'th dimension of 1025 output tensor has `input.shape[i] * dims[i]` elements, and the values of `input` 1026 are replicated `dims[i]` times along the i'th dimension. 
1027 1028 Args: 1029 input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as 1030 :math:`(x_1, x_2, ..., x_S)` . 1031 1032 dims (tuple[int]): The parameter that specifies the number of replications, 1033 the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`. 1034 Only constant value is allowed. 1035 1036 Returns: 1037 Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`, 1038 the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`. 1039 1040 - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and 1041 the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`. 1042 - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent. 1043 Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`, 1044 then the shape of their corresponding positions can be multiplied, and the shape of Outputs is 1045 :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`. 1046 - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the 1047 `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions 1048 can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`. 1049 1050 Raises: 1051 TypeError: If `dims` is not a tuple or its elements are not all int. 1052 ValueError: If the elements of `dims` are not all greater than or equal to 0. 1053 1054 Supported Platforms: 1055 ``Ascend`` ``GPU`` ``CPU`` 1056 1057 Examples: 1058 >>> import mindspore 1059 >>> import numpy as np 1060 >>> from mindspore import Tensor, ops 1061 >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32) 1062 >>> dims = (2, 3) 1063 >>> output = ops.tile(input, dims) 1064 >>> print(output) 1065 [[1. 2. 1. 2. 1. 2.] 1066 [3. 4. 3. 4. 3. 4.] 1067 [1. 2. 1. 2. 1. 2.] 1068 [3. 4. 3. 4. 3. 4.]] 1069 >>> dims = (2, 3, 2) 1070 >>> output = ops.tile(input, dims) 1071 >>> print(output) 1072 [[[1. 2. 1. 2.] 1073 [3. 4. 3. 4.] 1074 [1. 2. 1. 2.] 1075 [3. 4. 3. 4.] 1076 [1. 2. 1. 2.] 1077 [3. 4. 3. 4.]] 1078 [[1. 2. 1. 2.] 1079 [3. 4. 3. 4.] 1080 [1. 2. 1. 2.] 1081 [3. 4. 3. 4.] 1082 [1. 2. 1. 2.] 1083 [3. 4. 3. 4.]]] 1084 """ 1085 tile_op = _get_cache_prim(Tile)() 1086 return tile_op(input, dims) 1087 1088 1089def scalar_cast(input_x, input_y): 1090 r""" 1091 The interface is deprecated from version 2.3 and will be removed in a future version, 1092 please use `int(x)` or `float(x)` instead. 1093 1094 Casts the input scalar to another type. 1095 1096 Args: 1097 input_x (scalar): The input scalar. 1098 input_y (mindspore.dtype): The type to be cast. Only constant value is allowed. 1099 The value should only be mindspore.int64, mindspore.float64, or mindspore.bool\_. 1100 1101 Returns: 1102 Scalar, the type is the same as the python type corresponding to `input_y`. 1103 1104 Raises: 1105 ValueError: if input_y's value is invalid. 1106 1107 Supported Platforms: 1108 Deprecated 1109 1110 Examples: 1111 >>> import mindspore 1112 >>> from mindspore import ops 1113 >>> output = ops.scalar_cast(255.0, mindspore.int64) 1114 >>> print(output) 1115 255 1116 """ 1117 scalar_cast_op = _get_cache_prim(ScalarCast)() 1118 return scalar_cast_op(input_x, input_y) 1119 1120 1121class Cast(Primitive): 1122 """ 1123 Returns a tensor with the new specified data type. 
1124 1125 Note: 1126 When converting complex numbers to boolean type, the imaginary part of the complex number is not 1127 taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False. 1128 1129 Inputs: 1130 - **input_x** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 1131 The tensor to be cast. 1132 - **type** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed. 1133 1134 Outputs: 1135 Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`. 1136 1137 Raises: 1138 TypeError: If `input_x` is neither Tensor nor Number. 1139 TypeError: If `type` is not a Number. 1140 1141 Supported Platforms: 1142 ``Ascend`` ``GPU`` ``CPU`` 1143 1144 Examples: 1145 >>> import mindspore 1146 >>> import numpy as np 1147 >>> from mindspore import Tensor, ops 1148 >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) 1149 >>> input_x = Tensor(input_np) 1150 >>> type_dst = mindspore.int32 1151 >>> cast = ops.Cast() 1152 >>> output = cast(input_x, type_dst) 1153 >>> print(output.dtype) 1154 Int32 1155 >>> print(output.shape) 1156 (2, 3, 4, 5) 1157 """ 1158 1159 @prim_attr_register 1160 def __init__(self): 1161 """Initialize Cast""" 1162 self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) 1163 1164 def check_elim(self, x, dtype): 1165 if isinstance(x, (Tensor, numbers.Number, Parameter)): 1166 if isinstance(x, Parameter): 1167 data = x.data 1168 if data.dtype == dtype: 1169 return (True, x) 1170 if isinstance(x, Tensor) and x.dtype == dtype: 1171 x = Tensor(x) 1172 x.set_cast_dtype() 1173 return (True, x) 1174 if isinstance(x, numbers.Number): 1175 return (True, Tensor(x, dtype=dtype)) 1176 return (False, None) 1177 1178 def __call__(self, input_x, dtype): 1179 should_elim, output = self.check_elim(input_x, dtype) 1180 if should_elim: 1181 return output 1182 return _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)])) 1183 1184 1185def to_sequence(val): 1186 """ 1187 to_sequence 1188 """ 1189 if isinstance(val, (tuple, list)): 1190 return val 1191 return (val,) 1192 1193 1194class EmbeddingTableExport(Primitive): 1195 """ 1196 EmbeddingTableExport 1197 """ 1198 1199 @prim_attr_register 1200 def __init__(self, embedding_dim, value_total_len, export_mode="all", 1201 only_var_flag=False, file_type="bin", table_name=(), 1202 filter_export_flag=False, steps_to_live_list=()): 1203 """Initialize EmbeddingTableExport""" 1204 self.add_prim_attr("_process_node_engine_id", "PS") 1205 1206 1207class EmbeddingTableImport(Primitive): 1208 """ 1209 EmbeddingTableImport 1210 """ 1211 1212 @prim_attr_register 1213 def __init__(self, embedding_dim, value_total_len, 1214 only_var_flag=False, file_type="bin", table_name=()): 1215 """Initialize EmbeddingTableImport""" 1216 self.add_prim_attr("_process_node_engine_id", "PS") 1217 1218 1219class EmbeddingComputeVarImport(Primitive): 1220 """ 1221 EmbeddingComputeVarImport 1222 """ 1223 1224 @prim_attr_register 1225 def __init__(self, table_name=()): 1226 """Initialize EmbeddingComputeVarImport""" 1227 self.add_prim_attr("_process_node_engine_id", "PS") 1228 1229 1230class EmbeddingComputeVarExport(Primitive): 1231 """ 1232 EmbeddingComputeVarExport 1233 """ 1234 1235 @prim_attr_register 1236 def __init__(self, table_name=()): 1237 """Initialize EmbeddingComputeVarExport""" 1238 self.add_prim_attr("_process_node_engine_id", "PS") 1239 1240 1241class InitEmbeddingHashmap(Primitive): 1242 """ 
1243 InitEmbeddingHashmap 1244 """ 1245 @prim_attr_register 1246 def __init__(self, value_total_len, embedding_dim, _table_id, 1247 bucket_size=0, dtype=mstype.float32, initializer_mode="", 1248 constant_valu=0., min=-2., max=2., mu=0., sigma=1., seed=0, 1249 seed2=0, filter_mode="no_filter", optimizer_mode="", 1250 optimizer_params=()): 1251 self.add_prim_attr("_process_node_engine_id", "PS") 1252 1253 1254def init_embedding_hashmap(table_id, value_total_len, embedding_dim, _table_id, 1255 bucket_size=0, dtype=mstype.float32, initializer_mode='', 1256 constant_value=0.0, min=-2.0, max=2.0, mu=0.0, sigma=1.0, 1257 seed=0, seed2=0, filter_mode='no_filter', 1258 optimizer_mode='', optimizer_params=()): 1259 """ 1260 init_embedding_hashmap 1261 """ 1262 op = _get_cache_prim(InitEmbeddingHashmap)(value_total_len, embedding_dim, _table_id, 1263 bucket_size, dtype, initializer_mode, 1264 constant_value, min, max, mu, sigma, seed, 1265 seed2, filter_mode, optimizer_mode, optimizer_params) 1266 return op(table_id) 1267 1268 1269class InitPartitionMap(Primitive): 1270 """ 1271 InitPartitionMap 1272 """ 1273 @prim_attr_register 1274 def __init__(self, _embedding_dim, _max_key_num, 1275 _ps_num=1, partition_num=65537): 1276 self.add_prim_attr("_process_node_engine_id", "PS") 1277 1278 1279def init_partition_map(ps_num, ps_ids, _embedding_dim, _max_key_num, 1280 _ps_num=1, partition_num=65537): 1281 """ 1282 init_partition_map 1283 """ 1284 op = _get_cache_prim(InitPartitionMap)(_embedding_dim, _max_key_num, _ps_num, partition_num) 1285 return op(ps_num, ps_ids) 1286 1287 1288class EmbeddingApplyAdam(Primitive): 1289 """ 1290 EmbeddingApplyAdam 1291 """ 1292 @prim_attr_register 1293 def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,), 1294 padding_key=(0,), padding_key_mask=(1,), 1295 completion_key=(0,), completion_key_mask=(1,)): 1296 self.add_prim_attr("_process_node_engine_id", "PS") 1297 1298 1299class EmbeddingApplyAdamW(Primitive): 1300 """ 1301 EmbeddingApplyAdam 1302 """ 1303 @prim_attr_register 1304 def __init__(self, embedding_dim, _max_key_num, amsgrad=(0,), 1305 maximize=(0,), mask_zero=(0,), padding_key=(0,), 1306 padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,)): 1307 self.add_prim_attr("_process_node_engine_id", "PS") 1308 1309 1310class EmbeddingApplyAdaGrad(Primitive): 1311 """ 1312 EmbeddingApplyAdaGrad 1313 """ 1314 @prim_attr_register 1315 def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,), 1316 padding_key=(0,), padding_key_mask=(1,), 1317 completion_key=(0,), completion_key_mask=(1,)): 1318 self.add_prim_attr("_process_node_engine_id", "PS") 1319 1320 1321class EmbeddingApplyFtrl(Primitive): 1322 """ 1323 EmbeddingApplyFtrl 1324 """ 1325 @prim_attr_register 1326 def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,), 1327 padding_key=(0,), padding_key_mask=(1,), 1328 completion_key=(0,), completion_key_mask=(1,)): 1329 self.add_prim_attr("_process_node_engine_id", "PS") 1330 1331 1332class EmbeddingTableFind(Primitive): 1333 """ 1334 EmbeddingTableFind 1335 """ 1336 @prim_attr_register 1337 def __init__(self, embedding_dim, _embedding_dim, _max_key_num, 1338 _table_id, default_value=(-1.), _use_counter_filter=0): 1339 self.add_prim_attr("_process_node_engine_id", "PS") 1340 self.add_prim_attr("_execute_times", 2) 1341 1342 1343def embedding_table_find(table_id, keys, embedding_dim, _max_key_num, 1344 _table_id, default_value=(-1.0,), _use_counter_filter=0): 1345 r""" 1346 embedding_table_find 1347 """ 1348 _embedding_dim = 
embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id] 1349 op = _get_cache_prim(EmbeddingTableFind)(to_sequence(embedding_dim), _embedding_dim, 1350 _max_key_num, _table_id, 1351 to_sequence(default_value), 1352 _use_counter_filter) 1353 return op(table_id, keys) 1354 1355 1356class EmbeddingTableFindAndInit(Primitive): 1357 """ 1358 EmbeddingTableFindAndInit 1359 """ 1360 @prim_attr_register 1361 def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id, 1362 _max_key_num, initializer_mode=("random_uniform",), 1363 constant_value=(0.,), min=(-2.,), max=(2.,), mu=(0.,), 1364 sigma=(1.,), seed=(0,), seed2=(0,), 1365 filter_mode=("no_filter",), filter_freq=(0,), 1366 default_key_or_value=(0,), default_key=(0,), 1367 default_value=(0.,), completion_key=(0,), 1368 completion_key_mask=(1,), optimizer_mode=(), 1369 optimizer_params=(), _use_counter_filter=0, 1370 backward_mode="adam", 1371 backward_int_params=((0,), (0,), (0,), (1,)), 1372 backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)): 1373 self.add_prim_attr("_process_node_engine_id", "PS") 1374 self.add_prim_attr("_execute_times", 2) 1375 1376 1377def embedding_table_find_and_init(table_id, keys, max_grad_norm, parameter, embedding_dim, 1378 value_total_len, _table_id, _max_key_num, 1379 initializer_mode=('random_uniform',), constant_value=(0.,), 1380 min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,), 1381 seed2=(0,), filter_mode=("no_filter",), 1382 filter_freq=(0,), default_key_or_value=(0,), 1383 default_key=(0,), default_value=(0.,), 1384 completion_key=(0,), completion_key_mask=(1,), 1385 optimizer_mode=(), optimizer_params=(), _use_counter_filter=0, 1386 backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)), 1387 backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)): 1388 """ 1389 embedding_table_find_and_init 1390 1391 backward_int_params (Union[tuple[tuple[int]], list[list[int]]]): 1392 - when the backward_mode is 'adam', 'ftrl' or 'adagrad', 1393 it means [[global_step], mask_zero, padding_key, padding_key_mask] 1394 - when the backward_mode is 'adamw', it means: 1395 [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask] 1396 backward_float_params (Union[tuple[float], list[float]]): 1397 - when the backward_mode is 'adam', it means: 1398 [beta1_power, beta2_power, lr, beta1, beta2, epsilon] 1399 - when the backward_mode is 'ftrl', it means: 1400 [lr, lr_power, lambda1, lambda2] 1401 - when the backward_mode is 'adamw', it means: 1402 [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon] 1403 - when the backward_mode is 'adagrad', it means [lr,] 1404 """ 1405 _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id] 1406 op = _get_cache_prim(EmbeddingTableFindAndInit)(to_sequence(embedding_dim), to_sequence(value_total_len), 1407 _embedding_dim, _table_id, _max_key_num, 1408 to_sequence(initializer_mode), 1409 to_sequence(constant_value), to_sequence(min), 1410 to_sequence(max), to_sequence(mu), 1411 to_sequence(sigma), to_sequence(seed), 1412 to_sequence(seed2), to_sequence(filter_mode), 1413 to_sequence(filter_freq), to_sequence(default_key_or_value), 1414 to_sequence(default_key), to_sequence(default_value), 1415 to_sequence(completion_key), to_sequence(completion_key_mask), 1416 to_sequence(optimizer_mode), to_sequence(optimizer_params), 1417 _use_counter_filter, 1418 backward_mode, backward_int_params, backward_float_params) 1419 return op(table_id, keys, max_grad_norm, 
parameter) 1420 1421 1422class FakeRemoteLookupUniqued(Primitive): 1423 1424 """ 1425 FakeRemoteLookupUniqued 1426 """ 1427 @prim_attr_register 1428 def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id, 1429 _max_key_num, initializer_mode=('random_uniform',), constant_value=(0.,), 1430 min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,), seed2=(0,), 1431 filter_mode=("no_filter",), filter_freq=(0,), 1432 default_key_or_value=(0,), default_key=(0,), default_value=(0.,), 1433 completion_key=(0,), completion_key_mask=(1,), 1434 optimizer_mode=(), optimizer_params=(), _use_counter_filter=0, 1435 backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)), 1436 backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)): 1437 self.add_prim_attr("_process_node_engine_id", "PS") 1438 self.add_prim_attr("_execute_times", 2) 1439 1440 1441def fake_remote_lookup_uniqued(table_id, keys, actual_keys_num, unique_indices, 1442 key_count, max_grad_norm, parameter, 1443 embedding_dim, value_total_len, _table_id, _max_key_num, 1444 initializer_mode=('random_uniform',), constant_value=(0.,), 1445 min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,), 1446 seed2=(0,), filter_mode=("no_filter",), 1447 filter_freq=(0,), default_key_or_value=(0,), 1448 default_key=(0,), default_value=(0.,), 1449 completion_key=(0,), completion_key_mask=(1,), 1450 optimizer_mode=(), optimizer_params=(), _use_counter_filter=0, 1451 backward_mode='adam', backward_int_params=((0,), (0,), (0,), (1,)), 1452 backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)): 1453 """ 1454 fake_remote_lookup_uniqued 1455 1456 backward_mode (str): determine the optimizer used by backpropagation, 1457 valid values are ["adam", "adamw", "adagrad", "ftrl"] 1458 backward_int_params (Union[tuple[tuple[int]], list[list[int]]]): 1459 - when the backward_mode is 'adam', 'ftrl' or 'adagrad', 1460 it means [[global_step], mask_zero, padding_key, padding_key_mask] 1461 - when the backward_mode is 'adamw', it means: 1462 [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask] 1463 backward_float_params (Union[tuple[float], list[float]]): 1464 - when the backward_mode is 'adam', it means: 1465 [beta1_power, beta2_power, lr, beta1, beta2, epsilon] 1466 - when the backward_mode is 'ftrl', it means: 1467 [lr, lr_power, lambda1, lambda2] 1468 - when the backward_mode is 'adamw', it means: 1469 [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon] 1470 - when the backward_mode is 'adagrad', it means [lr,] 1471 """ 1472 _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id] 1473 op = _get_cache_prim(FakeRemoteLookupUniqued)(to_sequence(embedding_dim), to_sequence(value_total_len), 1474 _embedding_dim, _table_id, _max_key_num, 1475 to_sequence(initializer_mode), to_sequence(constant_value), 1476 to_sequence(min), to_sequence(max), to_sequence(mu), 1477 to_sequence(sigma), to_sequence(seed), to_sequence(seed2), 1478 to_sequence(filter_mode), to_sequence(filter_freq), 1479 to_sequence(default_key_or_value), to_sequence(default_key), 1480 to_sequence(default_value), to_sequence(completion_key), 1481 to_sequence(completion_key_mask), to_sequence(optimizer_mode), 1482 to_sequence(optimizer_params), _use_counter_filter, 1483 backward_mode, backward_int_params, backward_float_params) 1484 return op(table_id, keys, actual_keys_num, unique_indices, key_count, max_grad_norm, parameter) 1485 1486 1487# Following is Python Infer Value. 
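# An illustrative sketch of the convention described below (the op name "Neg"
# here is hypothetical and only shows the expected shape of such a helper):
#
#     def infer_value_for_Neg(x):
#         """Infer value for Neg op."""
#         if x is None:
#             return None
#         return Tensor(-x.asnumpy())
#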
1488# A valid infer value function should be: 1489# 1490# 1. named as infer_value_for_OpName 1491# 2. All inputs should pass without default value. 1492# 3. If not const input is given, return None. (for now) 1493 1494 1495def infer_value_for_Tile(input, dims): 1496 """Infer value for Tile op.""" 1497 if input is None or dims is None or None in dims: 1498 return None 1499 return Tensor(np.tile(input.asnumpy(), dims)) 1500 1501 1502def infer_value_for_Concat(tensors, axis): 1503 """Infer value for Concat op.""" 1504 if not tensors or None in tensors or axis is None: 1505 return None 1506 1507 tensor_to_concat = [x.asnumpy() if x.dtype != mstype.bfloat16 else x.float().asnumpy() for x in tensors] 1508 return Tensor(np.concatenate(tensor_to_concat, axis), dtype=tensors[0].dtype) 1509 1510 1511def infer_value_for_ReduceSum(input_x, axis, keep_dims, skip_mode): 1512 """Infer value for ReduceSum op.""" 1513 value = None 1514 if input_x is not None and axis is not None: 1515 value = input_x.asnumpy() 1516 if isinstance(axis, int): 1517 pass 1518 elif axis: 1519 axis = tuple(set(axis)) 1520 elif axis in ((), []) and skip_mode: 1521 return input_x 1522 else: 1523 axis = tuple(range(len(value.shape))) 1524 value = np.sum(value, axis, keepdims=keep_dims) 1525 value = np.array(value) 1526 value = Tensor(value) 1527 return value 1528 1529 1530def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name): 1531 """Infer value for Common Reduce op.""" 1532 value = None 1533 if input_x is not None and axis is not None: 1534 prim_map = { 1535 'ReduceMax': np.max, 1536 'ReduceMin': np.min, 1537 'ReduceProd': np.prod, 1538 'ReduceMean': np.mean, 1539 'ReduceAll': np.all, 1540 'ReduceAny': np.any, 1541 } 1542 np_reduce_func = prim_map.get(prim_name, None) 1543 1544 if np_reduce_func is not None: 1545 value = input_x.asnumpy() 1546 if isinstance(axis, int): 1547 pass 1548 elif axis: 1549 axis = tuple(set(axis)) 1550 else: 1551 axis = tuple(range(len(value.shape))) 1552 value = np_reduce_func(value, axis, keepdims=keep_dims) 1553 value = np.array(value) 1554 value = Tensor(value) 1555 return value 1556 1557 1558def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name): 1559 """Infer value for Common ReduceExtand op.""" 1560 value = None 1561 if input_x is not None: 1562 prim_map = { 1563 'MeanExt': np.mean, 1564 'SumExt': np.sum, 1565 'ProdExt': np.prod, 1566 } 1567 np_reduce_extand_func = prim_map.get(prim_name, None) 1568 1569 if np_reduce_extand_func is not None: 1570 value = input_x.asnumpy() 1571 if isinstance(axis, int): 1572 pass 1573 elif axis: 1574 axis = tuple(set(axis)) 1575 else: 1576 axis = tuple(range(len(value.shape))) 1577 if dtype is not None: 1578 np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype)) 1579 value = np_reduce_extand_func(value, axis, dtype=np_dtype, keepdims=keep_dims) 1580 else: 1581 value = np_reduce_extand_func(value, axis, keepdims=keep_dims) 1582 1583 value = np.array(value) 1584 value = Tensor(value) 1585 return value 1586 1587 1588def _infer_value_for_max_min(input_x, prim_name): 1589 """Infer value for Max/Min op.""" 1590 value = None 1591 if input_x is not None: 1592 prim_map = { 1593 'Max': np.max, 1594 'Min': np.min, 1595 } 1596 np_reduce_func = prim_map.get(prim_name, None) 1597 1598 if np_reduce_func is not None: 1599 value = input_x.asnumpy() 1600 value = np_reduce_func(value, None, keepdims=False) 1601 value = np.array(value) 1602 value = Tensor(value) 1603 return value 1604 1605 1606def infer_value_for_Cast(x, dst_type_enum=None): 
1607 """Infer value for Cast op.""" 1608 if x is None or dst_type_enum is None: 1609 return None 1610 dst_type = typing.type_id_to_type(dst_type_enum) 1611 src_type = mstype.get_py_obj_dtype(x) 1612 validator.check_subclass("input_x", src_type, [mstype.tensor_type, mstype.number], "Cast") 1613 validator.check_subclass("type", dst_type, mstype.number, "Cast") 1614 1615 if isinstance(src_type, type(mstype.tensor_type)): 1616 src_type = src_type.element_type() 1617 if isinstance(dst_type, type(mstype.tensor_type)): 1618 dst_type = dst_type.element_type() 1619 1620 value = None 1621 np_dst_type = mstype.dtype_to_nptype(dst_type) 1622 if isinstance(x, (int, float)): 1623 value = Tensor(np.array(x).astype(np_dst_type), dtype=dst_type) 1624 else: 1625 value = Tensor_(x.asnumpy().astype(np_dst_type), dtype=dst_type) 1626 return value 1627 1628 1629def infer_value_for_ReduceMax(input_x, axis, keep_dims): 1630 """Infer value for ReduceMax op.""" 1631 return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax') 1632 1633 1634def infer_value_for_Max(input_x): 1635 """Infer value for Max op.""" 1636 return _infer_value_for_max_min(input_x, 'Max') 1637 1638 1639def infer_value_for_ReduceMin(input_x, axis, keep_dims): 1640 """Infer value for ReduceMin op.""" 1641 return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMin') 1642 1643 1644def infer_value_for_Min(input_x): 1645 """Infer value for Max op.""" 1646 return _infer_value_for_max_min(input_x, 'Min') 1647 1648 1649def infer_value_for_ReduceProd(input_x, axis, keep_dims): 1650 """Infer value for ReduceProd op.""" 1651 return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceProd') 1652 1653 1654def infer_value_for_ReduceMean(input_x, axis, keep_dims): 1655 """Infer value for ReduceMean op.""" 1656 return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMean') 1657 1658 1659def infer_value_for_ReduceAll(input_x, axis, keep_dims): 1660 """Infer value for ReduceAll op.""" 1661 return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAll') 1662 1663 1664def infer_value_for_ReduceAny(input_x, axis, keep_dims): 1665 """Infer value for ReduceAny op.""" 1666 return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAny') 1667 1668 1669def infer_value_for_MeanExt(input_x, axis, keep_dims, dtype): 1670 """Infer value for MeanExt op.""" 1671 return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'MeanExt') 1672 1673 1674def infer_value_for_SumExt(input_x, axis, keep_dims, dtype): 1675 """Infer value for SumExt op.""" 1676 return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'SumExt') 1677 1678 1679def infer_value_for_ProdExt(input_x, axis, keep_dims, dtype): 1680 """Infer value for ProdExt op.""" 1681 return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'ProdExt') 1682 1683 1684def infer_value_for_Diag(input_x): 1685 """Infer value for Diag op.""" 1686 if input_x is None: 1687 return None 1688 # do constant-folding only when x rank is 1 1689 if len(input_x.shape) != 1: 1690 return None 1691 ret = np.diag(input_x.asnumpy()) 1692 return Tensor(ret) 1693 1694 1695def infer_value_for_BroadcastTo(x, shape): 1696 """Infer value for BroadcastTo op.""" 1697 def none_in_tuple_or_list(x): 1698 return isinstance(x, (tuple, list)) and None in x 1699 if shape is None or none_in_tuple_or_list(shape) or x is None: 1700 return None 1701 1702 if isinstance(shape, (Tensor, Tensor_)): 1703 validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype), 1704 [mstype.int32, 


def infer_value_for_BroadcastTo(x, shape):
    """Infer value for BroadcastTo op."""
    def none_in_tuple_or_list(x):
        return isinstance(x, (tuple, list)) and None in x
    if shape is None or none_in_tuple_or_list(shape) or x is None:
        return None

    if isinstance(shape, (Tensor, Tensor_)):
        validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
                                           [mstype.int32, mstype.int64], "BroadcastTo")
        shape = shape.asnumpy().tolist()
    else:
        validator.check_value_type("shape", shape, [tuple], "BroadcastTo")
        shape = list(shape)

    np_data = np.broadcast_to(x.asnumpy(), shape)
    if 0 in shape:
        init_func = Zero()
        init_func.__enable_zero_dim__ = True
        out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
        return out
    return Tensor(np_data)


def infer_value_for_Reshape(x, shape):
    """Infer value for Reshape op."""
    def none_in_tuple_or_list(x):
        return isinstance(x, (tuple, list)) and None in x
    # for shape is not constant
    if shape is None or none_in_tuple_or_list(shape) or x is None:
        return None

    if isinstance(shape, (Tensor, Tensor_)):
        validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
                                           [mstype.int32, mstype.int64], "Reshape")
        shape = shape.asnumpy().tolist()
    else:
        validator.check_value_type("shape", shape, [tuple], "Reshape")
        shape = list(shape)

    neg_index = -1
    dim_prod = 1
    for i, shp_i in enumerate(shape):
        validator.check_value_type("shape[%d]" % i, shp_i, [int], "Reshape")
        if shp_i == -1:
            if neg_index != -1:
                raise ValueError(f"For 'Reshape', there can be at most one '-1' in 'input_shape', "
                                 f"but got {shape}.")
            neg_index = i
        else:
            dim_prod *= shp_i
    out = None
    if not is_shape_unknown(x.shape):
        x_shp = x.shape
        if dim_prod < 0:
            raise ValueError(f"For 'Reshape', the shape of 'input_x' is {x_shp}, "
                             f"the value of 'input_shape' is {shape}. "
                             f"The product of 'input_shape' should > 0, but got {dim_prod}.")
        arr_prod = np.prod(x_shp)
        if neg_index != -1:
            shape[neg_index] = int(arr_prod // dim_prod)
            dim_prod *= shape[neg_index]
        if dim_prod != arr_prod:
            raise ValueError(f"For 'Reshape', the product of the 'input_x' shape "
                             f"should be equal to product of 'input_shape', but got product of the"
                             f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
        if 0 in shape:
            init_func = Zero()
            init_func.__enable_zero_dim__ = True
            out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
        else:
            out = Tensor(x.asnumpy().reshape(shape))
    return out
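

# A minimal sketch (illustrative only, not part of the module) of how ``infer_value_for_Reshape``
# resolves a single ``-1`` dimension: the known dimensions are multiplied together and the missing
# one is the total element count divided by that product:
#
#     >>> import numpy as np
#     >>> from mindspore import Tensor
#     >>> x = Tensor(np.arange(24, dtype=np.float32).reshape(2, 3, 4))
#     >>> infer_value_for_Reshape(x, (4, -1))        # 24 elements / 4 = 6, folds to shape (4, 6)
#     >>> infer_value_for_Reshape(x, None)           # non-const shape -> None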


class Ones(Primitive):
    r"""
    Creates a tensor filled with value ones.

    Refer to :func:`mindspore.ops.ones` for more details.

    .. warning::
        For argument `size`, Tensor type input will be deprecated in the future version.

    Inputs:
        - **shape** (Union[tuple[int], List[int], int, Tensor]) - The specified shape of output tensor.
        - **type** (:class:`mindspore.dtype`) - The specified type of output tensor.

    Outputs:
        Tensor, whose dtype and size are defined by input.

    Raises:
        TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import ops
        >>> ones = ops.Ones()
        >>> output = ones((2, 2), mindspore.float32)
        >>> print(output)
        [[1. 1.]
         [1. 1.]]
        >>> output = ones((3, 3), mindspore.float32)
        >>> print(output)
        [[1. 1. 1.]
         [1. 1. 1.]
         [1. 1. 1.]]
    """

    __mindspore_signature__ = (
        sig.make_sig('size'),
        sig.make_sig('type', default=None),
    )

    @prim_arg_register
    def __init__(self):
        pass

    def __call__(self, size, type=None):
        return _convert_stub(pyboost_ones(self, [size, type if type is None
                                                  else handler.dtype_to_type_id('Ones', 'type', type)]))


class Zeros(Primitive):
    r"""
    Zeros will be deprecated in the future. Please use :func:`mindspore.ops.zeros` instead.

    Creates a tensor filled with value zeros.

    Creates a tensor with shape described by the first argument and
    fills it with value zeros in type of the second argument.

    .. warning::
        For argument `size`, Tensor type input will be deprecated in the future version.

    Inputs:
        - **shape** (tuple[int], List[int], int, Tensor) - The specified shape of output tensor.
        - **type** (mindspore.dtype) - The specified type of output tensor.

    Outputs:
        Tensor, whose dtype and size are defined by input.

    Raises:
        TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import ops
        >>> zeros = ops.Zeros()
        >>> output = zeros((2, 2), mindspore.float32)
        >>> print(output)
        [[0. 0.]
         [0. 0.]]

    """

    __mindspore_signature__ = (
        sig.make_sig('size'),
        sig.make_sig('type', default=None),
    )

    @prim_arg_register
    def __init__(self):
        pass

    def __call__(self, size, type=None):
        return _convert_stub(pyboost_zeros(self, [size, type if type is None
                                                  else handler.dtype_to_type_id('Zeros', 'type', type)]))
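

# A minimal usage sketch (illustrative only, not part of the module) for the two primitives
# above; per their docstrings the shape may be given as a tuple or as a plain int:
#
#     >>> import mindspore
#     >>> from mindspore import ops
#     >>> ops.Ones()((2, 3), mindspore.int32)    # 2 x 3 tensor filled with ones
#     >>> ops.Zeros()(4, mindspore.float32)      # shape given as an int -> 1-D tensor of 4 zeros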


def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
                          attn_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, keep_prob=1.0,
                          scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
                          input_layout='BSH', sparse_mode=0):
    r"""
    This interface is not open to the public; it is for internal use only.

    .. math::
        \begin{array}{ll} \\
            y = Dropout(Softmax(Mask(scale\_value \cdot (real\_shift + query \cdot key),
            attn\_mask), -1), keep\_prob) \\
            \cdot value \\
        \end{array}

    B -- Batch size. Value range 1 to 2k.
    S1 -- Sequence length of query. Value range 1 to 512k.
    S2 -- Sequence length of key and value. Value range 1 to 512k.
    N1 -- Num heads of query. Value range 1 to 256.
    N2 -- Num heads of key and value, and N2 must be a factor of N1.
    D -- Head size. The value must be a multiple of 16, with a maximum of 512.
    H1 -- Hidden size of query, which equals N1 * D.
    H2 -- Hidden size of key and value, which equals N2 * D.

    .. warning::
        This is an experimental API that is subject to change or deletion. Only supported on the Atlas
        training series.

    Args:
        query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
            `(B, N1, S1, D)`, `(S1, B, H1)`, `(B, S1, N1, D)` or `(T1, N1, D)`.
        key (Tensor[float16, bfloat16]): The key tensor. Input tensor of shape :math:`(B, S2, H2)`,
            `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`.
        value (Tensor[float16, bfloat16]): The value tensor. Input tensor of shape :math:`(B, S2, H2)`,
            `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`. The key and value have the
            same shape.
        head_num (int): The head num of query, equal to N1.
        real_shift (Union[Tensor[float16, bfloat16], None]): Also known as pse. The position embedding code. If S
            is greater than 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of
            the lower triangle for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
            `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.

            - ALiBi scenario: real_shift must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower
              triangle. In this scenario, real_shift is `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
            - Non-ALiBi scenario: real_shift is `(B, N1, S1, S2)`, `(1, N1, S1, S2)`.

            The shape of `real_shift` should be `(B, N1, 1024, S2)` or `(1, N1, 1024, S2)` when input_layout is
            `TND`.
        drop_mask (Union[Tensor[uint8], None]): The dropout mask tensor. Input tensor of shape
            :math:`(B, N1, S1, S2 // 8)` or None. S2 is a multiple of 8 when not None.
        padding_mask (None): Reserved parameter. Not implemented yet.
        attn_mask (Union[Tensor[uint8], Tensor[bool], None]): The attention mask tensor. For each element, 0
            indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
            `(B, 1, S1, S2)`, `(S1, S2)` or `(2048, 2048)`. In the compression scenario, where sparse_mode is
            2, 3, or 4, attn_mask must be `(2048, 2048)`. When sparse_mode is 5, attn_mask must be
            `(B, N1, S1, S2)` or `(B, 1, S1, S2)`. When sparse_mode is 0 or 1, attn_mask should be
            `(B, N1, S1, S2)`, `(B, 1, S1, S2)` or `(S1, S2)`.
        prefix (Union[List[int64], Tuple[int64], None]): N value of each Batch in the prefix sparse calculation
            scenario. Input tensor of shape :math:`(B,)`. B max value 32. Not None only when sparse_mode is 5.
            If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
        actual_seq_qlen (Union[List[int64], Tuple[int64], None]): Size of query corresponding to each batch, array
            with increasing values and the last value equal to T1.
        actual_seq_kvlen (Union[List[int64], Tuple[int64], None]): Size of key and value corresponding to each
            batch, array with increasing values and the last value equal to T2.
        keep_prob (float): The keep probability of dropout. Value range is (0.0, 1.0]. Default: 1.0. When
            keep_prob is 1.0, drop_mask should be None.
        scalar_value (float): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0.
        pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
            When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
        next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
            When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
            The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. They define
            the valid area on the attn_mask matrix. It must be ensured that the band is not empty.
            The following values are not allowed:

            - pre_tokens < 0 and next_tokens < 0.
            - (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
            - (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).

        inner_precise (int): The parameter is reserved and not implemented yet. Default: 0.
        input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH", "BNSD",
            "SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH".
            When input_layout is "TND", the following restrictions must be met.
            There are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
            value in the list indicates the length of the sequence in the batch. For example,
            list_seq_q = [4, 2, 6], list_seq_k = [10, 3, 9]. The elements of the lists indicate S. T1 is
            sum(list_seq_q) = 12, T2 is sum(list_seq_k) = 22.
            max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
            qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element-wise products.

            - The lengths of the two lists are the same, and the size of the lists is the batch. batch is less
              than or equal to 1024.
            - When input_layout is "TND", actual_seq_qlen and actual_seq_kvlen must not be None.
              Otherwise, they are None.
            - actual_seq_qlen and actual_seq_kvlen are the cumulative sums of the sequence lengths of query and
              key/value respectively, so they must be non-decreasing.
            - If real_shift is not None, list_seq_q and list_seq_k must be the same. The maximum value of
              list_seq_q and list_seq_k is greater than 1024. real_shift should be `(B, N1, 1024, S2)` or
              `(1, N1, 1024, S2)`, and S2 is equal to max_seqlen_k.
            - attn_mask must be a lower triangular matrix, so sparse_mode should be 2 or 3. The shape of attn_mask
              should be `(2048, 2048)`.
            - The shape of drop_mask is (qk_pointer * N1 // 8,).
            - prefix is None.
            - next_tokens is 0, and pre_tokens is not less than max_seqlen_q.
            - When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
            - 0 should not exist in list_seq_k.

        sparse_mode (int): Indicates the sparse mode. Default: 0.

            - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
              and preTokens and nextTokens (internally assigned as INT_MAX) are ignored. If passed in, the full
              attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and
              nextTokens needs to be calculated.
            - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
            - 2: Represents the leftUpCausal mode, corresponding to the lower triangle scenario divided by the
              left vertex; the optimized attn_mask matrix (2048*2048) is required.
            - 3: Represents the rightDownCausal mode, corresponding to the lower triangle scenario divided by the
              lower right vertex; the optimized attn_mask matrix (2048*2048) is required.
            - 4: Represents the band scenario, that is, the part between preTokens and nextTokens is counted, and
              the optimized attn_mask matrix (2048*2048) is required.
            - 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1
              and width N is added to the left side. The value of N is obtained by the new input prefix, and the N
              value of each Batch axis is different. Not implemented yet.
            - 6: Represents the global scenario. Not implemented yet.
            - 7: Represents the dilated scenario. Not implemented yet.
            - 8: Represents the block_local scenario. Not implemented yet.

    Returns:
        attention_out (Tensor[float16, bfloat16]), the output of attention. Its shape and data type are the same
        as those of `query`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import mindspore.common.dtype as mstype
        >>> import numpy as np
        >>> from mindspore import ops, Tensor
        >>> query = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
        >>> key = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
        >>> value = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
        >>> head_num = 4
        >>> output = ops.flash_attention_score(query, key, value, head_num)
        >>> print(output.shape)
        (2, 4, 64)
    """
    rank_op = _get_cache_prim(FlashAttentionScore)(head_num, keep_prob, scalar_value, pre_tokens, next_tokens,
                                                   inner_precise, input_layout, sparse_mode)
    return rank_op(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen,
                   actual_seq_kvlen)[3]
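

# A minimal bookkeeping sketch (illustrative only, not part of the module) for the "TND" layout
# constraints described in the docstring above; list_seq_q / list_seq_k are hypothetical per-batch
# sequence lengths, and their cumulative sums are what would be passed as actual_seq_qlen /
# actual_seq_kvlen:
#
#     >>> list_seq_q, list_seq_k = [4, 2, 6], [10, 3, 9]
#     >>> T1, T2 = sum(list_seq_q), sum(list_seq_k)                                     # 12, 22
#     >>> actual_seq_qlen = [sum(list_seq_q[:i + 1]) for i in range(len(list_seq_q))]   # [4, 6, 12]
#     >>> actual_seq_kvlen = [sum(list_seq_k[:i + 1]) for i in range(len(list_seq_k))]  # [10, 13, 22]
#     >>> D = 64
#     >>> scalar_value = 1.0 / (D ** 0.5)                                               # usual scale factor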