# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""inner_ops"""

import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
from .. import signature as sig


class ScalarCast(PrimitiveWithInfer):
    """
    Casts the input scalar to another type.

    Inputs:
        - **input_x** (scalar) - The input scalar. Only constant value is allowed.
        - **input_y** (mindspore.dtype) - The type to be cast. Only constant value is allowed.

    Outputs:
        Scalar. The type is the same as the python type corresponding to `input_y`.

    Raises:
        TypeError: If `input_x` or `input_y` is not a constant value.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> scalar_cast = ops.ScalarCast()
        >>> output = scalar_cast(255.0, mindspore.int32)
        >>> print(output)
        255
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __infer__(self, x, t):
        validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name)
        value, to = x['value'], t['value']
        if value is not None:
            validator.check_value_type("value", value, [numbers.Number, bool], self.name)
            if isinstance(to, type(tensor)):
                to = to.element_type()
            np_type = dtype_to_pytype(to)
            value = np_type(value)
        out = {'shape': x['shape'],
               'dtype': t['value'],
               'value': value}
        return out


class Randperm(PrimitiveWithInfer):
    """
    Generates n random samples from 0 to n-1 without repeating. If `max_length` > n,
    the last `max_length-n` elements will be filled with `pad`.

    Args:
        max_length (int): Number of items expected to get and the number must be greater than 0. Default: 1.
        pad (int): The pad value to be filled. Default: -1.
        dtype (mindspore.dtype): The type of output. Default: mindspore.int32.

    Inputs:
        - **n** (Tensor[int32]) - The input tensor with shape: (1,) and the number must be in [0, `max_length`].

    Outputs:
        - **output** (Tensor) - The output Tensor with shape: (`max_length`,) and type: `dtype`.

    Raises:
        TypeError: If `max_length` or `pad` is not an int.
        TypeError: If `n` is not a Tensor.
        TypeError: If `n` has non-int elements.
        TypeError: If `n` has negative elements.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # The result of every execution is different because this operator will generate n random samples.
        >>> randperm = ops.Randperm(max_length=30, pad=-1)
        >>> n = Tensor([20], dtype=mindspore.int32)
        >>> output = randperm(n)
        >>> print(output)
        [15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 1 12 3 7
         -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
    """

    @prim_attr_register
    def __init__(self, max_length=1, pad=-1, dtype=mstype.int32):
        """Initialize Randperm"""
        validator.check_value_type("pad", pad, [int], self.name)
        validator.check_value_type("max_length", max_length, [int], self.name)
        validator.check_int(max_length, 1, Rel.GE, "max_length", self.name)
        self.dtype = dtype
        self.max_length = max_length
        self.init_prim_io_names(inputs=[], outputs=['output'])

    def infer_shape(self, n_shape):
        validator.check_int(len(n_shape), 1, Rel.EQ, "rank_of_n", self.name)
        validator.check_int(n_shape[0], 1, Rel.EQ, "length_of_n", self.name)
        return [self.max_length]

    def infer_dtype(self, n_type):
        validator.check_type_name("n_type", n_type, mstype.int32, self.name)

        valid_values = (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
                        mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64)
        validator.check_type_name("dtype", self.dtype, valid_values, self.name)
        return self.dtype

class NoRepeatNGram(PrimitiveWithInfer):
    """
    Updates log_probs with repeat n-grams.

    During beam search, if consecutive `ngram_size` words exist in the generated word sequence,
    the consecutive `ngram_size` words will be avoided during subsequent prediction.
    For example, when `ngram_size` is 3 and the generated word sequence is [1, 2, 3, 2, 3],
    the next predicted word will not be 2 and its value in `log_probs` will be replaced with -FLOAT_MAX,
    because the 3 consecutive words [2, 3, 2] would otherwise appear twice in the word sequence.

    Args:
        ngram_size (int): Size of n-grams, must be greater than 0. Default: 1.

    Inputs:
        - **state_seq** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, m).
        - **log_probs** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, vocab_size).
          The value of log_probs will be replaced with -FLOAT_MAX when n-grams repeated.

    Outputs:
        - **log_probs** (Tensor) - The output Tensor with same shape and type as original `log_probs`.

    Raises:
        TypeError: If `ngram_size` is not an int.
        TypeError: If `state_seq` or `log_probs` is not a Tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
        >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
        ...                      [9, 3, 9, 5, 4, 1, 5]],
        ...                     [[4, 8, 6, 4, 5, 6, 4],
        ...                      [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
        >>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
        ...                      [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
        ...                     [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
        ...                      [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32)
        >>> output = no_repeat_ngram(state_seq, log_probs)
        >>> print(output)
        [[[ 6.9999999e-01 -3.4028235e+38  6.0000002e-01  8.9999998e-01
            2.0000000e-01 -3.4028235e+38  4.0000001e-01  6.0000002e-01
            2.0000000e-01  6.9999999e-01]
          [ 4.0000001e-01  5.0000000e-01  6.0000002e-01  6.9999999e-01
            8.0000001e-01  1.0000000e-01  8.9999998e-01  8.0000001e-01
            6.9999999e-01  1.0000000e-01]]
         [[ 8.9999998e-01  6.9999999e-01  6.0000002e-01  3.0000001e-01
            5.0000000e-01 -3.4028235e+38  5.0000000e-01  4.0000001e-01
            8.0000001e-01  6.0000002e-01]
          [ 5.0000000e-01  8.0000001e-01  8.0000001e-01  6.9999999e-01
            6.9999999e-01  8.0000001e-01  2.0000000e-01  6.9999999e-01
           -3.4028235e+38  6.9999999e-01]]]
    """

    @prim_attr_register
    def __init__(self, ngram_size=1):
        """Initialize NoRepeatNGram"""
        validator.check_value_type("ngram_size", ngram_size, [int], self.name)
        validator.check_int(ngram_size, 1, Rel.GE, "ngram_size", self.name)
        self.ngram_size = ngram_size
        self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs'])

    def infer_shape(self, seq_shape, log_shape):
        validator.check_int(len(seq_shape), 3, Rel.EQ, "rank of state_seq", self.name)
        validator.check_int(len(log_shape), 3, Rel.EQ, "rank of log_probs", self.name)
        validator.check("state_seq shape[0]", seq_shape[0], "log_probs shape[0]", log_shape[0], Rel.EQ, self.name)
        validator.check("state_seq shape[1]", seq_shape[1], "log_probs shape[1]", log_shape[1], Rel.EQ, self.name)
        validator.check("ngram_size", self.ngram_size, "state_seq shape[2] + 1", seq_shape[2] + 1, Rel.LE, self.name)
        return log_shape

    def infer_dtype(self, seq_type, log_type):
        validator.check_type_name("seq_type", seq_type, mstype.int32, self.name)
        valid_values = (mstype.float16, mstype.float32, mstype.float64)
        validator.check_type_name("log_type", log_type, valid_values, self.name)
        return log_type

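# Illustrative reference only (not used by the framework): a minimal NumPy sketch of the n-gram
# masking documented above -- for each beam, if the trailing (ngram_size - 1) generated words plus a
# candidate word reproduce an n-gram already present in state_seq, that candidate's log-probability
# is set to -FLOAT_MAX. The helper name `_no_repeat_ngram_reference` is hypothetical and the code is
# a sketch of the documented behaviour, not the Ascend kernel.
def _no_repeat_ngram_reference(state_seq, log_probs, ngram_size):
    """Return a masked copy of log_probs (reference sketch of NoRepeatNGram)."""
    import numpy as np
    out = np.array(log_probs, dtype=np.float32, copy=True)
    batch_size, beam_width, m = np.shape(state_seq)
    for b in range(batch_size):
        for k in range(beam_width):
            seq = list(state_seq[b][k])
            # Last (ngram_size - 1) generated words form the prefix of the next n-gram.
            prefix = tuple(seq[m - (ngram_size - 1):m])
            for start in range(m - ngram_size + 1):
                if tuple(seq[start:start + ngram_size - 1]) == prefix:
                    banned = seq[start + ngram_size - 1]  # word that would complete a repeated n-gram
                    out[b, k, banned] = -np.finfo(np.float32).max
    return out
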
class LambApplyOptimizerAssign(PrimitiveWithInfer):
    r"""
    Updates gradients by the LAMB optimizer algorithm and gets the compute ratio.

    The LAMB optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            m = \frac{m}{1 - \beta_1^t} \\
            v = \frac{v}{1 - \beta_2^t} \\
            r = \frac{m}{\sqrt{v} + \epsilon} \\
            w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and
    `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
    `epsilon`.

    Inputs:
        - **gradient** (Tensor) - Gradient of parameters, float32/float16.
        - **v** (Tensor) - The 2nd moment vector in the updating formula, has the same type as `gradient`.
        - **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `gradient`.
        - **var** (Tensor) - Weights to be updated, has the same type as `gradient`.
        - **beta1** (Tensor) - :math:`\beta_1` in the updating formula, float32/float16.
        - **sub1** (Tensor) - :math:`1 - \beta_1` in the updating formula, has the same type as `beta1`.
        - **beta2** (Tensor) - :math:`\beta_2` in the updating formula, has the same type as `beta1`.
        - **sub2** (Tensor) - :math:`1 - \beta_2` in the updating formula, has the same type as `beta1`.
        - **epsilon** (Tensor) - Term added to the denominator, has the same type as `beta1`.
        - **steps** (Tensor) - :math:`t` in the updating formula, the global step, has the same type as `beta1`.
        - **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, has the same type as `beta1`.
        - **decay_flag** (Tensor) - Specifies whether the parameter is updated with weight decay,
          has the same type as `beta1`.
        - **weight_decay** (Tensor) - :math:`\lambda` in the updating formula, has the same type as `beta1`.

    Outputs:
        Tuple of 3 Tensors, the compute ratio and the updated moment vectors.

        - **update** (Tensor) - :math:`r + \lambda * w` in the updating formula. The same shape and data type as `m`.
        - **v** (Tensor) - The 2nd moment vector in the updating formula after being updated in place,
          has the same type as `gradient`.
        - **m** (Tensor) - The 1st moment vector in the updating formula after being updated in place,
          has the same type as `gradient`.

    Supported Platforms:
        ``Ascend``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize LambApplyOptimizerAssign"""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
                    beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return m_shape, v_shape, m_shape

    def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
                    beta2_dtype, sub2_dtype, eps_dtype, steps_dtype, use_weight_dtype, weight_decay_dtype):
        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        args = {"beta1": beta1_dtype, "sub1": sub1_dtype, "beta2": beta2_dtype, "sub2": sub2_dtype,
                "eps": eps_dtype, "steps": steps_dtype, "use_weight": use_weight_dtype,
                "weight_decay": weight_decay_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
        return m_dtype, v_dtype, v_dtype

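# Illustrative reference only (not used by the framework): a minimal NumPy sketch of the update that
# LambApplyOptimizerAssign is documented to perform. The helper name is hypothetical; `decay_flag` is
# treated here as a plain boolean (in the primitive it is a tensor), and the in-place semantics of the
# kernel are replaced by returned values.
def _lamb_optimizer_assign_reference(gradient, v, m, var, beta1, sub1, beta2, sub2,
                                     epsilon, steps, decay_flag, weight_decay):
    """Return (update, v, m) following the documented LAMB formulas (sketch, not the kernel)."""
    import numpy as np
    # Moment updates: m <- beta1 * m + (1 - beta1) * g, v <- beta2 * v + (1 - beta2) * g * g
    m = beta1 * m + sub1 * gradient
    v = beta2 * v + sub2 * gradient * gradient
    # Bias correction by beta^t, with t taken from `steps`
    m_hat = m / (1.0 - np.power(beta1, steps))
    v_hat = v / (1.0 - np.power(beta2, steps))
    # Compute ratio r, plus the decoupled weight decay term when enabled
    update = m_hat / (np.sqrt(v_hat) + epsilon)
    if decay_flag:
        update = update + weight_decay * var
    return update, v, m
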
class LambApplyWeightAssign(PrimitiveWithInfer):
    r"""
    Updates the weights by the LAMB optimizer algorithm. This is the weight update part.

    The LAMB optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
    <https://arxiv.org/abs/1904.00962>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            m = \frac{m}{1 - \beta_1^t} \\
            v = \frac{v}{1 - \beta_2^t} \\
            r = \frac{m}{\sqrt{v} + \epsilon} \\
            w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and
    `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
    `epsilon`.

    Inputs:
        - **w_norm** (Tensor) - :math:`\left \| w \right \|` in the updating formula, float32/float16.
        - **g_norm** (Tensor) - :math:`\left \| r \right \|` in the updating formula, has the same type as `w_norm`.
        - **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, float32/float16.
        - **update** (Tensor) - :math:`r + \lambda * w` in the updating formula, float32/float16.
        - **var** (Tensor) - Weights to be updated, the same shape and type as `update`.

    Outputs:
        - **var** (Tensor) - Weights to be updated in place, the same shape and type as `var` in inputs.

    Supported Platforms:
        ``Ascend``
    """
    @prim_attr_register
    def __init__(self):
        """Initialize LambApplyWeightAssign"""
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape):
        validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name)
        return var_shape

    def infer_dtype(self, w_norm_dtype, g_norm_dtype, lr_dtype, update_dtype, var_dtype):
        args = {"var": var_dtype, "update": update_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
        return var_dtype

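# Illustrative reference only (not used by the framework): a minimal NumPy sketch of the weight-update
# step documented for LambApplyWeightAssign, chaining the `update` produced by LambApplyOptimizerAssign.
# The helper name is hypothetical, and the handling of zero norms (falling back to a trust ratio of 1.0,
# as in the common LAMB formulation) is an assumption about the kernel's edge-case behaviour.
def _lamb_weight_assign_reference(w_norm, g_norm, lr, update, var):
    """Return var - lr * (||w|| / ||r||) * update, the documented LAMB weight update (sketch)."""
    ratio = w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0  # trust ratio ||w|| / ||r||
    return var - lr * ratio * update
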
class MakeRefKey(Primitive):
    """
    Makes a RefKey instance by string. RefKey stores the name of a Parameter, can be passed through functions,
    and is used as an Assign target.

    Args:
        tag (str): Parameter name to make the RefKey.

    Inputs:
        No inputs.

    Outputs:
        RefKeyType, made from the Parameter name.

    Raises:
        TypeError: If `tag` is not a str.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Parameter, Tensor
        >>> from mindspore import dtype as mstype
        >>> import mindspore.ops as ops
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.y = Parameter(Tensor(np.ones([2, 3]), mstype.int32), name="y")
        ...         self.make_ref_key = ops.MakeRefKey("y")
        ...
        ...     def construct(self, x):
        ...         key = self.make_ref_key()
        ...         ref = ops.make_ref(key, x, self.y)
        ...         return ref * x
        ...
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.int32)
        >>> net = Net()
        >>> output = net(x)
        >>> print(output)
        [[ 1  4  9]
         [16 25 36]]
    """

    @prim_attr_register
    def __init__(self, tag):
        validator.check_value_type('tag', tag, (str,), self.name)

    def __call__(self):
        pass

class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
    """
    Optimizer that implements the Momentum algorithm with weight decay and loss scale.

    Refer to the paper `On the importance of initialization and momentum in deep
    learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.

    Refer to :class:`mindspore.nn.Momentum` for more details about the formula and usage.

    Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
    to make the data types consistent.
    If they have different data types, the lower priority data type will be converted to the
    relatively highest priority data type.
    Data type conversion of Parameter is not supported, and a RuntimeError exception will be thrown.

    Inputs:
        - **weight_decay** (Tensor) - The weight decay value, must be a scalar tensor with float data type.
          Default: 0.0.
        - **loss_scale** (Tensor) - The loss scale value, must be a scalar tensor with float data type.
          Default: 1.0.
        - **variable** (Parameter) - Weights to be updated. Data type must be float.
        - **accumulation** (Parameter) - Accumulated gradient value by moment weight.
          Has the same data type with `variable`.
        - **learning_rate** (Union[Number, Tensor]) - The learning rate value, must be a float number or
          a scalar tensor with float data type.
        - **gradient** (Tensor) - Gradient, has the same data type as `variable`.
        - **momentum** (Union[Number, Tensor]) - Momentum, must be a float number or
          a scalar tensor with float data type.

    Outputs:
        Tensor, parameters to be updated.

    Supported Platforms:
        ``GPU``

    Examples:
        Please refer to the usage in :class:`mindspore.nn.Momentum`, and add weight_decay and loss_scale as inputs.
    """
    __mindspore_signature__ = (
        sig.make_sig('weight_decay', dtype=sig.sig_dtype.T3),
        sig.make_sig('loss_scale', dtype=sig.sig_dtype.T3),
        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1),
        sig.make_sig('gradient', dtype=sig.sig_dtype.T),
        sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
    )

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['weight_decay', 'loss_scale', 'variable', 'accumulation', 'learning_rate',
                                        'gradient', 'momentum'], outputs=['output'])

    def infer_shape(self, d_shape, s_shape, v_shape, a_shape, l_shape, g_shape, m_shape):
        return v_shape

    def infer_dtype(self, d_dtype, s_dtype, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
        valid_dtypes = [mstype.float16, mstype.float32]
        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
            validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
            validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name)
        return v_dtype

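# Illustrative reference only (not used by the framework): a minimal NumPy-style sketch of the fused
# update this primitive is documented to perform -- undo the loss scale, add the weight decay term,
# then take a classic momentum step. The helper name and the exact fusion order are assumptions based
# on the docstring and the conventional momentum formulation, not a transcription of the GPU kernel.
def _fused_weight_scale_momentum_reference(weight_decay, loss_scale, variable, accumulation,
                                           learning_rate, gradient, momentum):
    """Return the updated (variable, accumulation) for one momentum step (sketch)."""
    grad = gradient / loss_scale + weight_decay * variable   # unscale the gradient and apply weight decay
    accumulation = momentum * accumulation + grad            # momentum accumulation
    variable = variable - learning_rate * accumulation       # parameter update
    return variable, accumulation
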
class FusedCastAdamWeightDecay(PrimitiveWithInfer):
    r"""
    Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay. This operator
    incorporates type conversion when parameters are initialized with dtype of float16.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
    The AdamWeightDecay variant was proposed in `Decoupled Weight Decay Regularization
    <https://arxiv.org/abs/1711.05101>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            update = \frac{m}{\sqrt{v} + \epsilon} \\
            update =
            \begin{cases}
                update + weight\_decay * w
                    & \text{ if } weight\_decay > 0 \\
                update
                    & \text{ otherwise }
            \end{cases} \\
            w = w - lr * update
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
    `gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`lr` represents the learning rate `lr`, :math:`w` represents `var`, :math:`weight\_decay` represents
    `decay`, :math:`\epsilon` represents `epsilon`.

    Args:
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.

    Inputs:
        - **var** (Tensor) - Weights to be updated with the type float16 or float32.
        - **m** (Tensor) - The 1st moment vector in the updating formula with the type float32.
        - **v** (Tensor) - The 2nd moment vector in the updating formula with the type float32.
        - **lr** (float) - :math:`lr` in the updating formula.
        - **beta1** (float) - The exponential decay rate for the 1st moment estimations.
        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
        - **decay** (float) - The weight decay value.
        - **gradient** (Tensor) - Gradient, has the type float16.

    Outputs:
        Tuple of 3 Tensors, the updated parameters.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
        - **v** (Tensor) - The same shape and data type as `v`.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.context as context
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.opt = ops.FusedCastAdamWeightDecay()
        ...         self.var = Parameter(Tensor(np.ones([2, 2]), mstype.float16), name="var")
        ...         self.m = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="m")
        ...         self.v = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="v")
        ...     def construct(self, lr, beta1, beta2, epsilon, decay, grad):
        ...         out = self.opt(self.var, self.m, self.v, lr, beta1, beta2, epsilon, decay, grad)
        ...         return out
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
        >>> net = Net()
        >>> gradient = Tensor(np.ones([2, 2]), mstype.float16)
        >>> output = net(0.001, 0.9, 0.999, 1e-8, 0.0, gradient)
        >>> print(net.var.asnumpy())
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Initialize FusedCastAdamWeightDecay"""
        self.add_prim_attr('side_effect_mem', True)
        validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, grad_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, grad_dtype):
        args = {"m": m_dtype, "v": v_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
        validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)

        args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
                "decay": decay_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype, m_dtype, v_dtype

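# Illustrative reference only (not used by the framework): a minimal NumPy sketch of the documented
# FusedCastAdamWeightDecay update, including the cast of a float16 gradient (and parameter) up to
# float32 for the arithmetic and the cast back to the parameter's dtype. The helper name is
# hypothetical; the in-place kernel semantics are replaced by returned values.
def _fused_cast_adam_weight_decay_reference(var, m, v, lr, beta1, beta2, epsilon, decay, gradient):
    """Return the updated (var, m, v) for one AdamWeightDecay step (sketch)."""
    import numpy as np
    grad = gradient.astype(np.float32)                  # cast the float16 gradient up for the update
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    update = m / (np.sqrt(v) + epsilon)
    if decay > 0:
        update = update + decay * var.astype(np.float32)
    var = (var.astype(np.float32) - lr * update).astype(var.dtype)  # cast back to the parameter dtype
    return var, m, v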