# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Built-in optimizer classes.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
from tensorflow.python.keras.optimizer_v2 import ftrl
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export


class Optimizer(object):
  """Abstract optimizer base class.

  Note: this is the parent class of all optimizers, not an actual optimizer
  that can be used for training models.

  All Keras optimizers support the following keyword arguments:

      clipnorm: float >= 0. Gradients will be clipped
          when their L2 norm exceeds this value.
      clipvalue: float >= 0. Gradients will be clipped
          when their absolute value exceeds this value.
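
  Example (a minimal sketch: gradient clipping is configured the same way on
  any of the concrete subclasses defined below, e.g. `SGD`):

  ```python
  # Clip each gradient to an L2 norm of at most 1.0 and, in addition, clip
  # every gradient element to the range [-0.5, 0.5].
  sgd = SGD(lr=0.01, clipnorm=1.0, clipvalue=0.5)
  ```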
  """

  def __init__(self, **kwargs):
    allowed_kwargs = {'clipnorm', 'clipvalue'}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError('Unexpected keyword argument '
                        'passed to optimizer: ' + str(k))
      # checks that clipnorm >= 0 and clipvalue >= 0
      if kwargs[k] < 0:
        raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k]))
    self.__dict__.update(kwargs)
    self.updates = []
    self.weights = []

  def get_updates(self, loss, params):
    raise NotImplementedError

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
        loss: Loss tensor.
        params: List of variables.

    Returns:
        List of gradient tensors.

    Raises:
        ValueError: In case any gradient cannot be computed (e.g. if gradient
          function not implemented).
    """
    grads = K.gradients(loss, params)
    if None in grads:
      raise ValueError('An operation has `None` for gradient. '
                       'Please make sure that all of your ops have a '
                       'gradient defined (i.e. are differentiable). '
                       'Common ops without gradient: '
                       'K.argmax, K.round, K.eval.')
    if hasattr(self, 'clipnorm'):
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if hasattr(self, 'clipvalue'):
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]
    return grads

  def set_weights(self, weights):
    """Sets the weights of the optimizer, from Numpy arrays.

    Should only be called after computing the gradients
    (otherwise the optimizer has no weights).

    Arguments:
        weights: a list of Numpy arrays. The number
            of arrays and their shapes must match the
            number and shapes of the weights of the
            optimizer (i.e. it should match the
            output of `get_weights`).

    Raises:
        ValueError: in case of incompatible weight shapes.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          'Length of the specified weight list (' + str(len(weights)) +
          ') does not match the number of weights '
          'of the optimizer (' + str(len(params)) + ')')
    weight_value_tuples = []
    param_values = K.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError(
            'Optimizer weight shape ' + str(pv.shape) + ' not compatible with '
            'provided weight shape ' + str(w.shape))
      weight_value_tuples.append((p, w))
    K.batch_set_value(weight_value_tuples)

  def get_weights(self):
    """Returns the current value of the weights of the optimizer.

    Returns:
        A list of numpy arrays.
    """
    return K.batch_get_value(self.weights)

  def get_config(self):
    config = {}
    if hasattr(self, 'clipnorm'):
      config['clipnorm'] = self.clipnorm
    if hasattr(self, 'clipvalue'):
      config['clipvalue'] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config):
    return cls(**config)


class SGD(Optimizer):
  """Stochastic gradient descent optimizer.

  Includes support for momentum,
  learning rate decay, and Nesterov momentum.

  Arguments:
      lr: float >= 0. Learning rate.
      momentum: float >= 0. Parameter that accelerates SGD
          in the relevant direction and dampens oscillations.
      decay: float >= 0. Learning rate decay over each update.
      nesterov: boolean. Whether to apply Nesterov momentum.
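
  Example (a usage sketch; `model` below stands for a hypothetical,
  already-built Keras model):

  ```python
  opt = SGD(lr=0.01, momentum=0.9, decay=1e-6, nesterov=True)
  # `model` is assumed to be an existing Keras model instance.
  model.compile(optimizer=opt, loss='categorical_crossentropy')
  ```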
  """

  def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, **kwargs):
    super(SGD, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.lr = K.variable(lr, name='lr')
      self.momentum = K.variable(momentum, name='momentum')
      self.decay = K.variable(decay, name='decay')
    self.initial_decay = decay
    self.nesterov = nesterov

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. / (1. + self.decay * math_ops.cast(self.iterations,
                                                K.dtype(self.decay))))
    # momentum
    shapes = [K.int_shape(p) for p in params]
    moments = [K.zeros(shape) for shape in shapes]
    self.weights = [self.iterations] + moments
    for p, g, m in zip(params, grads, moments):
      v = self.momentum * m - lr * g  # velocity
      self.updates.append(state_ops.assign(m, v))

      if self.nesterov:
        new_p = p + self.momentum * v - lr * g
      else:
        new_p = p + v

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'momentum': float(K.get_value(self.momentum)),
        'decay': float(K.get_value(self.decay)),
        'nesterov': self.nesterov
    }
    base_config = super(SGD, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class RMSprop(Optimizer):
  """RMSProp optimizer.

  It is recommended to leave the parameters of this optimizer
  at their default values
  (except the learning rate, which can be freely tuned).

  This optimizer is usually a good choice for recurrent
  neural networks.

  Arguments:
      lr: float >= 0. Learning rate.
      rho: float >= 0.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.

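  Example (a minimal construction sketch, keeping `rho`, `epsilon` and
  `decay` at their recommended defaults):

  ```python
  opt = RMSprop(lr=0.001)
  ```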
  """

  def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs):
    super(RMSprop, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.lr = K.variable(lr, name='lr')
      self.rho = K.variable(rho, name='rho')
      self.decay = K.variable(decay, name='decay')
      self.iterations = K.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    self.weights = accumulators
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. / (1. + self.decay * math_ops.cast(self.iterations,
                                                K.dtype(self.decay))))

    for p, g, a in zip(params, grads, accumulators):
      # update accumulator
      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
      self.updates.append(state_ops.assign(a, new_a))
      new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'rho': float(K.get_value(self.rho)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(RMSprop, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adagrad(Optimizer):
  """Adagrad optimizer.

  Adagrad is an optimizer with parameter-specific learning rates,
  which are adapted relative to how frequently a parameter gets
  updated during training. The more updates a parameter receives,
  the smaller the updates.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  Arguments:
      lr: float >= 0. Initial learning rate.
      epsilon: float >= 0. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.

  References:
      - [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
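
  Example (a minimal construction sketch; the defaults recommended above):

  ```python
  opt = Adagrad()  # lr=0.01; epsilon resolves to K.epsilon(); decay=0.
  ```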
  """

  def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs):
    super(Adagrad, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.lr = K.variable(lr, name='lr')
      self.decay = K.variable(decay, name='decay')
      self.iterations = K.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    shapes = [K.int_shape(p) for p in params]
    accumulators = [K.zeros(shape) for shape in shapes]
    self.weights = accumulators
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. / (1. + self.decay * math_ops.cast(self.iterations,
                                                K.dtype(self.decay))))

    for p, g, a in zip(params, grads, accumulators):
      new_a = a + math_ops.square(g)  # update accumulator
      self.updates.append(state_ops.assign(a, new_a))
      new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adagrad, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adadelta(Optimizer):
  """Adadelta optimizer.

  Adadelta is a more robust extension of Adagrad
  that adapts learning rates based on a moving window of gradient updates,
  instead of accumulating all past gradients. This way, Adadelta continues
  learning even when many updates have been done. Compared to Adagrad, in the
  original version of Adadelta you don't have to set an initial learning
  rate. In this version, initial learning rate and decay factor can
  be set, as in most other Keras optimizers.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  Arguments:
      lr: float >= 0. Initial learning rate, defaults to 1.
          It is recommended to leave it at the default value.
      rho: float >= 0. Adadelta decay factor, corresponding to fraction of
          gradient to keep at each time step.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Initial learning rate decay.

  References:
      - [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
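
  Example (a minimal construction sketch; the default lr of 1.0 is the
  recommended setting):

  ```python
  opt = Adadelta()  # lr=1.0, rho=0.95; epsilon resolves to K.epsilon()
  ```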
  """

  def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs):
    super(Adadelta, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.lr = K.variable(lr, name='lr')
      self.decay = K.variable(decay, name='decay')
      self.iterations = K.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = K.epsilon()
    self.rho = rho
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    shapes = [K.int_shape(p) for p in params]
    accumulators = [K.zeros(shape) for shape in shapes]
    delta_accumulators = [K.zeros(shape) for shape in shapes]
    self.weights = accumulators + delta_accumulators
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. / (1. + self.decay * math_ops.cast(self.iterations,
                                                K.dtype(self.decay))))

    for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
      # update accumulator
      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
      self.updates.append(state_ops.assign(a, new_a))

      # use the new accumulator and the *old* delta_accumulator
      update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon)
      new_p = p - lr * update

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))

      # update delta_accumulator
      new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update)
      self.updates.append(state_ops.assign(d_a, new_d_a))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'rho': self.rho,
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adadelta, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adam(Optimizer):
  """Adam optimizer.

  Default parameters follow those provided in the original paper.

  Arguments:
      lr: float >= 0. Learning rate.
      beta_1: float, 0 < beta < 1. Generally close to 1.
      beta_2: float, 0 < beta < 1. Generally close to 1.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.
      amsgrad: boolean. Whether to apply the AMSGrad variant of this
          algorithm from the paper "On the Convergence of Adam and
          Beyond".

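  Example (a minimal construction sketch):

  ```python
  opt = Adam(lr=0.001, amsgrad=True)  # enable the AMSGrad variant
  ```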
  """

  def __init__(self,
               lr=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               decay=0.,
               amsgrad=False,
               **kwargs):
    super(Adam, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.lr = K.variable(lr, name='lr')
      self.beta_1 = K.variable(beta_1, name='beta_1')
      self.beta_2 = K.variable(beta_2, name='beta_2')
      self.decay = K.variable(decay, name='decay')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay
    self.amsgrad = amsgrad

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. / (1. + self.decay * math_ops.cast(self.iterations,
                                                K.dtype(self.decay))))

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, K.floatx())
    lr_t = lr * (
        K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
        (1. - math_ops.pow(self.beta_1, t)))

    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    if self.amsgrad:
      vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    else:
      vhats = [K.zeros(1) for _ in params]
    self.weights = [self.iterations] + ms + vs + vhats

    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
      v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
      if self.amsgrad:
        vhat_t = math_ops.maximum(vhat, v_t)
        p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
        self.updates.append(state_ops.assign(vhat, vhat_t))
      else:
        p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(v, v_t))
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'beta_1': float(K.get_value(self.beta_1)),
        'beta_2': float(K.get_value(self.beta_2)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon,
        'amsgrad': self.amsgrad
    }
    base_config = super(Adam, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Adamax(Optimizer):
  """Adamax optimizer from Adam paper's Section 7.

  It is a variant of Adam based on the infinity norm.
  Default parameters follow those provided in the paper.

  Arguments:
      lr: float >= 0. Learning rate.
      beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.

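  Example (a minimal construction sketch):

  ```python
  opt = Adamax(lr=0.002)  # beta_1, beta_2 and epsilon left at their defaults
  ```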
  """

  def __init__(self,
               lr=0.002,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               decay=0.,
               **kwargs):
    super(Adamax, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.lr = K.variable(lr, name='lr')
      self.beta_1 = K.variable(beta_1, name='beta_1')
      self.beta_2 = K.variable(beta_2, name='beta_2')
      self.decay = K.variable(decay, name='decay')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. / (1. + self.decay * math_ops.cast(self.iterations,
                                                K.dtype(self.decay))))

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, K.floatx())
    lr_t = lr / (1. - math_ops.pow(self.beta_1, t))

    shapes = [K.int_shape(p) for p in params]
    # zero init of 1st moment
    ms = [K.zeros(shape) for shape in shapes]
    # zero init of exponentially weighted infinity norm
    us = [K.zeros(shape) for shape in shapes]
    self.weights = [self.iterations] + ms + us

    for p, g, m, u in zip(params, grads, ms, us):

      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
      u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g))
      p_t = p - lr_t * m_t / (u_t + self.epsilon)

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(u, u_t))
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'beta_1': float(K.get_value(self.beta_1)),
        'beta_2': float(K.get_value(self.beta_2)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adamax, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class Nadam(Optimizer):
  """Nesterov Adam optimizer.

  Much like Adam is essentially RMSprop with momentum,
  Nadam is Adam with Nesterov momentum.

  Default parameters follow those provided in the paper.
  It is recommended to leave the parameters of this optimizer
  at their default values.

  Arguments:
      lr: float >= 0. Learning rate.
      beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.

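  Example (a minimal construction sketch; the defaults are the recommended
  settings):

  ```python
  opt = Nadam()  # lr=0.002, beta_1=0.9, beta_2=0.999
  ```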
  """

  def __init__(self,
               lr=0.002,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               schedule_decay=0.004,
               **kwargs):
    super(Nadam, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.m_schedule = K.variable(1., name='m_schedule')
      self.lr = K.variable(lr, name='lr')
      self.beta_1 = K.variable(beta_1, name='beta_1')
      self.beta_2 = K.variable(beta_2, name='beta_2')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.schedule_decay = schedule_decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, K.floatx())

    # Due to the recommendations in [2], i.e. warming momentum schedule
    momentum_cache_t = self.beta_1 * (
        1. - 0.5 *
        (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
    momentum_cache_t_1 = self.beta_1 * (
        1. - 0.5 *
        (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
    m_schedule_new = self.m_schedule * momentum_cache_t
    m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
    self.updates.append((self.m_schedule, m_schedule_new))

    shapes = [K.int_shape(p) for p in params]
    ms = [K.zeros(shape) for shape in shapes]
    vs = [K.zeros(shape) for shape in shapes]

    self.weights = [self.iterations, self.m_schedule] + ms + vs

    for p, g, m, v in zip(params, grads, ms, vs):
      # the following equations given in [1]
      g_prime = g / (1. - m_schedule_new)
      m_t = self.beta_1 * m + (1. - self.beta_1) * g
      m_t_prime = m_t / (1. - m_schedule_next)
      v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
      v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
      m_t_bar = (
          1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(v, v_t))

      p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'beta_1': float(K.get_value(self.beta_1)),
        'beta_2': float(K.get_value(self.beta_2)),
        'epsilon': self.epsilon,
        'schedule_decay': self.schedule_decay
    }
    base_config = super(Nadam, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class TFOptimizer(Optimizer, trackable.Trackable):
  """Wrapper class for native TensorFlow optimizers.
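
  Example (a sketch; assumes `import tensorflow as tf` and a TF 1.x-style
  `tf.train` optimizer):

  ```python
  # Wrap a native tf.train optimizer so it can be used through the Keras
  # training API. `get()` below performs the same wrapping automatically.
  native_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  keras_opt = TFOptimizer(native_opt)
  ```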
  """

  def __init__(self, optimizer, iterations=None):  # pylint: disable=super-init-not-called
    self.optimizer = optimizer
    self._track_trackable(optimizer, name='optimizer')
    if iterations is None:
      with K.name_scope(self.__class__.__name__):
        self.iterations = K.variable(0, dtype='int64', name='iterations')
    else:
      self.iterations = iterations
    self._track_trackable(self.iterations, name='global_step')

  def apply_gradients(self, grads):
    self.optimizer.apply_gradients(grads, global_step=self.iterations)

  def get_grads(self, loss, params):
    return self.optimizer.compute_gradients(loss, params)

  def get_updates(self, loss, params):
    if distribution_strategy_context.has_strategy():
      self.updates = []

      if not params:
        # After the model variables have been created, the second call to
        # get_updates is made with an empty params list. In that case we call
        # compute_gradients without passing params (i.e. var_list=None).
        grads = self.optimizer.compute_gradients(loss)
      else:
        grads = self.optimizer.compute_gradients(loss, params)
      global_step = training_util.get_global_step()
      opt_update = self.optimizer.apply_gradients(grads, global_step)
    else:
      if not params:
        self.updates = [state_ops.assign_add(self.iterations, 1)]
        return self.updates

      # Updates list starts out empty because the iterations variable is
      # incremented in optimizer.apply_gradients()
      self.updates = []
      grads = self.optimizer.compute_gradients(loss, params)
      opt_update = self.optimizer.apply_gradients(
          grads, global_step=self.iterations)

    self.updates.append(opt_update)
    return self.updates

  @property
  def weights(self):
    raise NotImplementedError

  def get_config(self):
    raise NotImplementedError

  def from_config(self, config):
    raise NotImplementedError


# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamax = Adamax
nadam = Nadam


@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  return serialize_keras_object(optimizer)


@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Arguments:
      config: Optimizer configuration dictionary.
      custom_objects: Optional dictionary mapping
          names (strings) to custom objects
          (classes and functions)
          to be considered during deserialization.

  Returns:
      A Keras Optimizer instance.
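
  Example (a sketch; note that, per the mapping below, built-in optimizer
  names resolve case-insensitively to the `optimizer_v2` implementations):

  ```python
  config = {'class_name': 'Adam', 'config': {'learning_rate': 0.001}}
  opt = deserialize(config)  # returns an adam_v2.Adam instance
  ```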
  """
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl
  }

  # Make deserialization case-insensitive for built-in optimizers.
  if config['class_name'].lower() in all_classes:
    config['class_name'] = config['class_name'].lower()
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')


@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Arguments:
      identifier: Optimizer identifier, one of
          - String: name of an optimizer
          - Dictionary: configuration dictionary.
          - Keras Optimizer instance (it will be returned unchanged).
          - TensorFlow Optimizer instance
              (it will be wrapped as a Keras Optimizer).

  Returns:
      A Keras Optimizer instance.

  Raises:
      ValueError: If `identifier` cannot be interpreted.
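
  Example (a sketch of the accepted identifier forms):

  ```python
  get('adam')                               # name string -> a fresh optimizer with defaults
  get({'class_name': 'SGD', 'config': {}})  # config dict -> deserialized optimizer
  get(Adam(lr=0.001))                       # optimizer instance -> returned unchanged
  ```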
  """
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier
  # Wrap TF optimizer instances
  elif isinstance(identifier, tf_optimizer_module.Optimizer):
    opt = TFOptimizer(identifier)
    K.track_tf_optimizer(opt)
    return opt
  elif isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  else:
    raise ValueError('Could not interpret optimizer identifier:', identifier)