• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Built-in optimizer classes.
17
18For more examples see the base class `tf.keras.optimizers.Optimizer`.
19"""
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import six
25
26from tensorflow.python.keras import backend as K
27from tensorflow.python.keras.optimizer_v1 import Optimizer
28from tensorflow.python.keras.optimizer_v1 import TFOptimizer
29from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
30from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
31from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
32from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
33from tensorflow.python.keras.optimizer_v2 import ftrl
34from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
35from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
36from tensorflow.python.keras.optimizer_v2 import optimizer_v2
37from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
38from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
39from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
40from tensorflow.python.training import optimizer as tf_optimizer_module
41from tensorflow.python.util.tf_export import keras_export
42
43
@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  """Serializes a Keras optimizer.

  This is the inverse of `deserialize`.

  Args:
      optimizer: The optimizer to serialize.

  Returns:
      Whatever `serialize_keras_object` produces for `optimizer`.
  """
  serialized = serialize_keras_object(optimizer)
  return serialized
47
48
@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Args:
      config: Optimizer configuration dictionary. It must contain a
        `'class_name'` key. This function does not modify the dictionary:
        when the class name needs to be normalized, a shallow copy is made
        first.
      custom_objects: Optional dictionary mapping names (strings) to custom
        objects (classes and functions) to be considered during
        deserialization.

  Returns:
      A Keras Optimizer instance.

  Raises:
      KeyError: If `config` has no `'class_name'` entry.
  """
  # loss_scale_optimizer depends directly on the optimizer modules, so it is
  # imported here rather than at the top of the file to avoid a cyclic import.
  from tensorflow.python.keras.mixed_precision import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers. The
  # normalization is applied to a shallow copy so that the caller's dict is
  # not rewritten as a side effect of deserializing it.
  normalized_name = config['class_name'].lower()
  if normalized_name in all_classes:
    config = dict(config)
    config['class_name'] = normalized_name
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')
88
89
@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Args:
      identifier: Optimizer identifier, one of
          - String: name of an optimizer.
          - Dictionary: configuration dictionary.
          - Keras Optimizer instance (it will be returned unchanged).
          - TensorFlow Optimizer instance (it will be wrapped as a Keras
            Optimizer).

  Returns:
      A Keras Optimizer instance.

  Raises:
      ValueError: If `identifier` cannot be interpreted.
  """
  # Keras optimizers (the v1 `Optimizer` or `OptimizerV2`) pass through as-is.
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier

  # A plain TF optimizer is wrapped in `TFOptimizer`, and the wrapper is handed
  # to `K.track_tf_optimizer` before being returned.
  if isinstance(identifier, tf_optimizer_module.Optimizer):
    wrapped = TFOptimizer(identifier)
    K.track_tf_optimizer(wrapped)
    return wrapped

  if isinstance(identifier, dict):
    return deserialize(identifier)

  if isinstance(identifier, six.string_types):
    # A bare name is deserialized with an empty inner config.
    return deserialize({'class_name': str(identifier), 'config': {}})

  raise ValueError(
      'Could not interpret optimizer identifier: {}'.format(identifier))
116  elif isinstance(identifier, six.string_types):
117    config = {'class_name': str(identifier), 'config': {}}
118    return deserialize(config)
119  else:
120    raise ValueError(
121        'Could not interpret optimizer identifier: {}'.format(identifier))
122