1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=invalid-name 16"""Built-in optimizer classes. 17 18For more examples see the base class `tf.keras.optimizers.Optimizer`. 19""" 20 21from tensorflow.python.keras import backend 22from tensorflow.python.keras.optimizer_v1 import Optimizer 23from tensorflow.python.keras.optimizer_v1 import TFOptimizer 24from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2 25from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2 26from tensorflow.python.keras.optimizer_v2 import adam as adam_v2 27from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2 28from tensorflow.python.keras.optimizer_v2 import ftrl 29from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 30from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 31from tensorflow.python.keras.optimizer_v2 import optimizer_v2 32from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 33from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 34from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 35from tensorflow.python.training import optimizer as tf_optimizer_module 36from tensorflow.python.util.tf_export import keras_export 37 38 
@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  """Serialize the optimizer configuration to JSON compatible python dict.

  The configuration can be used for persistence and to reconstruct the
  `Optimizer` instance again.

  >>> tf.keras.optimizers.serialize(tf.keras.optimizers.SGD())
  {'class_name': 'SGD', 'config': {'name': 'SGD', 'learning_rate': 0.01,
                                   'decay': 0.0, 'momentum': 0.0,
                                   'nesterov': False}}

  Args:
    optimizer: An `Optimizer` instance to serialize.

  Returns:
    Python dict which contains the configuration of the input optimizer.
  """
  return serialize_keras_object(optimizer)


@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Class names of the built-in optimizers are matched case-insensitively
  (e.g. `'Adam'`, `'adam'` and `'ADAM'` all resolve to the built-in `Adam`).
  The caller's `config` dict is not rewritten in place by this function; the
  normalized class name is applied to a shallow copy instead.

  Args:
    config: Optimizer configuration dictionary. It must contain a
      `'class_name'` key; the remaining keys (presumably `'config'` holding the
      constructor arguments) are passed through to `deserialize_keras_object`.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras Optimizer instance.
  """
  # loss_scale_optimizer has a direct dependency of optimizer, import here
  # rather than top to avoid the cyclic dependency.
  from tensorflow.python.keras.mixed_precision import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers. Rewrite the
  # class name on a shallow copy so the caller's dict (e.g. one handed to
  # `get`) is not mutated as a side effect of deserializing it.
  class_name = config['class_name']
  if class_name.lower() in all_classes:
    config = dict(config)
    config['class_name'] = class_name.lower()
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')


@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Args:
    identifier: Optimizer identifier, one of
      - String: name of an optimizer.
      - Dictionary: configuration dictionary.
      - Keras Optimizer instance (it will be returned unchanged).
      - TensorFlow Optimizer instance (it will be wrapped as a Keras
        Optimizer).

  Returns:
    A Keras Optimizer instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
  """
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier
  # Wrap legacy TF optimizer instances in `TFOptimizer` and register the
  # wrapper via `backend.track_tf_optimizer`.
  elif isinstance(identifier, tf_optimizer_module.Optimizer):
    opt = TFOptimizer(identifier)
    backend.track_tf_optimizer(opt)
    return opt
  elif isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, str):
    # A bare name is treated as a config with an empty constructor-argument
    # dict, i.e. the optimizer's default hyperparameters.
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  else:
    raise ValueError(
        'Could not interpret optimizer identifier: {}'.format(identifier))