# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LARS optimizer: wraps another MindSpore optimizer and adjusts its gradients with LARSUpdate."""
from __future__ import absolute_import

from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import _checkparam as validator
from mindspore.common import Tensor, Parameter, dtype as mstype
from mindspore.common.api import jit
from mindspore.nn.optim.optimizer import _grad_scale, Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register

# Per-parameter LARS step, applied element-wise over the parameter list by ``hyper_map`` in LARS.construct.
_lars_opt = C.MultitypeFuncGraph("lars_opt")


@_lars_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt(lars, loss_scale, learning_rate, weight_decay, gradient, weight, decay_flag, lars_flag):
    """Apply lars optimizer to the weight parameter.

    Args:
        lars (Function): The ``P.LARSUpdate`` primitive built in ``LARS.__init__``.
        loss_scale (Number): The wrapped optimizer's ``loss_scale``; ``weight_decay`` is divided by it.
        learning_rate (Tensor): Learning rate forwarded to ``lars``.
        weight_decay (Tensor): Weight decay for this parameter (shared or per-parameter, depending on caller).
        gradient (Tensor): Gradient of ``weight``.
        weight (Tensor): The parameter being optimized.
        decay_flag (bool): Whether weight decay applies to this parameter; if False, 0.0 is passed to ``lars``.
        lars_flag (bool): Whether LARS applies to this parameter (the ``lars_filter`` result).

    Returns:
        Tensor, the output of ``lars`` when ``lars_flag`` is True, otherwise ``gradient`` unchanged.
    """
    if lars_flag:
        op_reduce_sum = P.SquareSumAll()
        # Sums of squares of the weight and the gradient, both forwarded to LARSUpdate.
        w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
        if decay_flag:
            # NOTE(review): the weight decay is divided by loss_scale here; the reason is not visible in
            # this file -- confirm against the inputs LARSUpdate expects.
            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay / loss_scale, learning_rate)
        else:
            # No weight decay for this parameter.
            num_zero = 0.0
            grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate)
        return grad_t

    # Parameters excluded by lars_filter keep their gradient as-is.
    return gradient


def _check_param_value(optimizer, epsilon, coefficient, use_clip, prim_name):
    """Validate the types of the arguments given to ``LARS.__init__``.

    Args:
        optimizer: Must be an instance of ``Optimizer``.
        epsilon: Must be a ``float``.
        coefficient: Must be a ``float``.
        use_clip: Must be a ``bool``.
        prim_name (str): Name of the calling class, forwarded to ``validator.check_value_type``.

    Raises:
        Whatever ``validator.check_value_type`` raises on a mismatch (presumably ``TypeError`` -- confirm).
    """
    validator.check_value_type("optimizer", optimizer, Optimizer, prim_name)
    validator.check_value_type("epsilon", epsilon, [float], prim_name)
    validator.check_value_type("coefficient", coefficient, [float], prim_name)
    validator.check_value_type("use_clip", use_clip, [bool], prim_name)


