# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""inner_ops"""

import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
from .. import signature as sig


class ScalarCast(PrimitiveWithInfer):
28    """
29    Casts the input scalar to another type.
30
31    Inputs:
32        - **input_x** (scalar) - The input scalar. Only constant value is allowed.
33        - **input_y** (mindspore.dtype) - The type to be cast. Only constant value is allowed.
34
35    Outputs:
36        Scalar. The type is the same as the python type corresponding to `input_y`.
37
38    Raises:
39        TypeError: If neither `input_x` nor `input_y` is a constant value.
40
41    Supported Platforms:
42        ``Ascend`` ``GPU`` ``CPU``
43
44    Examples:
45        >>> scalar_cast = ops.ScalarCast()
46        >>> output = scalar_cast(255.0, mindspore.int32)
47        >>> print(output)
48        255
49    """
50
51    @prim_attr_register
52    def __init__(self):
53        pass
54
55    def __infer__(self, x, t):
56        validator.check_equal_int(len(x['shape']), 0, 'x shape', self.name)
57        value, to = x['value'], t['value']
58        if value is not None:
59            validator.check_value_type("value", value, [numbers.Number, bool], self.name)
60            if isinstance(to, type(tensor)):
61                to = to.element_type()
62            np_type = dtype_to_pytype(to)
63            value = np_type(value)
64        out = {'shape': x['shape'],
65               'dtype': t['value'],
66               'value': value}
67        return out
68
69
70class Randperm(PrimitiveWithInfer):
71    """
72    Generates n random samples from 0 to n-1 without repeating. If `max_length` > n,
73    the last `max_length-n` elements will be filled with `pad`.
74
75    Args:
76        max_length (int): Number of items expected to get and the number must be greater than 0. Default: 1.
77        pad (int): The pad value to be filled. Default: -1.
78        dtype (mindspore.dtype): The type of output. Default: mindspore.int32.
79
80    Inputs:
81        - **n** (Tensor[int32]) - The input tensor with shape: (1,) and the number must be in [0, `max_length`].
82
83    Outputs:
84        - **output** (Tensor) - The output Tensor with shape: (`max_length`,) and type: `dtype`.
85
86    Raises:
87        TypeError: If neither `max_length` nor `pad` is an int.
88        TypeError: If `n` is not a Tensor.
89        TypeError: If `n` has non-Int elements.
90        TypeError: If `n` has negative elements.
91
92    Supported Platforms:
93        ``Ascend`` ``GPU``
94
95    Examples:
96        >>> # The result of every execution is different because this operator will generate n random samples.
97        >>> randperm = ops.Randperm(max_length=30, pad=-1)
98        >>> n = Tensor([20], dtype=mindspore.int32)
99        >>> output = randperm(n)
100        >>> print(output)
101        [15 6 11 19 14 16 9 5 13 18 4 10 8 0 17 2 1 12 3 7
102         -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
103    """
104
105    @prim_attr_register
106    def __init__(self, max_length=1, pad=-1, dtype=mstype.int32):
107        """Initialize Randperm"""
108        validator.check_value_type("pad", pad, [int], self.name)
109        validator.check_value_type("max_length", max_length, [int], self.name)
110        validator.check_int(max_length, 1, Rel.GE, "max_length", self.name)
111        self.dtype = dtype
112        self.max_length = max_length
113        self.init_prim_io_names(inputs=[], outputs=['output'])
114
115    def infer_shape(self, n_shape):
116        validator.check_int(len(n_shape), 1, Rel.EQ, "rank_of_n", self.name)
117        validator.check_int(n_shape[0], 1, Rel.EQ, "length_of_n", self.name)
118        return [self.max_length]
119
120    def infer_dtype(self, n_type):
121        validator.check_type_name("n_type", n_type, mstype.int32, self.name)
122
123        valid_values = (mstype.int8, mstype.int16, mstype.int32, mstype.int64,
124                        mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64)
125        validator.check_type_name("dtype", self.dtype, valid_values, self.name)
126        return self.dtype
127
128
129class NoRepeatNGram(PrimitiveWithInfer):
130    """
131    Updates log_probs with repeat n-grams.
132
133    During beam search, if consecutive `ngram_size` words exist in the generated word sequence,
134    the consecutive `ngram_size` words will be avoided during subsequent prediction.
135    For example, when `ngram_size` is 3, the generated word sequence is [1, 2, 3, 2, 3],
136    the next predicted word will not be 2 and the value of `log_probs` will be replaced with -FLOAT_MAX.
137    Because 3 consecutive words [2, 3, 2] do not appear twice in the word sequence.
138
139    Args:
140        ngram_size (int): Size of n-grams, must be greater than 0. Default: 1.
141
142    Inputs:
143        - **state_seq** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, m).
144        - **log_probs** (Tensor) - A 3-D tensor with shape: (batch_size, beam_width, vocab_size).
145          The value of log_probs will be replaced with -FLOAT_MAX when n-grams repeated.
146
147    Outputs:
148        - **log_probs** (Tensor) - The output Tensor with same shape and type as original `log_probs`.
149
150    Raises:
151        TypeError: If `ngram_size` is not an int.
152        TypeError: If neither `state_seq` nor `log_probs` is a Tensor.
153
154    Supported Platforms:
155        ``Ascend``
156
157    Examples:
158        >>> no_repeat_ngram = ops.NoRepeatNGram(ngram_size=3)
159        >>> state_seq = Tensor([[[1, 2, 1, 2, 5, 1, 2],
160        ...                      [9, 3, 9, 5, 4, 1, 5]],
161        ...                     [[4, 8, 6, 4, 5, 6, 4],
162        ...                      [4, 8, 8, 4, 3, 4, 8]]], dtype=mindspore.int32)
163        >>> log_probs = Tensor([[[0.7, 0.8, 0.6, 0.9, 0.2, 0.8, 0.4, 0.6, 0.2, 0.7],
164        ...                      [0.4, 0.5, 0.6, 0.7, 0.8, 0.1, 0.9, 0.8, 0.7, 0.1]],
165        ...                     [[0.9, 0.7, 0.6, 0.3, 0.5, 0.3, 0.5, 0.4, 0.8, 0.6],
166        ...                      [0.5, 0.8, 0.8, 0.7, 0.7, 0.8, 0.2, 0.7, 0.9, 0.7]]], dtype=mindspore.float32)
167        >>> output = no_repeat_ngram(state_seq, log_probs)
168        >>> print(output)
169        [[[ 6.9999999e-01 -3.4028235e+38  6.0000002e-01  8.9999998e-01
170            2.0000000e-01 -3.4028235e+38  4.0000001e-01  6.0000002e-01
171            2.0000000e-01  6.9999999e-01]
172          [ 4.0000001e-01  5.0000000e-01  6.0000002e-01  6.9999999e-01
173            8.0000001e-01  1.0000000e-01  8.9999998e-01  8.0000001e-01
174            6.9999999e-01  1.0000000e-01]]
175         [[ 8.9999998e-01  6.9999999e-01  6.0000002e-01  3.0000001e-01
176            5.0000000e-01 -3.4028235e+38  5.0000000e-01  4.0000001e-01
177            8.0000001e-01  6.0000002e-01]
178          [ 5.0000000e-01  8.0000001e-01  8.0000001e-01  6.9999999e-01
179            6.9999999e-01  8.0000001e-01  2.0000000e-01  6.9999999e-01
180           -3.4028235e+38  6.9999999e-01]]]
181    """
182
183    @prim_attr_register
184    def __init__(self, ngram_size=1):
185        """NoRepeatNGram Randperm"""
186        validator.check_value_type("ngram_size", ngram_size, [int], self.name)
187        validator.check_int(ngram_size, 1, Rel.GE, "ngram_size", self.name)
188        self.ngram_size = ngram_size
189        self.init_prim_io_names(inputs=['state_seq', 'log_probs'], outputs=['log_probs'])
190
191    def infer_shape(self, seq_shape, log_shape):
192        validator.check_int(len(seq_shape), 3, Rel.EQ, "rank of state_seq", self.name)
193        validator.check_int(len(log_shape), 3, Rel.EQ, "rank of log_probs", self.name)
194        validator.check("state_seq shape[0]", seq_shape[0], "log_probs shape[0]", log_shape[0], Rel.EQ, self.name)
195        validator.check("state_seq shape[1]", seq_shape[1], "log_probs shape[1]", log_shape[1], Rel.EQ, self.name)
196        validator.check("ngram_size", self.ngram_size, "state_seq shape[2] + 1", seq_shape[2] + 1, Rel.LE, self.name)
197        return log_shape
198
199    def infer_dtype(self, seq_type, log_type):
200        validator.check_type_name("seq_type", seq_type, mstype.int32, self.name)
201        valid_values = (mstype.float16, mstype.float32, mstype.float64)
202        validator.check_type_name("log_type", log_type, valid_values, self.name)
203        return log_type
204
205
206class LambApplyOptimizerAssign(PrimitiveWithInfer):
207    r"""
208    Updates gradients by LAMB optimizer algorithm. Get the compute ratio.
209
210    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
211    <https://arxiv.org/abs/1904.00962>`_.
212
213    The updating formulas are as follows,
214
215    .. math::
216        \begin{array}{ll} \\
217            m = \beta_1 * m + (1 - \beta_1) * g \\
218            v = \beta_2 * v + (1 - \beta_2) * g * g \\
219            m = \frac{m}{1 - \beta_1^t} \\
220            v = \frac{v}{1 - \beta_2^t} \\
221            r = \frac{m}{\sqrt{v} + \epsilon} \\
222            w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w))
223        \end{array}
224
225    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
226    `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
227    :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
228    `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
229    `epsilon`.
230
231    Inputs:
232        - **gradient** (Tensor) - Gradient of parameters, float32/float16.
233        - **v** (Tensor) - the 2nd moment vector in the updating formula, has the same type as `gradient`.
234        - **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `gradient`.
235        - **var** (Tensor) - Weights to be updated, has the same type as `gradient`.
236        - **beta1** (Tensor) - :math:`beta_1` in the updating formula, float32/float16.
237        - **sub1** (Tensor) - :math:`1-beta_1` in the updating formula, has the same type as `beta1`.
238        - **beta2** (Tensor) - :math:`beta_2` in the updating formula, has the same type as `beta1`.
239        - **sub2** (Tensor) - :math:`1-beta_2` in the updating formula, has the same type as `beta1`.
240        - **epsilon** (Tensor) - Term added to the denominator, has the same type as `beta1`.
241        - **steps** (Tensor) - :math:`t` in the updating formula, global step, has the same type as `beta1`.
242        - **lr** (Tensor) - :math:`l` in the updating formula, learning rate, has the same type as `beta1`.
243        - **decay_flag** (Tensor) -Specify whether param update with weight decay, has the same type as `beta1`.
244        - **weight_decay** (Tensor) - :math:`\lambda` in the updating formula, has the same type as `beta1`.
245
246    Outputs:
247        Tensor, the compute ratio r.
248        - **update** (Tensor) - :math:`r + \lambda * w` in the updating formula. The same shape and data type as `m`.
249        - **v** (Tensor) - the 2nd moment vector in the updating formula after updated inplace,
250                           has the same type as `gradient`.
251        - **m** (Tensor) - The 1st moment vector in the updating formula after updated inplace,
252                           has the same type as `gradient`.
253
254    Supported Platforms:
255        ``Ascend``
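
    Examples:
        The following is a minimal illustrative sketch; the tensor shapes and scalar values are
        assumptions chosen only to show the expected input order (which follows this operator's
        infer functions), not real training values.

        >>> import numpy as np
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>> lamb_assign = ops.LambApplyOptimizerAssign()
        >>> gradient = Tensor(np.ones([2, 2]), mstype.float32)
        >>> v = Tensor(np.ones([2, 2]), mstype.float32)
        >>> m = Tensor(np.ones([2, 2]), mstype.float32)
        >>> var = Tensor(np.ones([2, 2]), mstype.float32)
        >>> beta1 = Tensor(np.array([0.9]), mstype.float32)
        >>> sub1 = Tensor(np.array([0.1]), mstype.float32)
        >>> beta2 = Tensor(np.array([0.999]), mstype.float32)
        >>> sub2 = Tensor(np.array([0.001]), mstype.float32)
        >>> epsilon = Tensor(np.array([1e-6]), mstype.float32)
        >>> steps = Tensor(np.array([10.0]), mstype.float32)
        >>> decay_flag = Tensor(np.array([1.0]), mstype.float32)
        >>> weight_decay = Tensor(np.array([0.01]), mstype.float32)
        >>> update, v, m = lamb_assign(gradient, v, m, var, beta1, sub1, beta2, sub2,
        ...                            epsilon, steps, decay_flag, weight_decay)
        >>> print(update.shape)
        (2, 2)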
256    """
257    @prim_attr_register
258    def __init__(self):
259        """Initialize LambApplyOptimizerAssign"""
260        self.add_prim_attr('side_effect_mem', True)
261
262    def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
263                    beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
264        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
265        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
266        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
267        return m_shape, v_shape, m_shape
268
269    def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype,
270                    beta2_dtype, sub2_dtype, eps_dtype, steps_dtype, use_weight_dtype, weight_decay_dtype):
271        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
272        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
273
274        args = {"beta1": beta1_dtype, "sub1": sub1_dtype, "beta2": beta2_dtype, "sub2": sub2_dtype,
275                "eps": eps_dtype, "steps": steps_dtype, "use_weight": use_weight_dtype,
276                "weight_decay": weight_decay_dtype}
277        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
278        return m_dtype, v_dtype, v_dtype
279
280
281class LambApplyWeightAssign(PrimitiveWithInfer):
282    r"""
283    Updates gradients by LAMB optimizer algorithm. The weight update part.
284
285    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
286    <https://arxiv.org/abs/1904.00962>`_.
287
288    The updating formulas are as follows,
289
290    .. math::
291        \begin{array}{ll} \\
292            m = \beta_1 * m + (1 - \beta_1) * g \\
293            v = \beta_2 * v + (1 - \beta_2) * g * g \\
294            m = \frac{m}{1 - \beta_1^t} \\
295            v = \frac{v}{1 - \beta_2^t} \\
296            r = \frac{m}{\sqrt{v} + \epsilon} \\
297            w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w))
298        \end{array}
299
300    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
301    `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
302    :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
303    `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents
304    `epsilon`.
305
306    Inputs:
307        - **w_norm** (Tensor) - :math:`\left \| w \right \|` in the updating formula, float32/float16.
308        - **g_norm** (Tensor) - :math:`\left \| r \right \|` in the updating formula, has the same type as `w_norm`.
309        - **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, float32/float16.
310        - **update** (Tensor) -:math:`r + \lambda * w`in the updating formula, float32/float16.
311        - **var** (Tensor) - Weights to be updated, the same shape and type as `update`.
312
313    Outputs:
314        - **var** (Tensor) - Weights to be updated in place, the same shape and type as `var` in inputs.
315
316    Supported Platforms:
317        ``Ascend``
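
    Examples:
        The following is a minimal illustrative sketch; the tensor shapes and values are assumptions
        chosen only to show the expected input order.

        >>> import numpy as np
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> lamb_weight_assign = ops.LambApplyWeightAssign()
        >>> var = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="var")
        >>> update = Tensor(np.full([2, 2], 0.5), mstype.float32)
        >>> w_norm = Tensor(np.array([2.0]), mstype.float32)
        >>> g_norm = Tensor(np.array([1.0]), mstype.float32)
        >>> lr = Tensor(np.array([0.001]), mstype.float32)
        >>> output = lamb_weight_assign(w_norm, g_norm, lr, update, var)
        >>> print(output.shape)
        (2, 2)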
318    """
319    @prim_attr_register
320    def __init__(self):
321        """Initialize LambApplyWeightAssign"""
322        self.add_prim_attr('side_effect_mem', True)
323
324    def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape):
325        validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name)
326        return var_shape
327
328    def infer_dtype(self, w_norm_dtype, g_norm_dtype, lr_dtype, update_dtype, var_dtype):
329        args = {"var": var_dtype, "update": update_dtype}
330        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
331
332        args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype}
333        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
334        return var_dtype
335
336
337class MakeRefKey(Primitive):
338    """
339    Makes a RefKey instance by string. RefKey stores the name of Parameter, can be passed through the functions,
340    and used for Assign target.
341
342    Args:
343        tag (str): Parameter name to make the RefKey.
344
345    Inputs:
346        No inputs.
347
348    Outputs:
349        RefKeyType, made from the Parameter name.
350
351    Raises:
352        TypeError: If `tag` is not a str.
353
354    Supported Platforms:
355        ``Ascend`` ``GPU`` ``CPU``
356
357    Examples:
358        >>> import numpy as np
359        >>> from mindspore import Parameter, Tensor
360        >>> from mindspore import dtype as mstype
361        >>> import mindspore.ops as ops
362        >>> class Net(nn.Cell):
363        ...     def __init__(self):
364        ...         super(Net, self).__init__()
365        ...         self.y = Parameter(Tensor(np.ones([2, 3]), mstype.int32), name="y")
366        ...         self.make_ref_key = ops.MakeRefKey("y")
367        ...
368        ...     def construct(self, x):
369        ...         key = self.make_ref_key()
370        ...         ref = ops.make_ref(key, x, self.y)
371        ...         return ref * x
372        ...
373        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.int32)
374        >>> net = Net()
375        >>> output = net(x)
376        >>> print(output)
377        [[ 1  4  9]
378         [16 25 36]]
379    """
380
381    @prim_attr_register
382    def __init__(self, tag):
383        validator.check_value_type('tag', tag, (str,), self.name)
384
385    def __call__(self):
386        pass
387
388
389class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
390    """
391    Optimizer that implements the Momentum algorithm with weight decay and loss scale.
392
393    Refer to the paper `On the importance of initialization and momentum in deep
394    learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_  for more details.
395
396    Refer to :class:`mindspore.nn.Momentum` for more details about the formula and usage.
397
398    Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
399    to make the data types consistent.
400    If they have different data types, lower priority data type will be converted to
401    relatively highest priority data type.
402    Data type conversion of Parameter is not supported. RuntimeError exception will be thrown.
403
404    Inputs:
405        - **weight_decay** (Tensor) - The weight decay value, must be a scalar tensor with float data type.
406          Default: 0.0.
407        - **loss_scale** (Tensor) - The loss scale value, must be a scalar tensor with float data type.
408          Default: 1.0.
409        - **variable** (Parameter) - Weights to be updated. data type must be float.
410        - **accumulation** (Parameter) - Accumulated gradient value by moment weight.
411          Has the same data type with `variable`.
412        - **learning_rate** (Union[Number, Tensor]) - The learning rate value, must be a float number or
413          a scalar tensor with float data type.
414        - **gradient** (Tensor) - Gradient, has the same data type as `variable`.
415        - **momentum** (Union[Number, Tensor]) - Momentum, must be a float number or
416          a scalar tensor with float data type.
417
418    Outputs:
419        Tensor, parameters to be updated.
420
421    Supported Platforms:
422        ``GPU``
423    Examples:
424        Please refer to the usage in :class:`mindspore.nn.Momentum`, and add weight_decay and loss_scale as inputs.
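
        The following is a minimal illustrative sketch; the network, shapes and scalar values are
        assumptions for demonstration only.

        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, Parameter
        >>> from mindspore import dtype as mstype
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.opt = ops.FusedWeightScaleApplyMomentum()
        ...         self.var = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="var")
        ...         self.accum = Parameter(Tensor(np.zeros([2, 2]), mstype.float32), name="accum")
        ...     def construct(self, weight_decay, loss_scale, lr, grad, momentum):
        ...         return self.opt(weight_decay, loss_scale, self.var, self.accum, lr, grad, momentum)
        >>> net = Net()
        >>> weight_decay = Tensor(0.0001, mstype.float32)
        >>> loss_scale = Tensor(1.0, mstype.float32)
        >>> lr = Tensor(0.1, mstype.float32)
        >>> momentum = Tensor(0.9, mstype.float32)
        >>> grad = Tensor(np.ones([2, 2]), mstype.float32)
        >>> output = net(weight_decay, loss_scale, lr, grad, momentum)
        >>> print(output.shape)
        (2, 2)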
425    """
426    __mindspore_signature__ = (
427        sig.make_sig('weight_decay', dtype=sig.sig_dtype.T3),
428        sig.make_sig('loss_scale', dtype=sig.sig_dtype.T3),
429        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
430        sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
431        sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1),
432        sig.make_sig('gradient', dtype=sig.sig_dtype.T),
433        sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
434    )
435
436    @prim_attr_register
437    def __init__(self):
438        self.init_prim_io_names(inputs=['weight_decay', 'loss_scale', 'variable', 'accumulation', 'learning_rate',
439                                        'gradient', 'momentum'], outputs=['output'])
440
441    def infer_shape(self, d_shape, s_shape, v_shape, a_shape, l_shape, g_shape, m_shape):
442        return v_shape
443
444    def infer_dtype(self, d_dtype, s_dtype, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
445        valid_dtypes = [mstype.float16, mstype.float32]
446        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
447            validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
448            validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
449        validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
450        validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
451        validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
452        validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name)
453        validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name)
454        return v_dtype
455
456
457class FusedCastAdamWeightDecay(PrimitiveWithInfer):
458    r"""
459    Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay. This operator
460    incorporates type conversion when parameters are initialized with dtype of float16.
461
462    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
463    The AdamWeightDecay variant was proposed in `Decoupled Weight Decay Regularization
464    <https://arxiv.org/abs/1711.05101>`_.
465
466    The updating formulas are as follows,
467
468    .. math::
469        \begin{array}{ll} \\
470            m = \beta_1 * m + (1 - \beta_1) * g \\
471            v = \beta_2 * v + (1 - \beta_2) * g * g \\
472            update = \frac{m}{\sqrt{v} + eps} \\
473            update =
474            \begin{cases}
475                update + weight\_decay * w
476                    & \text{ if } weight\_decay > 0 \\
477                update
478                    & \text{ otherwise }
479            \end{cases} \\
480            w  = w - lr * update
481        \end{array}
482
483    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
484    `gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
485    :math:`lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`,
486    :math:`\epsilon` represents `epsilon`.
487
488    Args:
489        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
490            If true, updates of the var, m, and v tensors will be protected by a lock.
491            If false, the result is unpredictable. Default: False.
492
493    Inputs:
494        - **var** (Tensor) - Weights to be updated with the type float16 or float32.
495        - **m** (Tensor) - The 1st moment vector in the updating formula with the type float32.
496        - **v** (Tensor) - the 2nd moment vector in the updating formula with the type float32.
497        - **lr** (float) - :math:`lr` in the updating formula.
498        - **beta1** (float) - The exponential decay rate for the 1st moment estimations.
499        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
500        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
501        - **decay** (float) - The weight decay value, must be a scalar tensor with float data type.
502        - **gradient** (Tensor) - Gradient, has the type float16.
503
504    Outputs:
505        Tuple of 3 Tensor, the updated parameters.
506
507        - **var** (Tensor) - The same shape and data type as `var`.
508        - **m** (Tensor) - The same shape and data type as `m`.
509        - **v** (Tensor) - The same shape and data type as `v`.
510
511    Supported Platforms:
512        ``CPU``
513
514    Examples:
515        >>> import numpy as np
516        >>> import mindspore.context as context
517        >>> import mindspore.nn as nn
518        >>> import mindspore.ops as ops
519        >>> from mindspore import Tensor, Parameter
520        >>> from mindspore import dtype as mstype
521        >>> class Net(nn.Cell):
522        ...     def __init__(self):
523        ...         super(Net, self).__init__()
524        ...         self.opt = ops.FusedCastAdamWeightDecay()
525        ...         self.var = Parameter(Tensor(np.ones([2, 2]), mstype.float16), name="var")
526        ...         self.m = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="m")
527        ...         self.v = Parameter(Tensor(np.ones([2, 2]), mstype.float32), name="v")
528        ...     def construct(self, lr, beta1, beta2, epsilon, decay, grad):
529        ...         out = self.opt(self.var, self.m, self.v, lr, beta1, beta2, epsilon, decay, grad)
530        ...         return out
531        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
532        >>> net = Net()
533        >>> gradient = Tensor(np.ones([2, 2]), mstype.float16)
534        >>> output = net(0.001, 0.9, 0.999, 1e-8, 0.0, gradient)
535        >>> print(net.var.asnumpy())
536    """
537
538    @prim_attr_register
    def __init__(self, use_locking=False):
        self.add_prim_attr('side_effect_mem', True)
        validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, grad_shape):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, grad_dtype):
        args = {"m": m_dtype, "v": v_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
        validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)

        args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
                "decay": decay_dtype}
        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
        return var_dtype, m_dtype, v_dtype