# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Inner operators."""

import numpy as np
from mindspore.common import Tensor
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive
from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import GlobalComm
from .. import signature as sig


class ExtractImagePatches(PrimitiveWithInfer):
    """
    Extracts patches from images.
    The input tensor must be a 4-D tensor and the data format is NCHW.

    Args:
        ksizes (Union[tuple[int], list[int]]): The size of the sliding window, must be a tuple or a list of
            integers, and the format is [1, 1, ksize_row, ksize_col].
        strides (Union[tuple[int], list[int]]): Distance between the centers of two consecutive patches,
            must be a tuple or a list of integers, and the format is [1, 1, stride_row, stride_col].
        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between corresponding pixel
            positions in each dimension, must be a tuple or a list of integers, and the format is
            [1, 1, rate_row, rate_col].
        padding (str): The type of padding algorithm, a string whose value is "same" or "valid",
            not case sensitive. Default: "valid".

            - same: Means that the patch can extend beyond the original image, and the part outside is
              filled with 0.

            - valid: Means that the extracted patch area must be completely contained in the original image.

    Inputs:
        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
          whose data type is number.

    Outputs:
        Tensor, a 4-D tensor whose data type is the same as `input_x`,
        with shape [out_batch, out_depth, out_row, out_col], where out_batch is the same as in_batch.
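
    Examples:
        >>> # A minimal sketch, assuming the NCHW layout used by `infer_shape`: a 1x1x4x4 input
        >>> # with 2x2 patches, stride 2 and rate 1 yields an output of shape (1, 4, 2, 2).
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> input_x = Tensor(np.arange(16).reshape(1, 1, 4, 4), mindspore.float32)
        >>> extract = inner.ExtractImagePatches(ksizes=[1, 1, 2, 2], strides=[1, 1, 2, 2],
        >>>                                     rates=[1, 1, 1, 1], padding="valid")
        >>> output = extract(input_x)
        >>> print(output.shape)
        (1, 4, 2, 2)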
56 """ 57 58 @prim_attr_register 59 def __init__(self, ksizes, strides, rates, padding="valid"): 60 """init""" 61 62 def _check_tuple_or_list(arg_name, arg_val, prim_name): 63 validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.name) 64 if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[1] != 1: 65 raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, " 66 f"{arg_name}_col, 1], but got {arg_val}.") 67 if not isinstance(arg_val[2], int) or not isinstance(arg_val[3], int) or arg_val[2] < 1 or arg_val[3] < 1: 68 raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be " 69 f"an positive integer number, but got {arg_name}_row is {arg_val[2]}, " 70 f"{arg_name}_col is {arg_val[3]}") 71 72 _check_tuple_or_list("ksize", ksizes, self.name) 73 _check_tuple_or_list("stride", strides, self.name) 74 _check_tuple_or_list("rate", rates, self.name) 75 self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name) 76 self.add_prim_attr("padding", self.padding) 77 self.is_ge = context.get_context("enable_ge") 78 79 def infer_shape(self, input_x): 80 """infer shape""" 81 if len(input_x) != 4: 82 raise ValueError("The `input_x` should be a 4-D tensor, " 83 f"but got a {len(input_x)}-D tensor whose shape is {input_x}") 84 85 in_batch, in_depth, in_row, in_col = input_x 86 _, _, ksize_row, ksize_col = self.ksizes 87 _, _, stride_row, stride_col = self.strides 88 _, _, rate_row, rate_col = self.rates 89 90 out_batch = in_batch 91 out_depth = ksize_row * ksize_col * in_depth 92 93 if self.padding == "VALID": 94 out_row = \ 95 (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1 96 out_col = \ 97 (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1 98 else: 99 out_row = (in_row - 1) // stride_row + 1 100 out_col = (in_col - 1) // stride_col + 1 101 102 out_shape = [out_batch, out_depth, out_row, out_col] 103 # avoiding empty outputs 104 validator.check("out_batch", out_batch, "", 0, Rel.GT, self.name) 105 validator.check("out_depth", out_depth, "", 0, Rel.GT, self.name) 106 validator.check("out_row", out_row, "", 0, Rel.GT, self.name) 107 validator.check("out_col", out_col, "", 0, Rel.GT, self.name) 108 return out_shape 109 110 def infer_dtype(self, input_x): 111 """infer dtype""" 112 validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name) 113 return input_x 114 115 116class Range(PrimitiveWithInfer): 117 r""" 118 Creates a sequence of numbers. 119 Set `input_x` as :math:`x_i` for each element, `output` as follows: 120 121 .. math:: 122 \text{output}(x_i) = x_i * \text{delta} + \text{start} 123 124 Args: 125 start (float): If `limit` is `None`, the value acts as limit in the range and first entry 126 defaults to `0`. Otherwise, it acts as first entry in the range. 127 limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start` 128 while set the first entry of the range to `0`. It can not be equal to `start`. Default: None. 129 delta (float): Increment of the range. It can not be equal to zero. Default: 1.0. 130 131 Inputs: 132 - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32. 133 134 Outputs: 135 Tensor, has the same shape and dtype as `input_x`. 

    Examples:
        >>> range_op = ops.Range(1.0, 8.0, 2.0)
        >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
        >>> output = range_op(x)
        >>> print(output)
        [3 5 7 5]
    """

    @prim_attr_register
    def __init__(self, start, limit=None, delta=1.0):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
        self.delta = validator.check_value_type("delta", delta, [float], self.name)
        validator.check_value_type("start", start, [float], self.name)
        if limit is None:
            self.start = 0.0
            self.limit = start
            self.add_prim_attr("start", self.start)
            self.add_prim_attr("limit", self.limit)
        else:
            validator.check_value_type("limit", limit, [float], self.name)
            validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name)
        if self.delta == 0.0:
            raise ValueError("The input of `delta` can not be equal to zero.")
        if self.delta > 0.0 and self.start > self.limit:
            raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, "
                             f"but got start:{self.start}, limit:{self.limit}")
        if self.delta < 0.0 and self.start < self.limit:
            raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, "
                             f"but got start:{self.start}, limit:{self.limit}")

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
        return x_dtype

    def infer_value(self, x_value):
        return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype)


class Quant(PrimitiveWithInfer):
    r"""
    Returns the quantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = round(scale * x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = round(scale * x * scale + offset)

    Note:
        This operation only supports the Ascend 310 inference environment.

    Args:
        scale (float): Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool): Specifies whether to perform square root on `scale`. Default: False.
        round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
            Default: "Round".

    Inputs:
        - **input_x** (Tensor) - The input tensor. Its data type must be mindspore.float16 or mindspore.float32.

    Outputs:
        - Tensor: The quantized output tensor of type mindspore.int8.
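
    Raises:
        TypeError: If `input_x` is not a Tensor whose data type is mindspore.float16 or mindspore.float32.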

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> quant = ops.Quant(80.0, 0.0, False, "Round")
        >>> y = quant(input_x)
    """

    @prim_attr_register
    def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
        self.scale = validator.check_value_type("scale", scale, [float], self.name)
        self.offset = validator.check_value_type("offset", offset, [float], self.name)
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                 "round_mode", self.name)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
        validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
        return mstype.int8


class Dequant(PrimitiveWithInfer):
    r"""
    Returns the dequantized value of input_x.
    This operation applies ReLU to the dequantized value if `relu_flag` is True.

    If `sqrt_mode` is False:

    .. math::
        y = x * deq\_scale

    If `sqrt_mode` is True:

    .. math::
        y = x * deq\_scale * deq\_scale

    Note:
        This operation only supports the Ascend 310 inference environment.

    Args:
        sqrt_mode (bool): Specifies whether to perform square root on `deq_scale`. Default: False.
        relu_flag (bool): Specifies whether to perform ReLU. Default: False.

    Inputs:
        - **input_x** (Tensor) - The input tensor. Must be mindspore.int32.
        - **deq_scale** (Tensor) - Specifies the scaling ratio.
          Data type must be mindspore.float16 or mindspore.uint64.

    Outputs:
        - Tensor: The dequantized output tensor of type mindspore.float16.

    Examples:
        >>> input_x = Tensor([100, 150], mstype.int32)
        >>> deq_scale = Tensor(2.0, mstype.float16)
        >>> dequant = ops.Dequant(False, False)
        >>> y = dequant(input_x, deq_scale)
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, relu_flag=False):
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
        self.add_prim_attr("dtype", mstype.float16)

    def infer_shape(self, x_shape, deq_scale_shape):
        return x_shape

    def infer_dtype(self, x_type, deq_scale_type):
        validator.check_subclass("x", x_type, mstype.tensor, self.name)
        validator.check_type_name("x", x_type, [mstype.int32], self.name)
        validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
        return mstype.float16


class MatrixDiag(PrimitiveWithInfer):
    """
    Returns a batched diagonal tensor with given batched diagonal values.

    Inputs:
        - **x** (Tensor) - A tensor to be element-wise multiplied by `assist`. It can be one of the following
          data types: float32, float16, int32, int8, and uint8.
        - **assist** (Tensor) - An eye-like tensor of the same type as `x`. Its rank must be greater than or
          equal to 2, and its last dimension must be equal to its second-to-last dimension.

    Outputs:
        Tensor, has the same type and shape as input `assist`.

    Examples:
        >>> x = Tensor(np.array([1, -1]), mstype.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag = ops.MatrixDiag()
        >>> result = matrix_diag(x, assist)
        >>> print(result)
        [[[-12.  11.]
          [-10.   9.]]
         [[ -8.   7.]
          [ -6.   5.]]
         [[ -4.   3.]
          [ -2.   1.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiag"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), Rel.LE, self.name)
        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
                        assist_shape[-1], Rel.EQ, self.name)

        # x must be broadcastable to assist: compare the trailing dimensions of x with the
        # corresponding (shifted by one) dimensions of assist, skipping dimensions of size 1.
        r_end_dim = -len(x_shape)
        r_idx = -1
        while r_idx >= r_end_dim:
            if x_shape[r_idx] != 1:
                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
                                assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name)
            r_idx = r_idx - 1

        return assist_shape


class MatrixDiagPart(PrimitiveWithInfer):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye-like tensor of the same type and shape as `x`.

    Outputs:
        Tensor, with the same data type as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag_part = ops.MatrixDiagPart()
        >>> result = matrix_diag_part(x, assist)
        >>> print(result)
        [[12. -9.]
         [ 8. -5.]
         [ 4. -1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiagPart"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

        if assist_shape[-2] < assist_shape[-1]:
            out_shape = assist_shape[:-1]
        else:
            out_shape = assist_shape[:-2] + assist_shape[-1:]
        return out_shape


class Send(PrimitiveWithInfer):
    """
    Send tensors to the specified dest_rank.

    Note:
        Send and Receive must be used in combination and have the same sr_tag.
        Send must be used between servers.

    Args:
        sr_tag (int): A required integer identifying the send/recv message tag. The message
            will be received by the Receive op with the same "sr_tag".
        dest_rank (int): A required integer identifying the destination rank.
        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
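
    Supported Platforms:
        ``Ascend`` ``GPU``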

    Examples:
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.depend = ops.Depend()
        >>>         self.send = ops.Send(sr_tag=0, dest_rank=8, group="hccl_world_group")
        >>>
        >>>     def construct(self, x):
        >>>         out = self.depend(x, self.send(x))
        >>>         return out
        >>>
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
    """

    @prim_attr_register
    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
        self.rank = dest_rank
        self.sr_tag = sr_tag
        self.group = group

    def infer_shape(self, x_shape):
        self.add_prim_attr("shape", x_shape)
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


class Receive(PrimitiveWithInfer):
    """
    Receive tensors from src_rank.

    Note:
        Send and Receive must be used in combination and have the same sr_tag.
        Receive must be used between servers.

    Args:
        sr_tag (int): A required integer identifying the send/recv message tag. The message
            will be sent by the Send op with the same "sr_tag".
        src_rank (int): A required integer identifying the source rank.
        shape (list[int]): A required list identifying the shape of the tensor to be received.
        dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
            int8, uint8, int32, float16, float32.
        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".

    Outputs:
        - **output** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=mindspore.float32,
        >>>                                 group="hccl_world_group")
        >>>
        >>>     def construct(self):
        >>>         out = self.recv()
        >>>         return out
        >>>
        >>> net = Net()
        >>> output = net()
    """

    @prim_attr_register
    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
                 group_back=GlobalComm.WORLD_COMM_GROUP):
        self.rank = src_rank
        self.tag = sr_tag
        self.shape = shape
        self.dtype = dtype
        self.group = group
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"dtype": dtype}
        validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)

    def infer_shape(self, x_shape=None):
        return self.shape

    def infer_dtype(self, x_dtype=None):
        return self.dtype


class MatrixSetDiag(PrimitiveWithInfer):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
        - **assist** (Tensor) - An eye-like tensor of the same type and shape as `x`.

    Outputs:
        Tensor, with the same data type and shape as input `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> matrix_set_diag = ops.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal)
        >>> print(result)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]

    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixSetDiag"""

    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

        if x_shape[-2] < x_shape[-1]:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the last dimension",
                            x_shape[:-1], Rel.EQ, self.name)
        else:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
                            x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name)

        return assist_shape


class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the element-wise product of `input0` and `input1`.

    `output1` is the element-wise product of `input0` and `input1`, reduced by summation afterwards.

    Args:
        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
            Default: (), reduce all dimensions. Only constant value is allowed.
        keep_dims (bool):

            - If true, keep these reduced dimensions and the length is 1.
            - If false, don't keep these dimensions. Default: False.

    Inputs:
        - **input_0** (Tensor) - The input Tensor.
        - **input_1** (Tensor) - The input Tensor.
        - **input_2** (Tensor) - The input Tensor.

    Outputs:
        - **output_0** (Tensor) - The same shape as `input0`.
        - **output_1** (Tensor)

          - If axis is (), and keep_dims is false, the output is a 0-D tensor representing
            the sum of all elements in the input tensor.
          - If axis is int, set as 2, and keep_dims is false,
            the shape of output is :math:`(x_1, x_3, ..., x_R)`.
          - If axis is tuple(int), set as (2, 3), and keep_dims is false,
            the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Examples:
        >>> confusion_mul_grad = ops.ConfusionMulGrad()
        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
        output_0:
        [[ 3.  1.  0.]
         [-6.  2. -2.]]
        output_1:
        -3.0
    """

    @prim_attr_register
    def __init__(self, axis=(), keep_dims=False):
        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)

    def infer_shape(self, input0_shape, input1_shape, input2_shape):
        outshape0 = input0_shape
        outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
        return outshape0, outshape1

    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
        validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name)
        validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
        validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
        return input0_dtype, input1_dtype


class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    This op is used for dynamic shape testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically shaped.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic shape support.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> class TestDynamicShapeReshapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(TestDynamicShapeReshapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         # suppose we are testing Reshape op
        >>>         self.reshape = ops.Reshape()
        >>>
        >>>     def construct(self, input, new_shape):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         reshaped_input = self.reshape(dynamic_shape_input, new_shape)
        >>>         return reshaped_input
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0, 1, 2, 3]))
        >>> new_shape = (2, 2)
        >>> net = TestDynamicShapeReshapeNet()
        >>> output = net(input, new_shape)
        >>> print(output)
        [[0 1]
         [2 3]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        validator.check("input_shape rank", len(input_shape), "", 0, Rel.GT, self.name)

    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)


class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. The only purpose of this operator is
    that it will throw a value error if the input is dynamically shaped.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.
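
    Raises:
        ValueError: If any dimension of the input shape is -1, that is, if the input is dynamically shaped.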

    Examples:
        >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
        >>> class AssertDynamicShapeNet(nn.Cell):
        >>>     def __init__(self):
        >>>         super(AssertDynamicShapeNet, self).__init__()
        >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
        >>>
        >>>     def construct(self, input):
        >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
        >>>         return self.error_on_dynamic_shape_input(dynamic_shape_input)
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        >>> input = Tensor(np.array([0]))
        >>> net = AssertDynamicShapeNet()
        >>> output = net(input)
        ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        shape = list(input_shape)

        for dim in shape:
            if dim == -1:
                raise ValueError("Input is dynamically shaped.")

        return input_shape

    def infer_type(self, input_dtype):
        """Infer the dtype of input for ErrorOnDynamicShapeInput."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
        return input_dtype

    def infer_value(self, input_tensor):
        return input_tensor


class SequenceMask(PrimitiveWithCheck):
    """
    Returns a mask tensor representing the first N positions of each cell.

    If `lengths` has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]).

    Inputs:
        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
          less than or equal to `maxlen`; values greater than `maxlen` will be treated as `maxlen`.
          Must be type int32 or int64.

        - **maxlen** (int) - Size of the last dimension of the returned tensor. Must be positive and of the
          same type as the elements of `lengths`.

    Outputs:
        One mask tensor of shape lengths.shape + (maxlen,).

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> sequence_mask = ops.SequenceMask()
        >>> output = sequence_mask(x, 3)
        >>> print(output)
        [[[ True False False]
          [ True  True  True]]
         [[ True  True False]
          [False False False]]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

    def check_shape(self, lengths_shape, maxlen_shape):
        validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name)
        validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name)

    def check_dtype(self, lengths_dtype, maxlen_dtype):
        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)


class SyncBatchNorm(PrimitiveWithInfer):
    r"""
    Sync Batch Normalization for input data and updated parameters.

    Sync Batch Normalization is cross-device synchronized Batch Normalization. Batch Normalization is
    widely used in convolutional neural networks.
    This operation applies Batch Normalization over the input to avoid internal covariate shift as described
    in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate
    Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the features using a mini-batch
    of data and the learned parameters, which can be described by the following formula,

    .. math::
        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

    Args:
        epsilon (float): A small value added for numerical stability. Default: 1e-5.
        momentum (float): The hyper parameter to compute the moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be in [0, 1]. Default: 0.1.
        group (str): The communication group to work on. Default: "sync_bn_group0".
        device_num (int): The number of devices in each group. Default: 2.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `mean`.

    Outputs:
        Tuple of 5 Tensors, the normalized input and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as the `input_x`. The shape is :math:`(N, C)`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to nn.SyncBatchNorm for direct use.
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00,  1.00000000e+00],
         [ 1.00000000e+00,  1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00,  1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00,  1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00,  1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
        [ 1.00000000e+00,  1.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        validator.check_isinstance("group", group, str)
        validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)
        self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])

    def infer_shape(self, input_x, scale, bias, mean, variance):
        validator.check_equal_int(len(scale), 1, "scale rank", self.name)
        validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
        validator.check("scale shape[0]", scale[0], "input_x channel", input_x[1], Rel.EQ, self.name)
        validator.check_equal_int(len(mean), 1, "mean rank", self.name)
        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
        validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
        return (input_x, scale, scale, scale, scale)

    def infer_dtype(self, input_x, scale, bias, mean, variance):
        validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
        args = {"scale": scale, "bias": bias}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        args_moving = {"mean": mean, "variance": variance}
        validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name)
        return (input_x, scale, bias, input_x, input_x)


class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
        - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all
          dimensions. Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as the `input_x`.

    Raises:
        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
        TypeError: If `axis` has non-Int elements.
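        ValueError: If `axis` is not a constant value.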

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
         [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        x_shape = list(input_x['shape'])
        x_dtype = input_x['dtype']
        axis_v = axis['value']
        rank = len(x_shape)

        args = {'input_x': input_x['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        if axis_v is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)

        if isinstance(axis_v, int):
            validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
        elif axis_v:
            for index, one_axis in enumerate(axis_v):
                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)

        out = {'shape': x_shape,
               'dtype': x_dtype,
               'value': None}
        return out


class StackInit(PrimitiveWithInfer):
    """
    Create a stack that produces tensors in first-in last-out order.

    After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
    from the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> index = 0
        >>> stack = ops.StackInit(index)
        >>> push = ops.StackPush(index)
        >>> pop = ops.StackPop(index, x.shape, x.dtype)
        >>> destroy = ops.StackDestroy(index)
        >>> stack()
        >>> push(x)
        >>> y = pop()
        >>> destroy()
        >>> print(y)
        [[1 3]
         [2 0]]
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackInit"""
        validator.check_value_type("index", index, [int], self.name)


class StackPush(PrimitiveWithInfer):
    """
    Push a tensor onto the stack.

    Before `StackPush`, the stack should be created using `StackInit`.
    Please refer to the usage in the source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Inputs:
        - **input** (Tensor) - A tensor to be pushed onto the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackPush"""
        validator.check_value_type("index", index, [int], self.name)
        self.init_prim_io_names(inputs=['input'], outputs=[])


class StackPop(PrimitiveWithInfer):
    """
    Pop the tensor at the top of the stack.

    Before `StackPop`, the stack should be created using `StackInit`.
    Please refer to the usage in the source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.
        shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
        dtype (mindspore.dtype): The type of the tensor at the top of the stack.
            Default: mindspore.float32.

    Outputs:
        - **output** (Tensor) - The tensor at the top of the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
        """StackPop"""
        validator.check_value_type("index", index, [int], self.name)

        validator.check_value_type('shape type', shape, [list, tuple], self.name)
        validator.check_int(len(np.array(shape).shape), 1, Rel.EQ, "dim of shape", self.name)
        for elem in shape:
            validator.check_int(elem, 1, Rel.GE, 'shape element', self.name)
            validator.check_value_type('type of shape element', elem, [int], self.name)

        validator.check_type_name("dtype", dtype, (mstype.bool_,) + mstype.number_type, self.name)
        self.shape = shape
        self.dtype = dtype

        self.init_prim_io_names(inputs=[], outputs=['output'])

    def __infer__(self):
        return {'shape': (list(self.shape)),
                'dtype': (self.dtype),
                'value': None}


class StackDestroy(PrimitiveWithInfer):
    """
    Destroy the stack.

    Before `StackDestroy`, the stack should be created using `StackInit`.
    Please refer to the usage in the source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackDestroy"""
        validator.check_value_type("index", index, [int], self.name)


class DynamicStitch(PrimitiveWithCheck):
    r"""
    Interleaves the values from the `data` tensors into a single tensor.

    Inputs:
        - **indices** (Union[tuple, list]) - A tuple or list of Tensor objects of type int32.
        - **data** (Union[tuple, list]) - A tuple or list of Tensor objects with the same type.

    Outputs:
        Tensor. A stacked Tensor with the same type as `data`.

    Raises:
        TypeError: If the data types of elements in `data` or `indices` are not the same.
        ValueError: If the length of `data` or `indices` is less than 1.
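        ValueError: If the shape of an element of `data` does not start with the shape of the corresponding
            element of `indices`.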

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x1 = Tensor([6], mstype.int32)
        >>> x2 = Tensor(np.array([4, 1]), mstype.int32)
        >>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
        >>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
        >>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
        >>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
        >>> stitch = ops.DynamicStitch()
        >>> output = stitch([x1, x2, x3], [y1, y2, y3])
        >>> print(output)
        [[ 1  2]
         [11 12]
         [21 22]
         [31 32]
         [41 42]
         [51 52]
         [61 62]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicStitch"""

    def check_shape(self, indices_shape, data_shape):
        validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
        validator.check_int(len(indices_shape), 1, Rel.GE, "len of indices_shape", self.name)
        indices_dim0 = len(indices_shape[0])
        indices_num = len(indices_shape)

        validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
        validator.check_int(len(data_shape), 1, Rel.GE, "len of data_shape", self.name)
        data_dim0 = len(data_shape[0])
        data_num = len(data_shape)

        validator.check("size of indices", indices_num, 'size of data', data_num, Rel.EQ, self.name)

        # shape of `data` must start with shape of `indices`
        for i in range(0, indices_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LE, self.name)
            if data_shape[i][:indices_dim] != indices_shape[i]:
                raise ValueError(f"data[{i}].shape: {data_shape[i]} does not start with "
                                 f"indices[{i}].shape: {indices_shape[i]}")

        # the last (data_dim0 - indices_dim0) dims of every data shape must be the same.
        base_extra = data_dim0 - indices_dim0
        for i in range(0, data_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            extra = data_dim - indices_dim
            validator.check(f"extra dim of data[{i}]", extra,
                            f"extra dim of data[0]", base_extra, Rel.EQ, self.name)
            validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
                            f"data[{i}].shape[{len(indices_shape[i])}:]",
                            data_shape[i][indices_dim:], Rel.EQ, self.name)

        out_shape = [-1] + data_shape[0][indices_dim0:]
        return out_shape

    def check_dtype(self, indices_type, data_type):
        validator.check_subclass("indices[0]", indices_type[0], mstype.tensor, self.name)
        validator.check_subclass("data[0]", data_type[0], mstype.tensor, self.name)
        indices_num = len(indices_type)
        for i in range(0, indices_num):
            validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
            validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
                                               mstype.number_type + (mstype.bool_,), self.name)
            validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]", data_type[0], Rel.EQ, self.name)
        return data_type[0]


class DynamicBroadcastGradientArgs(Primitive):
    """
    Broadcast the two input shapes and return the dimensions along which each needs to be broadcast.

    Input shape `s0` and shape `s1` can be broadcast to a common shape if for each dimension pair they are
    either equal, one of them is 1, or the target dimension is -1.
    In case of a -1 in the target shape, it will be replaced by the value of the input shape in that dimension.

    Inputs:
        - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
          uint32, uint64.
        - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.

    Outputs:
        Tuple(Tensor), tuple of 2 tensors, `r0` and `r1`. The first one is the index tensor and the other one
        is the mask tensor.

        - **r0** (Tensor) - A 1-D tensor with the same type as `s0`.
        - **r1** (Tensor) - A 1-D tensor with the same type as `s0`.

    Raises:
        ValueError: If `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid
            location.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape0 = (4, 2, 1)
        >>> shape1 = (2, 7)
        >>> from mindspore.ops.operations import _inner_ops
        >>> args = _inner_ops.DynamicBroadcastGradientArgs()
        >>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
        >>> print(r0, r1)
        [2], [0]
    """

    @prim_attr_register
    def __init__(self):
        """Init DynamicBroadcastGradientArgs"""


class TensorCopySlices(Primitive):
    """
    Copy continuous memory.

    Inputs:
        - **x** (Tensor) - The target Tensor.
        - **value** (Tensor) - The tensor used to update `x`.
        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
          constant value is allowed.
        - **end** (tuple[int]) - A tuple which represents the maximum location where to end.
          Only constant value is allowed.
        - **strides** (tuple[int]) - A tuple which represents the stride that is continuously added
          before reaching the maximum location. Only constant value is allowed.

    Outputs:
        - **y** (Tensor), has the same shape and data type as `x`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore.ops.operations import _inner_ops
        >>> copy_slices = _inner_ops.TensorCopySlices()
        >>> out = copy_slices(Tensor(np.zeros((5, 5))), Tensor(np.ones((2, 5))), (3, 0), (5, 5), (1, 1))
        >>> print(out)
        [[0. 0. 0. 0. 0.]
         [0. 0. 0. 0. 0.]
         [0. 0. 0. 0. 0.]
         [1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TensorCopySlices"""
        self.init_prim_io_names(inputs=['x', 'value', 'begin', 'end', 'strides'], outputs=['y'])


class Roll(Primitive):
    """
    Rolls the elements of a tensor along an axis.

    The elements are shifted positively (towards larger indices) by the offset of `shift` along the dimension
    of `axis`. Negative `shift` values will shift elements in the opposite direction. Elements that roll past
    the last position will wrap around to the first, and vice versa. Multiple shifts along multiple axes may
    be specified.

    Note:
        This inner operation is valid only if the axis is equal to 0. If the shift and the axis are tuples or
        lists, this inner operation is valid only for the first pair of elements.

    Args:
        shift (Union[list(int), tuple(int), int]): Specifies the number of places by which elements are shifted
            positively (towards larger indices) along the specified dimension. Negative shifts will roll the
            elements in the opposite direction.
        axis (Union[list(int), tuple(int), int]): Specifies the dimension indexes of the shape to be rolled.
            The value is forced to be zero in this operation.

    Inputs:
        - **input_x** (Tensor) - Input tensor.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `shift` is not an int, a tuple or a list.
        TypeError: If `axis` is not an int, a tuple or a list.
        TypeError: If an element of `shift` is not an int.
        TypeError: If an element of `axis` is not an int.
        ValueError: If `axis` is not equal to 0.
        ValueError: If the size of `shift` is not equal to 1.
        ValueError: If the size of `axis` is not equal to 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32))
        >>> op = inner.Roll(shift=2, axis=0)
        >>> output = op(input_x)
        >>> print(output)
        [3. 4. 0. 1. 2.]
        >>> input_x = Tensor(np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).astype(np.float32))
        >>> op = inner.Roll(shift=-1, axis=0)
        >>> output = op(input_x)
        >>> print(output)
        [[5. 6. 7. 8. 9.]
         [0. 1. 2. 3. 4.]]
    """

    @prim_attr_register
    def __init__(self, shift, axis):
        """Initialize Roll"""
        validator.check_value_type("shift", shift, [int, tuple, list], self.name)
        validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        if isinstance(shift, (tuple, list)) and isinstance(axis, (tuple, list)):
            validator.check_equal_int(len(shift), 1, "shift size", self.name)
            validator.check_equal_int(len(axis), 1, "axis size", self.name)
            validator.check_equal_int(axis[0], 0, "axis", self.name)
        elif isinstance(shift, int) and isinstance(axis, int):
            validator.check_equal_int(axis, 0, "axis", self.name)
        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])


class DSDMatmul(PrimitiveWithInfer):
    """
    The definition of the DSDMatmul primitive.
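
    Takes `input_w1`, `input_w2` and `input_v` and produces `output_y`. The output shape
    (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16) is derived in `infer_shape`,
    which appears to assume the inputs are laid out in 16x16 blocks.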
1272 """ 1273 1274 @prim_attr_register 1275 def __init__(self): 1276 self.init_prim_io_names(inputs=['input_w1', 'input_w2', 'input_v'], outputs=['output_y']) 1277 1278 def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape): 1279 batch_size = input_w1_shape[0] 1280 head = input_w1_shape[1] 1281 v_embedding = input_v_shape[1] * 16 // head 1282 seq_len = input_v_shape[0] * 16 // batch_size 1283 return (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16) 1284 1285 def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3): 1286 return data_dtype1 1287 1288 1289class MatmulDDS(PrimitiveWithInfer): 1290 """MatmulDDS definition""" 1291 1292 @prim_attr_register 1293 def __init__(self, bs, heads): 1294 """init MatmulDDS""" 1295 self.init_prim_io_names(inputs=['q', 'k', 'local_mask', 'global_mask'], 1296 outputs=['local_prob', 'global_prob']) 1297 1298 self.heads = heads 1299 1300 def infer_shape(self, q, k, local_mask, global_mask): 1301 seq_len = local_mask[0] * local_mask[-1] 1302 bs = q[1] * q[2] // seq_len 1303 global_size = seq_len // 4 1304 size_per_head = q[0] * q[-1] // self.heads 1305 heads = q[0] * q[-1] // size_per_head 1306 block_size = local_mask[1] * local_mask[2] // bs 1307 block_num = seq_len // block_size 1308 l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16) 1309 g_size = (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16) 1310 1311 return l_size, g_size 1312 1313 def infer_dtype(self, q, k, local_mask, global_mask): 1314 return q, q 1315 1316 1317class DSDGrad(PrimitiveWithInfer): 1318 """ 1319 The definition of the CusSquare primitive. 1320 """ 1321 @prim_attr_register 1322 def __init__(self): 1323 self.init_prim_io_names(inputs=['w1_gm', 'w2_gm', 'v_gm', 'a_gm', 'd_a_gm'], 1324 outputs=['d_w1_gm', 'd_w2_gm', 'd_v_gm']) 1325 1326 def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape, input_a_shape, input_da_shape): 1327 return input_w1_shape, input_w2_shape, input_v_shape 1328 1329 def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3, data_dtype4, data_dtype5): 1330 return data_dtype1, data_dtype1, data_dtype1 1331 1332 1333class MatmulDDSGrad(PrimitiveWithInfer): 1334 """MatmulDDS definition""" 1335 1336 @prim_attr_register 1337 def __init__(self): 1338 """init MatmulDDS""" 1339 self.init_prim_io_names(inputs=['q', 'k', 'local_prob', 'global_prob', 'local_prob_grad', 'global_prob_grad'], 1340 outputs=['dq', 'dk']) 1341 1342 def infer_shape(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad): 1343 k_size = (q[1], q[0], q[3], q[2]) 1344 1345 return q, k_size 1346 1347 def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad): 1348 return q, k 1349