1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16"""Operators for gradients.""" 17import math 18from functools import partial 19from mindspore._checkparam import _check_3d_int_or_tuple 20from .. import signature as sig 21from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register 22from ..._checkparam import Validator as validator, Rel 23from .._utils import get_concat_offset 24from ...common import dtype as mstype 25from ... import context 26 27 28class AbsGrad(PrimitiveWithInfer): 29 """Computes gradients for abs operation.""" 30 31 @prim_attr_register 32 def __init__(self): 33 """Initialize AbsGrad""" 34 35 def infer_shape(self, y, dy): 36 return y 37 38 def infer_dtype(self, y, dy): 39 return y 40 41 42class ACosGrad(PrimitiveWithInfer): 43 """ 44 Computes ACosGrad of input element-wise. 45 46 Returns: 47 Tensor, has the same type as input. 48 """ 49 50 @prim_attr_register 51 def __init__(self): 52 """Initialize ACosGrad""" 53 54 def infer_shape(self, x, dout): 55 validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) 56 return x 57 58 def infer_dtype(self, x, dout): 59 args = {"x": x, "dout": dout} 60 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 61 return x 62 63 64class AcoshGrad(PrimitiveWithInfer): 65 """Performs grad of Acosh operation.""" 66 67 @prim_attr_register 68 def __init__(self): 69 """Initialize AcoshGrad""" 70 71 def infer_shape(self, x, dout): 72 validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) 73 return x 74 75 def infer_dtype(self, x, dout): 76 args = {"x": x, "dout": dout} 77 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 78 return x 79 80 81class AsinGrad(PrimitiveWithInfer): 82 """ 83 Computes AsinGrad of input element-wise. 84 85 Returns: 86 Tensor, has the same type as input. 87 """ 88 89 @prim_attr_register 90 def __init__(self): 91 """Initialize AsinGrad""" 92 93 def infer_shape(self, x, dout): 94 validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) 95 return x 96 97 def infer_dtype(self, x, dout): 98 args = {"x": x, "dout": dout} 99 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 100 return x 101 102 103class AsinhGrad(PrimitiveWithInfer): 104 """Performs grad of Asinh operation.""" 105 106 @prim_attr_register 107 def __init__(self): 108 """Initialize AsinhGrad""" 109 110 def infer_shape(self, x, dout): 111 validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) 112 return x 113 114 def infer_dtype(self, x, dout): 115 args = {"x": x, "dout": dout} 116 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 117 return x 118 119 120class ReciprocalGrad(PrimitiveWithInfer): 121 """Performs grad of Reciprocal operation.""" 122 123 @prim_attr_register 124 def __init__(self): 125 """Initialize ReciprocalGrad""" 126 127 def infer_shape(self, x_shape, dout_shape): 128 validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) 129 return x_shape 130 131 def infer_dtype(self, x_dtype, dout_dtype): 132 args = {"x": x_dtype, "dout": dout_dtype} 133 validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) 134 return x_dtype 135 136 137class RsqrtGrad(PrimitiveWithInfer): 138 """Performs grad of Rsqrt operation.""" 139 140 @prim_attr_register 141 def __init__(self): 142 """Initialize RsqrtGrad""" 143 144 def infer_shape(self, x_shape, dout_shape): 145 validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) 146 return x_shape 147 148 def infer_dtype(self, x_dtype, dout_dtype): 149 args = {"x": x_dtype, "dout": dout_dtype} 150 validator.check_tensors_dtypes_same_and_valid(args, 151 [mstype.float16, mstype.float32, mstype.int32, mstype.int8], 152 self.name) 153 return x_dtype 154 155 156class SoftmaxGrad(ReciprocalGrad): 157 """Performs grad of Softmax operation.""" 158 159 160class SqrtGrad(PrimitiveWithInfer): 161 """Performs grad of Sqrt operation.""" 162 163 @prim_attr_register 164 def __init__(self): 165 """Initialize SqrtGrad""" 166 167 def infer_shape(self, x_shape, dout_shape): 168 validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name) 169 return x_shape 170 171 def infer_dtype(self, x_dtype, dout_dtype): 172 args = {"x": x_dtype, "dout": dout_dtype} 173 valid_types = [mstype.float16, mstype.float32, mstype.float64] 174 validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name) 175 return x_dtype 176 177 178class BatchNormGrad(PrimitiveWithInfer): 179 """Performs grad of BatchNorm operation.""" 180 181 @prim_attr_register 182 def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'): 183 self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) 184 self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) 185 self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) 186 187 def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve): 188 validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) 189 return (x_shape, scale_shape, scale_shape) 190 191 def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape, reserve): 192 return (x_type, scale_type, scale_type) 193 194 195class SyncBatchNormGrad(PrimitiveWithInfer): 196 """Performs grad of SyncBatchNorm operation.""" 197 198 @prim_attr_register 199 def __init__(self, epsilon=1e-5, group="group0", device_num=2): 200 validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) 201 if not isinstance(group, str): 202 raise TypeError("The group attr of SyncBatchNormGrad should be str.") 203 validator.check_int(device_num, 2, Rel.GE, "device_num", self.name) 204 205 def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape): 206 validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) 207 return (x_shape, scale_shape, scale_shape) 208 209 def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape): 210 return (x_type, scale_type, scale_type) 211 212 213class BiasAddGrad(Primitive): 214 """Computes gradients of BiasAdd.""" 215 216 @prim_attr_register 217 def __init__(self, data_format="NCHW"): 218 self.init_prim_io_names(inputs=['dout'], outputs=['output']) 219 self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name) 220 if context.get_context("device_target") != "GPU" and self.format == "NHWC": 221 raise ValueError("NHWC format only support in GPU target.") 222 if self.format == "NCDHW": 223 self.format = "NCHW" 224 self.add_prim_attr('data_format', self.format) 225 226 227class KLDivLossGrad(PrimitiveWithInfer): 228 """Computes gradients for `KLDivLoss` operation.""" 229 230 @prim_attr_register 231 def __init__(self, reduction='mean'): 232 self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name) 233 234 def infer_shape(self, x_shape, y_shape, doutput_shape): 235 validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) 236 return x_shape, y_shape 237 238 def infer_dtype(self, x_type, y_type, doutput_type): 239 args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} 240 validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) 241 return x_type, y_type 242 243 244class BinaryCrossEntropyGrad(PrimitiveWithInfer): 245 """Computes gradients for `BinaryCrossEntropy` operation.""" 246 247 @prim_attr_register 248 def __init__(self, reduction='mean'): 249 self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name) 250 251 def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape): 252 validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) 253 if weight_shape: 254 validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name) 255 return x_shape 256 257 def infer_dtype(self, x_type, y_type, doutput_type, weight_type): 258 args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} 259 validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) 260 if weight_type: 261 validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) 262 return x_type 263 264 265class ConcatOffset(PrimitiveWithInfer): 266 """primitive for computing Concat's gradient.""" 267 268 @prim_attr_register 269 def __init__(self, N=2, axis=0): 270 """Initialize ConcatOffset""" 271 272 def __infer__(self, input_x): 273 axis = self.axis 274 x_shp = input_x['shape'] 275 x_type = input_x['dtype'] 276 offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name) 277 self.add_prim_attr('T', x_type[0].element_type()) 278 offset_values = [] 279 for i in range(len(x_shp)): 280 values = [] 281 for j in range(len(x_shp[0])): 282 value = 0 283 if j == axis: 284 value = offset[i] 285 values.append(value) 286 offset_values.append(tuple(values)) 287 out = {'shape': None, 288 'dtype': None, 289 'value': tuple(offset_values)} 290 return out 291 292 293class Conv3DBackpropFilter(PrimitiveWithInfer): 294 """ 295 Computes the gradients of convolution 3D with respect to the filter. 296 297 Args: 298 out_channel (int): The dimension of the output. 299 kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution. 300 mode (int): Modes for different convolutions. Not currently used. 301 pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". 302 pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of 303 head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four 304 integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], 305 pad[3], pad[4] and pad[5] correspondingly. 306 stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. 307 dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. 308 group (int): Splits input into groups. Default: 1. 309 data_format (str): The optional value for data format. Currently only support 'NCDHW'. 310 311 Inputs: 312 - **x** (Tensor) - The input of the convolution, then the shape is :math:`(C_{out}, C_{in}, D_{in}, K_1, K_2)`. 313 Currently dout data type only support float16 and float32. 314 - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default 315 data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Currently dout data type only support float16 316 and float32. 317 - **w_size** (tuple(int)) - A tuple describes the shape of the weight which conforms to the format 318 :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. 319 320 Outputs: 321 Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight. 322 323 Supported Platforms: 324 ``Ascend`` 325 326 Examples: 327 >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16) 328 >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16) 329 >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16) 330 >>> conv3d_backprop_input = P.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2)) 331 >>> output = conv3d_backprop_input(x, dout, F.shape(w)) 332 >>> print(output.shape) 333 (32, 32, 4, 6, 2) 334 """ 335 336 @prim_attr_register 337 def __init__(self, 338 out_channel, 339 kernel_size, 340 mode=1, 341 pad_mode="valid", 342 pad=0, 343 stride=(1, 1, 1, 1, 1), 344 dilation=(1, 1, 1, 1, 1), 345 group=1, 346 data_format="NCDHW"): 347 """Initialize Convolution""" 348 self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y']) 349 self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) 350 self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) 351 self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True) 352 self.add_prim_attr('strides', self.stride) 353 self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True) 354 self.add_prim_attr('dilations', self.dilation) 355 validator.check_value_type('pad', pad, (int, tuple), self.name) 356 if isinstance(pad, int): 357 pad = (pad,) * 6 358 validator.check_equal_int(len(pad), 6, 'pad size', self.name) 359 self.add_prim_attr('pad', self.pad) 360 self.pad_list = pad 361 self.add_prim_attr('pad_list', self.pad_list) 362 363 self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) 364 if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0): 365 raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.") 366 if self.pad_mode == 'pad': 367 for item in pad: 368 validator.check_non_negative_int(item, 'pad item', self.name) 369 self.add_prim_attr('pad_mode', self.pad_mode) 370 371 self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) 372 self.add_prim_attr('mode', self.mode) 373 self.group = validator.check_positive_int(group, 'group', self.name) 374 self.add_prim_attr('groups', self.group) 375 self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) 376 self.add_prim_attr('data_format', self.format) 377 378 def __infer__(self, x, doutput, w_size): 379 w_size_v = w_size['value'] 380 validator.check_value_type('w_size', w_size_v, [tuple], self.name) 381 for i, dim_len in enumerate(w_size_v): 382 validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) 383 args = {"x": x['dtype'], "doutput": doutput['dtype']} 384 valid_dtypes = [mstype.float16, mstype.float32] 385 validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) 386 387 validator.check("filter's batch", w_size_v[0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name) 388 validator.check("filter's channel", w_size_v[1], "input_size's channel", x['shape'][1], Rel.EQ, self.name) 389 validator.check("input_size's batch", x['shape'][0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name) 390 391 # infer shape 392 x_shape = x['shape'] 393 dout_shape = doutput['shape'] 394 kernel_d = self.kernel_size[0] 395 kernel_h = self.kernel_size[1] 396 kernel_w = self.kernel_size[2] 397 stride_d = self.stride[2] 398 stride_h = self.stride[3] 399 stride_w = self.stride[4] 400 dilation_d = self.dilation[2] 401 dilation_h = self.dilation[3] 402 dilation_w = self.dilation[4] 403 # The pad_mode is valid by default. If pad_mode is not valid or same, then pad. 404 if self.pad_mode == "valid": 405 self.pad_list = (0, 0, 0, 0, 0, 0) 406 if self.pad_mode == "same": 407 pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_shape[2]) 408 pad_head = math.floor(pad_needed_d / 2) 409 pad_tail = pad_needed_d - pad_head 410 411 pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_shape[3]) 412 pad_top = math.floor(pad_needed_h / 2) 413 pad_bottom = pad_needed_h - pad_top 414 415 pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_shape[4]) 416 pad_left = math.floor(pad_needed_w / 2) 417 pad_right = pad_needed_w - pad_left 418 self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right) 419 420 self.add_prim_attr('pad_list', self.pad_list) 421 out = { 422 'value': None, 423 'shape': w_size_v, 424 'dtype': mstype.float32, 425 } 426 return out 427 428 429class Conv2DBackpropFilter(Primitive): 430 """ 431 Computes the gradients of convolution with respect to the filter. 432 433 Args: 434 out_channel (int): The dimensionality of the output space. 435 kernel_size (Union[int, tuple[int]]): The size of the convolution window. 436 pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". 437 pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of 438 top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the 439 padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly. 440 pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0). 441 mode (int): Modes for different convolutions. 0 Math convolutiuon, 1 cross-correlation convolution , 442 2 deconvolution, 3 depthwise convolution. Default: 1. 443 stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1). 444 dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1). 445 group (int): Splits input into groups. Default: 1. 446 data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\ 447 default is 'NCHW'. 448 449 Returns: 450 Tensor, the gradients of convolution. 451 """ 452 453 @prim_attr_register 454 def __init__(self, 455 out_channel, 456 kernel_size, 457 pad_mode="valid", 458 pad=0, 459 pad_list=(0, 0, 0, 0), 460 mode=1, 461 stride=(1, 1), 462 dilation=(1, 1, 1, 1), 463 group=1, 464 data_format="NCHW"): 465 """Initialize Convolution""" 466 self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output']) 467 self.out_channel = out_channel 468 self.kernel_size = kernel_size 469 self.mode = mode 470 pad_mode = pad_mode.upper() 471 self.add_prim_attr('pad_mode', pad_mode) 472 if isinstance(pad, int): 473 pad = (pad,) * 4 474 else: 475 validator.check_equal_int(len(pad), 4, 'pad size', self.name) 476 self.add_prim_attr("pad", pad) 477 if isinstance(stride, tuple) and len(stride) == 4: 478 self.stride = (stride[2], stride[3]) 479 self.add_prim_attr('stride', self.stride) 480 self.dilation = dilation 481 self.group = group 482 self.add_prim_attr('groups', group) 483 self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) 484 if context.get_context("device_target") != "GPU" and self.format == "NHWC": 485 raise ValueError("NHWC format only support in GPU target.") 486 self.add_prim_attr('data_format', self.format) 487 488 489class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): 490 """ 491 Returns the gradient of filter for DepthwiseConv2dNative. 492 493 Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier. 494 495 Refer to class DepthwiseConv2dNative for more details. 496 497 Args: 498 channel_multiplier (int): The multiplier for the original output conv. 499 kernel_size (int or tuple): The size of the conv kernel. 500 mode (int): Modes for different convolutions. 0 Math convolutiuon, 1 cross-correlation convolution, 501 2 deconvolution,3 depthwise convolution. Default: 3. 502 pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid". 503 pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of 504 top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the 505 padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly. 506 pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0). 507 stride (int): The stride to be applied to the convolution filter. Default: 1. 508 dilation (int): Specifies the space to use between kernel elements. Default: 1. 509 group (int): Splits input into groups. Default: 1. 510 511 Returns: 512 Tensor, the value is the gradient of filter for DepthwiseConv2dNative. 513 """ 514 515 @prim_attr_register 516 def __init__(self, 517 channel_multiplier, 518 kernel_size, 519 pad_mode="valid", 520 pad=0, 521 pad_list=(0, 0, 0, 0), 522 mode=3, 523 stride=1, 524 dilation=1, 525 group=1): 526 """Initialize Convolution""" 527 self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output']) 528 self.channel_multiplier = channel_multiplier 529 self.kernel_size = kernel_size 530 self.mode = mode 531 self.pad_mode = pad_mode 532 if isinstance(pad, int): 533 pad = (pad,) * 4 534 else: 535 validator.check_equal_int(len(pad), 4, 'pad size', self.name) 536 self.add_prim_attr("pad", pad) 537 self.pad_list = pad_list 538 self.stride = stride 539 self.dilation = dilation 540 self.group = group 541 self.add_prim_attr('data_format', "NCHW") 542 543 def __call__(self, x, w_size, dout): 544 raise NotImplementedError 545 546 def __infer__(self, x, w_size, dout): 547 w_size_v = w_size['value'] 548 args = {'x': x['dtype'], 'dout': dout['dtype']} 549 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 550 out = { 551 'value': None, 552 'shape': w_size_v, 553 'dtype': dout['dtype'], 554 } 555 return out 556 557 558class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer): 559 """ 560 Returns the gradient of input for DepthwiseConv2dNative. 561 562 Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier. 563 564 Args: 565 channel_multiplier (int): The multiplier for the original output conv. 566 kernel_size (int or tuple): The size of the conv kernel. 567 mode (int): Modes for different convolutions. 0 Math convolutiuon, 1 cross-correlation convolution , 568 2 deconvolution,3 depthwise convolution. Default: 3. 569 pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid". 570 pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of 571 top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the 572 padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly. 573 pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0). 574 stride (int): The stride to be applied to the convolution filter. Default: 1. 575 dilation (int): Specifies the space to use between kernel elements. Default: 1. 576 group (int): Splits input into groups. Default: 1. 577 578 Returns: 579 Tensor, the value is the gradient of input for DepthwiseConv2dNative. 580 """ 581 582 @prim_attr_register 583 def __init__(self, 584 channel_multiplier, 585 kernel_size, 586 pad_mode="valid", 587 pad=0, 588 pad_list=(0, 0, 0, 0), 589 mode=3, 590 stride=1, 591 dilation=1, 592 group=1): 593 """Initialize Convolution""" 594 self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output']) 595 self.channel_multiplier = channel_multiplier 596 self.kernel_size = kernel_size 597 self.mode = mode 598 self.pad_mode = pad_mode 599 if isinstance(pad, int): 600 pad = (pad,) * 4 601 else: 602 validator.check_equal_int(len(pad), 4, 'pad size', self.name) 603 self.add_prim_attr("pad", pad) 604 self.pad_list = pad_list 605 self.stride = stride 606 self.dilation = dilation 607 self.group = group 608 self.add_prim_attr('data_format', "NCHW") 609 610 def __call__(self, x_size, w, dout): 611 raise NotImplementedError 612 613 def __infer__(self, x_size, w, dout): 614 args = {'w': w['dtype'], 'dout': dout['dtype']} 615 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 616 x_size_v = x_size['value'] 617 out = { 618 'value': None, 619 'shape': x_size_v, 620 'dtype': dout['dtype'], 621 } 622 return out 623 624 625class DropoutGrad(Primitive): 626 """ 627 The gradient of Dropout. During training, randomly zeroes some of the elements 628 of the input tensor with probability. 629 630 Args: 631 keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9, 632 means dropping out 10% of input units. Default: 0.5. 633 634 Inputs: 635 - **shape** (tuple[int]) - The shape of target mask. 636 637 Outputs: 638 Tensor, the value of generated mask for input shape. 639 640 Examples: 641 >>> dropout_grad = ops.DropoutGrad(keep_prob=0.5) 642 >>> in = Tensor((20, 16, 50, 50)) 643 >>> out = dropout_grad(in) 644 """ 645 646 @prim_attr_register 647 def __init__(self, keep_prob=0.5): 648 self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) 649 650 651class FlattenGrad(PrimitiveWithInfer): 652 """Performs gradients of Flatten.""" 653 654 @prim_attr_register 655 def __init__(self): 656 self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output']) 657 658 def __infer__(self, *args): 659 out = { 660 'value': None, 661 'shape': args[1]['value'], 662 'dtype': args[0]['dtype'], 663 } 664 return out 665 666 667class InstanceNormGrad(PrimitiveWithInfer): 668 """Gradients of InstanceNorm operation.""" 669 670 @prim_attr_register 671 def __init__(self, epsilon=0.0, momentum=0.1): 672 self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'], 673 outputs=['dx', 'bn_gamma', 'bn_beta']) 674 675 def infer_shape(self, y_backprop_shape, x_shape, gamma_shape, save_mean_shape, save_variance_shape): 676 return (x_shape, gamma_shape, gamma_shape) 677 678 def infer_dtype(self, y_backprop_type, x_type, gamma_type, save_mean_type, save_variance_type): 679 return (x_type, gamma_type, gamma_type) 680 681 682class UniqueGrad(Primitive): 683 """Gradients of Unique operation.""" 684 685 @prim_attr_register 686 def __init__(self): 687 self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx']) 688 689 def __call__(self, dy, x, scale, save_mean, save_inv_variance): 690 raise NotImplementedError 691 692 693class BNTrainingReduceGrad(PrimitiveWithInfer): 694 """Gradients of FusedBatchNorm operation.""" 695 696 @prim_attr_register 697 def __init__(self, epsilon=0.0001): 698 _inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance'] 699 self.init_prim_io_names(inputs=_inputs, outputs=['y']) 700 701 def infer_shape(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance): 702 return grads 703 704 def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance): 705 return grads 706 707 708class BNTrainingUpdateGrad(PrimitiveWithInfer): 709 """Gradients of FusedBatchNorm operation.""" 710 711 @prim_attr_register 712 def __init__(self, epsilon=0.0001): 713 self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'], 714 outputs=['diff_scale', 'diff_offset']) 715 716 def infer_shape(self, grads, x, batch_mean, batch_variance): 717 return (batch_mean, batch_variance) 718 719 def infer_dtype(self, grads, x, batch_mean, batch_variance): 720 return (batch_mean, batch_variance) 721 722 723class GeLUGrad(PrimitiveWithInfer): 724 """Gradients of GeLU operation.""" 725 726 @prim_attr_register 727 def __init__(self): 728 """Initialize GeLUGrad""" 729 730 def infer_shape(self, y_backprop_shape, x_shape, y_shape): 731 return x_shape 732 733 def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): 734 tuple(map(partial(validator.check_tensor_dtype_valid, 735 valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), 736 ("y_backprop", "x", "y"), 737 (y_backprop_dtype, x_dtype, y_dtype))) 738 return x_dtype 739 740 741class FastGeLUGrad(PrimitiveWithInfer): 742 """Gradients of FastGeLU operation.""" 743 744 @prim_attr_register 745 def __init__(self): 746 """init FastGeLUGrad""" 747 748 def infer_shape(self, y_backprop_shape, x_shape): 749 return x_shape 750 751 def infer_dtype(self, y_backprop_dtype, x_dtype): 752 tuple(map(partial(validator.check_tensor_dtype_valid, 753 valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), 754 ("y_backprop", "x"), 755 (y_backprop_dtype, x_dtype))) 756 return x_dtype 757 758 759class _PoolGrad(PrimitiveWithInfer): 760 """Gradients of the max/avg pool operation.""" 761 762 @prim_attr_register 763 def __init__(self, kernel_size, strides, pad_mode="VALID", data_format="NCHW"): 764 self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) 765 766 validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name) 767 validator.check_value_type('strides', strides, [int, tuple], self.name) 768 self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name) 769 self.add_prim_attr("pad_mode", self.pad_mode) 770 self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) 771 if context.get_context("device_target") != "GPU" and self.format == "NHWC": 772 raise ValueError("NHWC format only support in GPU target.") 773 self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax") 774 if not self.is_maxpoolgradwithargmax: 775 self.add_prim_attr('data_format', self.format) 776 777 def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax): 778 validator.check_value_type(arg_name, arg_val, (int, tuple), self.name) 779 error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number " 780 f"or a tuple of two or four positive int numbers, but got {arg_val}") 781 if isinstance(arg_val, int): 782 ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val) 783 elif len(arg_val) == 2: 784 ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1]) 785 elif len(arg_val) == 4: 786 ret = arg_val 787 else: 788 raise error_msg 789 # whether all elements of tuple are positive integers 790 for item in ret: 791 if not isinstance(item, int) or item <= 0: 792 raise error_msg 793 return ret 794 795 kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size, self.is_maxpoolgradwithargmax) 796 self.kernel_size = kernel_size if self.format == "NCHW" else [kernel_size[0], kernel_size[2], 797 kernel_size[3], kernel_size[1]] 798 self.add_prim_attr("kernel_size", self.kernel_size) 799 800 strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax) 801 self.strides = strides if self.format == "NCHW" else [strides[0], strides[2], strides[3], strides[1]] 802 self.add_prim_attr("strides", self.strides) 803 804 805class AvgPoolGradVm(_PoolGrad): 806 """Gradients of the avg pool operation for vm.""" 807 808 @prim_attr_register 809 def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"): 810 super(AvgPoolGradVm, self).__init__(kernel_size, strides, pad_mode) 811 self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output']) 812 813 def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix): 814 out = { 815 'value': None, 816 'shape': tuple(origin_input['value']), 817 'dtype': dout['dtype'], 818 } 819 820 return out 821 822 823class AvgPoolGrad(_PoolGrad): 824 """Gradients of the avg pool operation.""" 825 826 @prim_attr_register 827 def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"): 828 super(AvgPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format) 829 830 def infer_shape(self, x1_shape, x2_shape, grad_shape): 831 return x1_shape 832 833 def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): 834 return x1_dtype 835 836 837class AdaptiveAvgPool2DGrad(PrimitiveWithInfer): 838 """Gradients of the adaptive avg pool 2D operation.""" 839 840 @prim_attr_register 841 def __init__(self): 842 """Initialize AdaptiveAvgPool2DGrad""" 843 844 def infer_shape(self, x1_shape, grad_shape): 845 return x1_shape 846 847 def infer_dtype(self, x1_dtype, grad_dtype): 848 return x1_dtype 849 850 851class AvgPool3DGrad(Primitive): 852 """Gradients of the avg pool3d operation.""" 853 854 @prim_attr_register 855 def __init__(self, kernel_size=1, strides=1, pads=0, ceil_mode=False, 856 count_include_pad=True, divisor_override=0, data_format="NCDHW"): 857 self.init_prim_io_names(inputs=['origin_input_shape', 'grads'], outputs=['output']) 858 self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name) 859 self.add_prim_attr('kernel_size', self.kernel_size) 860 self.strides = _check_3d_int_or_tuple('strides', strides, self.name) 861 self.add_prim_attr('strides', self.strides) 862 validator.check_value_type('pads', pads, (int, tuple), self.name) 863 if isinstance(pads, int): 864 pads = (pads,) * 6 865 validator.check_equal_int(len(pads), 6, 'pad size', self.name) 866 for item in pads: 867 validator.check_non_negative_int(item, 'pad item', self.name) 868 self.add_prim_attr('pad_list', pads) 869 self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name) 870 self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name) 871 self.divisor_override = validator.check_value_type('divisor_override', divisor_override, int, self.name) 872 self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) 873 874 875class MaxPoolGrad(_PoolGrad): 876 """Performs gradients of the max pool operation.""" 877 878 @prim_attr_register 879 def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"): 880 super(MaxPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format) 881 882 def infer_shape(self, x1_shape, x2_shape, grad_shape): 883 return x1_shape 884 885 def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): 886 return x1_dtype 887 888 889class MaxPoolGradGrad(_PoolGrad): 890 r""" 891 Performs gradients of the MaxPoolGrad operation. 892 893 Args: 894 kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value, 895 is an int number that represents height and width are both kernel_size, or a tuple 896 of two int numbers that represent height and width respectively. Default: 1. 897 strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents 898 the height and width of movement are both strides, or a tuple of two int numbers that 899 represent height and width of movement respectively. Default: 1. 900 pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive. 901 Default: "valid". 902 903 - same: Adopts the way of completion. The height and width of the output will be the same as 904 the input. The total number of padding will be calculated in horizontal and vertical 905 directions and evenly distributed to top and bottom, left and right if possible. 906 Otherwise, the last extra padding will be done from the bottom and the right side. 907 908 - valid: Adopts the way of discarding. The possible largest height and width of output 909 will be returned without padding. Extra pixels will be discarded. 910 911 Inputs: 912 - **origin_input** (Tensor) - Tensor with data format "NCHW", data type must be float16. 913 - **origin_output** (Tensor) - Data type same as `origin_input`. 914 - **grad** (Tensor) - Data type same as `origin_input`. 915 916 Outputs: 917 Tensor, with data type same as `origin_input`. 918 919 """ 920 921 @prim_attr_register 922 def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"): 923 super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode) 924 925 def infer_shape(self, x1_shape, x2_shape, grad_shape): 926 return x2_shape 927 928 def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): 929 args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype} 930 validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) 931 return x2_dtype 932 933 934def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode): 935 """ 936 helper for get max pool3d grad pads by pad_mode 937 """ 938 def get_pad(origin_shape, ksize, stride): 939 tail = origin_shape % stride 940 pad = (ksize - tail) if tail > 0 else (ksize - stride) 941 pad = max(pad, 0) 942 pad1 = int(pad / 2) 943 pad2 = int(pad / 2) + pad % 2 944 return pad1, pad2 945 946 _, _, d, h, w = input_shape 947 _, _, kd, kh, kw = kernel_size 948 _, _, strd, strh, strw = strides 949 950 pads = (0, 0, 0, 0, 0, 0) 951 if pad_mode == 'SAME': 952 pads_d = get_pad(d, kd, strd) 953 pads_h = get_pad(h, kh, strh) 954 pads_w = get_pad(w, kw, strw) 955 pads = pads_d + pads_h + pads_w 956 return pads 957 958 959class MaxPool3DGrad(PrimitiveWithInfer): 960 """Gradients of the max pool3d operation.""" 961 962 @prim_attr_register 963 def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), 964 pad_mode='VALID', pad_list=0, data_format="NCDHW"): 965 validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name) 966 validator.check_value_type('strides', strides, [int, tuple], self.name) 967 validator.check_value_type('pad_mode', pad_mode, [str], self.name) 968 self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) 969 if pad_mode.upper() == 'PAD': 970 pad_mode = 'CALCULATED' 971 self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'CALCULATED'], 'pad_mode', self.name) 972 self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, 973 allow_five=True, ret_five=True) 974 self.add_prim_attr("kernel_size", self.kernel_size) 975 self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True) 976 self.add_prim_attr("strides", self.strides) 977 validator.check_value_type('pad_list', pad_list, (int, tuple), self.name) 978 self.pad_list = pad_list 979 if isinstance(self.pad_list, int): 980 self.pad_list = (self.pad_list,) * 6 981 if len(self.pad_list) == 3: 982 self.pad_list = (pad_list[0], pad_list[0], pad_list[1], pad_list[1], pad_list[2], pad_list[3]) 983 if len(self.pad_list) != 3 and len(self.pad_list) != 6: 984 raise ValueError(f"For `maxpool3d` attr 'pad_list' should be an positive int number or a tuple of " 985 f"three or six positive int numbers, but got `{len(self.pad_list)}` numbers.") 986 if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0): 987 raise ValueError(f"For '{self.name}', when pad_list is not 0, pad_mode should be set as 'pad'.") 988 if self.pad_mode == 'CALCULATED': 989 for item in self.pad_list: 990 validator.check_non_negative_int(item, 'pad_list item', self.name) 991 self.add_prim_attr("pad_list", self.pad_list) 992 993 def infer_shape(self, x_shape, y_shape, grad_shape): 994 validator.check_equal_int(len(x_shape), 5, "x rank", self.name) 995 return x_shape 996 997 def infer_dtype(self, x_dtype, y_dtype, grad_dtype): 998 args = {'x_dtype': x_dtype, 'y_dtype': y_dtype, 'grad_dtype': grad_dtype} 999 validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) 1000 return x_dtype 1001 1002 1003class MaxPool3DGradGrad(PrimitiveWithInfer): 1004 """Gradients of the max pool3d grad operation.""" 1005 1006 @prim_attr_register 1007 def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"): 1008 validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name) 1009 validator.check_value_type('strides', strides, [int, tuple], self.name) 1010 validator.check_value_type('pad_mode', pad_mode, [str], self.name) 1011 self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) 1012 self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name) 1013 self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, 1014 allow_five=True, ret_five=True) 1015 self.add_prim_attr("kernel_size", self.kernel_size) 1016 self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True) 1017 self.add_prim_attr("strides", self.strides) 1018 1019 def infer_shape(self, x_shape, y_shape, grad_shape): 1020 validator.check_equal_int(len(x_shape), 5, "x rank", self.name) 1021 validator.check('x_shape', x_shape, 'grad_shape', grad_shape, prim_name=self.name) 1022 pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode) 1023 for pad in pad_list: 1024 validator.check_non_negative_int(pad, 'element of pad_list', self.name) 1025 self.add_prim_attr("pad_list", pad_list) 1026 return y_shape 1027 1028 def infer_dtype(self, x_dtype, y_dtype, grad_dtype): 1029 args = {'x_dtype': x_dtype, 'y_dtype': y_dtype} 1030 validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) 1031 validator.check_tensor_dtype_valid('grad_dtype', grad_dtype, [mstype.float16, mstype.float32], self.name) 1032 return x_dtype 1033 1034 1035class MaximumGrad(Primitive): 1036 """Grad for maximum.""" 1037 1038 @prim_attr_register 1039 def __init__(self, grad_x=True, grad_y=True): 1040 """Initialize MaximumGrad""" 1041 1042 def __call__(self, x, y, dout): 1043 raise NotImplementedError 1044 1045 1046class MaxPoolGradWithArgmax(_PoolGrad): 1047 """Computes the gradients of MaxPoolWithArgmax.""" 1048 1049 @prim_attr_register 1050 def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"): 1051 self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output']) 1052 super(MaxPoolGradWithArgmax, self).__init__(kernel_size, strides, pad_mode) 1053 1054 def infer_shape(self, x_shape, grad_shape, argmax_shape): 1055 if not grad_shape: 1056 raise TypeError("The dout of MaxPoolGradWithArgmax should be a Tensor.") 1057 return x_shape 1058 1059 def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype): 1060 return grad_dtype 1061 1062 1063class MaxPoolGradGradWithArgmax(_PoolGrad): 1064 r""" 1065 Computes the gradients of MaxPoolGradWithArgmax. 1066 1067 Args: 1068 kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value, 1069 is an int number that represents height and width are both kernel_size, or a tuple 1070 of two int numbers that represent height and width respectively. Default: 1. 1071 strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents 1072 the height and width of movement are both strides, or a tuple of two int numbers that 1073 represent height and width of movement respectively. Default: 1. 1074 pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive. 1075 Default: "valid". 1076 1077 - same: Adopts the way of completion. The height and width of the output will be the same as 1078 the input. The total number of padding will be calculated in horizontal and vertical 1079 directions and evenly distributed to top and bottom, left and right if possible. 1080 Otherwise, the last extra padding will be done from the bottom and the right side. 1081 1082 - valid: Adopts the way of discarding. The possible largest height and width of output 1083 will be returned without padding. Extra pixels will be discarded. 1084 1085 Inputs: 1086 - **x** (Tensor) - Tensor with data format "NCHW", data type must be float16. 1087 - **grad** (Tensor) - Data type same as `x`. 1088 - **argmax** (Tensor) - Data type must be uint16 or int64. 1089 1090 Outputs: 1091 Tensor, with data type same as `x`. 1092 1093 """ 1094 1095 @prim_attr_register 1096 def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"): 1097 self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output']) 1098 super(MaxPoolGradGradWithArgmax, self).__init__(kernel_size, strides, pad_mode) 1099 1100 def infer_shape(self, x_shape, grad_shape, argmax_shape): 1101 if not grad_shape: 1102 raise TypeError("The dout of MaxPoolGradGradWithArgmax should be a Tensor.") 1103 return x_shape 1104 1105 def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype): 1106 args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype} 1107 validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) 1108 return grad_dtype 1109 1110 1111class MinimumGrad(Primitive): 1112 """Grad for minimum.""" 1113 1114 @prim_attr_register 1115 def __init__(self, grad_x=True, grad_y=True): 1116 """Initialize MinimumGrad""" 1117 1118 def __call__(self, x, y, dout): 1119 raise NotImplementedError 1120 1121 1122class L2NormalizeGrad(PrimitiveWithInfer): 1123 r""" 1124 Gradients of L2 normalize. 1125 1126 Args: 1127 axis (Union[list(int), tuple(int), int]): The begin axis for the input to apply L2 normalize. Default: 0. 1128 epsilon (float): A small value added for numerical stability. Default: 1e-4. 1129 1130 Inputs: 1131 - **input_x** (Tensor) - Must be the input `weight` of forward operator L2Normalize. 1132 - **out** (Tensor) - Must be the output of forward operator L2Normalize. 1133 - **dout** (Tensor) - The backprop of the next layer. 1134 1135 Outputs: 1136 Tensor, gradients of L2Normalize `input_x`. 1137 """ 1138 1139 @prim_attr_register 1140 def __init__(self, axis=0, epsilon=1e-4): 1141 axis = [axis] if isinstance(axis, int) else axis 1142 validator.check_value_type('axis', axis, [list, tuple], self.name) 1143 validator.check_value_type('epsilon', epsilon, [int, float], self.name) 1144 self.add_prim_attr('axis', axis) 1145 self.init_attrs['axis'] = axis 1146 if len(axis) != 1: 1147 raise TypeError("The length of axis must be 1, later will support multiple axis!") 1148 1149 def infer_shape(self, input_x, out, dout): 1150 validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name) 1151 validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name) 1152 return input_x 1153 1154 def infer_dtype(self, input_x, out, dout): 1155 args = {'input_x': input_x, 'out': out, 'dout': dout} 1156 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 1157 return input_x 1158 1159 1160class LayerNormGrad(Primitive): 1161 """ 1162 Applies the layer Normalization to the input array. 1163 1164 This operator will calculate the input gradients of layernorm. 1165 1166 Args: 1167 begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1. 1168 begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1. 1169 1170 Returns: 1171 tuple[int], tuple of 3 values (the gradients of layernorm input, gamma, beta). 1172 """ 1173 1174 @prim_attr_register 1175 def __init__(self, begin_norm_axis=1, begin_params_axis=1): 1176 """init""" 1177 self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) 1178 self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) 1179 1180 def __call__(self, x, dy, variance, mean, gamma): 1181 raise NotImplementedError 1182 1183 1184class LayerNormGradGrad(PrimitiveWithInfer): 1185 """ 1186 Gets the gradient of LayerNormGrad operation. 1187 1188 Args: 1189 begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1. 1190 begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1. 1191 1192 Returns: 1193 tuple[int], tuple of 3 values (the gradients of layernormgrad input, dy, gamma). 1194 """ 1195 1196 @prim_attr_register 1197 def __init__(self, begin_norm_axis=1, begin_params_axis=1): 1198 """init""" 1199 self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) 1200 self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) 1201 1202 def __call__(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db): 1203 raise NotImplementedError 1204 1205 def infer_shape(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db): 1206 return x, dy, gamma 1207 1208 def infer_dtype(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db): 1209 return x, dy, gamma 1210 1211 1212class LogSoftmaxGrad(PrimitiveWithInfer): 1213 """Computes gradient for the Log Softmax activation.""" 1214 1215 @prim_attr_register 1216 def __init__(self, axis=-1): 1217 """Initialize LogSoftmaxGrad""" 1218 validator.check_value_type("axis", axis, [int], self.name) 1219 1220 def infer_shape(self, dout, logits): 1221 rank = len(logits) 1222 validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name) 1223 return logits 1224 1225 def infer_dtype(self, dout, logits): 1226 validator.check_subclass("logits", logits, mstype.tensor, self.name) 1227 return logits 1228 1229 1230class LSTMGradData(PrimitiveWithInfer): 1231 """Computes the data gradients of LSTM.""" 1232 1233 @prim_attr_register 1234 def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): 1235 self.input_size = validator.check_positive_int(input_size, 'input_size', self.name) 1236 self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) 1237 self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) 1238 self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) 1239 self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) 1240 self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) 1241 self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) 1242 1243 if bidirectional: 1244 self.num_directions = 2 1245 else: 1246 self.num_directions = 1 1247 1248 def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape, 1249 hx_shape, cx_shape, reserve_shape, state_shape): 1250 # dhy and dcy should be same shape 1251 validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) 1252 validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) 1253 validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) 1254 validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) 1255 validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) 1256 1257 validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) 1258 validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name) 1259 1260 validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) 1261 validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) 1262 validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name) 1263 1264 dx_shape = (y_shape[0], y_shape[1], self.input_size) 1265 dhx_shape = dhy_shape 1266 dcx_shape = dcy_shape 1267 1268 return (dx_shape, dhx_shape, dcx_shape) 1269 1270 def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype, 1271 hx_dtype, cx_dtype, reserve_dtype, state_dtype): 1272 args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} 1273 validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name) 1274 return (dy_dtype, dy_dtype, dy_dtype) 1275 1276 1277class LSTMGradWeight(PrimitiveWithInfer): 1278 """Computes the weight gradients of LSTM.""" 1279 1280 @prim_attr_register 1281 def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): 1282 self.input_size = validator.check_positive_int(input_size, 'input_size', self.name) 1283 self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) 1284 self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) 1285 self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) 1286 self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) 1287 self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) 1288 self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) 1289 1290 if bidirectional: 1291 self.num_directions = 2 1292 else: 1293 self.num_directions = 1 1294 1295 def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape): 1296 weight_size = 0 1297 gate_size = 4 * self.hidden_size 1298 for layer in range(self.num_layers): 1299 for _ in range(self.num_directions): 1300 input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions 1301 weight_size += gate_size * input_layer_size 1302 weight_size += gate_size * self.hidden_size 1303 if self.has_bias: 1304 weight_size += 2 * gate_size 1305 1306 return (weight_size, 1, 1) 1307 1308 def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype): 1309 return hx_dtype 1310 1311 1312class LSTMGrad(PrimitiveWithInfer): 1313 """Computes the data and weight gradients of LSTM.""" 1314 1315 @prim_attr_register 1316 def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): 1317 self.input_size = validator.check_positive_int(input_size, 'input_size', self.name) 1318 self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) 1319 self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) 1320 self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) 1321 self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) 1322 self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) 1323 self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) 1324 1325 if bidirectional: 1326 self.num_directions = 2 1327 else: 1328 self.num_directions = 1 1329 1330 def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape, 1331 dcy_shape, reserve_shape): 1332 # dhy and dcy should be same shape 1333 validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) 1334 validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) 1335 validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) 1336 validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) 1337 validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) 1338 1339 validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) 1340 validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name) 1341 1342 validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) 1343 validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) 1344 validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name) 1345 1346 dx_shape = (y_shape[0], y_shape[1], self.input_size) 1347 dhx_shape = dhy_shape 1348 dcx_shape = dcy_shape 1349 weight_size = 0 1350 gate_size = 4 * self.hidden_size 1351 for layer in range(self.num_layers): 1352 for _ in range(self.num_directions): 1353 input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions 1354 weight_size += gate_size * input_layer_size 1355 weight_size += gate_size * self.hidden_size 1356 if self.has_bias: 1357 weight_size += gate_size 1358 1359 return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1)) 1360 1361 def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype, 1362 dcy_dtype, reserve_dtype): 1363 return (dy_dtype, dy_dtype, dy_dtype, hx_dtype) 1364 1365 1366class DynamicRNNGrad(PrimitiveWithInfer): 1367 """Computes the input gradients of DynamicRNN.""" 1368 1369 @prim_attr_register 1370 def __init__(self, 1371 cell_type='LSTM', 1372 direction='UNIDIRECTIONAL', 1373 cell_depth=1, 1374 use_peephole=False, 1375 keep_prob=1.0, 1376 cell_clip=-1.0, 1377 num_proj=0, 1378 time_major=True, 1379 forget_bias=0.0): 1380 self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) 1381 1382 def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape, 1383 c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape): 1384 validator.check_equal_int(len(x_shape), 3, "x_shape", self.name) 1385 num_step, batch_size, input_size = x_shape 1386 hidden_size = w_shape[-1] // 4 1387 if w_shape[-1] % 4 != 0: 1388 raise ValueError(f"For {self.name}, w_shape[-1] should multiple of 4.") 1389 validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size", 1390 input_size + hidden_size, Rel.EQ, self.name) 1391 valid_shape = [num_step, batch_size, hidden_size] 1392 validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name) 1393 validator.check("y_shape", y_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1394 validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1395 validator.check("c_shape", c_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1396 validator.check("i_shape", i_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1397 validator.check("j_shape", j_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1398 validator.check("f_shape", f_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1399 validator.check("o_shape", o_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1400 validator.check("tanhc_shape", tanhc_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1401 validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1402 validator.check("dh_shape", dh_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name) 1403 validator.check("dc_shape", dc_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name) 1404 1405 return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape 1406 1407 def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype, 1408 c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype): 1409 return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype 1410 1411 1412class DynamicGRUV2Grad(PrimitiveWithInfer): 1413 r""" 1414 Computes the input gradients of DynamicGRUV2. 1415 1416 Args: 1417 direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'. 1418 Only 'UNIDIRECTIONAL' is currently supported. 1419 cell_depth (int): An integer identifying the cell depth in the op. Default: 1. 1420 keep_prob (float): A float identifying the keep prob in the op. Default: 1.0. 1421 cell_clip (float): A float identifying the cell clip in the op. Default: -1.0. 1422 num_proj (int): An integer identifying the num proj in the op. Default: 0. 1423 time_major (bool): A bool identifying the time major in the op. Default: True. 1424 gate_order (str): An string identifying the gate order in weight and bias. Default: 'rzh. 1425 'zrh' is another option. 1426 reset_after (bool): An bool identifying whether to apply reset gate after matrix multiplication. Default: True. 1427 1428 Inputs: 1429 - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`. 1430 The data type must be float16 or float32. 1431 - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`. 1432 The data type must be float16 or float32. 1433 - **weight_hidden** (Tensor) - Bias. Tensor of shape :math:`(hidden_size, 3 x hidden_size)`. 1434 The data type must be float16 or float32. 1435 - **y** (Tensor) - A Tensor of shape :math: 1436 if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`, 1437 if num_proj == 0 `(num_step, batch_size, hidden_size)`. 1438 The data type must be float16 or float32. 1439 - **init_h** (Tensor) - Hidden state of initial time. 1440 Tensor of shape :math:`(batch_size, hidden_size)`. 1441 The data type must be float16 or float32. 1442 - **h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. 1443 The data type must be float16 or float32. 1444 - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`. 1445 - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`. 1446 - **update** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. 1447 The data type must be float16 or float32. 1448 - **reset** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. 1449 The data type must be float16 or float32. 1450 - **new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. 1451 The data type must be float16 or float32. 1452 - **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. 1453 The data type must be float16 or float32. 1454 - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`. 1455 Only `None` is currently supported. 1456 - **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32. 1457 1458 Outputs: 1459 - **dw_input** (Tensor) - A Tensor has the same shape as `weight_input`. 1460 Has the same type with input `x`. 1461 - **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`. 1462 Has the same type with input `x`. 1463 - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`. 1464 Has the same type with input `x`. 1465 - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`. 1466 Has the same type with input `x`. 1467 - **dx** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`. 1468 Has the same type with input `x`. 1469 - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch_size, hidden_size)`. 1470 Has the same type with input `x`. 1471 """ 1472 1473 @prim_attr_register 1474 def __init__(self, 1475 direction='UNIDIRECTIONAL', 1476 cell_depth=1, 1477 keep_prob=1.0, 1478 cell_clip=-1.0, 1479 num_proj=0, 1480 time_major=True, 1481 gate_order="rzh", 1482 reset_after=True): 1483 self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name) 1484 self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) 1485 self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name) 1486 self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name) 1487 self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) 1488 self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) 1489 self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) 1490 self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) 1491 1492 def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape, 1493 dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape): 1494 validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) 1495 validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name) 1496 validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name) 1497 validator.check_int(len(y_shape), 3, Rel.EQ, "y shape rank", self.name) 1498 num_step, batch_size, input_size = x_shape 1499 hidden_size = whidden_shape[0] 1500 validator.check("weight_hidden_shape[-1]", whidden_shape[-1], "3 * hidden_size", 1501 3 * hidden_size, Rel.EQ, self.name) 1502 validator.check("weight_input_shape", winput_shape, "excepted shape", 1503 [input_size, 3 * hidden_size], Rel.EQ, self.name) 1504 if self.num_proj > 0: 1505 valid_y_shape = [num_step, batch_size, min(hidden_size, self.num_proj)] 1506 else: 1507 valid_y_shape = [num_step, batch_size, hidden_size] 1508 validator.check("y_shape", y_shape, "excepted shape", valid_y_shape, Rel.EQ, self.name) 1509 1510 validator.check("init_h_shape", init_h_shape, "excepted shape", 1511 [batch_size, hidden_size], Rel.EQ, self.name) 1512 valid_shape = [num_step, batch_size, hidden_size] 1513 validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1514 validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1515 validator.check("dh_shape", dh_shape, "excepted shape", 1516 [batch_size, hidden_size], Rel.EQ, self.name) 1517 validator.check("update_shape", update_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1518 validator.check("reset_shape", reset_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1519 validator.check("new_shape", new_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1520 validator.check("hnew_shape", hnew_shape, "excepted shape", valid_shape, Rel.EQ, self.name) 1521 if seq_shape is not None: 1522 validator.check("seq_shape", seq_shape, "batch_size", batch_size, Rel.EQ, self.name) 1523 1524 dx_shape = (num_step, batch_size, input_size) 1525 dh_shape = (batch_size, hidden_size) 1526 dwinput_shape = (input_size, 3 * hidden_size) 1527 dwhidden_shape = (hidden_size, 3 * hidden_size) 1528 db_shape = (3 * hidden_size,) 1529 return dwinput_shape, dwhidden_shape, db_shape, db_shape, dx_shape, dh_shape 1530 1531 def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype, 1532 dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype): 1533 valid_types = (mstype.float16, mstype.float32) 1534 args = {"y_dtype": y_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype, 1535 "dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype, 1536 "new_dtype": new_dtype, "hnew_dtype": hnew_dtype} 1537 validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name) 1538 validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name) 1539 validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name) 1540 validator.check_tensor_dtype_valid("init_h_dtype", init_h_dtype, valid_types, self.name) 1541 validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name) 1542 if seq_dtype is not None: 1543 validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name) 1544 if mask_dtype is not None: 1545 validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name) 1546 return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype 1547 1548 1549class PReLUGrad(PrimitiveWithInfer): 1550 r""" 1551 Gradients of PReLU operation. 1552 1553 Note: 1554 1-dimensional input_x is not supported. 1555 1556 Inputs: 1557 - **y_backprop** (Tensor) - Representing the backprop of the next layer. 1558 - **input_x** (Tensor) - Must be the input `input_x` of forward operator PRelu. 1559 - **weight** (Tensor) - Float Tensor, w > 0, must be the input `weight` of forward operator PRelu. 1560 1561 Outputs: 1562 Tensor, with the same type as `input_x`. 1563 """ 1564 1565 @prim_attr_register 1566 def __init__(self): 1567 pass 1568 1569 def infer_shape(self, y_backprop_shape, a_shape, w_shape): 1570 return y_backprop_shape, w_shape 1571 1572 def infer_dtype(self, y_backprop_dtype, a_dtype, w_dtype): 1573 tuple(map(partial(validator.check_tensor_dtype_valid, 1574 valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), 1575 ('y_backprop', "input_x", "weight"), 1576 (y_backprop_dtype, a_dtype, w_dtype))) 1577 return y_backprop_dtype, w_dtype 1578 1579 1580class ReluGrad(Primitive): 1581 """Performs grad of Relu operation.""" 1582 1583 @prim_attr_register 1584 def __init__(self): 1585 """Initialize ReluGrad""" 1586 self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output']) 1587 1588 def __call__(self, y_backprop, x): 1589 raise NotImplementedError 1590 1591 1592class ReLU6Grad(PrimitiveWithInfer): 1593 """Performs grad of ReLU6 operation.""" 1594 1595 @prim_attr_register 1596 def __init__(self): 1597 self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output']) 1598 1599 def __call__(self, y_grad, x): 1600 raise NotImplementedError 1601 1602 def infer_shape(self, y_grad_shape, x_shape): 1603 return x_shape 1604 1605 def infer_dtype(self, y_grad_dtype, x_dtype): 1606 valid_dtypes = (mstype.float16, mstype.float32) 1607 validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) 1608 validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) 1609 return x_dtype 1610 1611 1612class ReluGradV2(Primitive): 1613 """Performs grad of ReLUV2 operation.""" 1614 1615 @prim_attr_register 1616 def __init__(self): 1617 self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output']) 1618 1619 def __call__(self, gradients, mask): 1620 raise NotImplementedError 1621 1622 1623class EluGrad(PrimitiveWithInfer): 1624 """Performs grad of Elu operation.""" 1625 1626 @prim_attr_register 1627 def __init__(self): 1628 """Initialize EluGrad""" 1629 1630 def infer_shape(self, y_grad_shape, x_shape): 1631 return x_shape 1632 1633 def infer_dtype(self, y_grad_dtype, x_dtype): 1634 args = {'y_grad': y_grad_dtype, 'x': x_dtype} 1635 validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) 1636 return x_dtype 1637 1638 1639class GatherDGrad(PrimitiveWithInfer): 1640 """Performs grad of GatherD operation.""" 1641 1642 @prim_attr_register 1643 def __init__(self, dim=0, shape=None): 1644 """Initialize GatherDGrad""" 1645 validator.check_is_int(dim, int) 1646 self.add_prim_attr("dim", dim) 1647 self.dim = dim 1648 self.out_shape = shape 1649 self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output']) 1650 1651 def infer_shape(self, index_shape, grad_shape): 1652 return self.out_shape 1653 1654 def infer_dtype(self, index_dtype, grad_dtype): 1655 return grad_dtype 1656 1657 1658class ResizeBilinearGrad(PrimitiveWithInfer): 1659 """Performs grad of ResizeBilinear operation.""" 1660 1661 @prim_attr_register 1662 def __init__(self, align_corners=False): 1663 """init""" 1664 1665 def infer_shape(self, dout_shape, orig_shape): 1666 return orig_shape 1667 1668 def infer_dtype(self, dout_dtype, orig_type): 1669 return orig_type 1670 1671 1672class ResizeNearestNeighborGrad(PrimitiveWithInfer): 1673 """ 1674 Compute gradient of `ResizeNearestNeighbor` operator. 1675 1676 Note: 1677 The shape of input parameter `size` must be (height, width). 1678 1679 Args: 1680 align_corners (bool): Whether the centers of the 4 corner pixels of the input 1681 and output tensors are aligned. Default: False. 1682 """ 1683 1684 @prim_attr_register 1685 def __init__(self, align_corners=False): 1686 """Initialize ResizeNearestNeighborGrad""" 1687 self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y']) 1688 1689 def __infer__(self, grads, size): 1690 shp = (grads['shape'][0],) + (grads['shape'][1],) + size['value'] 1691 return {'shape': shp, 1692 'dtype': grads['dtype'], 1693 'value': None} 1694 1695 1696class ROIAlignGrad(PrimitiveWithInfer): 1697 """ 1698 ROIAlignGrad operator. 1699 1700 Args: 1701 xdiff_shape (tuple): The diff shape. 1702 pooled_height (int): The output feature height. 1703 pooled_width (int): The output feature width. 1704 spatial_scale (float): The feature stride. 1705 sample_num (int): Number of sampling points. Default: 2. 1706 """ 1707 1708 @prim_attr_register 1709 def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2): 1710 """Initialize ROIAlignGrad""" 1711 validator.check_value_type("pooled_height", pooled_height, [int], self.name) 1712 validator.check_value_type("pooled_width", pooled_width, [int], self.name) 1713 validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) 1714 validator.check_value_type("sample_num", sample_num, [int], self.name) 1715 validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name) 1716 self.xdiff_shape = xdiff_shape 1717 self.pooled_height = pooled_height 1718 self.pooled_width = pooled_width 1719 self.spatial_scale = spatial_scale 1720 self.sample_num = sample_num 1721 1722 def infer_shape(self, ydiff_shape, rois_shape): 1723 return self.xdiff_shape 1724 1725 def infer_dtype(self, ydiff_type, rois_type): 1726 return ydiff_type 1727 1728 1729class SigmoidGrad(PrimitiveWithInfer): 1730 """Gets the gradient of Sigmoid operation.""" 1731 1732 @prim_attr_register 1733 def __init__(self): 1734 pass 1735 1736 def infer_shape(self, out, dout): 1737 return out 1738 1739 def infer_dtype(self, out, dout): 1740 args = {'out': out, 'dout': dout} 1741 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 1742 return out 1743 1744 1745class _ActivationGrad(PrimitiveWithInfer): 1746 """_ActivationGrad base class.""" 1747 1748 @prim_attr_register 1749 def __init__(self): 1750 self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output']) 1751 1752 def infer_shape(self, y_grad_shape, x_shape): 1753 return x_shape 1754 1755 def infer_dtype(self, y_grad_dtype, x_dtype): 1756 valid_dtypes = (mstype.float16, mstype.float32) 1757 validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) 1758 validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) 1759 return x_dtype 1760 1761 1762class HSwishGrad(_ActivationGrad): 1763 """Gets the gradient of HSwish operation.""" 1764 1765 1766class HSigmoidGrad(_ActivationGrad): 1767 """Gets the gradient of HSigmoid operation.""" 1768 1769 1770class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): 1771 """Computes the gradients of `SigmoidCrossEntropyWithLogits`.""" 1772 1773 @prim_attr_register 1774 def __init__(self): 1775 """Initialize SigmoidCrossEntropyWithLogitsGrad""" 1776 self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad']) 1777 1778 def infer_shape(self, x_shape, y_shape, dout_shape): 1779 validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) 1780 validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name) 1781 return x_shape 1782 1783 def infer_dtype(self, x_dtype, y_dtype, dout_dtype): 1784 args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} 1785 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 1786 return dout_dtype 1787 1788 1789class SliceGrad(PrimitiveWithInfer): 1790 """Reverse of slice.""" 1791 1792 @prim_attr_register 1793 def __init__(self): 1794 """Initialize SliceGrad""" 1795 self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx']) 1796 1797 def __infer__(self, dy, x, begin, size): 1798 dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value'] 1799 dy_shape_len = len(dy_shape) 1800 for i in range(dy_shape_len): 1801 validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name) 1802 validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name) 1803 return {'shape': x_shape, 1804 'dtype': x['dtype'], 1805 'value': None} 1806 1807 1808class NLLLossGrad(PrimitiveWithInfer): 1809 """Computes the gradients of `NLLLoss`.""" 1810 1811 @prim_attr_register 1812 def __init__(self, reduction="mean"): 1813 """Initialize NLLLoss""" 1814 self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss']) 1815 self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name) 1816 self.add_prim_attr('reduction', self.reduction) 1817 1818 def infer_shape(self, x_shape, y_grad_shape, t_shape, w_shape, tw_shape): 1819 validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name) 1820 validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name) 1821 validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name) 1822 validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name) 1823 if len(x_shape) == 1: 1824 validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name) 1825 else: 1826 validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name) 1827 return x_shape 1828 1829 def infer_dtype(self, x_dtype, y_grad_dtype, t_dtype, w_dtype, tw_dtype): 1830 valid_dtypes = (mstype.float16, mstype.float32) 1831 validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name) 1832 validator.check_tensor_dtype_valid("y_grad_dtype", y_grad_dtype, valid_dtypes, self.name) 1833 validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name) 1834 validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name) 1835 validator.check_tensor_dtype_valid("tw_dtype", tw_dtype, valid_dtypes, self.name) 1836 validator.check('tw_shape_dtype', tw_dtype, 'w_shape_dtype', w_dtype, Rel.EQ, self.name) 1837 return x_dtype 1838 1839 1840class SmoothL1LossGrad(PrimitiveWithInfer): 1841 """Computes gradient for prediction on SmoothL1Loss.""" 1842 1843 @prim_attr_register 1844 def __init__(self, beta=1.0): 1845 pass 1846 1847 def infer_shape(self, prediction, target, dloss): 1848 validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) 1849 validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name) 1850 return prediction 1851 1852 def infer_dtype(self, prediction, target, dloss): 1853 args = {"prediction": prediction, "target": target, 'dloss': dloss} 1854 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 1855 return dloss 1856 1857 1858class SoftMarginLossGrad(Primitive): 1859 """Computes gradient for prediction on SoftMarginLoss.""" 1860 1861 @prim_attr_register 1862 def __init__(self, reduction="mean"): 1863 self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient']) 1864 self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name) 1865 1866 1867class StridedSliceGrad(PrimitiveWithInfer): 1868 """ 1869 Performs grad of StridedSlice operation. 1870 1871 Args: 1872 begin_mask (int): Start indexing the slice. Default: 0. 1873 end_mask (int): End indexing the slice. Default: 0. 1874 ellipsis_mask (int): An int32 mask. Default: 0. 1875 new_axis_mask (int): An int32 mask. Default: 0. 1876 shrink_axis_mask (int): An int32 mask. Default: 0. 1877 1878 Returns: 1879 Tensor, has the same shape of input. 1880 """ 1881 1882 @prim_attr_register 1883 def __init__(self, 1884 begin_mask=0, 1885 end_mask=0, 1886 ellipsis_mask=0, 1887 new_axis_mask=0, 1888 shrink_axis_mask=0): 1889 """Initialize StridedSliceGrad""" 1890 validator.check_value_type('begin_mask', begin_mask, [int], self.name) 1891 validator.check_value_type('end_mask', end_mask, [int], self.name) 1892 validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) 1893 validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name) 1894 validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name) 1895 self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) 1896 1897 def __infer__(self, dy, shapex, begin, end, strides): 1898 validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name) 1899 1900 for idx, item in enumerate(shapex['value']): 1901 validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) 1902 for idx, item in enumerate(begin['value']): 1903 validator.check_value_type("begin[%d]" % idx, item, [int], self.name) 1904 for idx, item in enumerate(end['value']): 1905 validator.check_value_type("end[%d]" % idx, item, [int], self.name) 1906 for idx, item in enumerate(strides['value']): 1907 validator.check_value_type("strides[%d]" % idx, item, [int], self.name) 1908 1909 return {'shape': shapex['value'], 1910 'dtype': dy['dtype'], 1911 'value': None} 1912 1913 1914class SoftplusGrad(PrimitiveWithInfer): 1915 """Computes gradient for the Softplus activation.""" 1916 1917 @prim_attr_register 1918 def __init__(self): 1919 self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output']) 1920 1921 def infer_shape(self, dout_shape, x_shape): 1922 validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name) 1923 return x_shape 1924 1925 def infer_dtype(self, dout_dtype, x_dtype): 1926 args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype} 1927 validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) 1928 return x_dtype 1929 1930 1931class TanhGrad(PrimitiveWithInfer): 1932 """Computes gradient of hyperbolic tangent of input element-wise.""" 1933 1934 @prim_attr_register 1935 def __init__(self): 1936 pass 1937 1938 def infer_shape(self, out, dout): 1939 return out 1940 1941 def infer_dtype(self, out, dout): 1942 args = {"out": out, "dout": dout} 1943 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 1944 return out 1945 1946 1947class MirrorPadGrad(PrimitiveWithInfer): 1948 """Gradients of MirrorPad operation.""" 1949 1950 @prim_attr_register 1951 def __init__(self, mode="REFLECT"): 1952 """Initialize MirrorPad""" 1953 validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name) 1954 self.mode = mode 1955 1956 def __infer__(self, dout, paddings): 1957 validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name) 1958 validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name) 1959 validator.check("paddings rank", len(paddings['shape']), "expected", 2, Rel.EQ, self.name) 1960 validator.check("paddings dim_1", paddings['shape'][1], "expected", 2, Rel.EQ, self.name) 1961 1962 if paddings['value'] is None: 1963 raise ValueError(f"For {self.name}, paddings must be const.") 1964 paddings_value = paddings['value'].asnumpy() 1965 y_shape = () 1966 dout_shape = dout['shape'] 1967 for i, val in enumerate(dout_shape): 1968 y_shape += (val - paddings_value[i][0] - paddings_value[i][1],) 1969 return {'shape': y_shape, 1970 'dtype': dout['dtype'], 1971 'value': None} 1972 1973 1974class EmbeddingLookupCommGrad(PrimitiveWithInfer): 1975 """ 1976 Performs the gradient for the communication part of EmbeddingLookup operator. 1977 1978 This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking, 1979 this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host. 1980 """ 1981 1982 @prim_attr_register 1983 def __init__(self): 1984 self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output']) 1985 self.add_prim_attr('primitive_target', 'CPU') 1986 self.tuple_setitem = Primitive('tuple_setitem') 1987 1988 def __infer__(self, dy, split_num): 1989 """ 1990 This primitive is implemented by three steps: 1991 1) Splits the 'dy' along dimension 0 into 'split_num' parts. 1992 2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host. 1993 3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them 1994 along dimension 0. 1995 1996 The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8 1997 """ 1998 dy_shape = tuple(dy['shape']) 1999 split_num_value = split_num['value'] 2000 validator.check_value_type("split_num_value", split_num_value, [int], self.name) 2001 dy_shape_all = self.tuple_setitem(dy_shape, 0, dy_shape[0] * 8) 2002 return {'shape': dy_shape_all, 2003 'dtype': dy['dtype'], 2004 'value': None} 2005 2006 2007class RefToEmbed(Primitive): 2008 r""" 2009 Make a key from Ref. 2010 2011 The Key is a symbolic_key, is a embedding on Parameter, which is used as a key of the variable in env_type, 2012 and get items by operation `env_get_item` with the symbolic_key instance. The `Parameter` is a ref. 2013 2014 Inputs: 2015 - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref. 2016 2017 Outputs: 2018 symbolic_key, made from the Ref. 2019 2020 Examples: 2021 >>> class Net(nn.Cell): 2022 >>> def __init__(self): 2023 >>> super(Net, self).__init__() 2024 >>> self.weight = mindspore.Parameter(1.0, name='weight') 2025 >>> 2026 >>> def construct(self): 2027 >>> key = RefToEmbed()(self.weight) 2028 >>> return key, self.weight 2029 """ 2030 __mindspore_signature__ = ( 2031 sig.make_sig('variable', sig.sig_rw.RW_REF), 2032 ) 2033 2034 @prim_attr_register 2035 def __init__(self): 2036 pass 2037 2038 2039class AtanGrad(PrimitiveWithInfer): 2040 """ 2041 Computes AtanGrad of input element-wise. 2042 2043 Returns: 2044 Tensor, has the same type as input. 2045 """ 2046 2047 @prim_attr_register 2048 def __init__(self): 2049 """Initialize AtanGrad""" 2050 2051 def infer_shape(self, x, dout): 2052 validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) 2053 return x 2054 2055 def infer_dtype(self, x, dout): 2056 args = {"x": x, "dout": dout} 2057 validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) 2058 return x 2059 2060 2061class BasicLSTMCellCStateGrad(PrimitiveWithInfer): 2062 """Computes the state gradients of BasicLSTMCell.""" 2063 2064 @prim_attr_register 2065 def __init__(self, forget_bias, activation): 2066 self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) 2067 self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) 2068 2069 def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape): 2070 # dhy and dcy should be same shape 2071 validator.check_equal_int(len(c_shape), 2, "c rank", self.name) 2072 validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name) 2073 validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name) 2074 validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name) 2075 validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name) 2076 validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name) 2077 validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name) 2078 validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name) 2079 validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name) 2080 validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name) 2081 validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name) 2082 validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name) 2083 validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name) 2084 validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name) 2085 validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name) 2086 2087 dgate_shape = (c_shape[0], 4 * c_shape[1]) 2088 dct_1_shape = c_shape 2089 2090 return (dgate_shape, dct_1_shape) 2091 2092 def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype): 2093 validator.check_subclass("c", c_dtype, [mstype.tensor], self.name) 2094 validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name) 2095 validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name) 2096 validator.check_subclass("it", it_dtype, [mstype.tensor], self.name) 2097 validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name) 2098 validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name) 2099 validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name) 2100 validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name) 2101 validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name) 2102 validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name) 2103 validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name) 2104 validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name) 2105 validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name) 2106 validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name) 2107 validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name) 2108 validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name) 2109 return (c_dtype, c_dtype) 2110 2111 2112class BasicLSTMCellWeightGrad(PrimitiveWithInfer): 2113 """Computes the weight gradients of BasicLSTM.""" 2114 @prim_attr_register 2115 def __init__(self): 2116 pass 2117 2118 def infer_shape(self, x_shape, h_shape, dgate_shape): 2119 validator.check_equal_int(len(x_shape), 2, "x rank", self.name) 2120 validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name) 2121 validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name) 2122 validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name) 2123 validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name) 2124 validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name) 2125 input_size = x_shape[1] 2126 hidden_size = h_shape[1] 2127 dw_shape = (input_size + hidden_size, 4 * hidden_size) 2128 db_shape = (4 * hidden_size,) 2129 return (dw_shape, db_shape) 2130 2131 def infer_dtype(self, x_dtype, h_dtype, dgate_dtype): 2132 validator.check_subclass("x", x_dtype, mstype.tensor, self.name) 2133 validator.check_subclass("h", h_dtype, mstype.tensor, self.name) 2134 validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name) 2135 validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name) 2136 validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name) 2137 validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name) 2138 return (x_dtype, x_dtype) 2139 2140 2141class BasicLSTMCellInputGrad(PrimitiveWithInfer): 2142 """Computes the input gradients of BasicLSTM.""" 2143 2144 @prim_attr_register 2145 def __init__(self, keep_prob): 2146 self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) 2147 self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) 2148 2149 def infer_shape(self, dgate_shape, w_shape): 2150 validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name) 2151 validator.check_equal_int(len(w_shape), 2, "w rank", self.name) 2152 validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) 2153 batch_size = dgate_shape[0] 2154 hidden_size = dgate_shape[1] // 4 2155 input_size = w_shape[0] - hidden_size 2156 dxt_shape = (batch_size, input_size) 2157 dht_shape = (batch_size, hidden_size) 2158 return (dxt_shape, dht_shape) 2159 2160 def infer_dtype(self, dgate_dtype, w_dtype): 2161 validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name) 2162 validator.check_subclass("w", w_dtype, mstype.tensor, self.name) 2163 validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name) 2164 validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name) 2165 return (dgate_dtype, dgate_dtype) 2166 2167 2168class InvGrad(PrimitiveWithInfer): 2169 """Computes gradients for inv operation.""" 2170 2171 @prim_attr_register 2172 def __init__(self): 2173 pass 2174 2175 def infer_shape(self, x, grad): 2176 validator.check("x_shape", x, "grad_shape", grad, Rel.EQ, self.name) 2177 return x 2178 2179 def infer_dtype(self, x, grad): 2180 validator.check_type_name("dgate", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) 2181 validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) 2182 return x 2183 2184 2185class LRNGrad(PrimitiveWithInfer): 2186 """Computes gradients for LRN operation.""" 2187 2188 @prim_attr_register 2189 def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): 2190 self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z']) 2191 validator.check_value_type("depth_radius", depth_radius, [int], self.name) 2192 validator.check_value_type("bias", bias, [float], self.name) 2193 validator.check_value_type("alpha", alpha, [float], self.name) 2194 validator.check_value_type("beta", beta, [float], self.name) 2195 2196 def infer_dtype(self, grads, x, y): 2197 args = {"grads": grads, "x": x, "y": y} 2198 validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name) 2199 return x 2200 2201 def infer_shape(self, grads, x, y): 2202 return x 2203 2204 2205class MaskedSelectGrad(PrimitiveWithInfer): 2206 """Computes gradient for MaskedSelect.""" 2207 2208 @prim_attr_register 2209 def __init__(self): 2210 pass 2211 2212 def infer_shape(self, x, mask, grad): 2213 return x 2214 2215 def infer_dtype(self, x, mask, grad): 2216 return x 2217 2218 2219class SoftShrinkGrad(Primitive): 2220 r""" 2221 Gradients for SoftShrink operation. 2222 2223 Args: 2224 lambd – The \lambdaλ (must be no less than zero) value for the Softshrink formulation. Default: 0.5. 2225 2226 Inputs: 2227 - **input_grad** (Tensor) - The input gradient. 2228 - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32. 2229 Any number of additional dimensions. 2230 2231 Outputs: 2232 output - Tensor, has the same shape and data type as input_x. 2233 2234 Raises: 2235 TypeError: If lambd is not a float. 2236 TypeError: If dtype of input_x is neither float16 nor float32. 2237 ValueError: If lambd is less than to 0. 2238 2239 Supported Platforms: 2240 ``Ascend`` 2241 """ 2242 2243 @prim_attr_register 2244 def __init__(self, lambd=0.5): 2245 self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output']) 2246 validator.check_value_type("lambd", lambd, [float], self.name) 2247 validator.check_number("lambd", lambd, 0, Rel.GE, self.name) 2248 2249 2250class CdistGrad(Primitive): 2251 """Computes gradient for Cdist.""" 2252 2253 @prim_attr_register 2254 def __init__(self, p=2.0): 2255 validator.check_value_type("p", p, [float], self.name) 2256 self.init_prim_io_names(inputs=['grad', 'input_x', 'input_y', 'cdist'], outputs=['output']) 2257 2258 2259class HShrinkGrad(Primitive): 2260 """ 2261 Computes gradients for HShrinkGrad operation. 2262 2263 Args: 2264 Lambd (float): the λ value for the Hardshrink formulation. Default: 0.5 2265 2266 Inputs: 2267 - **Gradients** (Tensor) - the gradients of loss to output of HShrink function. 2268 Currently gradients data type only support float16 and float32. 2269 - **Features** (Tensor) - Must be the input `input_x` of the forward operator HSHrink. 2270 Currently features data type only support float16 and float32. 2271 2272 Outputs: 2273 backprops - Tensor, with the same shape and data type as `features`. 2274 2275 Rasise: 2276 ValueError: If `lambd` is not a float. 2277 ValueError: If shape of `gradients` is not the same as `features`. 2278 TypeError: If dtype of `gradients` is not the same as `features`. 2279 TypeError: If dtype of `gradients` or `features` is neither float16 nor float32. 2280 2281 Supported Platforms: 2282 ``Ascend`` 2283 """ 2284 2285 @prim_attr_register 2286 def __init__(self, lambd=0.5): 2287 validator.check_value_type("lambd", lambd, [float], self.name) 2288 if lambd < 0.0: 2289 lambd = 0.0 2290 self.add_prim_attr('lambd', lambd) 2291