# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer"""
import inspect
from typing import Iterable

import numpy as np

import mindspore
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.nn.cell import Cell
from mindspore.nn.layer.container import CellList
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, RowTensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer', 'opt_init_args_register']


def opt_init_args_register(fn):
    """Register optimizer init args."""
    def deco(self, *args, **kwargs):
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        arguments.pop('self')
        if 'params' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['params']}))
            arguments.pop('params')
        if 'optimizer' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['optimizer'].init_params["params"]}))
            arguments.pop('optimizer')
        setattr(self, 'init_args', arguments)
        fn(self, *args, **kwargs)
    return deco
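
# Illustrative sketch (not part of the original file): a subclass applies
# `opt_init_args_register` to its `__init__` so that the constructor arguments
# are recorded on the instance as `init_params` and `init_args`. The subclass
# name and arguments below are hypothetical.
#
#     class MyMomentum(Optimizer):
#         @opt_init_args_register
#         def __init__(self, params, learning_rate=0.1, momentum=0.9):
#             super(MyMomentum, self).__init__(learning_rate, params)
#             self.momentum = momentum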


class Optimizer(Cell):
    """
    Base class for all optimizers.

    Note:
        This class defines the API to add Ops to train a model. Never use
        this class directly, but instead instantiate one of its subclasses.

        Different parameter groups can set different `learning_rate`, `weight_decay` and `grad_centralization`.

        When separating parameter groups, the weight decay in each group will be applied to the parameters if the
        weight decay is positive. For most optimizers, when not separating parameters, the `weight_decay` in the API
        will be applied to the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        When separating parameter groups, set `grad_centralization` to True to centralize the gradient. Gradient
        centralization can only be applied to the parameters of convolution layers; if it is set to True for the
        parameters of a non-convolution layer, an error will be reported.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used
            and the i-th step takes the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, a dynamic learning rate is also used and the i-th learning rate is calculated
            during training according to the formula of the LearningRateSchedule. When the learning_rate is a float
            or a zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported. The float
            learning rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float.
        parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter` which will be
            updated, the element in `parameters` must be class `Parameter`. When the `parameters` is a list of `dict`,
            the keys "params", "lr", "weight_decay", "order_params" and "grad_centralization" can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters
              and this order will be followed in the optimizer. There should be no other keys in this `dict`, and
              every parameter in the value of 'order_params' must be in one of the group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
              is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
              This parameter only works on the convolution layer.

        weight_decay (Union[float, int]): An int or a floating point value for the weight decay.
            It must be equal to or greater than 0.
            If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
            type of `loss_scale` input is int, it will be converted to float. In general, use the default value. Only
            when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` is less than 0.
        ValueError: If `learning_rate` is a Tensor, but the dimension of the tensor is greater than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
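
    Examples:
        A minimal, illustrative sketch of passing grouped parameters to a concrete
        subclass (the subclass ``nn.Momentum`` and the network ``net`` below are
        assumptions used only for illustration):

        >>> from mindspore import nn
        >>>
        >>> net = Net()
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> # conv_params use weight decay 0.01 and gradient centralization; the other
        >>> # parameters use learning rate 0.01; the order follows net.trainable_params().
        >>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)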
    """

    def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
        super(Optimizer, self).__init__(auto_prefix=False)
        parameters = self._parameters_base_check(parameters, "parameters")
        if not all(isinstance(x, Parameter) for x in parameters) and not all(isinstance(x, dict) for x in parameters):
            raise TypeError("All elements of the optimizer parameters must be of type `Parameter` or `dict`.")

        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)
        self.grad_centralization = False

        self._unique = True
        self._target = context.get_context("device_target")
        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self.group_grad_centralization = []
            self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)

        # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')

        if self.is_group_lr:
            self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
                else ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
            decay_filter = lambda x: x > 0
            self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
            self.grad_centralization_flags = tuple(self.group_grad_centralization)
        else:
            self.parameters = ParameterTuple(parameters)
            self.weight_decay = weight_decay * loss_scale
            self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
            decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
            self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0
        # When a parameter has already been made unique, there is no need to do another unique in the optimizer.
        for param in self.parameters:
            if param.unique:
                self._unique = False
                break
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
        cache_filter = lambda x: x.cache_enable
        self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
        self.need_scale = loss_scale != 1.0
        self.global_step_increase_tensor = Tensor(1, mstype.int32)
        self.param_length = len(self.parameters)
        self.map_ = C.Map()
        self.map_reverse = C.Map(None, True)
        self.hyper_map = C.HyperMap()
        self.hyper_map_reverse = C.HyperMap(None, True)
        self._use_parallel_optimizer()

    def _use_parallel_optimizer(self):
        """Check whether the parallel optimizer can be used and set `use_parallel` accordingly."""
        if context.get_auto_parallel_context("enable_parallel_optimizer"):
            if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
                self.use_parallel = True
            elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
                    and context.get_context("device_target") != "Ascend":
                raise RuntimeError("Parallel optimizer only supports Ascend in data parallel mode.")
            elif _get_parallel_mode() in (ParallelMode.STAND_ALONE, ParallelMode.HYBRID_PARALLEL):
                raise RuntimeError("Parallel optimizer is not supported in {}.".format(_get_parallel_mode()))
            else:
                self.use_parallel = False
        else:
            self.use_parallel = False
        if self.use_parallel:
            if self.cls_name not in ["Lamb", "AdamWeightDecay", "AdaFactor"]:
                raise RuntimeError("Parallel optimizer does not support optimizer {}".format(self.cls_name))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError("Parallel optimizer cannot be applied when the number of parameters {} is"
                                   " less than the number of devices {}".format(self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
            self.param_names = []
            for param in self.parameters:
                self.param_names.append(param.name)
        else:
            self.optim_filter = (True,) * self.param_length

    @property
    def unique(self):
        """Whether to make the sparse gradients' indices unique. The value type is bool."""
        return self._unique

    @unique.setter
    def unique(self, value):
        """Set whether to make the sparse gradients' indices unique."""
        if not isinstance(value, bool):
            raise TypeError("The value type must be bool, but got value type is {}".format(type(value)))
        self._unique = value

    @property
    def target(self):
        """
        The method is used to determine whether the parameters are updated on the host or on the device. The input
        type is str and can only be 'CPU', 'Ascend' or 'GPU'.
        """
        return self._target

    @target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        raise NotImplementedError

    def _set_base_target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        if not isinstance(value, str):
            raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

        if value not in ('CPU', 'Ascend', 'GPU'):
            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))

        if self._target == "CPU" and value in ('Ascend', 'GPU'):
            raise ValueError("In the CPU environment, target cannot be set to 'GPU' or 'Ascend'.")

        if self._target == "Ascend" and value == 'GPU':
            raise ValueError("In the Ascend environment, target cannot be set to 'GPU'.")

        if self._target == "GPU" and value == 'Ascend':
            raise ValueError("In the GPU environment, target cannot be set to 'Ascend'.")

        self._is_device = (value != 'CPU')
        self._target = value

    def decay_weight(self, gradients):
        """
        Weight decay.

        An approach to reduce the overfitting of a deep learning neural network model.
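
        For a parameter :math:`w` with gradient :math:`g` and weight decay coefficient :math:`\lambda`
        (which has already been multiplied by `loss_scale`), the decayed gradient computed below is

        .. math::
            g_{out} = g + \lambda \cdot w

        For sparse `RowTensor` gradients, the same update is applied only to the gathered rows.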

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, which have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], The gradients after weight decay.
        """
        if self.exec_weight_decay:
            params = self.parameters
            if self.is_group:
                gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
                                      params, gradients)
            else:
                gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
                                      params, gradients)

        return gradients

    def gradients_centralization(self, gradients):
        """
        Gradients centralization.

        A method for optimizing convolutional layer parameters to improve the training speed of a deep learning
        neural network model.
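
        As a rough sketch of the operation applied below (assuming the `Centralization` operator
        subtracts the mean over the given axes), a gradient :math:`g` with more than one dimension,
        such as a convolution weight gradient, is centralized as

        .. math::
            g_{out} = g - \mathrm{mean}(g, \mathrm{axes} = 1, \dots, n-1)

        One-dimensional gradients, and gradients whose second dimension is not a multiple of 16,
        are returned unchanged.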

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, which have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], The gradients after gradient centralization.
        """
        if self.is_group:
            gradients = self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients)

        return gradients

    def scale_grad(self, gradients):
        """
        Loss scale for mixed precision.

        An approach of mixed precision training to improve the speed and energy efficiency of training a deep
        neural network.
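
        When `loss_scale` is not 1.0, each gradient is multiplied by the reciprocal of the scale,
        i.e. :math:`g_{out} = g / \text{loss\_scale}`, to undo the scaling that was applied to the loss.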

        Args:
            gradients (tuple[Tensor]): The gradients of `self.parameters`, which have the same shape as
                `self.parameters`.

        Returns:
            tuple[Tensor], The gradients after loss scale.

        """
        if self.need_scale:
            gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        return gradients

    def _grad_sparse_indices_deduplicate(self, gradients):
        """In the case of using big operators, deduplicate the indices in gradients."""
        if self._target != 'CPU' and self._unique:
            gradients = self.map_(F.partial(_indices_deduplicate), gradients)
        return gradients

    def _preprocess_weight_decay(self, weight_decay):
        """Check weight decay, and convert int to float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
            return weight_decay
        raise TypeError("Weight decay should be int or float.")

    def _preprocess_grad_centralization(self, grad_centralization):
        """Check the type of grad_centralization."""
        if not isinstance(grad_centralization, bool):
            raise TypeError("The grad_centralization should be bool.")
        return grad_centralization

    def _preprocess_single_lr(self, learning_rate):
        """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            return learning_rate

        self.dynamic_lr = True
        if isinstance(learning_rate, Iterable):
            return Tensor(np.array(list(learning_rate)).astype(np.float32))
        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim > 1:
                raise ValueError("The dim of `Tensor` type learning rate should be 0 or 1, "
                                 f"but got {learning_rate.ndim}.")
            if learning_rate.ndim == 1 and learning_rate.size < 2:
                logger.warning("When using a `Tensor` type dynamic learning rate, please make sure that the "
                               "number of elements in the tensor is greater than 1.")
            return learning_rate
        if isinstance(learning_rate, LearningRateSchedule):
            return learning_rate
        raise TypeError("Learning rate should be int, float, Tensor, Iterable or LearningRateSchedule.")

    def _build_single_lr(self, learning_rate, name):
        """Build learning rate value, convert learning rate to a Parameter or a LearningRateSchedule."""
        if isinstance(learning_rate, float):
            learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 0:
            learning_rate = Parameter(learning_rate, name)
            if self.is_group_lr and self.dynamic_lr:
                learning_rate = _ConvertToCell(learning_rate)
            return learning_rate
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            return _IteratorLearningRate(learning_rate, name)
        return learning_rate

    def _parameters_base_check(self, parameters, param_info):
        """Check that `parameters` is a non-empty Iterable and convert it to a list."""
        if parameters is None:
            raise ValueError(f"Optimizer {param_info} can not be None.")
        if not isinstance(parameters, Iterable):
            raise TypeError(f"Optimizer {param_info} must be Iterable.")
        parameters = list(parameters)

        if not parameters:
            raise ValueError(f"Optimizer got an empty {param_info} list.")
        return parameters

    def _check_group_params(self, parameters):
        """Check group params."""
        parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
        for group_param in parameters:
            invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
            if invalid_key:
                raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.')

            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                continue

            parameters = self._parameters_base_check(group_param['params'], "group `params`")
            if not all(isinstance(x, Parameter) for x in parameters):
                raise TypeError("The group `params` should be an iterator of Parameter type.")

    def _parse_group_params(self, parameters, learning_rate):
        """Parse group params."""
        self._check_group_params(parameters)
        if isinstance(learning_rate, Tensor) and learning_rate.ndim == 1:
            tensor_lr_length = learning_rate.size
        else:
            tensor_lr_length = 0

        for group_param in parameters:
            if 'order_params' in group_param.keys():
                if len(group_param.keys()) > 1:
                    raise ValueError("The order params dict in group parameters should "
                                     "only include the 'order_params' key.")
                if not isinstance(group_param['order_params'], Iterable):
                    raise TypeError("The value of 'order_params' should be an Iterable type.")
                self.is_group_params_ordered = True
                continue

            if 'lr' in group_param.keys():
                self.is_group_lr = True
                group_lr = self._preprocess_single_lr(group_param['lr'])

                if isinstance(group_lr, Tensor) and group_lr.ndim == 1:
                    group_lr_length = group_lr.size
                    if tensor_lr_length == 0:
                        tensor_lr_length = group_lr_length
                    elif group_lr_length != tensor_lr_length:
                        raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")

    def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
        """Initialize learning rate, weight decay or grad centralization in group params."""
        self._parse_group_params(parameters, learning_rate)
        default_lr = self._build_single_lr(learning_rate, 'learning_rate')

        params_store = []
        for group_num, group_param in enumerate(parameters):
            if 'order_params' in group_param.keys():
                ordered_parameters = group_param['order_params']
                continue

            self.group_params += group_param['params']

            if 'lr' in group_param.keys():
                lr_param_name = 'learning_rate_group_' + str(group_num)
                lr = self._preprocess_single_lr(group_param['lr'])
                lr = self._build_single_lr(lr, lr_param_name)
            else:
                lr = default_lr

            if 'weight_decay' in group_param.keys():
                cur_weight_decay = self._preprocess_weight_decay(group_param['weight_decay'])
                weight_decay_ = cur_weight_decay * self.loss_scale
            else:
                weight_decay_ = weight_decay * self.loss_scale

            if 'grad_centralization' in group_param.keys():
                self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                for param in group_param['params']:
                    validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                    grad_centralization_ = self.grad_centralization
            else:
                grad_centralization_ = grad_centralization

            for key in group_param.keys():
                if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
                    logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")

            for param in group_param['params']:
                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                if param.name in params_store:
                    raise RuntimeError(f"The {param.name} parameter already exists in parameter groups, "
                                       f"duplicate parameters are not supported.")

                params_store.append(param.name)
                self.group_lr.append(lr)
                self.group_weight_decay.append(weight_decay_)
                self.group_grad_centralization.append(grad_centralization_)

        if self.is_group_params_ordered:
            self._order_and_adjust_group_params(ordered_parameters)

    def _order_and_adjust_group_params(self, ordered_parameters):
        """
        Order group parameters, learning rates, weight decay and grad centralization in group params.
        """
        params_length = len(self.group_params)
        if len(ordered_parameters) != len(self.group_params):
            raise ValueError("The value of 'order_params' should contain the same parameters as all the "
                             "group parameters.")

        ordered_params = [None] * params_length
        ordered_learning_rate = [None] * params_length
        ordered_weight_decay = [None] * params_length
        ordered_grad_centralization = [None] * params_length
        params_name = [param.name for param in ordered_parameters]

        for param, lr, wd, gc in zip(self.group_params, self.group_lr, self.group_weight_decay,
                                     self.group_grad_centralization):
            index = params_name.index(param.name)
            ordered_params[index] = param
            ordered_learning_rate[index] = lr
            ordered_weight_decay[index] = wd
            ordered_grad_centralization[index] = gc

        self.group_params = ordered_params
        self.group_lr = ordered_learning_rate
        self.group_weight_decay = ordered_weight_decay
        self.group_grad_centralization = ordered_grad_centralization

    def get_lr(self):
        """
        Get the learning rate of the current step.

        Returns:
            float, the learning rate of the current step.
        """
        lr = self.learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.learning_rate(self.global_step)

            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return lr

    def get_lr_parameter(self, param):
        """
        Get the learning rate of the given parameter.

        Args:
            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.

        Returns:
            Parameter, single `Parameter` or `list[Parameter]` according to the input type.
        """
        def get_lr_value(learning_rate):
            if isinstance(learning_rate, (_ConvertToCell, _IteratorLearningRate)):
                return learning_rate.learning_rate

            return learning_rate

        if isinstance(param, Parameter):
            param_list = [param]
        elif isinstance(param, list):
            param_list = param
        else:
            raise TypeError("The parameter only supports 'Parameter' or 'list' type.")

        lr = []
        ids = [id(p) for p in self.parameters]
        for p in param_list:
            validator.check_value_type("parameter", p, [Parameter], self.cls_name)
            if id(p) not in ids:
                raise ValueError(f"The parameter {p.name} is not in the optimizer.")
            if self.is_group_lr:
                index = ids.index(id(p))
                lr.append(get_lr_value(self.learning_rate[index]))
            else:
                lr.append(get_lr_value(self.learning_rate))

        return lr if isinstance(param, list) else lr[0]

    def _get_parameter_group_id(self):
        """
        Get the parameter partition group id, which is less than the number of devices.

        Returns:
            tuple, the group id tuple of parameters.
        """
        rank_list = ()
        count = 0
        for _ in range(self.param_length):
            rank_list = rank_list + (count,)
            count = count + 1
            if count == self.dev_num:
                count = 0
        return rank_list

    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Returns:
            bool, the status flag.
        """
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
            key = P.MakeRefKey(self.param_names[i])()
            key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            if root > 0:
                param_group[root] = F.depend(param_group[root], new_param_group[root-1])
            else:
                param_group[root] = F.depend(param_group[root], optim_result)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        return new_param_group

    def construct(self, *hyper_params):
        raise NotImplementedError


op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()
op_gc = inner.Centralization()

_apply_decay = C.MultitypeFuncGraph("apply_decay")
_apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")


@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        indices = gradient.indices
        values = op_add((op_gather(weight, indices, 0) * F.cast(weight_decay, F.dtype(weight)), gradient.values))
        shape = gradient.dense_shape
        return RowTensor(indices, values, shape)
    return gradient


@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((op_mul(weight, F.cast(weight_decay, F.dtype(weight))), gradient))
    return gradient


@_apply_grad_centralization.register("Bool", "RowTensor")
def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        indices = gradient.indices
        shape = gradient.dense_shape
        grad_shape = F.shape(gradient)
        axis = []
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            values = op_gc(gradient.values, axis)
            return RowTensor(indices, values, shape)
    return gradient


@_apply_grad_centralization.register("Bool", "Tensor")
def _tensor_apply_grad_centralization(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        axis = []
        grad_shape = F.shape(gradient)
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            return op_gc(gradient, axis)
    return gradient


_grad_scale = C.MultitypeFuncGraph("grad_scale")
_indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")


@_grad_scale.register("Number", "Tensor")
def tensor_grad_scale(scale, grad):
    """Get grad with scale."""
    if scale == 1.0:
        return grad
    return op_mul(grad, F.cast(scale, F.dtype(grad)))


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale_with_tensor(scale, grad):
    """Get grad with scale."""
    return op_mul(grad, F.cast(scale, F.dtype(grad)))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
    """Get grad with scale."""
    return RowTensor(grad.indices, grad.values * F.cast(scale, F.dtype(grad.values)), grad.dense_shape)


@_indices_deduplicate.register("RowTensor")
def rowtensor_deduplicate_indices_slices(grad):
    """Unique the indices and sum the 'values' corresponding to the duplicate indices."""
    indices = grad.indices
    values = grad.values

    unique_indices, index_position = P.Unique()(indices)
    summed_values = P.UnsortedSegmentSum()(values, index_position, P.DynamicShape()(unique_indices)[0])

    return RowTensor(unique_indices, summed_values, grad.dense_shape)


@_indices_deduplicate.register("Tensor")
def tensor_deduplicate_indice_slices(grad):
    """Return the input gradient directly in the dense scenario."""
    return grad


class _ConvertToCell(LearningRateSchedule):
    """Inner api, convert learning rate of scalar to LearningRateSchedule."""
    def __init__(self, learning_rate):
        super(_ConvertToCell, self).__init__()
        if not isinstance(learning_rate, Parameter):
            raise TypeError('Learning rate must be Parameter.')
        self.learning_rate = learning_rate

    def construct(self, global_step):
        return self.learning_rate + 1.0 - 1.0


class _IteratorLearningRate(LearningRateSchedule):
    """Inner api, convert learning rate of Tensor(list) to LearningRateSchedule."""
    def __init__(self, learning_rate, name):
        super(_IteratorLearningRate, self).__init__()
        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim != 1:
                raise ValueError("The dim of `Tensor` type dynamic learning rate should be 1, "
                                 f"but got {learning_rate.ndim}.")
        else:
            raise TypeError("Learning rate should be Tensor.")

        self.learning_rate = Parameter(learning_rate, name)
        self.gather = P.Gather()

    def construct(self, global_step):
        return self.gather(self.learning_rate, global_step, 0)
