# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nadam optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Nadam')
class Nadam(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the NAdam algorithm.

  Much like Adam is essentially RMSprop with momentum, Nadam is Adam with
  Nesterov momentum.

  Args:
    learning_rate: A Tensor or a floating point value. The learning rate.
    beta_1: A float value or a constant float tensor. The exponential decay
      rate for the 1st moment estimates.
    beta_2: A float value or a constant float tensor. The exponential decay
      rate for the 2nd moment estimates.
    epsilon: A small constant for numerical stability.
    name: Optional name for the operations created when applying gradients.
      Defaults to `"Nadam"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Usage Example:

  >>> opt = tf.keras.optimizers.Nadam(learning_rate=0.2)
  >>> var1 = tf.Variable(10.0)
  >>> loss = lambda: (var1 ** 2) / 2.0
  >>> step_count = opt.minimize(loss, [var1]).numpy()
  >>> "{:.1f}".format(var1.numpy())
  '9.8'

  Reference:
    - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               name='Nadam',
               **kwargs):
    # Backwards compatibility with keras NAdam optimizer.
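    # The v1 `schedule_decay` argument is remapped onto the `decay`
    # hyperparameter (used below in the momentum schedule), and `lr` is
    # accepted as an alias for `learning_rate`.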
    kwargs['decay'] = kwargs.pop('schedule_decay', 0.004)
    learning_rate = kwargs.get('lr', learning_rate)
    if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
      raise ValueError('The Nadam optimizer does not support '
                       'tf.keras.optimizers.LearningRateSchedules as the '
                       'learning rate.')

    super(Nadam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self._m_cache = None

  def _create_slots(self, var_list):
    var_dtype = var_list[0].dtype.base_dtype
    if self._m_cache is None:
      self._m_cache = self.add_weight(
          'momentum_cache',
          shape=[],
          dtype=var_dtype,
          initializer='ones',
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._m_cache)
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      # Create slots for the first moments.
      self.add_slot(var, 'm')
    for var in var_list:
      # Create slots for the second moments.
      self.add_slot(var, 'v')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    next_step = math_ops.cast(self.iterations + 2, var_dtype)

    decay_base = math_ops.cast(0.96, var_dtype)

    m_t = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * local_step)))
    m_t_1 = beta_1_t * (1. - 0.5 * (
        math_ops.pow(decay_base, self._initial_decay * next_step)))

    m_schedule_new = math_ops.cast(self._m_cache_read, var_dtype) * m_t
    if var_dtype is self._m_cache.dtype:
      m_schedule_new = array_ops.identity(state_ops.assign(
          self._m_cache, m_schedule_new, use_locking=self._use_locking))
    m_schedule_next = m_schedule_new * m_t_1

    apply_state[(var_device, var_dtype)] = dict(
        lr_t=lr_t,
        neg_lr_t=-lr_t,  # pylint: disable=invalid-unary-operand-type
        epsilon=ops.convert_to_tensor_v2_with_dispatch(self.epsilon, var_dtype),
        beta_1_t=beta_1_t,
        beta_2_t=beta_2_t,
        m_t=m_t,
        m_t_1=m_t_1,
        one_minus_beta_1_t=1 - beta_1_t,
        one_minus_beta_2_t=1 - beta_2_t,
        one_minus_m_t=1. - m_t,
        one_minus_m_schedule_new=1. - m_schedule_new,
        one_minus_m_schedule_next=1. - m_schedule_next,
        v_t_prime_denominator=1. - math_ops.pow(beta_2_t, local_step),
    )

  def _prepare(self, var_list):
    # Get the value of the momentum cache before starting to apply gradients.
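    # Reading it once up front ensures every (device, dtype) entry built in
    # `_prepare_local` starts from the same cached value before that method
    # writes the updated schedule back to `momentum_cache`.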
    self._m_cache_read = array_ops.identity(self._m_cache)
    return super(Nadam, self)._prepare(var_list)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']
    m_t = (coefficients['beta_1_t'] * m +
           coefficients['one_minus_beta_1_t'] * grad)
    m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
    m_t_prime = m_t / coefficients['one_minus_m_schedule_next']
    v_t = (coefficients['beta_2_t'] * v +
           coefficients['one_minus_beta_2_t'] * math_ops.square(grad))
    v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)
    v_t_prime = v_t / coefficients['v_t_prime_denominator']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)
    var_t = var - coefficients['lr_t'] * m_t_bar / (
        math_ops.sqrt(v_t_prime) + coefficients['epsilon'])
    return state_ops.assign(var, var_t, use_locking=self._use_locking).op

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

    g_prime = grad / coefficients['one_minus_m_schedule_new']

    # m_t = beta1 * m + (1 - beta1) * g_t
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
      m_t_slice = array_ops.gather(m_t, indices)

    m_t_prime = m_t_slice / coefficients['one_minus_m_schedule_next']
    m_t_bar = (coefficients['one_minus_m_t'] * g_prime +
               coefficients['m_t_1'] * m_t_prime)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                           use_locking=self._use_locking)

    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
      v_t_slice = array_ops.gather(v_t, indices)

    v_t_prime = v_t_slice / coefficients['v_t_prime_denominator']
    v_prime_sqrt_plus_eps = math_ops.sqrt(v_t_prime) + coefficients['epsilon']

    var_update = self._resource_scatter_add(
        var, indices,
        coefficients['neg_lr_t'] * m_t_bar / v_prime_sqrt_plus_eps)
    return control_flow_ops.group(*[var_update, m_t_bar, v_t])

  def get_config(self):
    config = super(Nadam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
    })
    return config
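

# For reference, a minimal standalone sketch of the dense update that
# `_resource_apply_dense` above performs, written on plain Python floats for a
# single scalar parameter. It is illustrative only and is not used by the
# class; the helper name `_nadam_step_sketch` is made up for this example.
def _nadam_step_sketch(theta, g, m, v, m_schedule, step,
                       lr=0.001, beta_1=0.9, beta_2=0.999,
                       epsilon=1e-7, decay=0.004):
  """One Nadam step with the scheduled Nesterov momentum used above."""
  # Scheduled momentum coefficients for this step and the next one
  # (`m_t` / `m_t_1` in `_prepare_local`).
  mu_t = beta_1 * (1. - 0.5 * 0.96 ** (decay * step))
  mu_t_1 = beta_1 * (1. - 0.5 * 0.96 ** (decay * (step + 1)))
  # Running product of the momentum coefficients (the `momentum_cache`).
  m_schedule_new = m_schedule * mu_t
  m_schedule_next = m_schedule_new * mu_t_1
  # Bias-corrected gradient and moment estimates.
  g_prime = g / (1. - m_schedule_new)
  m = beta_1 * m + (1. - beta_1) * g
  m_prime = m / (1. - m_schedule_next)
  v = beta_2 * v + (1. - beta_2) * g * g
  v_prime = v / (1. - beta_2 ** step)
  # Nesterov-style combination of the lookahead momentum and the gradient.
  m_bar = (1. - mu_t) * g_prime + mu_t_1 * m_prime
  theta = theta - lr * m_bar / (v_prime ** 0.5 + epsilon)
  return theta, m, v, m_schedule_new


# E.g. starting from theta=10.0, g=10.0, m=v=0.0, m_schedule=1.0, step=1 and
# lr=0.2, the sketch yields theta ~= 9.79, matching the '9.8' shown in the
# class docstring example after formatting.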