# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-classes-have-attributes
"""Legacy v1 optimizer classes.

For more examples see the base class `tf.compat.v1.keras.optimizers.Optimizer`.
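
A minimal usage sketch (illustrative only; it assumes these classes are
exported as `tf.compat.v1.keras.optimizers.*`, per the note above, and it
trains on a tiny random dataset):

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Build a small model and train it with one of the legacy optimizers below.
model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.01, momentum=0.9),
              loss='mse')
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), verbose=0)
```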
20"""
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from six.moves import zip  # pylint: disable=redefined-builtin
26
27from tensorflow.python.distribute import distribution_strategy_context
28from tensorflow.python.eager import backprop
29from tensorflow.python.framework import ops
30from tensorflow.python.keras import backend as K
31from tensorflow.python.ops import clip_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import state_ops
34from tensorflow.python.training import training_util
35from tensorflow.python.training.tracking import base as trackable
36from tensorflow.python.util import nest
37
38
39class Optimizer(object):
40  """Abstract optimizer base class.
41
42  Note: this is the parent class of all optimizers, not an actual optimizer
43  that can be used for training models.
44
45  All Keras optimizers support the following keyword arguments:
46
47      clipnorm: float >= 0. Gradients will be clipped
48          when their L2 norm exceeds this value.
49      clipvalue: float >= 0. Gradients will be clipped
50          when their absolute value exceeds this value.
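
  For example (a hedged sketch; `SGD` is one concrete subclass defined later
  in this module, and the clipping itself is applied in `get_gradients`):

  ```python
  # Clip each gradient to an L2 norm of at most 1.0 before it is applied.
  opt = SGD(lr=0.01, clipnorm=1.0)
  # Or clip each gradient element to the range [-0.5, 0.5].
  opt = SGD(lr=0.01, clipvalue=0.5)
  ```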
51  """
52
53  def __init__(self, **kwargs):
54    allowed_kwargs = {'clipnorm', 'clipvalue'}
55    for k in kwargs:
56      if k not in allowed_kwargs:
57        raise TypeError('Unexpected keyword argument '
58                        'passed to optimizer: ' + str(k))
59      # checks that clipnorm >= 0 and clipvalue >= 0
60      if kwargs[k] < 0:
61        raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k]))
62    self.__dict__.update(kwargs)
63    self.updates = []
64    self.weights = []
65
66  # Set this to False, indicating `apply_gradients` does not take the
67  # `experimental_aggregate_gradients` argument.
68  _HAS_AGGREGATE_GRAD = False
69
70  def _create_all_weights(self, params):
71    """Creates and sets all optimizer weights.
72
73    Args:
74      params: list or tuple of `Variable` objects that will be minimized
75        using this optimizer.
76
77    Returns:
78      Specific weight values that are used in `get_updates`
79    """
80    raise NotImplementedError
81
82  def get_updates(self, loss, params):
83    raise NotImplementedError
84
85  def get_gradients(self, loss, params):
86    """Returns gradients of `loss` with respect to `params`.
87
88    Args:
89        loss: Loss tensor.
90        params: List of variables.
91
92    Returns:
93        List of gradient tensors.
94
95    Raises:
96        ValueError: In case any gradient cannot be computed (e.g. if gradient
97          function not implemented).
98    """
99    grads = K.gradients(loss, params)
100    if any(g is None for g in grads):
101      raise ValueError('An operation has `None` for gradient. '
102                       'Please make sure that all of your ops have a '
103                       'gradient defined (i.e. are differentiable). '
104                       'Common ops without gradient: '
105                       'K.argmax, K.round, K.eval.')
106    if hasattr(self, 'clipnorm'):
107      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
108    if hasattr(self, 'clipvalue'):
109      grads = [
110          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
111          for g in grads
112      ]
113    return grads
114
115  def set_weights(self, weights):
116    """Sets the weights of the optimizer, from Numpy arrays.
117
118    Should only be called after computing the gradients
119    (otherwise the optimizer has no weights).
120
121    Args:
122        weights: a list of Numpy arrays. The number of arrays and their shape
123          must match number of the dimensions of the weights of the optimizer
124          (i.e. it should match the output of `get_weights`).
125
126    Raises:
127        ValueError: in case of incompatible weight shapes.
128    """
129    params = self.weights
130    if len(params) != len(weights):
131      raise ValueError('Length of the specified weight list (' +
132                       str(len(weights)) +
133                       ') does not match the number of weights '
134                       'of the optimizer (' + str(len(params)) + ')')
135    weight_value_tuples = []
136    param_values = K.batch_get_value(params)
137    for pv, p, w in zip(param_values, params, weights):
138      if pv.shape != w.shape:
139        raise ValueError('Optimizer weight shape ' + str(pv.shape) +
140                         ' not compatible with '
141                         'provided weight shape ' + str(w.shape))
142      weight_value_tuples.append((p, w))
143    K.batch_set_value(weight_value_tuples)
144
145  def get_weights(self):
146    """Returns the current value of the weights of the optimizer.
147
148    Returns:
149        A list of numpy arrays.
150    """
151    return K.batch_get_value(self.weights)
152
153  def get_config(self):
154    config = {}
155    if hasattr(self, 'clipnorm'):
156      config['clipnorm'] = self.clipnorm
157    if hasattr(self, 'clipvalue'):
158      config['clipvalue'] = self.clipvalue
159    return config
160
161  @classmethod
162  def from_config(cls, config):
163    return cls(**config)
164
165
166class SGD(Optimizer):
167  """Stochastic gradient descent optimizer.
168
169  Includes support for momentum,
170  learning rate decay, and Nesterov momentum.
171
172  Args:
173      lr: float >= 0. Learning rate.
174      momentum: float >= 0. Parameter that accelerates SGD in the relevant
175        direction and dampens oscillations.
176      decay: float >= 0. Learning rate decay over each update.
177      nesterov: boolean. Whether to apply Nesterov momentum.
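
  Example (an illustrative sketch; `model` is assumed to be an already-built
  `tf.keras` model, and the export path follows the module docstring):

  ```python
  # Nesterov momentum plus a small per-update learning rate decay.
  opt = tf.compat.v1.keras.optimizers.SGD(
      lr=0.01, momentum=0.9, decay=1e-6, nesterov=True)
  model.compile(optimizer=opt, loss='categorical_crossentropy')
  ```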
178  """
179
180  def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, **kwargs):
181    super(SGD, self).__init__(**kwargs)
182    with K.name_scope(self.__class__.__name__):
183      self.iterations = K.variable(0, dtype='int64', name='iterations')
184      self.lr = K.variable(lr, name='lr')
185      self.momentum = K.variable(momentum, name='momentum')
186      self.decay = K.variable(decay, name='decay')
187    self.initial_decay = decay
188    self.nesterov = nesterov
189
190  def _create_all_weights(self, params):
191    shapes = [K.int_shape(p) for p in params]
192    moments = [K.zeros(shape) for shape in shapes]
193    self.weights = [self.iterations] + moments
194    return moments
195
196  def get_updates(self, loss, params):
197    grads = self.get_gradients(loss, params)
198    self.updates = [state_ops.assign_add(self.iterations, 1)]
199
200    lr = self.lr
201    if self.initial_decay > 0:
202      lr = lr * (  # pylint: disable=g-no-augmented-assignment
203          1. /
204          (1. +
205           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
206    # momentum
207    moments = self._create_all_weights(params)
208    for p, g, m in zip(params, grads, moments):
209      v = self.momentum * m - lr * g  # velocity
210      self.updates.append(state_ops.assign(m, v))
211
212      if self.nesterov:
213        new_p = p + self.momentum * v - lr * g
214      else:
215        new_p = p + v
216
217      # Apply constraints.
218      if getattr(p, 'constraint', None) is not None:
219        new_p = p.constraint(new_p)
220
221      self.updates.append(state_ops.assign(p, new_p))
222    return self.updates
223
224  def get_config(self):
225    config = {
226        'lr': float(K.get_value(self.lr)),
227        'momentum': float(K.get_value(self.momentum)),
228        'decay': float(K.get_value(self.decay)),
229        'nesterov': self.nesterov
230    }
231    base_config = super(SGD, self).get_config()
232    return dict(list(base_config.items()) + list(config.items()))
233
234
235class RMSprop(Optimizer):
236  """RMSProp optimizer.
237
238  It is recommended to leave the parameters of this optimizer
239  at their default values
240  (except the learning rate, which can be freely tuned).
241
242  Args:
243      lr: float >= 0. Learning rate.
244      rho: float >= 0.
245      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
246      decay: float >= 0. Learning rate decay over each update.
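
  Example (a hedged sketch; only the learning rate is changed, as the
  docstring above recommends):

  ```python
  # rho, epsilon and decay stay at their defaults; lr is tuned freely.
  opt = tf.compat.v1.keras.optimizers.RMSprop(lr=0.0005)
  ```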
247  """
248
249  def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs):
250    super(RMSprop, self).__init__(**kwargs)
251    with K.name_scope(self.__class__.__name__):
252      self.lr = K.variable(lr, name='lr')
253      self.rho = K.variable(rho, name='rho')
254      self.decay = K.variable(decay, name='decay')
255      self.iterations = K.variable(0, dtype='int64', name='iterations')
256    if epsilon is None:
257      epsilon = K.epsilon()
258    self.epsilon = epsilon
259    self.initial_decay = decay
260
261  def _create_all_weights(self, params):
262    accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
263    self.weights = accumulators
264    return accumulators
265
266  def get_updates(self, loss, params):
267    grads = self.get_gradients(loss, params)
268    accumulators = self._create_all_weights(params)
269    self.updates = [state_ops.assign_add(self.iterations, 1)]
270
271    lr = self.lr
272    if self.initial_decay > 0:
273      lr = lr * (  # pylint: disable=g-no-augmented-assignment
274          1. /
275          (1. +
276           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
277
278    for p, g, a in zip(params, grads, accumulators):
279      # update accumulator
280      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
281      self.updates.append(state_ops.assign(a, new_a))
282      new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
283
284      # Apply constraints.
285      if getattr(p, 'constraint', None) is not None:
286        new_p = p.constraint(new_p)
287
288      self.updates.append(state_ops.assign(p, new_p))
289    return self.updates
290
291  def get_config(self):
292    config = {
293        'lr': float(K.get_value(self.lr)),
294        'rho': float(K.get_value(self.rho)),
295        'decay': float(K.get_value(self.decay)),
296        'epsilon': self.epsilon
297    }
298    base_config = super(RMSprop, self).get_config()
299    return dict(list(base_config.items()) + list(config.items()))
300
301
302class Adagrad(Optimizer):
303  """Adagrad optimizer.
304
305  Adagrad is an optimizer with parameter-specific learning rates,
306  which are adapted relative to how frequently a parameter gets
307  updated during training. The more updates a parameter receives,
308  the smaller the updates.
309
310  It is recommended to leave the parameters of this optimizer
311  at their default values.
312
313  # Arguments
314      lr: float >= 0. Initial learning rate.
315      epsilon: float >= 0. If `None`, defaults to `K.epsilon()`.
316      decay: float >= 0. Learning rate decay over each update.
317
318  # References
319      - [Adaptive Subgradient Methods for Online Learning and Stochastic
320      Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
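
  Example (an illustrative sketch, using the default hyperparameters):

  ```python
  # Adagrad accumulates squared gradients, so frequently-updated parameters
  # receive progressively smaller steps.
  opt = tf.compat.v1.keras.optimizers.Adagrad(lr=0.01)
  ```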
321  """
322
323  def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs):
324    super(Adagrad, self).__init__(**kwargs)
325    with K.name_scope(self.__class__.__name__):
326      self.lr = K.variable(lr, name='lr')
327      self.decay = K.variable(decay, name='decay')
328      self.iterations = K.variable(0, dtype='int64', name='iterations')
329    if epsilon is None:
330      epsilon = K.epsilon()
331    self.epsilon = epsilon
332    self.initial_decay = decay
333
334  def _create_all_weights(self, params):
335    shapes = [K.int_shape(p) for p in params]
336    accumulators = [K.zeros(shape) for shape in shapes]
337    self.weights = accumulators
338    return accumulators
339
340  def get_updates(self, loss, params):
341    grads = self.get_gradients(loss, params)
342    accumulators = self._create_all_weights(params)
343
344    self.updates = [state_ops.assign_add(self.iterations, 1)]
345
346    lr = self.lr
347    if self.initial_decay > 0:
348      lr = lr * (  # pylint: disable=g-no-augmented-assignment
349          1. /
350          (1. +
351           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
352
353    for p, g, a in zip(params, grads, accumulators):
354      new_a = a + math_ops.square(g)  # update accumulator
355      self.updates.append(state_ops.assign(a, new_a))
356      new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
357
358      # Apply constraints.
359      if getattr(p, 'constraint', None) is not None:
360        new_p = p.constraint(new_p)
361
362      self.updates.append(state_ops.assign(p, new_p))
363    return self.updates
364
365  def get_config(self):
366    config = {
367        'lr': float(K.get_value(self.lr)),
368        'decay': float(K.get_value(self.decay)),
369        'epsilon': self.epsilon
370    }
371    base_config = super(Adagrad, self).get_config()
372    return dict(list(base_config.items()) + list(config.items()))
373
374
375class Adadelta(Optimizer):
376  """Adadelta optimizer.
377
378  Adadelta is a more robust extension of Adagrad
379  that adapts learning rates based on a moving window of gradient updates,
380  instead of accumulating all past gradients. This way, Adadelta continues
381  learning even when many updates have been done. Compared to Adagrad, in the
382  original version of Adadelta you don't have to set an initial learning
383  rate. In this version, initial learning rate and decay factor can
384  be set, as in most other Keras optimizers.
385
386  It is recommended to leave the parameters of this optimizer
387  at their default values.
388
389  # Arguments
390      lr: float >= 0. Initial learning rate, defaults to 1.
391          It is recommended to leave it at the default value.
392      rho: float >= 0. Adadelta decay factor, corresponding to fraction of
393          gradient to keep at each time step.
394      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
395      decay: float >= 0. Initial learning rate decay.
396
397  # References
398      - [Adadelta - an adaptive learning rate
399      method](http://arxiv.org/abs/1212.5701)
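
  Example (a minimal sketch; every argument is left at its default, as
  recommended above):

  ```python
  # lr=1.0 and rho=0.95 are the defaults from the signature below.
  opt = tf.compat.v1.keras.optimizers.Adadelta()
  ```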
400  """
401
402  def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs):
403    super(Adadelta, self).__init__(**kwargs)
404    with K.name_scope(self.__class__.__name__):
405      self.lr = K.variable(lr, name='lr')
406      self.decay = K.variable(decay, name='decay')
407      self.iterations = K.variable(0, dtype='int64', name='iterations')
408    if epsilon is None:
409      epsilon = K.epsilon()
410    self.rho = rho
411    self.epsilon = epsilon
412    self.initial_decay = decay
413
414  def _create_all_weights(self, params):
415    shapes = [K.int_shape(p) for p in params]
416    accumulators = [K.zeros(shape) for shape in shapes]
417    delta_accumulators = [K.zeros(shape) for shape in shapes]
418    self.weights = accumulators + delta_accumulators
419    return accumulators, delta_accumulators
420
421  def get_updates(self, loss, params):
422    grads = self.get_gradients(loss, params)
423    self.updates = [state_ops.assign_add(self.iterations, 1)]
424    accumulators, delta_accumulators = self._create_all_weights(params)
425
426    lr = self.lr
427    if self.initial_decay > 0:
428      lr = lr * (  # pylint: disable=g-no-augmented-assignment
429          1. /
430          (1. +
431           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
432
433    for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
434      # update accumulator
435      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
436      self.updates.append(state_ops.assign(a, new_a))
437
438      # use the new accumulator and the *old* delta_accumulator
439      update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon)
440      new_p = p - lr * update
441
442      # Apply constraints.
443      if getattr(p, 'constraint', None) is not None:
444        new_p = p.constraint(new_p)
445
446      self.updates.append(state_ops.assign(p, new_p))
447
448      # update delta_accumulator
449      new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update)
450      self.updates.append(state_ops.assign(d_a, new_d_a))
451    return self.updates
452
453  def get_config(self):
454    config = {
455        'lr': float(K.get_value(self.lr)),
456        'rho': self.rho,
457        'decay': float(K.get_value(self.decay)),
458        'epsilon': self.epsilon
459    }
460    base_config = super(Adadelta, self).get_config()
461    return dict(list(base_config.items()) + list(config.items()))
462
463
464class Adam(Optimizer):
465  """Adam optimizer.
466
467  Default parameters follow those provided in the original paper.
468
469  Args:
470      lr: float >= 0. Learning rate.
471      beta_1: float, 0 < beta < 1. Generally close to 1.
472      beta_2: float, 0 < beta < 1. Generally close to 1.
473      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
474      decay: float >= 0. Learning rate decay over each update.
475      amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
476        from the paper "On the Convergence of Adam and Beyond".
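
  Example (a hedged sketch; the config round-trip relies only on
  `get_config`/`from_config` as defined in this module):

  ```python
  opt = tf.compat.v1.keras.optimizers.Adam(lr=3e-4, amsgrad=True)
  # Hyperparameters can be serialized and restored through the config dict.
  config = opt.get_config()
  restored = tf.compat.v1.keras.optimizers.Adam.from_config(config)
  ```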
477  """
478
479  def __init__(self,
480               lr=0.001,
481               beta_1=0.9,
482               beta_2=0.999,
483               epsilon=None,
484               decay=0.,
485               amsgrad=False,
486               **kwargs):
487    super(Adam, self).__init__(**kwargs)
488    with K.name_scope(self.__class__.__name__):
489      self.iterations = K.variable(0, dtype='int64', name='iterations')
490      self.lr = K.variable(lr, name='lr')
491      self.beta_1 = K.variable(beta_1, name='beta_1')
492      self.beta_2 = K.variable(beta_2, name='beta_2')
493      self.decay = K.variable(decay, name='decay')
494    if epsilon is None:
495      epsilon = K.epsilon()
496    self.epsilon = epsilon
497    self.initial_decay = decay
498    self.amsgrad = amsgrad
499
500  def _create_all_weights(self, params):
501    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
502    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
503    if self.amsgrad:
504      vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
505    else:
506      vhats = [K.zeros(1) for _ in params]
507    self.weights = [self.iterations] + ms + vs + vhats
508    return ms, vs, vhats
509
510  def get_updates(self, loss, params):
511    grads = self.get_gradients(loss, params)
512    self.updates = []
513
514    lr = self.lr
515    if self.initial_decay > 0:
516      lr = lr * (  # pylint: disable=g-no-augmented-assignment
517          1. /
518          (1. +
519           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
520
521    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
522      t = math_ops.cast(self.iterations, K.floatx())
523    lr_t = lr * (
524        K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
525        (1. - math_ops.pow(self.beta_1, t)))
526
527    ms, vs, vhats = self._create_all_weights(params)
528    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
529      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
530      v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
531      if self.amsgrad:
532        vhat_t = math_ops.maximum(vhat, v_t)
533        p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
534        self.updates.append(state_ops.assign(vhat, vhat_t))
535      else:
536        p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
537
538      self.updates.append(state_ops.assign(m, m_t))
539      self.updates.append(state_ops.assign(v, v_t))
540      new_p = p_t
541
542      # Apply constraints.
543      if getattr(p, 'constraint', None) is not None:
544        new_p = p.constraint(new_p)
545
546      self.updates.append(state_ops.assign(p, new_p))
547    return self.updates
548
549  def get_config(self):
550    config = {
551        'lr': float(K.get_value(self.lr)),
552        'beta_1': float(K.get_value(self.beta_1)),
553        'beta_2': float(K.get_value(self.beta_2)),
554        'decay': float(K.get_value(self.decay)),
555        'epsilon': self.epsilon,
556        'amsgrad': self.amsgrad
557    }
558    base_config = super(Adam, self).get_config()
559    return dict(list(base_config.items()) + list(config.items()))
560
561
562class Adamax(Optimizer):
563  """Adamax optimizer from Adam paper's Section 7.
564
565  It is a variant of Adam based on the infinity norm.
566  Default parameters follow those provided in the paper.
567
568  Args:
569      lr: float >= 0. Learning rate.
570      beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
571      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
572      decay: float >= 0. Learning rate decay over each update.
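
  Example (an illustrative sketch, using the default learning rate from the
  signature below):

  ```python
  # Adamax scales updates by an exponentially weighted infinity norm of past
  # gradients rather than Adam's second moment estimate.
  opt = tf.compat.v1.keras.optimizers.Adamax(lr=0.002)
  ```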
573  """
574
575  def __init__(self,
576               lr=0.002,
577               beta_1=0.9,
578               beta_2=0.999,
579               epsilon=None,
580               decay=0.,
581               **kwargs):
582    super(Adamax, self).__init__(**kwargs)
583    with K.name_scope(self.__class__.__name__):
584      self.iterations = K.variable(0, dtype='int64', name='iterations')
585      self.lr = K.variable(lr, name='lr')
586      self.beta_1 = K.variable(beta_1, name='beta_1')
587      self.beta_2 = K.variable(beta_2, name='beta_2')
588      self.decay = K.variable(decay, name='decay')
589    if epsilon is None:
590      epsilon = K.epsilon()
591    self.epsilon = epsilon
592    self.initial_decay = decay
593
594  def _create_all_weights(self, params):
595
596    shapes = [K.int_shape(p) for p in params]
597    # zero init of 1st moment
598    ms = [K.zeros(shape) for shape in shapes]
599    # zero init of exponentially weighted infinity norm
600    us = [K.zeros(shape) for shape in shapes]
601    self.weights = [self.iterations] + ms + us
602    return ms, us
603
604  def get_updates(self, loss, params):
605    grads = self.get_gradients(loss, params)
606    self.updates = []
607
608    lr = self.lr
609    if self.initial_decay > 0:
610      lr = lr * (  # pylint: disable=g-no-augmented-assignment
611          1. /
612          (1. +
613           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
614
615    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
616      t = math_ops.cast(self.iterations, K.floatx())
617    lr_t = lr / (1. - math_ops.pow(self.beta_1, t))
618
619    ms, us = self._create_all_weights(params)
620
621    for p, g, m, u in zip(params, grads, ms, us):
622
623      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
624      u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g))
625      p_t = p - lr_t * m_t / (u_t + self.epsilon)
626
627      self.updates.append(state_ops.assign(m, m_t))
628      self.updates.append(state_ops.assign(u, u_t))
629      new_p = p_t
630
631      # Apply constraints.
632      if getattr(p, 'constraint', None) is not None:
633        new_p = p.constraint(new_p)
634
635      self.updates.append(state_ops.assign(p, new_p))
636    return self.updates
637
638  def get_config(self):
639    config = {
640        'lr': float(K.get_value(self.lr)),
641        'beta_1': float(K.get_value(self.beta_1)),
642        'beta_2': float(K.get_value(self.beta_2)),
643        'decay': float(K.get_value(self.decay)),
644        'epsilon': self.epsilon
645    }
646    base_config = super(Adamax, self).get_config()
647    return dict(list(base_config.items()) + list(config.items()))
648
649
650class Nadam(Optimizer):
651  """Nesterov Adam optimizer.
652
653  Much like Adam is essentially RMSprop with momentum,
654  Nadam is Adam RMSprop with Nesterov momentum.
655
656  Default parameters follow those provided in the paper.
657  It is recommended to leave the parameters of this optimizer
658  at their default values.
659
660  Args:
661      lr: float >= 0. Learning rate.
662      beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
663      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
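
  Example (a minimal sketch; Nadam maintains its own warming momentum
  schedule, so usually only `lr` is adjusted):

  ```python
  opt = tf.compat.v1.keras.optimizers.Nadam(lr=0.002)
  ```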
664  """
665
666  def __init__(self,
667               lr=0.002,
668               beta_1=0.9,
669               beta_2=0.999,
670               epsilon=None,
671               schedule_decay=0.004,
672               **kwargs):
673    super(Nadam, self).__init__(**kwargs)
674    with K.name_scope(self.__class__.__name__):
675      self.iterations = K.variable(0, dtype='int64', name='iterations')
676      self.m_schedule = K.variable(1., name='m_schedule')
677      self.lr = K.variable(lr, name='lr')
678      self.beta_1 = K.variable(beta_1, name='beta_1')
679      self.beta_2 = K.variable(beta_2, name='beta_2')
680    if epsilon is None:
681      epsilon = K.epsilon()
682    self.epsilon = epsilon
683    self.schedule_decay = schedule_decay
684
685  def _create_all_weights(self, params):
686    shapes = [K.int_shape(p) for p in params]
687    ms = [K.zeros(shape) for shape in shapes]
688    vs = [K.zeros(shape) for shape in shapes]
689
690    self.weights = [self.iterations, self.m_schedule] + ms + vs
691    return ms, vs
692
693  def get_updates(self, loss, params):
694    grads = self.get_gradients(loss, params)
695    self.updates = []
696
697    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
698      t = math_ops.cast(self.iterations, K.floatx())
699
700    # Due to the recommendations in [2], i.e. warming momentum schedule
701    momentum_cache_t = self.beta_1 * (
702        1. - 0.5 *
703        (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
704    momentum_cache_t_1 = self.beta_1 * (
705        1. - 0.5 *
706        (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
707    m_schedule_new = self.m_schedule * momentum_cache_t
708    m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
709    self.updates.append((self.m_schedule, m_schedule_new))
710
711    ms, vs = self._create_all_weights(params)
712
713    for p, g, m, v in zip(params, grads, ms, vs):
714      # the following equations given in [1]
715      g_prime = g / (1. - m_schedule_new)
716      m_t = self.beta_1 * m + (1. - self.beta_1) * g
717      m_t_prime = m_t / (1. - m_schedule_next)
718      v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
719      v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
720      m_t_bar = (1. -
721                 momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime
722
723      self.updates.append(state_ops.assign(m, m_t))
724      self.updates.append(state_ops.assign(v, v_t))
725
726      p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
727      new_p = p_t
728
729      # Apply constraints.
730      if getattr(p, 'constraint', None) is not None:
731        new_p = p.constraint(new_p)
732
733      self.updates.append(state_ops.assign(p, new_p))
734    return self.updates
735
736  def get_config(self):
737    config = {
738        'lr': float(K.get_value(self.lr)),
739        'beta_1': float(K.get_value(self.beta_1)),
740        'beta_2': float(K.get_value(self.beta_2)),
741        'epsilon': self.epsilon,
742        'schedule_decay': self.schedule_decay
743    }
744    base_config = super(Nadam, self).get_config()
745    return dict(list(base_config.items()) + list(config.items()))
746
747
class TFOptimizer(Optimizer, trackable.Trackable):
  """Wrapper class for native TensorFlow optimizers.

  def __init__(self, optimizer, iterations=None):  # pylint: disable=super-init-not-called
    self.optimizer = optimizer
    self._track_trackable(optimizer, name='optimizer')
    if iterations is None:
      with K.name_scope(self.__class__.__name__):
        self.iterations = K.variable(0, dtype='int64', name='iterations')
    else:
      self.iterations = iterations
    self._track_trackable(self.iterations, name='global_step')

  def _clip_gradients(self, grads):
    """Clip gradients according to the clipnorm and clipvalue attributes."""
    # TFOptimizer wrapper has no gradient clipping options.
    return grads

  def minimize(self, loss, var_list, grad_loss=None, tape=None):
    """Mimics the `OptimizerV2.minimize` API."""
    if not callable(loss) and tape is None:
      raise ValueError('`tape` is required when a `Tensor` loss is passed.')
    tape = tape if tape is not None else backprop.GradientTape()

    if callable(loss):
      with tape:
        if not callable(var_list):
          tape.watch(var_list)
        loss = loss()
        if callable(var_list):
          var_list = var_list()

    var_list = nest.flatten(var_list)
    if var_list:
      grads = tape.gradient(loss, var_list, grad_loss)
      grads_and_vars = list(zip(grads, var_list))
      self.apply_gradients(grads_and_vars)

  def apply_gradients(self, grads_and_vars):
    self.optimizer.apply_gradients(grads_and_vars, global_step=self.iterations)

  def get_grads(self, loss, params):
    return self.optimizer.compute_gradients(loss, params)

  def get_updates(self, loss, params):
    if distribution_strategy_context.has_strategy():
      self.updates = []

      if not params:
        # After the model vars have been created, the second call to
        # get_updates is made with params as an empty list. This ensures that
        # we call compute_gradients with params=None.
        grads = self.optimizer.compute_gradients(loss)
      else:
        grads = self.optimizer.compute_gradients(loss, params)
      global_step = training_util.get_global_step()
      opt_update = self.optimizer.apply_gradients(grads, global_step)
    else:
      if not params:
        self.updates = [state_ops.assign_add(self.iterations, 1)]
        return self.updates

      # Updates list starts out empty because the iterations variable is
      # incremented in optimizer.apply_gradients()
      self.updates = []
      grads = self.optimizer.compute_gradients(loss, params)
      opt_update = self.optimizer.apply_gradients(
          grads, global_step=self.iterations)

    self.updates.append(opt_update)
    return self.updates

  @property
  def weights(self):
    raise NotImplementedError

  def get_config(self):
    raise NotImplementedError

  def from_config(self, config):
    raise NotImplementedError


# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamax = Adamax
nadam = Nadam