class LARS(Optimizer):
    r"""
    Implements the LARS algorithm.

    LARS is an optimization algorithm employing a large batch optimization technique. Refer to paper `LARGE BATCH
    TRAINING OF CONVOLUTIONAL NETWORKS <https://arxiv.org/abs/1708.03888>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            &\newline
            &\hline \\
            &\textbf{Parameters}: \text{base learning rate } \gamma_{0} , \text{ momentum m}, \text{ weight decay }
             \lambda , \\
            &\hspace{5mm}\text{ LARS coefficient } \eta , \text{ number of steps } T \\
            &\textbf{Init}: \text{ t=0, v=0, init weight } w_{0}^{l} \text{ for each layer } l \\[-1.ex]
            &\newline
            &\hline \\
            &\textbf{while} \text{ t<T for each layer } l \textbf{ do} \\
            &\hspace{5mm}g_{t}^{l} \leftarrow \nabla L\left(w_{t}^{l}\right) \\
            &\hspace{5mm}\gamma_{t} \leftarrow \gamma_{0} *\left(1-\frac{t}{T}\right)^{2} \\
            &\hspace{5mm}\gamma^{l} \leftarrow \eta *\frac{\left\|w_{t}^{l}\right\|}{\left\|g_{t}^{l}\right\|+
             \lambda\left\|w_{t}^{l}\right\|} \text{(compute the local LR } \gamma^{ l)} \\
            &\hspace{5mm}v_{t+1}^{l} \leftarrow m v_{t}^{l}+\gamma_{t+1} * \gamma^{l} *\left(g_{t}^{l}+\lambda
             w_{t}^{l}\right) \\
            &\hspace{5mm}w_{t+1}^{l} \leftarrow w_{t}^{l}-v_{t+1}^{l} \\
            &\textbf{ end while } \\[-1.ex]
            &\newline
            &\hline \\[-1.ex]
        \end{array}

    :math:`w` represents the network parameters, :math:`g` represents `gradients`,
    :math:`t` represents the current step, :math:`\lambda` represents `weight_decay` in `optimizer`,
    :math:`\gamma` represents `learning_rate` in `optimizer`, :math:`\eta` represents `coefficient`.

    Args:
        optimizer (:class:`mindspore.nn.Optimizer`): MindSpore optimizer to wrap. LARS modifies the gradients
            and then passes them to this optimizer to perform the parameter update.
        epsilon (float): Term added to the denominator to improve numerical stability. Default: ``1e-05`` .
        coefficient (float): Trust coefficient for calculating the local learning rate. Default: ``0.001`` .
        use_clip (bool): Whether to use clip operation for calculating the local learning rate. When ``True``,
            the learning rate of `optimizer` is passed to the LARS update; otherwise a placeholder learning rate
            of 0.0 is passed. Default: ``False`` .
        lars_filter (Function): A function that decides, for each network parameter, whether the LARS algorithm
            is applied to it. Default: ``lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name``.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the
          same as the `params` in the optimizer.

    Outputs:
        Union[Tensor[bool], tuple[Parameter]], it depends on the output of `optimizer`.

    Raises:
        ValueError: If `use_clip` is ``True`` and `optimizer` uses both a dynamic learning rate and group
            learning rates.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
        >>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
    """

    @opt_init_args_register
    def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
                 lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
        # The base class gets a 0.0 lr and a single placeholder parameter; the real parameters are taken
        # from ``optimizer`` below. (Presumably the base constructor requires a parameter list -- TODO confirm.)
        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
        _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)
        self.opt = optimizer
        # Snapshot the wrapped optimizer's weight-decay configuration. The corresponding attributes on
        # ``optimizer`` are overwritten at the end of __init__, so they must be read first.
        self.dynamic_decay_flags = optimizer.dynamic_decay_flags
        self.dynamic_weight_decay = optimizer.dynamic_weight_decay
        self.weight_decay = optimizer.weight_decay
        # Share the step counter and the parameter list with the wrapped optimizer.
        self.global_step = optimizer.global_step
        self.parameters = optimizer.parameters
        # NOTE(review): flattened parameters are switched off on the wrapped optimizer; the reason is not
        # visible in this file.
        if optimizer._use_flattened_params:  # pylint: disable=W0212
            self.opt._use_flattened_params = False  # pylint: disable=W0212
        # NOTE(review): records the wrapped parameters' names in the base-class ``_user_parameters``;
        # how that attribute is consumed is not visible in this file.
        self._user_parameters += [param.name for param in self.parameters]
        self.use_clip = use_clip
        # One lars_filter result per parameter, consumed by _lars_opt as ``lars_flag``.
        self.lars_flag = tuple(lars_filter(x) for x in self.parameters)
        self.is_group = optimizer.is_group
        # Placeholder learning rate passed to LARSUpdate by construct when use_clip is False.
        self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
        self.decay_flags = optimizer.decay_flags
        # Read the loss-scale settings before optimizer.reciprocal_scale is reset to 1.0 below; construct
        # applies reciprocal_scale to the gradients itself.
        self.reciprocal_scale = optimizer.reciprocal_scale
        self.need_scale = optimizer.need_scale
        self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
        # NOTE(review): ``self.cast`` is not referenced anywhere else in this file.
        self.cast = P.Cast()
        self.loss_scale = optimizer.loss_scale

        # _get_lr (used only when use_clip is True) needs the wrapped optimizer's lr configuration.
        if use_clip:
            self.is_group_lr = optimizer.is_group_lr
            self.dynamic_lr = optimizer.dynamic_lr
            self.origin_learning_rate = optimizer.learning_rate
            if self.is_group_lr and self.dynamic_lr:
                raise ValueError("For 'LARS', if the argument 'use_clip' is set to True, then the dynamic "
                                 "learning rate and group learning rate cannot both be true.")

        # LARS applies weight decay (inside _lars_opt) and gradient unscaling (in construct) itself, so both
        # are switched off on the wrapped optimizer -- presumably so they are not applied a second time.
        if self.is_group:
            optimizer.dynamic_decay_flags = tuple(map(lambda x: False, self.dynamic_decay_flags))
        else:
            optimizer.dynamic_decay_flags = False
        optimizer.decay_flags = tuple(map(lambda x: False, self.decay_flags))
        optimizer.dynamic_weight_decay = False
        optimizer.reciprocal_scale = 1.0
        optimizer.exec_weight_decay = False

    def _get_lr(self):
        """Get the learning rate of current step.

        Called from ``construct`` only when ``use_clip`` is True, i.e. after ``__init__`` has copied
        ``is_group_lr``, ``dynamic_lr`` and ``origin_learning_rate`` from the wrapped optimizer.

        Returns:
            The wrapped optimizer's stored learning rate or, when it is dynamic, the value(s) obtained by
            calling the schedule(s) with ``global_step``.
        """
        lr = self.origin_learning_rate
        if self.dynamic_lr:
            if self.is_group_lr:
                # NOTE(review): unreachable as written -- __init__ raises ValueError when use_clip is True
                # and both is_group_lr and dynamic_lr are set.
                lr = ()
                for learning_rate in self.origin_learning_rate:
                    current_dynamic_lr = learning_rate(self.global_step)
                    lr += (current_dynamic_lr,)
            else:
                lr = self.origin_learning_rate(self.global_step)

        return lr

    @jit
    def construct(self, gradients):
        # Adjust the gradients with LARS, then delegate the parameter update to the wrapped optimizer.
        params = self.parameters
        gradients = self.flatten_gradients(gradients)
        # With use_clip, feed the wrapped optimizer's (possibly scheduled) lr to LARSUpdate; otherwise the
        # 0.0 placeholder ``fake_lr``.
        if self.use_clip:
            lr = self._get_lr()
        else:
            lr = self.learning_rate
        weight_decay = self.get_weight_decay()

        # Gradient unscaling happens here; __init__ resets the wrapped optimizer's reciprocal_scale to 1.0.
        if self.need_scale:
            gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)

        # Run _lars_opt per parameter. Per-parameter values are mapped alongside params; shared values are
        # bound with F.partial:
        #   group + group lr -> lr and weight_decay mapped
        #   group only       -> lr shared, weight_decay mapped
        #   no groups        -> lr and weight_decay shared
        # NOTE(review): when use_clip is False, __init__ does not copy is_group_lr from the wrapped optimizer,
        # so the value set by the base constructor is read here (presumably False) -- confirm.
        if self.is_group:
            if self.is_group_lr:
                gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale), lr, weight_decay,
                                           gradients, params, self.decay_flags, self.lars_flag)
            else:
                gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr), weight_decay,
                                           gradients, params, self.decay_flags, self.lars_flag)
        else:
            gradients = self.hyper_map(F.partial(_lars_opt, self.lars, self.loss_scale, lr, weight_decay),
                                       gradients, params, self.decay_flags, self.lars_flag)
        success = self.opt(gradients)
        # Increment the shared global_step here only when LARS itself has a dynamic lr or weight decay and
        # the wrapped optimizer has neither (presumably it increments the step itself otherwise -- confirm).
        # NOTE(review): __init__ sets optimizer.dynamic_weight_decay = False, so the last condition is always
        # true for code in this file.
        if self._is_dynamic_lr_or_weight_decay() and not self.opt.dynamic_lr and not self.opt.dynamic_weight_decay:
            self.assignadd(self.global_step, self.global_step_increase_tensor)
        return success