# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Adam')
class Adam(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Adam algorithm.

  Adam optimization is a stochastic gradient descent method that is based on
  adaptive estimation of first-order and second-order moments.
  According to the paper
  [Adam: A Method for Stochastic Optimization. Kingma et al.,
  2014](http://arxiv.org/abs/1412.6980), the method is "*computationally
  efficient, has little memory requirement, invariant to diagonal rescaling of
  gradients, and is well suited for problems that are large in terms of
  data/parameters*".

  For AMSGrad see [On The Convergence Of Adam And Beyond.
  Reddi et al., 2018](https://openreview.net/pdf?id=ryQu7f-RZ).
  """
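
  # Usage sketch (illustrative only; `model`, `x_train`, and `y_train` are
  # assumed to come from the surrounding program, not from this module):
  #
  #   opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
  #   model.compile(optimizer=opt, loss='sparse_categorical_crossentropy')
  #   model.fit(x_train, y_train)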

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               amsgrad=False,
               name='Adam',
               **kwargs):
    r"""Construct a new Adam optimizer.

    If amsgrad = False:
      Initialization:

      $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
      $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
      $$t := 0 \text{(Initialize timestep)}$$

      The update rule for `variable` with gradient `g` uses an optimization
      described at the end of section 2 of the paper:

      $$t := t + 1$$
      $$\text{lr}_t := \mathrm{learning\_rate} *
        \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$

      $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
      $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
      $$\text{variable} := \text{variable} -
        \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$

    If amsgrad = True:
      Initialization:

      $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
      $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
      $$\hat{v}_0 := 0 \text{(Initialize initial maximum 2nd moment vector)}$$
      $$t := 0 \text{(Initialize timestep)}$$

      The update rule for `variable` with gradient `g` uses an optimization
      described at the end of section 2 of the paper:

      $$t := t + 1$$
      $$\text{lr}_t := \mathrm{learning\_rate} *
        \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$

      $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
      $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
      $$\hat{v}_t := \max(\hat{v}_{t-1}, v_t)$$
      $$\text{variable} := \text{variable} -
        \text{lr}_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$

    The default value of 1e-7 for epsilon might not be a good default in
    general. For example, when training an Inception network on ImageNet a
    current good choice is 1.0 or 0.1. Note that since Adam uses the
    formulation just before Section 2.1 of the Kingma and Ba paper rather than
    the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
    hat" in the paper.

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta_1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the
    dense behavior (in contrast to some momentum implementations which ignore
    momentum unless a variable slice was actually used).

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
      beta_1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta_2: A float value or a constant float tensor. The exponential decay
        rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      amsgrad: Boolean. Whether to apply the AMSGrad variant of this algorithm
        from the paper "On the Convergence of Adam and Beyond".
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
        `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate; `lr` is included for
        backward compatibility, and it is recommended to use `learning_rate`
        instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `beta_1`, `beta_2`,
    and `epsilon` can each be a callable that takes no arguments and
    returns the actual value to use. This can be useful for changing these
    values across different invocations of optimizer functions.
    @end_compatibility
    """
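    # Illustrative note (comments only; `lr_schedule` is a name invented for
    # this example): `learning_rate` may also be given as a
    # `tf.keras.optimizers.schedules.LearningRateSchedule`, e.g.
    #
    #   lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    #       initial_learning_rate=1e-3, decay_steps=10000, decay_rate=0.96)
    #   opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)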

    super(Adam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self.amsgrad = amsgrad

  def _create_slots(self, var_list):
    # Create slots for the first and second moments.
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, 'm')
    for var in var_list:
      self.add_slot(var, 'v')
    if self.amsgrad:
      for var in var_list:
        self.add_slot(var, 'vhat')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Adam, self)._prepare_local(var_device, var_dtype, apply_state)

    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    lr = (apply_state[(var_device, var_dtype)]['lr_t'] *
          (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)))
    apply_state[(var_device, var_dtype)].update(dict(
        lr=lr,
        epsilon=ops.convert_to_tensor(self.epsilon, var_dtype),
        beta_1_t=beta_1_t,
        beta_1_power=beta_1_power,
        one_minus_beta_1_t=1 - beta_1_t,
        beta_2_t=beta_2_t,
        beta_2_power=beta_2_power,
        one_minus_beta_2_t=1 - beta_2_t
    ))
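
  # Worked example (illustrative) of the bias correction computed in
  # `_prepare_local` above, using the default beta_1=0.9 and beta_2=0.999: at
  # the first step local_step = 1, so
  #
  #   lr = lr_t * sqrt(1 - 0.999**1) / (1 - 0.9**1)
  #      = lr_t * sqrt(0.001) / 0.1
  #      ~= 0.316 * lr_t
  #
  # and the multiplier approaches 1 as beta_1_power and beta_2_power decay
  # toward zero on later steps.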

  def set_weights(self, weights):
    params = self.weights
    # If the weights are generated by Keras V1 optimizer, it includes vhats
    # even without amsgrad, i.e., V1 optimizer has 3x + 1 variables, while V2
    # optimizer has 2x + 1 variables. Filter vhats out for compatibility.
    num_vars = int((len(params) - 1) / 2)
    if len(weights) == 3 * num_vars + 1:
      weights = weights[:len(params)]
    super(Adam, self).set_weights(weights)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    if not self.amsgrad:
      return training_ops.resource_apply_adam(
          var.handle,
          m.handle,
          v.handle,
          coefficients['beta_1_power'],
          coefficients['beta_2_power'],
          coefficients['lr_t'],
          coefficients['beta_1_t'],
          coefficients['beta_2_t'],
          coefficients['epsilon'],
          grad,
          use_locking=self._use_locking)
    else:
      vhat = self.get_slot(var, 'vhat')
      return training_ops.resource_apply_adam_with_amsgrad(
          var.handle,
          m.handle,
          v.handle,
          vhat.handle,
          coefficients['beta_1_power'],
          coefficients['beta_2_power'],
          coefficients['lr_t'],
          coefficients['beta_1_t'],
          coefficients['beta_2_t'],
          coefficients['epsilon'],
          grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                           use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                           use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

    if not self.amsgrad:
      v_sqrt = math_ops.sqrt(v_t)
      var_update = state_ops.assign_sub(
          var, coefficients['lr'] * m_t / (v_sqrt + coefficients['epsilon']),
          use_locking=self._use_locking)
      return control_flow_ops.group(*[var_update, m_t, v_t])
    else:
      v_hat = self.get_slot(var, 'vhat')
      v_hat_t = math_ops.maximum(v_hat, v_t)
      with ops.control_dependencies([v_hat_t]):
        v_hat_t = state_ops.assign(
            v_hat, v_hat_t, use_locking=self._use_locking)
      v_hat_sqrt = math_ops.sqrt(v_hat_t)
      var_update = state_ops.assign_sub(
          var,
          coefficients['lr'] * m_t / (v_hat_sqrt + coefficients['epsilon']),
          use_locking=self._use_locking)
      return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])

  def get_config(self):
    config = super(Adam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._serialize_hyperparameter('decay'),
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
        'amsgrad': self.amsgrad,
    })
    return config
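

# Round-trip sketch (illustrative; not executed as part of this module): the
# dictionary returned by `get_config` above can be used to reconstruct an
# equivalent optimizer via the `from_config` classmethod inherited from
# `OptimizerV2`, e.g.
#
#   opt = Adam(learning_rate=0.01, amsgrad=True)
#   restored = Adam.from_config(opt.get_config())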