# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Adam for TensorFlow."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.eager import context
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import control_flow_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import resource_variable_ops
25from tensorflow.python.ops import state_ops
26from tensorflow.python.training import optimizer
27from tensorflow.python.training import training_ops
28from tensorflow.python.util.tf_export import tf_export
29
30
31@tf_export(v1=["train.AdamOptimizer"])
32class AdamOptimizer(optimizer.Optimizer):
33  """Optimizer that implements the Adam algorithm.
34
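  A minimal usage sketch (illustrative; assumes graph-mode execution and an
  already-defined scalar `loss` tensor):

      opt = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)
      train_op = opt.minimize(loss)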

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """

  def __init__(self,
               learning_rate=0.001,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               use_locking=False,
               name="Adam"):
48    r"""Construct a new Adam optimizer.
49
50    Initialization:
51
52    $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
53    $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
54    $$t := 0 \text{(Initialize timestep)}$$
55
    The update rule for `variable` with gradient `g` uses an optimization
    described at the end of section 2 of the paper:

    $$t := t + 1$$
    $$\text{lr}_t := \mathrm{learning\_rate} *
      \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$

    $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
    $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
    $$\text{variable} := \text{variable} -
      \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$

    The default value of 1e-8 for epsilon might not be a good default in
    general. For example, when training an Inception network on ImageNet, a
    current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
    formulation just before Section 2.1 of the Kingma and Ba paper rather than
    the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
    hat" in the paper.
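
    For example (an illustrative sketch; a larger epsilon is simply passed to
    the constructor):

        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001, epsilon=1.0)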

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).

    Args:
      learning_rate: A Tensor or a floating point value.  The learning rate.
      beta1: A float value or a constant float tensor. The exponential decay
        rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor. The exponential decay
        rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True, use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
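
    For example (an illustrative sketch; assumes eager execution is enabled
    and `current_lr` is a Python number updated elsewhere in the training
    loop):

        opt = tf.compat.v1.train.AdamOptimizer(learning_rate=lambda: current_lr)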
    """
    super(AdamOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None

  def _get_beta_accumulators(self):
    with ops.init_scope():
      if context.executing_eagerly():
        graph = None
      else:
        graph = ops.get_default_graph()
      return (self._get_non_slot_variable("beta1_power", graph=graph),
              self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(
        initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
    self._create_non_slot_variable(
        initial_value=self._beta2, name="beta2_power", colocate_with=first_var)

    # Create slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)

  def _prepare(self):
    lr = self._call_if_callable(self._lr)
    beta1 = self._call_if_callable(self._beta1)
    beta2 = self._call_if_callable(self._beta2)
    epsilon = self._call_if_callable(self._epsilon)

    self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
    self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
    self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
    self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")

  def _apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
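    # The full update (moment accumulation, bias correction, and the variable
    # update itself) is performed by the single fused ApplyAdam kernel.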
    return training_ops.apply_adam(
        var,
        m,
        v,
        math_ops.cast(beta1_power, var.dtype.base_dtype),
        math_ops.cast(beta2_power, var.dtype.base_dtype),
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._beta1_t, var.dtype.base_dtype),
        math_ops.cast(self._beta2_t, var.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.resource_apply_adam(
        var.handle,
        m.handle,
        v.handle,
        math_ops.cast(beta1_power, grad.dtype.base_dtype),
        math_ops.cast(beta2_power, grad.dtype.base_dtype),
        math_ops.cast(self._lr_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
        grad,
        use_locking=self._use_locking)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
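    # Bias-corrected step size: learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t),
    # i.e. lr_t from the formula in the constructor docstring.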
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = math_ops.sqrt(v_t)
    var_update = state_ops.assign_sub(
        var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(
        grad.values,
        var,
        grad.indices,
        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
            x,
            i,
            v,
            use_locking=self._use_locking))

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    return self._apply_sparse_shared(grad, var, indices,
                                     self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
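    # beta1_power and beta2_power track beta1^t and beta2^t, which the apply
    # ops above read for bias correction.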
    with ops.control_dependencies(update_ops):
      beta1_power, beta2_power = self._get_beta_accumulators()
      with ops.colocate_with(beta1_power):
        update_beta1 = beta1_power.assign(
            beta1_power * self._beta1_t, use_locking=self._use_locking)
        update_beta2 = beta2_power.assign(
            beta2_power * self._beta2_t, use_locking=self._use_locking)
    return control_flow_ops.group(
        *update_ops + [update_beta1, update_beta2], name=name_scope)