# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Adam')
class Adam(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Adam algorithm.

  Adam optimization is a stochastic gradient descent method that is based on
  adaptive estimation of first-order and second-order moments.
  According to the paper
  [Adam: A Method for Stochastic Optimization. Kingma and Ba,
  2014](http://arxiv.org/abs/1412.6980),
  the method is "*computationally efficient, has little memory
  requirement, invariant to diagonal rescaling of gradients, and is well suited
  for problems that are large in terms of data/parameters*".

  For the AMSGrad variant see [On the Convergence of Adam and Beyond.
  Reddi et al., 2018](https://openreview.net/pdf?id=ryQu7f-RZ).
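
  A minimal usage sketch (assuming the public `tf.keras` API under eager
  execution; `var` and `loss` below are illustrative placeholders only):

  ```python
  import tensorflow as tf

  opt = tf.keras.optimizers.Adam(learning_rate=0.1)
  var = tf.Variable(10.0)
  loss = lambda: (var ** 2) / 2.0   # d(loss)/d(var) == var
  opt.minimize(loss, var_list=[var])  # applies a single Adam update to `var`
  ```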
  """

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               amsgrad=False,
               name='Adam',
               **kwargs):
    r"""Construct a new Adam optimizer.

    If amsgrad = False:
      Initialization:

      $$m_0 := 0 \text{ (Initialize 1st moment vector)}$$
      $$v_0 := 0 \text{ (Initialize 2nd moment vector)}$$
      $$t := 0 \text{ (Initialize timestep)}$$

      The update rule for `variable` with gradient `g` uses an optimization
      described at the end of section 2 of the paper:

      $$t := t + 1$$
      $$\text{lr}_t := \text{learning\_rate} *
        \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$

      $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
      $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
      $$\text{variable} := \text{variable} -
        \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$

    If amsgrad = True:
      Initialization:

      $$m_0 := 0 \text{ (Initialize 1st moment vector)}$$
      $$v_0 := 0 \text{ (Initialize 2nd moment vector)}$$
      $$\hat{v}_0 := 0 \text{ (Initialize maximum of 2nd moment vector)}$$
      $$t := 0 \text{ (Initialize timestep)}$$

      The update rule for `variable` with gradient `g` uses an optimization
      described at the end of section 2 of the paper:

      $$t := t + 1$$
      $$\text{lr}_t := \text{learning\_rate} *
        \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$

      $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
      $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
      $$\hat{v}_t := \max(\hat{v}_{t-1}, v_t)$$
      $$\text{variable} := \text{variable} -
        \text{lr}_t * m_t / (\sqrt{\hat{v}_t} + \epsilon)$$
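
    As a plain-Python illustration of the formulas above (a scalar sketch
    only; the helper `adam_step` is made up for this example and is not the
    code path used by this class):

    ```python
    import math

    def adam_step(var, g, m, v, vhat, t, lr=0.001, beta_1=0.9, beta_2=0.999,
                  epsilon=1e-7, amsgrad=False):
      # One scalar Adam/AMSGrad update, mirroring the update rules above.
      t += 1
      lr_t = lr * math.sqrt(1 - beta_2**t) / (1 - beta_1**t)
      m = beta_1 * m + (1 - beta_1) * g
      v = beta_2 * v + (1 - beta_2) * g * g
      vhat = max(vhat, v)
      denom = math.sqrt(vhat if amsgrad else v) + epsilon
      var -= lr_t * m / denom
      return var, m, v, vhat, t
    ```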

    The default value of 1e-7 for epsilon might not be a good choice in
    general. For example, when training an Inception network on ImageNet a
    current good choice is 1.0 or 0.1. Note that since this implementation
    uses the formulation just before Section 2.1 of the Kingma and Ba paper
    rather than the formulation in Algorithm 1, the "epsilon" referred to
    here is "epsilon hat" in the paper.

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).
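
    For instance, a minimal sketch (assuming the TF 2.x public API under
    eager execution; `emb` is an illustrative variable only): a `tf.gather`
    in the forward pass yields an `IndexedSlices` gradient, which is handled
    by the sparse update path:

    ```python
    import tensorflow as tf

    emb = tf.Variable(tf.random.normal([100, 8]))
    opt = tf.keras.optimizers.Adam()
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(tf.gather(emb, [0, 3, 3]) ** 2)
    grad = tape.gradient(loss, emb)  # an IndexedSlices, not a dense Tensor
    opt.apply_gradients([(grad, emb)])
    ```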

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
      beta_1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta_2: A float value or a constant float tensor. The exponential decay
        rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      amsgrad: Boolean. Whether to apply the AMSGrad variant of this algorithm
        from the paper "On the Convergence of Adam and Beyond".
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".
      **kwargs: Keyword arguments. Allowed to be `clipnorm`, `clipvalue`, `lr`,
        or `decay`. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time-based inverse decay of the learning rate; `lr` is included
        for backward compatibility and `learning_rate` is recommended instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `beta_1`, `beta_2`,
    and `epsilon` can each be a callable that takes no arguments and
    returns the actual value to use. This can be useful for changing these
    values across different invocations of optimizer functions.
    @end_compatibility
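
    For example, a minimal sketch under the same eager assumption (`lr_var`
    is illustrative and not part of this module):

    ```python
    lr_var = tf.Variable(0.001, trainable=False)
    opt = tf.keras.optimizers.Adam(learning_rate=lambda: lr_var.read_value())
    # Assigning a new value to `lr_var` changes the rate used by later steps.
    ```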
    """

    super(Adam, self).__init__(name, **kwargs)
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('beta_1', beta_1)
    self._set_hyper('beta_2', beta_2)
    self.epsilon = epsilon or backend_config.epsilon()
    self.amsgrad = amsgrad

  def _create_slots(self, var_list):
    # Create slots for the first and second moments.
    # Separate for-loops to respect the ordering of slot variables from v1.
    for var in var_list:
      self.add_slot(var, 'm')
    for var in var_list:
      self.add_slot(var, 'v')
    if self.amsgrad:
      for var in var_list:
        self.add_slot(var, 'vhat')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Adam, self)._prepare_local(var_device, var_dtype, apply_state)

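    # Cache per-(device, dtype) coefficients for this step. `lr` is the
    # bias-corrected step size lr_t * sqrt(1 - beta_2^t) / (1 - beta_1^t)
    # consumed by the sparse update below; the fused dense kernel re-derives
    # it from the cached beta powers.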
    local_step = math_ops.cast(self.iterations + 1, var_dtype)
    beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
    beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
    beta_1_power = math_ops.pow(beta_1_t, local_step)
    beta_2_power = math_ops.pow(beta_2_t, local_step)
    lr = (apply_state[(var_device, var_dtype)]['lr_t'] *
          (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)))
    apply_state[(var_device, var_dtype)].update(dict(
        lr=lr,
        epsilon=ops.convert_to_tensor(self.epsilon, var_dtype),
        beta_1_t=beta_1_t,
        beta_1_power=beta_1_power,
        one_minus_beta_1_t=1 - beta_1_t,
        beta_2_t=beta_2_t,
        beta_2_power=beta_2_power,
        one_minus_beta_2_t=1 - beta_2_t
    ))

  def set_weights(self, weights):
    params = self.weights
    # If the weights are generated by a Keras V1 optimizer, they include vhats
    # even without amsgrad, i.e., the V1 optimizer has 3x + 1 variables, while
    # the V2 optimizer has 2x + 1 variables. Filter vhats out for compatibility.
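    # Illustrative example: for a model with two variables and amsgrad=False,
    # this optimizer stores 2*2 + 1 = 5 weights (the iteration counter plus
    # `m` and `v` per variable), whereas V1-style weights carry 3*2 + 1 = 7;
    # the surplus vhats are dropped by the slicing below.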
    num_vars = int((len(params) - 1) / 2)
    if len(weights) == 3 * num_vars + 1:
      weights = weights[:len(params)]
    super(Adam, self).set_weights(weights)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    m = self.get_slot(var, 'm')
    v = self.get_slot(var, 'v')

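    # Both branches delegate the dense update to fused kernels in
    # `training_ops`, which modify `var`, `m`, `v` (and `vhat` for AMSGrad)
    # in place and apply the bias correction from the cached beta powers.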
    if not self.amsgrad:
      return training_ops.resource_apply_adam(
          var.handle,
          m.handle,
          v.handle,
          coefficients['beta_1_power'],
          coefficients['beta_2_power'],
          coefficients['lr_t'],
          coefficients['beta_1_t'],
          coefficients['beta_2_t'],
          coefficients['epsilon'],
          grad,
          use_locking=self._use_locking)
    else:
      vhat = self.get_slot(var, 'vhat')
      return training_ops.resource_apply_adam_with_amsgrad(
          var.handle,
          m.handle,
          v.handle,
          vhat.handle,
          coefficients['beta_1_power'],
          coefficients['beta_2_power'],
          coefficients['lr_t'],
          coefficients['beta_1_t'],
          coefficients['beta_2_t'],
          coefficients['epsilon'],
          grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, 'm')
    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
    m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                           use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, 'v')
    v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
    v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                           use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

    if not self.amsgrad:
      v_sqrt = math_ops.sqrt(v_t)
      var_update = state_ops.assign_sub(
          var, coefficients['lr'] * m_t / (v_sqrt + coefficients['epsilon']),
          use_locking=self._use_locking)
      return control_flow_ops.group(*[var_update, m_t, v_t])
    else:
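      # amsgrad: also track the running maximum of v_t and use it in the
      # denominator, i.e. v_hat_t = max(v_hat_{t-1}, v_t).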
      v_hat = self.get_slot(var, 'vhat')
      v_hat_t = math_ops.maximum(v_hat, v_t)
      with ops.control_dependencies([v_hat_t]):
        v_hat_t = state_ops.assign(
            v_hat, v_hat_t, use_locking=self._use_locking)
      v_hat_sqrt = math_ops.sqrt(v_hat_t)
      var_update = state_ops.assign_sub(
          var,
          coefficients['lr'] * m_t / (v_hat_sqrt + coefficients['epsilon']),
          use_locking=self._use_locking)
      return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])

  def get_config(self):
    config = super(Adam, self).get_config()
    config.update({
        'learning_rate': self._serialize_hyperparameter('learning_rate'),
        'decay': self._serialize_hyperparameter('decay'),
        'beta_1': self._serialize_hyperparameter('beta_1'),
        'beta_2': self._serialize_hyperparameter('beta_2'),
        'epsilon': self.epsilon,
        'amsgrad': self.amsgrad,
    })
    return config