# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-classes-have-attributes
"""Legacy v1 optimizer classes.

For more examples see the base class `tf.compat.v1.keras.optimizers.Optimizer`.
"""

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import backprop
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training import training_util
from tensorflow.python.util import nest


class Optimizer(object):
  """Abstract optimizer base class.

  Note: this is the parent class of all optimizers, not an actual optimizer
  that can be used for training models.

  All Keras optimizers support the following keyword arguments:

      clipnorm: float >= 0. Gradients will be clipped
          when their L2 norm exceeds this value.
      clipvalue: float >= 0. Gradients will be clipped
          when their absolute value exceeds this value.
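
  Example (an illustrative sketch; any concrete subclass, e.g. `SGD`, accepts
  these keyword arguments):

      # Clip gradients to an L2 norm of 1.0 and an absolute value of 0.5.
      opt = SGD(lr=0.01, clipnorm=1.0, clipvalue=0.5)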
  """

  def __init__(self, **kwargs):
    allowed_kwargs = {'clipnorm', 'clipvalue'}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError('Unexpected keyword argument '
                        'passed to optimizer: ' + str(k))
      # checks that clipnorm >= 0 and clipvalue >= 0
      if kwargs[k] < 0:
        raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k]))
    self.__dict__.update(kwargs)
    self.updates = []
    self.weights = []

  # Set this to False, indicating `apply_gradients` does not take the
  # `experimental_aggregate_gradients` argument.
  _HAS_AGGREGATE_GRAD = False

  def _create_all_weights(self, params):
    """Creates and sets all optimizer weights.

    Args:
      params: list or tuple of `Variable` objects that will be minimized
        using this optimizer.

    Returns:
      Specific weight values that are used in `get_updates`
    """
    raise NotImplementedError

  def get_updates(self, loss, params):
    raise NotImplementedError

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Args:
        loss: Loss tensor.
        params: List of variables.

    Returns:
        List of gradient tensors.

    Raises:
        ValueError: In case any gradient cannot be computed (e.g. if gradient
          function not implemented).
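
    Example (illustrative; assumes a compiled graph-mode Keras model, so that
    `model.total_loss` and `model.trainable_weights` exist):

        grads = optimizer.get_gradients(model.total_loss,
                                        model.trainable_weights)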
    """
    grads = backend.gradients(loss, params)
    if any(g is None for g in grads):
      raise ValueError('An operation has `None` for gradient. '
                       'Please make sure that all of your ops have a '
                       'gradient defined (i.e. are differentiable). '
                       'Common ops without gradient: '
                       'backend.argmax, backend.round, backend.eval.')
    if hasattr(self, 'clipnorm'):
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if hasattr(self, 'clipvalue'):
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]
    return grads

  def set_weights(self, weights):
    """Sets the weights of the optimizer, from Numpy arrays.

    Should only be called after computing the gradients
    (otherwise the optimizer has no weights).

    Args:
        weights: a list of Numpy arrays. The number of arrays and their
          shapes must match the number and shapes of the weights of the
          optimizer (i.e. it should match the output of `get_weights`).

    Raises:
        ValueError: in case of incompatible weight shapes.
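
    Example (an illustrative round trip; `opt_a` and `opt_b` are hypothetical
    optimizers that have already created identically shaped weights):

        opt_b.set_weights(opt_a.get_weights())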
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError('Length of the specified weight list (' +
                       str(len(weights)) +
                       ') does not match the number of weights '
                       'of the optimizer (' + str(len(params)) + ')')
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError('Optimizer weight shape ' + str(pv.shape) +
                         ' not compatible with '
                         'provided weight shape ' + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def get_weights(self):
    """Returns the current value of the weights of the optimizer.

    Returns:
        A list of numpy arrays.
    """
    return backend.batch_get_value(self.weights)

  def get_config(self):
    config = {}
    if hasattr(self, 'clipnorm'):
      config['clipnorm'] = self.clipnorm
    if hasattr(self, 'clipvalue'):
      config['clipvalue'] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config):
    return cls(**config)
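
  # Illustrative config round trip (a sketch; `opt` is any optimizer
  # instance):
  #   config = opt.get_config()              # e.g. {'clipnorm': 1.0}
  #   opt_copy = opt.__class__.from_config(config)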


class SGD(Optimizer):
  """Stochastic gradient descent optimizer.

  Includes support for momentum,
  learning rate decay, and Nesterov momentum.

  Args:
      lr: float >= 0. Learning rate.
      momentum: float >= 0. Parameter that accelerates SGD in the relevant
        direction and dampens oscillations.
      decay: float >= 0. Learning rate decay over each update.
      nesterov: boolean. Whether to apply Nesterov momentum.
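
  Example (an illustrative usage sketch; the per-parameter update applied by
  this optimizer is the classic momentum rule):

      # velocity = momentum * velocity - lr * gradient; param += velocity
      opt = SGD(lr=0.01, momentum=0.9, nesterov=False)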
  """

  def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, **kwargs):
    super(SGD, self).__init__(**kwargs)
    with backend.name_scope(self.__class__.__name__):
      self.iterations = backend.variable(0, dtype='int64', name='iterations')
      self.lr = backend.variable(lr, name='lr')
      self.momentum = backend.variable(momentum, name='momentum')
      self.decay = backend.variable(decay, name='decay')
    self.initial_decay = decay
    self.nesterov = nesterov

  def _create_all_weights(self, params):
    shapes = [backend.int_shape(p) for p in params]
    moments = [backend.zeros(shape) for shape in shapes]
    self.weights = [self.iterations] + moments
    return moments

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations,
                                      backend.dtype(self.decay))))
    # momentum
    moments = self._create_all_weights(params)
    for p, g, m in zip(params, grads, moments):
      v = self.momentum * m - lr * g  # velocity
      self.updates.append(state_ops.assign(m, v))

      if self.nesterov:
        new_p = p + self.momentum * v - lr * g
      else:
        new_p = p + v

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(backend.get_value(self.lr)),
        'momentum': float(backend.get_value(self.momentum)),
        'decay': float(backend.get_value(self.decay)),
        'nesterov': self.nesterov
    }
    base_config = super(SGD, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class RMSprop(Optimizer):
  """RMSProp optimizer.

  It is recommended to leave the parameters of this optimizer
  at their default values
  (except the learning rate, which can be freely tuned).

  Args:
    lr: float >= 0. Learning rate.
    rho: float >= 0.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Learning rate decay over each update.
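
  Example (an illustrative sketch; `acc` denotes the moving average of squared
  gradients that this optimizer maintains per parameter):

      # acc = rho * acc + (1 - rho) * grad**2
      # param -= lr * grad / (sqrt(acc) + epsilon)
      opt = RMSprop(lr=0.001, rho=0.9)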
  """

  def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs):
    super(RMSprop, self).__init__(**kwargs)
    with backend.name_scope(self.__class__.__name__):
      self.lr = backend.variable(lr, name='lr')
      self.rho = backend.variable(rho, name='rho')
      self.decay = backend.variable(decay, name='decay')
      self.iterations = backend.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = backend.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def _create_all_weights(self, params):
    accumulators = [
        backend.zeros(backend.int_shape(p), dtype=backend.dtype(p))
        for p in params]
    self.weights = accumulators
    return accumulators

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    accumulators = self._create_all_weights(params)
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations,
                                      backend.dtype(self.decay))))

    for p, g, a in zip(params, grads, accumulators):
      # update accumulator
      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
      self.updates.append(state_ops.assign(a, new_a))
      new_p = p - lr * g / (backend.sqrt(new_a) + self.epsilon)

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(backend.get_value(self.lr)),
        'rho': float(backend.get_value(self.rho)),
        'decay': float(backend.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(RMSprop, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adagrad(Optimizer):
  """Adagrad optimizer.

  Adagrad is an optimizer with parameter-specific learning rates,
  which are adapted relative to how frequently a parameter gets
  updated during training. The more updates a parameter receives,
  the smaller the updates.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  Args:
      lr: float >= 0. Initial learning rate.
      epsilon: float >= 0. If `None`, defaults to `backend.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.

  References:
      - [Adaptive Subgradient Methods for Online Learning and Stochastic
      Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
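
  Example (an illustrative sketch; `acc` accumulates squared gradients, so the
  effective step size shrinks for frequently updated parameters):

      # acc += grad**2
      # param -= lr * grad / (sqrt(acc) + epsilon)
      opt = Adagrad(lr=0.01)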
  """

  def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs):
    super(Adagrad, self).__init__(**kwargs)
    with backend.name_scope(self.__class__.__name__):
      self.lr = backend.variable(lr, name='lr')
      self.decay = backend.variable(decay, name='decay')
      self.iterations = backend.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = backend.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def _create_all_weights(self, params):
    shapes = [backend.int_shape(p) for p in params]
    accumulators = [backend.zeros(shape) for shape in shapes]
    self.weights = accumulators
    return accumulators

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    accumulators = self._create_all_weights(params)

    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations,
                                      backend.dtype(self.decay))))

    for p, g, a in zip(params, grads, accumulators):
      new_a = a + math_ops.square(g)  # update accumulator
      self.updates.append(state_ops.assign(a, new_a))
      new_p = p - lr * g / (backend.sqrt(new_a) + self.epsilon)

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(backend.get_value(self.lr)),
        'decay': float(backend.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adagrad, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adadelta(Optimizer):
  """Adadelta optimizer.

  Adadelta is a more robust extension of Adagrad
  that adapts learning rates based on a moving window of gradient updates,
  instead of accumulating all past gradients. This way, Adadelta continues
  learning even when many updates have been done. Compared to Adagrad, in the
  original version of Adadelta you don't have to set an initial learning
  rate. In this version, initial learning rate and decay factor can
  be set, as in most other Keras optimizers.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  Args:
    lr: float >= 0. Initial learning rate, defaults to 1.
        It is recommended to leave it at the default value.
    rho: float >= 0. Adadelta decay factor, corresponding to fraction of
        gradient to keep at each time step.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Initial learning rate decay.

  References:
      - [Adadelta - an adaptive learning rate
      method](http://arxiv.org/abs/1212.5701)
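
  Example (an illustrative sketch; `acc` and `delta_acc` are the moving
  averages of squared gradients and squared updates kept by this optimizer):

      # update = grad * sqrt(delta_acc + epsilon) / sqrt(acc + epsilon)
      # param -= lr * update
      opt = Adadelta()  # the defaults are the recommended settings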
  """

  def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs):
    super(Adadelta, self).__init__(**kwargs)
    with backend.name_scope(self.__class__.__name__):
      self.lr = backend.variable(lr, name='lr')
      self.decay = backend.variable(decay, name='decay')
      self.iterations = backend.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = backend.epsilon()
    self.rho = rho
    self.epsilon = epsilon
    self.initial_decay = decay

  def _create_all_weights(self, params):
    shapes = [backend.int_shape(p) for p in params]
    accumulators = [backend.zeros(shape) for shape in shapes]
    delta_accumulators = [backend.zeros(shape) for shape in shapes]
    self.weights = accumulators + delta_accumulators
    return accumulators, delta_accumulators

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = [state_ops.assign_add(self.iterations, 1)]
    accumulators, delta_accumulators = self._create_all_weights(params)

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations,
                                      backend.dtype(self.decay))))

    for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
      # update accumulator
      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
      self.updates.append(state_ops.assign(a, new_a))

      # use the new accumulator and the *old* delta_accumulator
      update = g * backend.sqrt(d_a + self.epsilon) / backend.sqrt(
          new_a + self.epsilon)
      new_p = p - lr * update

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))

      # update delta_accumulator
      new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update)
      self.updates.append(state_ops.assign(d_a, new_d_a))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(backend.get_value(self.lr)),
        'rho': self.rho,
        'decay': float(backend.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adadelta, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adam(Optimizer):
  """Adam optimizer.

  Default parameters follow those provided in the original paper.

  Args:
    lr: float >= 0. Learning rate.
    beta_1: float, 0 < beta < 1. Generally close to 1.
    beta_2: float, 0 < beta < 1. Generally close to 1.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Learning rate decay over each update.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
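
  Example (an illustrative sketch; `m` and `v` are the per-parameter first-
  and second-moment estimates, and `lr_t` is the bias-corrected step size):

      # m = beta_1 * m + (1 - beta_1) * grad
      # v = beta_2 * v + (1 - beta_2) * grad**2
      # param -= lr_t * m / (sqrt(v) + epsilon)
      opt = Adam(lr=0.001, beta_1=0.9, beta_2=0.999)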
  """

  def __init__(self,
               lr=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               decay=0.,
               amsgrad=False,
               **kwargs):
    super(Adam, self).__init__(**kwargs)
    with backend.name_scope(self.__class__.__name__):
      self.iterations = backend.variable(0, dtype='int64', name='iterations')
      self.lr = backend.variable(lr, name='lr')
      self.beta_1 = backend.variable(beta_1, name='beta_1')
      self.beta_2 = backend.variable(beta_2, name='beta_2')
      self.decay = backend.variable(decay, name='decay')
    if epsilon is None:
      epsilon = backend.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay
    self.amsgrad = amsgrad

  def _create_all_weights(self, params):
    ms = [
        backend.zeros(backend.int_shape(p), dtype=backend.dtype(p))
        for p in params]
    vs = [
        backend.zeros(backend.int_shape(p), dtype=backend.dtype(p))
        for p in params]
    if self.amsgrad:
      vhats = [
          backend.zeros(backend.int_shape(p), dtype=backend.dtype(p))
          for p in params]
    else:
      vhats = [backend.zeros(1) for _ in params]
    self.weights = [self.iterations] + ms + vs + vhats
    return ms, vs, vhats

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations,
                                      backend.dtype(self.decay))))

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, backend.floatx())
    lr_t = lr * (
        backend.sqrt(1. - math_ops.pow(self.beta_2, t)) /
        (1. - math_ops.pow(self.beta_1, t)))

    ms, vs, vhats = self._create_all_weights(params)
    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
      v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
      if self.amsgrad:
        vhat_t = math_ops.maximum(vhat, v_t)
        p_t = p - lr_t * m_t / (backend.sqrt(vhat_t) + self.epsilon)
        self.updates.append(state_ops.assign(vhat, vhat_t))
      else:
        p_t = p - lr_t * m_t / (backend.sqrt(v_t) + self.epsilon)

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(v, v_t))
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(backend.get_value(self.lr)),
        'beta_1': float(backend.get_value(self.beta_1)),
        'beta_2': float(backend.get_value(self.beta_2)),
        'decay': float(backend.get_value(self.decay)),
        'epsilon': self.epsilon,
        'amsgrad': self.amsgrad
    }
    base_config = super(Adam, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adamax(Optimizer):
  """Adamax optimizer from Adam paper's Section 7.

  It is a variant of Adam based on the infinity norm.
  Default parameters follow those provided in the paper.

  Args:
    lr: float >= 0. Learning rate.
    beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Learning rate decay over each update.
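
  Example (an illustrative sketch; `u` is the exponentially weighted infinity
  norm of past gradients kept by this optimizer):

      # u = max(beta_2 * u, abs(grad)); param -= lr_t * m / (u + epsilon)
      opt = Adamax(lr=0.002)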
  """

  def __init__(self,
               lr=0.002,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               decay=0.,
               **kwargs):
    super(Adamax, self).__init__(**kwargs)
    with backend.name_scope(self.__class__.__name__):
      self.iterations = backend.variable(0, dtype='int64', name='iterations')
      self.lr = backend.variable(lr, name='lr')
      self.beta_1 = backend.variable(beta_1, name='beta_1')
      self.beta_2 = backend.variable(beta_2, name='beta_2')
      self.decay = backend.variable(decay, name='decay')
    if epsilon is None:
      epsilon = backend.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def _create_all_weights(self, params):
    shapes = [backend.int_shape(p) for p in params]
    # zero init of 1st moment
    ms = [backend.zeros(shape) for shape in shapes]
    # zero init of exponentially weighted infinity norm
    us = [backend.zeros(shape) for shape in shapes]
    self.weights = [self.iterations] + ms + us
    return ms, us

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations,
                                      backend.dtype(self.decay))))

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, backend.floatx())
    lr_t = lr / (1. - math_ops.pow(self.beta_1, t))

    ms, us = self._create_all_weights(params)

    for p, g, m, u in zip(params, grads, ms, us):
      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
      u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g))
      p_t = p - lr_t * m_t / (u_t + self.epsilon)

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(u, u_t))
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(backend.get_value(self.lr)),
        'beta_1': float(backend.get_value(self.beta_1)),
        'beta_2': float(backend.get_value(self.beta_2)),
        'decay': float(backend.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adamax, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Nadam(Optimizer):
  """Nesterov Adam optimizer.

  Much like Adam is essentially RMSprop with momentum,
  Nadam is Adam with Nesterov momentum.

  Default parameters follow those provided in the paper.
  It is recommended to leave the parameters of this optimizer
  at their default values.

  Args:
    lr: float >= 0. Learning rate.
    beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
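
  Example (an illustrative usage sketch; the defaults follow the paper and are
  the recommended settings):

      opt = Nadam(lr=0.002, beta_1=0.9, beta_2=0.999)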
  """

  def __init__(self,
               lr=0.002,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               schedule_decay=0.004,
               **kwargs):
    super(Nadam, self).__init__(**kwargs)
    with backend.name_scope(self.__class__.__name__):
      self.iterations = backend.variable(0, dtype='int64', name='iterations')
      self.m_schedule = backend.variable(1., name='m_schedule')
      self.lr = backend.variable(lr, name='lr')
      self.beta_1 = backend.variable(beta_1, name='beta_1')
      self.beta_2 = backend.variable(beta_2, name='beta_2')
    if epsilon is None:
      epsilon = backend.epsilon()
    self.epsilon = epsilon
    self.schedule_decay = schedule_decay

  def _create_all_weights(self, params):
    shapes = [backend.int_shape(p) for p in params]
    ms = [backend.zeros(shape) for shape in shapes]
    vs = [backend.zeros(shape) for shape in shapes]

    self.weights = [self.iterations, self.m_schedule] + ms + vs
    return ms, vs

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, backend.floatx())

    # Warming momentum schedule, following the recommendations in [2].
    momentum_cache_t = self.beta_1 * (
        1. - 0.5 *
        (math_ops.pow(backend.cast_to_floatx(0.96), t * self.schedule_decay)))
    momentum_cache_t_1 = self.beta_1 * (
        1. - 0.5 *
        (math_ops.pow(backend.cast_to_floatx(0.96),
                      (t + 1) * self.schedule_decay)))
    m_schedule_new = self.m_schedule * momentum_cache_t
    m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
    self.updates.append((self.m_schedule, m_schedule_new))

    ms, vs = self._create_all_weights(params)

    for p, g, m, v in zip(params, grads, ms, vs):
      # The following equations are given in [1].
      g_prime = g / (1. - m_schedule_new)
      m_t = self.beta_1 * m + (1. - self.beta_1) * g
      m_t_prime = m_t / (1. - m_schedule_next)
      v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
      v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
      m_t_bar = (1. -
                 momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(v, v_t))

      p_t = p - self.lr * m_t_bar / (backend.sqrt(v_t_prime) + self.epsilon)
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(backend.get_value(self.lr)),
        'beta_1': float(backend.get_value(self.beta_1)),
        'beta_2': float(backend.get_value(self.beta_2)),
        'epsilon': self.epsilon,
        'schedule_decay': self.schedule_decay
    }
    base_config = super(Nadam, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class TFOptimizer(Optimizer, trackable.Trackable):
  """Wrapper class for native TensorFlow optimizers."""
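
  # Illustrative usage sketch (assumes a native v1 optimizer such as
  # `tf.compat.v1.train.GradientDescentOptimizer`):
  #   native_opt = tf.compat.v1.train.GradientDescentOptimizer(0.01)
  #   opt = TFOptimizer(native_opt)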

  def __init__(self, optimizer, iterations=None):  # pylint: disable=super-init-not-called
    self.optimizer = optimizer
    self._track_trackable(optimizer, name='optimizer')
    if iterations is None:
      with backend.name_scope(self.__class__.__name__):
        self.iterations = backend.variable(0, dtype='int64', name='iterations')
    else:
      self.iterations = iterations
    self._track_trackable(self.iterations, name='global_step')

  def _clip_gradients(self, grads):
    """Clip gradients according to the clipnorm and clipvalue attributes."""
    # TFOptimizer wrapper has no gradient clipping options.
    return grads

  def minimize(self, loss, var_list, grad_loss=None, tape=None):
    """Mimics the `OptimizerV2.minimize` API."""
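    # Illustrative call pattern (sketch): pass a zero-argument callable so the
    # loss is computed under a fresh GradientTape, e.g.
    #   opt.minimize(lambda: loss_fn(model(x), y), model.trainable_weights)
    # where `loss_fn`, `model`, `x` and `y` are hypothetical user objects.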
    if not callable(loss) and tape is None:
      raise ValueError('`tape` is required when a `Tensor` loss is passed.')
    tape = tape if tape is not None else backprop.GradientTape()

    if callable(loss):
      with tape:
        if not callable(var_list):
          tape.watch(var_list)
        loss = loss()
        if callable(var_list):
          var_list = var_list()

    var_list = nest.flatten(var_list)
    if var_list:
      grads = tape.gradient(loss, var_list, grad_loss)
      grads_and_vars = list(zip(grads, var_list))
      self.apply_gradients(grads_and_vars)

  def apply_gradients(self, grads_and_vars):
    self.optimizer.apply_gradients(grads_and_vars, global_step=self.iterations)

  def get_grads(self, loss, params):
    return self.optimizer.compute_gradients(loss, params)

  def get_updates(self, loss, params):
    if distribution_strategy_context.has_strategy():
      self.updates = []

      if not params:
        # After the model variables have been created, the second call to
        # get_updates is made with an empty params list. This ensures that we
        # call compute_gradients with params=None.
        grads = self.optimizer.compute_gradients(loss)
      else:
        grads = self.optimizer.compute_gradients(loss, params)
      global_step = training_util.get_global_step()
      opt_update = self.optimizer.apply_gradients(grads, global_step)
    else:
      if not params:
        self.updates = [state_ops.assign_add(self.iterations, 1)]
        return self.updates

      # Updates list starts out empty because the iterations variable is
      # incremented in optimizer.apply_gradients().
      self.updates = []
      grads = self.optimizer.compute_gradients(loss, params)
      opt_update = self.optimizer.apply_gradients(
          grads, global_step=self.iterations)

    self.updates.append(opt_update)
    return self.updates

  @property
  def weights(self):
    raise NotImplementedError

  def get_config(self):
    raise NotImplementedError

  def from_config(self, config):
    raise NotImplementedError


# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamax = Adamax
nadam = Nadam