# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adamax optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Adamax')
class Adamax(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Adamax algorithm.

  It is a variant of Adam based on the infinity norm.
  Default parameters follow those provided in the paper.
  Adamax is sometimes superior to Adam, especially in models with embeddings.

  Initialization:

  ```python
  m = 0  # Initialize the 1st moment vector
  v = 0  # Initialize the exponentially weighted infinity norm
  t = 0  # Initialize timestep
  ```

  The update rule for parameter `w` with gradient `g` is
  described at the end of section 7.1 of the paper:

  ```python
  t += 1
  m = beta1 * m + (1 - beta1) * g
  v = max(beta2 * v, abs(g))
  current_lr = learning_rate / (1 - beta1 ** t)
  w = w - current_lr * m / (v + epsilon)
  ```

  Similarly to `Adam`, the epsilon is added for numerical stability
  (especially to get rid of division by zero when `v_t == 0`).

  In contrast to `Adam`, the sparse implementation of this algorithm
  (used when the gradient is an IndexedSlices object, typically because of
  `tf.gather` or an embedding lookup in the forward pass) only updates
  variable slices and corresponding `m_t`, `v_t` terms when that part of
  the variable was used in the forward pass. This means that the sparse
  behavior is in contrast to the dense behavior (similar to some momentum
  implementations which ignore momentum unless a variable slice was actually
  used).

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
    beta_1: A float value or a constant float tensor. The exponential decay
      rate for the 1st moment estimates.
    beta_2: A float value or a constant float tensor. The exponential decay
      rate for the exponentially weighted infinity norm.
    epsilon: A small constant for numerical stability.
    name: Optional name for the operations created when applying gradients.
      Defaults to `"Adamax"`.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.
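
  Usage (an illustrative sketch; the variable and loss below are arbitrary
  examples, not part of this module):

  ```python
  opt = tf.keras.optimizers.Adamax(learning_rate=0.001)
  var1 = tf.Variable(10.0)
  loss = lambda: (var1 ** 2) / 2.0  # d(loss)/d(var1) == var1
  opt.minimize(loss, [var1])  # Takes one optimization step on `var1`.
  ```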

  Reference:
    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               name='Adamax',
               **kwargs):
    super(Adamax, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()

  def _create_slots(self, var_list):
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, 'm')  # Create slots for the first moments.
    for var in var_list:
      self.add_slot(var, 'v')  # Create slots for the second moments.

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Adamax, self)._prepare_local(var_device, var_dtype, apply_state)

    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    lr_t = apply_state[(var_device, var_dtype)]['lr_t']

    apply_state[(var_device, var_dtype)].update(
        dict(
            neg_scaled_lr=-lr_t / (1 - beta_1_power),
            epsilon=ops.convert_to_tensor_v2_with_dispatch(
                self.epsilon, var_dtype),
            beta_1_t=beta_1_t,
            beta_1_power=beta_1_power,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            zero=array_ops.zeros((), dtype=dtypes.int64)))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')
    return gen_training_ops.ResourceApplyAdaMax(
        var=var.handle,
        m=m.handle,
        v=v.handle,
        beta1_power=coefficients['beta_1_power'],
        lr=coefficients['lr_t'],
        beta1=coefficients['beta_1_t'],
        beta2=coefficients['beta_2_t'],
        epsilon=coefficients['epsilon'],
        grad=grad,
        use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_slice = array_ops.gather(m, indices, axis=coefficients['zero'])
    m_t_slice = (m_slice * coefficients['beta_1_t'] +
                 grad * coefficients['one_minus_beta_1_t'])
    with ops.control_dependencies([m_t_slice]):
      m_t = self._resource_scatter_update(m, indices, m_t_slice)

    # u_t = max(beta2 * u, abs(g_t))
    v = self.get_slot(var, 'v')
    v_slice = array_ops.gather(v, indices, axis=coefficients['zero'])
    v_t_slice = math_ops.maximum(v_slice * coefficients['beta_2_t'],
                                 math_ops.abs(grad))
    with ops.control_dependencies([v_t_slice]):
      v_t = self._resource_scatter_update(v, indices, v_t_slice)
    # theta_t = theta - lr / (1 - beta1^t) * m_t / u_t
    var_slice = coefficients['neg_scaled_lr'] * (
        m_t_slice / (v_t_slice + coefficients['epsilon']))
    with ops.control_dependencies([var_slice]):
      var_update = self._resource_scatter_add(var, indices, var_slice)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def get_config(self):
    config = super(Adamax, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._initial_decay,
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
    })
    return config
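

# What follows is a minimal pure-Python sketch, added purely for illustration,
# of the dense Adamax update rule documented in the class docstring above.
# The helper name `_adamax_reference_step` and its scalar-only signature are
# hypothetical: they are not part of the Keras optimizer API and are not used
# by the `Adamax` class.
def _adamax_reference_step(w, g, m, v, t, learning_rate=0.001, beta_1=0.9,
                           beta_2=0.999, epsilon=1e-7):
  """Applies one Adamax step to a scalar weight `w` with gradient `g`."""
  t += 1
  m = beta_1 * m + (1 - beta_1) * g  # 1st moment estimate.
  v = max(beta_2 * v, abs(g))  # Exponentially weighted infinity norm.
  current_lr = learning_rate / (1 - beta_1 ** t)  # Bias-corrected step size.
  w = w - current_lr * m / (v + epsilon)
  return w, m, v, t


# For example, starting from `w = 10.0`, `g = 10.0` and zero state
# (`m = v = t = 0`), a single call moves `w` to roughly
# `10.0 - learning_rate`.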