# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Built-in optimizer classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
from tensorflow.python.keras.optimizer_v2 import ftrl
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export


class Optimizer(object):
  """Abstract optimizer base class.

  Note: this is the parent class of all optimizers, not an actual optimizer
  that can be used for training models.

  All Keras optimizers support the following keyword arguments:

      clipnorm: float >= 0. Gradients will be clipped
          when their L2 norm exceeds this value.
      clipvalue: float >= 0. Gradients will be clipped
          when their absolute value exceeds this value.
  """

  def __init__(self, **kwargs):
    allowed_kwargs = {'clipnorm', 'clipvalue'}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError('Unexpected keyword argument '
                        'passed to optimizer: ' + str(k))
      # checks that clipnorm >= 0 and clipvalue >= 0
      if kwargs[k] < 0:
        raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k]))
    self.__dict__.update(kwargs)
    self.updates = []
    self.weights = []

  def get_updates(self, loss, params):
    raise NotImplementedError

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
        loss: Loss tensor.
        params: List of variables.

    Returns:
        List of gradient tensors.

    Raises:
        ValueError: In case any gradient cannot be computed (e.g. if gradient
          function not implemented).
    """
    grads = K.gradients(loss, params)
    if any(g is None for g in grads):
      raise ValueError('An operation has `None` for gradient. '
                       'Please make sure that all of your ops have a '
                       'gradient defined (i.e. are differentiable). '
                       'Common ops without gradient: '
                       'K.argmax, K.round, K.eval.')
    if hasattr(self, 'clipnorm'):
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if hasattr(self, 'clipvalue'):
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]
    return grads

  def set_weights(self, weights):
    """Sets the weights of the optimizer, from Numpy arrays.

    Should only be called after computing the gradients
    (otherwise the optimizer has no weights).

    Arguments:
        weights: a list of Numpy arrays. The number of arrays and their shapes
          must match the number and shapes of the weights of the optimizer
          (i.e. it should match the output of `get_weights`).

    Raises:
        ValueError: in case of incompatible weight shapes.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError('Length of the specified weight list (' +
                       str(len(weights)) +
                       ') does not match the number of weights '
                       'of the optimizer (' + str(len(params)) + ')')
    weight_value_tuples = []
    param_values = K.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError('Optimizer weight shape ' + str(pv.shape) +
                         ' not compatible with '
                         'provided weight shape ' + str(w.shape))
      weight_value_tuples.append((p, w))
    K.batch_set_value(weight_value_tuples)

  def get_weights(self):
    """Returns the current value of the weights of the optimizer.

    Returns:
        A list of numpy arrays.
    """
    return K.batch_get_value(self.weights)

  def get_config(self):
    config = {}
    if hasattr(self, 'clipnorm'):
      config['clipnorm'] = self.clipnorm
    if hasattr(self, 'clipvalue'):
      config['clipvalue'] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config):
    return cls(**config)


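# Illustrative usage sketch (not part of the original file): the base class
# only handles the `clipnorm`/`clipvalue` keyword arguments and rejects
# anything else. The helper name below is hypothetical and exists purely for
# documentation purposes.
def _example_gradient_clipping_kwargs():
  opt = Optimizer(clipnorm=1.0, clipvalue=0.5)
  # Both clipping settings are stored as attributes and surfaced by get_config.
  assert opt.get_config() == {'clipnorm': 1.0, 'clipvalue': 0.5}
  # Any other keyword argument raises a TypeError.
  try:
    Optimizer(momentum=0.9)
  except TypeError:
    pass
  return opt

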
class SGD(Optimizer):
  """Stochastic gradient descent optimizer.

  Includes support for momentum,
  learning rate decay, and Nesterov momentum.

  Arguments:
      lr: float >= 0. Learning rate.
      momentum: float >= 0. Parameter that accelerates SGD in the relevant
        direction and dampens oscillations.
      decay: float >= 0. Learning rate decay over each update.
      nesterov: boolean. Whether to apply Nesterov momentum.
  """

  def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, **kwargs):
    super(SGD, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.lr = K.variable(lr, name='lr')
      self.momentum = K.variable(momentum, name='momentum')
      self.decay = K.variable(decay, name='decay')
    self.initial_decay = decay
    self.nesterov = nesterov

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))
    # momentum
    shapes = [K.int_shape(p) for p in params]
    moments = [K.zeros(shape) for shape in shapes]
    self.weights = [self.iterations] + moments
    for p, g, m in zip(params, grads, moments):
      v = self.momentum * m - lr * g  # velocity
      self.updates.append(state_ops.assign(m, v))

      if self.nesterov:
        new_p = p + self.momentum * v - lr * g
      else:
        new_p = p + v

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'momentum': float(K.get_value(self.momentum)),
        'decay': float(K.get_value(self.decay)),
        'nesterov': self.nesterov
    }
    base_config = super(SGD, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


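# Illustrative usage sketch (not part of the original file): constructing SGD
# with momentum, Nesterov updates and time-based decay, then reading the
# values back through get_config. The helper name is hypothetical.
def _example_sgd_configuration():
  opt = SGD(lr=0.01, momentum=0.9, decay=1e-6, nesterov=True)
  config = opt.get_config()
  # With decay > 0, get_updates rescales the learning rate every step:
  #   lr_t = lr / (1 + decay * iterations)
  assert config['nesterov'] is True
  return config

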
class RMSprop(Optimizer):
  """RMSProp optimizer.

  It is recommended to leave the parameters of this optimizer
  at their default values
  (except the learning rate, which can be freely tuned).

  Arguments:
      lr: float >= 0. Learning rate.
      rho: float >= 0. Decay factor for the moving average of squared
          gradients.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.
  """

  def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs):
    super(RMSprop, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.lr = K.variable(lr, name='lr')
      self.rho = K.variable(rho, name='rho')
      self.decay = K.variable(decay, name='decay')
      self.iterations = K.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    self.weights = accumulators
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

    for p, g, a in zip(params, grads, accumulators):
      # update accumulator
      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
      self.updates.append(state_ops.assign(a, new_a))
      new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'rho': float(K.get_value(self.rho)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(RMSprop, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


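# Illustrative usage sketch (not part of the original file): per the docstring
# above, typically only the learning rate of RMSprop is tuned while rho,
# epsilon and decay keep their defaults. The helper name is hypothetical.
def _example_rmsprop_default_usage():
  opt = RMSprop(lr=0.0005)
  config = opt.get_config()
  # rho keeps its default moving-average factor of 0.9.
  assert abs(config['rho'] - 0.9) < 1e-6
  return opt

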
class Adagrad(Optimizer):
  """Adagrad optimizer.

  Adagrad is an optimizer with parameter-specific learning rates,
  which are adapted relative to how frequently a parameter gets
  updated during training. The more updates a parameter receives,
  the smaller the updates.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  Arguments:
      lr: float >= 0. Initial learning rate.
      epsilon: float >= 0. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.

  References:
      - [Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
  """

  def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs):
    super(Adagrad, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.lr = K.variable(lr, name='lr')
      self.decay = K.variable(decay, name='decay')
      self.iterations = K.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    shapes = [K.int_shape(p) for p in params]
    accumulators = [K.zeros(shape) for shape in shapes]
    self.weights = accumulators
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

    for p, g, a in zip(params, grads, accumulators):
      new_a = a + math_ops.square(g)  # update accumulator
      self.updates.append(state_ops.assign(a, new_a))
      new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adagrad, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


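# Illustrative usage sketch (not part of the original file): Adagrad keeps one
# squared-gradient accumulator per parameter, so frequently updated parameters
# receive smaller steps. This hypothetical helper only shows construction and
# a config round-trip.
def _example_adagrad_roundtrip():
  opt = Adagrad(lr=0.01)
  restored = Adagrad.from_config(opt.get_config())
  return restored

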
class Adadelta(Optimizer):
  """Adadelta optimizer.

  Adadelta is a more robust extension of Adagrad
  that adapts learning rates based on a moving window of gradient updates,
  instead of accumulating all past gradients. This way, Adadelta continues
  learning even when many updates have been done. Compared to Adagrad, in the
  original version of Adadelta you don't have to set an initial learning
  rate. In this version, initial learning rate and decay factor can
  be set, as in most other Keras optimizers.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  Arguments:
      lr: float >= 0. Initial learning rate, defaults to 1.
          It is recommended to leave it at the default value.
      rho: float >= 0. Adadelta decay factor, corresponding to fraction of
          gradient to keep at each time step.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Initial learning rate decay.

  References:
      - [Adadelta - an adaptive learning rate
        method](http://arxiv.org/abs/1212.5701)
  """

  def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs):
    super(Adadelta, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.lr = K.variable(lr, name='lr')
      self.decay = K.variable(decay, name='decay')
      self.iterations = K.variable(0, dtype='int64', name='iterations')
    if epsilon is None:
      epsilon = K.epsilon()
    self.rho = rho
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    shapes = [K.int_shape(p) for p in params]
    accumulators = [K.zeros(shape) for shape in shapes]
    delta_accumulators = [K.zeros(shape) for shape in shapes]
    self.weights = accumulators + delta_accumulators
    self.updates = [state_ops.assign_add(self.iterations, 1)]

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

    for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
      # update accumulator
      new_a = self.rho * a + (1. - self.rho) * math_ops.square(g)
      self.updates.append(state_ops.assign(a, new_a))

      # use the new accumulator and the *old* delta_accumulator
      update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon)
      new_p = p - lr * update

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))

      # update delta_accumulator
      new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update)
      self.updates.append(state_ops.assign(d_a, new_d_a))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'rho': self.rho,
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adadelta, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


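# Illustrative usage sketch (not part of the original file): Adadelta is
# usually constructed with all defaults (lr=1.0, rho=0.95); the windowed
# accumulators make manual learning-rate tuning largely unnecessary. The
# helper name is hypothetical.
def _example_adadelta_defaults():
  opt = Adadelta()
  assert abs(opt.get_config()['lr'] - 1.0) < 1e-6
  return opt

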
class Adam(Optimizer):
  """Adam optimizer.

  Default parameters follow those provided in the original paper.

  Arguments:
      lr: float >= 0. Learning rate.
      beta_1: float, 0 < beta < 1. Generally close to 1.
      beta_2: float, 0 < beta < 1. Generally close to 1.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.
      amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
        from the paper "On the Convergence of Adam and Beyond".
  """

  def __init__(self,
               lr=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               decay=0.,
               amsgrad=False,
               **kwargs):
    super(Adam, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.lr = K.variable(lr, name='lr')
      self.beta_1 = K.variable(beta_1, name='beta_1')
      self.beta_2 = K.variable(beta_2, name='beta_2')
      self.decay = K.variable(decay, name='decay')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay
    self.amsgrad = amsgrad

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, K.floatx())
    lr_t = lr * (
        K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
        (1. - math_ops.pow(self.beta_1, t)))

    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    if self.amsgrad:
      vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
    else:
      vhats = [K.zeros(1) for _ in params]
    self.weights = [self.iterations] + ms + vs + vhats

    for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
      v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
      if self.amsgrad:
        vhat_t = math_ops.maximum(vhat, v_t)
        p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
        self.updates.append(state_ops.assign(vhat, vhat_t))
      else:
        p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(v, v_t))
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'beta_1': float(K.get_value(self.beta_1)),
        'beta_2': float(K.get_value(self.beta_2)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon,
        'amsgrad': self.amsgrad
    }
    base_config = super(Adam, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


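# Illustrative usage sketch (not part of the original file): enabling the
# AMSGrad variant, which keeps the running maximum of the second-moment
# estimate and divides by it instead of the raw estimate. The helper name is
# hypothetical.
def _example_adam_amsgrad():
  opt = Adam(lr=0.001, amsgrad=True)
  # get_updates applies the bias-corrected step size
  #   lr_t = lr * sqrt(1 - beta_2**t) / (1 - beta_1**t)
  assert opt.get_config()['amsgrad'] is True
  return opt

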
class Adamax(Optimizer):
  """Adamax optimizer from Adam paper's Section 7.

  It is a variant of Adam based on the infinity norm.
  Default parameters follow those provided in the paper.

  Arguments:
      lr: float >= 0. Learning rate.
      beta_1: float, 0 < beta < 1. Generally close to 1.
      beta_2: float, 0 < beta < 1. Generally close to 1.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.
  """

  def __init__(self,
               lr=0.002,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               decay=0.,
               **kwargs):
    super(Adamax, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.lr = K.variable(lr, name='lr')
      self.beta_1 = K.variable(beta_1, name='beta_1')
      self.beta_2 = K.variable(beta_2, name='beta_2')
      self.decay = K.variable(decay, name='decay')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.initial_decay = decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    lr = self.lr
    if self.initial_decay > 0:
      lr = lr * (  # pylint: disable=g-no-augmented-assignment
          1. /
          (1. +
           self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, K.floatx())
    lr_t = lr / (1. - math_ops.pow(self.beta_1, t))

    shapes = [K.int_shape(p) for p in params]
    # zero init of 1st moment
    ms = [K.zeros(shape) for shape in shapes]
    # zero init of exponentially weighted infinity norm
    us = [K.zeros(shape) for shape in shapes]
    self.weights = [self.iterations] + ms + us

    for p, g, m, u in zip(params, grads, ms, us):

      m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
      u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g))
      p_t = p - lr_t * m_t / (u_t + self.epsilon)

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(u, u_t))
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'beta_1': float(K.get_value(self.beta_1)),
        'beta_2': float(K.get_value(self.beta_2)),
        'decay': float(K.get_value(self.decay)),
        'epsilon': self.epsilon
    }
    base_config = super(Adamax, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


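# Illustrative usage sketch (not part of the original file): Adamax replaces
# Adam's second moment with an exponentially weighted infinity norm,
# u_t = max(beta_2 * u, |g|), so no square root of the norm is needed in the
# update. The helper name is hypothetical.
def _example_adamax_construction():
  opt = Adamax()  # paper default lr=0.002
  assert abs(opt.get_config()['lr'] - 0.002) < 1e-6
  return opt

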
class Nadam(Optimizer):
  """Nesterov Adam optimizer.

  Much like Adam is essentially RMSprop with momentum,
  Nadam is Adam with Nesterov momentum.

  Default parameters follow those provided in the paper.
  It is recommended to leave the parameters of this optimizer
  at their default values.

  Arguments:
      lr: float >= 0. Learning rate.
      beta_1: float, 0 < beta < 1. Generally close to 1.
      beta_2: float, 0 < beta < 1. Generally close to 1.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      schedule_decay: float >= 0. Decay rate used in the warming momentum
          schedule.
  """

  def __init__(self,
               lr=0.002,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=None,
               schedule_decay=0.004,
               **kwargs):
    super(Nadam, self).__init__(**kwargs)
    with K.name_scope(self.__class__.__name__):
      self.iterations = K.variable(0, dtype='int64', name='iterations')
      self.m_schedule = K.variable(1., name='m_schedule')
      self.lr = K.variable(lr, name='lr')
      self.beta_1 = K.variable(beta_1, name='beta_1')
      self.beta_2 = K.variable(beta_2, name='beta_2')
    if epsilon is None:
      epsilon = K.epsilon()
    self.epsilon = epsilon
    self.schedule_decay = schedule_decay

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    self.updates = []

    with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
      t = math_ops.cast(self.iterations, K.floatx())

    # Due to the recommendations in [2], i.e. warming momentum schedule
    momentum_cache_t = self.beta_1 * (
        1. - 0.5 *
        (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
    momentum_cache_t_1 = self.beta_1 * (
        1. - 0.5 *
        (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
    m_schedule_new = self.m_schedule * momentum_cache_t
    m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
    self.updates.append((self.m_schedule, m_schedule_new))

    shapes = [K.int_shape(p) for p in params]
    ms = [K.zeros(shape) for shape in shapes]
    vs = [K.zeros(shape) for shape in shapes]

    self.weights = [self.iterations, self.m_schedule] + ms + vs

    for p, g, m, v in zip(params, grads, ms, vs):
      # the following equations given in [1]
      g_prime = g / (1. - m_schedule_new)
      m_t = self.beta_1 * m + (1. - self.beta_1) * g
      m_t_prime = m_t / (1. - m_schedule_next)
      v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
      v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
      m_t_bar = (1. -
                 momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

      self.updates.append(state_ops.assign(m, m_t))
      self.updates.append(state_ops.assign(v, v_t))

      p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
      new_p = p_t

      # Apply constraints.
      if getattr(p, 'constraint', None) is not None:
        new_p = p.constraint(new_p)

      self.updates.append(state_ops.assign(p, new_p))
    return self.updates

  def get_config(self):
    config = {
        'lr': float(K.get_value(self.lr)),
        'beta_1': float(K.get_value(self.beta_1)),
        'beta_2': float(K.get_value(self.beta_2)),
        'epsilon': self.epsilon,
        'schedule_decay': self.schedule_decay
    }
    base_config = super(Nadam, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


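# Illustrative usage sketch (not part of the original file): Nadam adds a
# warming momentum schedule controlled by `schedule_decay`; as the docstring
# recommends, the defaults are normally left untouched. The helper name is
# hypothetical.
def _example_nadam_defaults():
  opt = Nadam()
  config = opt.get_config()
  assert abs(config['schedule_decay'] - 0.004) < 1e-12
  return opt

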
class TFOptimizer(Optimizer, trackable.Trackable):
  """Wrapper class for native TensorFlow optimizers."""

  def __init__(self, optimizer, iterations=None):  # pylint: disable=super-init-not-called
    self.optimizer = optimizer
    self._track_trackable(optimizer, name='optimizer')
    if iterations is None:
      with K.name_scope(self.__class__.__name__):
        self.iterations = K.variable(0, dtype='int64', name='iterations')
    else:
      self.iterations = iterations
    self._track_trackable(self.iterations, name='global_step')

  def apply_gradients(self, grads):
    self.optimizer.apply_gradients(grads, global_step=self.iterations)

  def get_grads(self, loss, params):
    return self.optimizer.compute_gradients(loss, params)

  def get_updates(self, loss, params):
    if distribution_strategy_context.has_strategy():
      self.updates = []

      if not params:
        # After the model variables have been created, get_updates is called a
        # second time with an empty params list. In that case we call
        # compute_gradients without an explicit variable list (params=None).
        grads = self.optimizer.compute_gradients(loss)
      else:
        grads = self.optimizer.compute_gradients(loss, params)
      global_step = training_util.get_global_step()
      opt_update = self.optimizer.apply_gradients(grads, global_step)
    else:
      if not params:
        self.updates = [state_ops.assign_add(self.iterations, 1)]
        return self.updates

      # Updates list starts out empty because the iterations variable is
      # incremented in optimizer.apply_gradients()
      self.updates = []
      grads = self.optimizer.compute_gradients(loss, params)
      opt_update = self.optimizer.apply_gradients(
          grads, global_step=self.iterations)

    self.updates.append(opt_update)
    return self.updates

  @property
  def weights(self):
    raise NotImplementedError

  def get_config(self):
    raise NotImplementedError

  def from_config(self, config):
    raise NotImplementedError


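# Illustrative usage sketch (not part of the original file, graph-mode / TF1
# style): wrapping a native tf.train optimizer so it can be passed wherever a
# Keras optimizer is expected. The helper name is hypothetical.
def _example_wrap_tf_optimizer():
  from tensorflow.python.training import gradient_descent  # pylint: disable=g-import-not-at-top
  native_opt = gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
  wrapped = TFOptimizer(native_opt)
  # Keep the wrapper alive for tracking purposes, mirroring what `get` does
  # below for TF optimizer instances.
  K.track_tf_optimizer(wrapped)
  return wrapped

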
# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamax = Adamax
nadam = Nadam


@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  return serialize_keras_object(optimizer)


@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Arguments:
      config: Optimizer configuration dictionary.
      custom_objects: Optional dictionary mapping names (strings) to custom
        objects (classes and functions) to be considered during deserialization.

  Returns:
      A Keras Optimizer instance.
  """
  # loss_scale_optimizer depends directly on this optimizer module, so import
  # it here rather than at the top of the file to avoid a cyclic dependency.
  from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers.
  if config['class_name'].lower() in all_classes:
    config['class_name'] = config['class_name'].lower()
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')


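# Illustrative usage sketch (not part of the original file): serialize and
# deserialize round-trip an optimizer through its config dict; note that
# deserialize resolves built-in names to the v2 optimizer classes listed
# above. The helper name is hypothetical.
def _example_serialize_roundtrip():
  original = rmsprop_v2.RMSprop(learning_rate=0.001)
  config = serialize(original)  # {'class_name': 'RMSprop', 'config': {...}}
  restored = deserialize(config)
  return restored

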
@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Arguments:
      identifier: Optimizer identifier, one of
          - String: name of an optimizer
          - Dictionary: configuration dictionary.
          - Keras Optimizer instance (it will be returned unchanged).
          - TensorFlow Optimizer instance (it will be wrapped as a Keras
            Optimizer).

  Returns:
      A Keras Optimizer instance.

  Raises:
      ValueError: If `identifier` cannot be interpreted.
  """
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier
  # Wrap TF optimizer instances
  elif isinstance(identifier, tf_optimizer_module.Optimizer):
    opt = TFOptimizer(identifier)
    K.track_tf_optimizer(opt)
    return opt
  elif isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  else:
    raise ValueError('Could not interpret optimizer identifier:', identifier)

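
# Illustrative usage sketch (not part of the original file): the identifier
# forms accepted by `get`. Strings and config dicts resolve to the v2
# optimizers via `deserialize`; existing optimizer instances pass through
# unchanged. The helper name is hypothetical.
def _example_get_identifiers():
  by_name = get('adam')  # string -> adam_v2.Adam with default arguments
  by_config = get({'class_name': 'sgd', 'config': {}})  # dict -> v2 SGD
  passthrough = get(by_name)  # existing instances are returned unchanged
  return by_name, by_config, passthrough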