# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""lars optimizer"""
from __future__ import absolute_import

from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import _checkparam as validator
from mindspore.common import Tensor, Parameter, dtype as mstype
from mindspore.common.api import jit
from mindspore.nn.optim.optimizer import _grad_scale, Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register

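# `_lars_opt` is a MultitypeFuncGraph: the overload registered below is dispatched by argument
# types, and `hyper_map` in `LARS.construct` applies it element-wise across the tuples of
# learning rates, weight decays, gradients, parameters and flags.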
_lars_opt = C.MultitypeFuncGraph("lars_opt")


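# Note: P.LARSUpdate fuses the layer-wise LARS computation. Roughly, the trust ratio
# coefficient * ||w|| / (||g|| + weight_decay * ||w||) rescales the (weight-decayed) gradient,
# and the rescaled gradient is then handed to the wrapped optimizer for the actual update.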
@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt(lars, loss_scale, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag):
    """Apply the LARS update to a single weight parameter and its gradient."""
    if lars_flag:
        # Compute the squared L2 norms of the weight and its gradient in a single fused op.
        op_reduce_sum = P.SquareSumAll()
        w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
        if decay_flag:
            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay / loss_scale, learning_rate)
        else:
            # Parameters excluded from weight decay use a decay of zero in the LARS update.
            num_zero = 0.0
            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate)
        return grad_t

    return gradient


def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name):
    """Check the types of the LARS constructor arguments."""
    validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
    validator.check_value_type("epsilon", epsilon, [float], prim_name)
    validator.check_value_type("coefficient", coefficient, [float], prim_name)
    validator.check_value_type("use_clip", use_clip, [bool], prim_name)


class LARS(Optimizer):
    r"""
    Implements the LARS algorithm.

    LARS (Layer-wise Adaptive Rate Scaling) is an optimization algorithm designed for large-batch training.
    Refer to the paper `LARGE BATCH TRAINING OF CONVOLUTIONAL NETWORKS <https://arxiv.org/abs/1708.03888>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            &\newline
            &\hline \\
            &\textbf{Parameters}: \text{base learning rate } \gamma_{0} , \text{ momentum } m, \text{ weight decay }
             \lambda , \\
            &\hspace{5mm}\text{ LARS coefficient } \eta , \text{ number of steps } T \\
            &\textbf{Init}: \text{ t=0, v=0, init weight }  w_{0}^{l}  \text{ for each layer } l \\[-1.ex]
            &\newline
            &\hline \\
            &\textbf{while} \text{ t<T  for each layer } l \textbf{ do} \\
            &\hspace{5mm}g_{t}^{l} \leftarrow \nabla L\left(w_{t}^{l}\right) \\
            &\hspace{5mm}\gamma_{t} \leftarrow \gamma_{0} *\left(1-\frac{t}{T}\right)^{2} \\
            &\hspace{5mm}\gamma^{l} \leftarrow \eta *\frac{\left\|w_{t}^{l}\right\|}{\left\|g_{t}^{l}\right\|+
             \lambda\left\|w_{t}^{l}\right\|} \text{ (compute the local LR } \gamma^{l}) \\
            &\hspace{5mm}v_{t+1}^{l} \leftarrow m v_{t}^{l}+\gamma_{t+1} * \gamma^{l} *\left(g_{t}^{l}+\lambda
             w_{t}^{l}\right) \\
            &\hspace{5mm}w_{t+1}^{l} \leftarrow w_{t}^{l}-v_{t+1}^{l} \\
            &\textbf{ end while } \\[-1.ex]
            &\newline
            &\hline \\[-1.ex]
        \end{array}

    :math:`w` represents the network parameters, :math:`g` represents `gradients`,
    :math:`t` represents the current step, :math:`\lambda` represents `weight_decay` in `optimizer`,
    :math:`\gamma` represents `learning_rate` in `optimizer`, :math:`\eta` represents `coefficient`.
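
    As a purely illustrative example, with :math:`\eta = 0.001`, :math:`\left\|w_{t}^{l}\right\| = 10`,
    :math:`\left\|g_{t}^{l}\right\| = 2` and :math:`\lambda = 0`, the local learning rate is
    :math:`\gamma^{l} = 0.001 \times 10 / 2 = 0.005`, which then scales the global learning rate
    :math:`\gamma_{t}` for that layer.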

    Args:
        optimizer (:class:`mindspore.nn.Optimizer`): The MindSpore optimizer to be wrapped; LARS rescales its
            gradients layer by layer before they are applied.
        epsilon (float): Term added to the denominator to improve numerical stability. Default: ``1e-05`` .
        coefficient (float): Trust coefficient for calculating the local learning rate. Default: ``0.001`` .
        use_clip (bool): Whether to use the clip operation when calculating the local learning rate.
            Default: ``False`` .
        lars_filter (Function): A function that determines which network parameters the LARS algorithm is applied to.
            Default: ``lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name``.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params` in the optimizer, with the same shape as
          the `params` in the optimizer.

    Outputs:
        Union[Tensor[bool], tuple[Parameter]], depending on the output of `optimizer`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
        >>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
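        >>> # Training then proceeds through the wrapped optimizer as usual, for example
        >>> # model.train(epoch, dataset) with a prepared `dataset` (not shown here).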
118    """

    @opt_init_args_register
    def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
                 lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
        # The wrapper itself owns no real parameters or learning rate; register placeholders only.
        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
        _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
        self.opt = optimizer
        self.dynamic_decay_flags = optimizer.dynamic_decay_flags
        self.dynamic_weight_decay = optimizer.dynamic_weight_decay
        self.weight_decay = optimizer.weight_decay
        self.global_step = optimizer.global_step
        self.parameters = optimizer.parameters
        if optimizer._use_flattened_params:  # pylint: disable=W0212
            self.opt._use_flattened_params = False  # pylint: disable=W0212
        self._user_parameters += [param.name for param in self.parameters]
        self.use_clip = use_clip
        self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
        self.is_group = optimizer.is_group
        self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
        self.decay_flags = optimizer.decay_flags
        self.reciprocal_scale = optimizer.reciprocal_scale
        self.need_scale = optimizer.need_scale
        self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
        self.cast = P.Cast()
        self.loss_scale = optimizer.loss_scale

        if use_clip:
            self.is_group_lr = optimizer.is_group_lr
            self.dynamic_lr = optimizer.dynamic_lr
            self.origin_learning_rate = optimizer.learning_rate
            if self.is_group_lr and self.dynamic_lr:
                raise ValueError("For 'LARS', if the argument 'use_clip' is set to True, then the dynamic "
                                 "learning rate and the group learning rate cannot be used together.")

        # Weight decay and loss scaling are applied inside the LARS update, so disable them
        # in the wrapped optimizer to avoid applying them twice.
        if self.is_group:
            optimizer.dynamic_decay_flags = tuple(map(lambda x: False, self.dynamic_decay_flags))
        else:
            optimizer.dynamic_decay_flags = False
        optimizer.decay_flags = tuple(map(lambda x: False, self.decay_flags))
        optimizer.dynamic_weight_decay = False
        optimizer.reciprocal_scale = 1.0
        optimizer.exec_weight_decay = False

    def _get_lr(self):
        """Get the learning rate of the current step."""
        lr = self.origin_learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                lr = ()
                for learning_rate in self.origin_learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.origin_learning_rate(self.global_step)

        return lr

    @jit
    def construct(self, gradients):
        params = self.parameters
        gradients = self.flatten_gradients(gradients)
        if self.use_clip:
            lr = self._get_lr()
        else:
            lr = self.learning_rate
        weight_decay = self.get_weight_decay()

        # Undo the loss scale on the gradients before computing the LARS update.
        if self.need_scale:
            gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        # Rescale each gradient with its layer-wise LARS learning rate, then let the wrapped
        # optimizer apply the actual parameter update.
        if self.is_group:
            if self.is_group_lr:
                gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale), lr, weight_decay,
                                           gradients, params, self.decay_flags, self.lars_flag)
            else:
                gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr), weight_decay,
                                           gradients, params, self.decay_flags, self.lars_flag)
        else:
            gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr, weight_decay),
                                       gradients, params, self.decay_flags, self.lars_flag)
        success = self.opt(gradients)
        # Advance the global step here only when the wrapped optimizer does not do it itself.
        if self._is_dynamic_lr_or_weight_decay() and not self.opt.dynamic_lr and not self.opt.dynamic_weight_decay:
            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return success