# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Built-in optimizer classes.

For more examples see the base class `tf.keras.optimizers.Optimizer`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.optimizer_v1 import Optimizer
from tensorflow.python.keras.optimizer_v1 import TFOptimizer
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
from tensorflow.python.keras.optimizer_v2 import ftrl
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.serialize')
def serialize(optimizer):
  """Serializes an optimizer using `serialize_keras_object`.

  Args:
    optimizer: The optimizer instance to serialize.

  Returns:
    Whatever `serialize_keras_object(optimizer)` returns. This is presumably
    a config dictionary of the form consumed by `deserialize` (with
    `'class_name'` and `'config'` keys); confirm against
    `serialize_keras_object`.
  """
  return serialize_keras_object(optimizer)


@keras_export('keras.optimizers.deserialize')
def deserialize(config, custom_objects=None):
  """Inverse of the `serialize` function.

  Class names of the built-in optimizers are matched case-insensitively
  (e.g. `'Adam'` and `'adam'` both resolve to `Adam`). The caller's `config`
  dictionary is not modified by this normalization.

  Args:
    config: Optimizer configuration dictionary. It must contain a
      `'class_name'` key; it is passed on to `deserialize_keras_object`.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras Optimizer instance.
  """
  # loss_scale_optimizer has a direct dependency of optimizer, import here
  # rather than top to avoid the cyclic dependency.
  from tensorflow.python.keras.mixed_precision import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
  all_classes = {
      'adadelta': adadelta_v2.Adadelta,
      'adagrad': adagrad_v2.Adagrad,
      'adam': adam_v2.Adam,
      'adamax': adamax_v2.Adamax,
      'nadam': nadam_v2.Nadam,
      'rmsprop': rmsprop_v2.RMSprop,
      'sgd': gradient_descent_v2.SGD,
      'ftrl': ftrl.Ftrl,
      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
      # LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
      # LossScaleOptimizerV1 will be removed soon but deserializing it will
      # still be supported.
      'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
  }

  # Make deserialization case-insensitive for built-in optimizers.
  # The key is read before any copy is made, so malformed inputs fail exactly
  # as before.
  lowered_name = config['class_name'].lower()
  if lowered_name in all_classes:
    # Shallow-copy so the caller's dict keeps its original 'class_name'.
    # This only protects the top-level dict; nested values are shared.
    config = dict(config)
    config['class_name'] = lowered_name
  return deserialize_keras_object(
      config,
      module_objects=all_classes,
      custom_objects=custom_objects,
      printable_module_name='optimizer')


@keras_export('keras.optimizers.get')
def get(identifier):
  """Retrieves a Keras Optimizer instance.

  Args:
    identifier: Optimizer identifier, one of
      - String: name of an optimizer (matched case-insensitively for the
        built-in optimizers).
      - Dictionary: configuration dictionary.
      - Keras Optimizer instance (it will be returned unchanged).
      - TensorFlow Optimizer instance (it will be wrapped as a Keras
        Optimizer).

  Returns:
    A Keras Optimizer instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
  """
  if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)):
    return identifier
  # Wrap TF optimizer instances
  elif isinstance(identifier, tf_optimizer_module.Optimizer):
    opt = TFOptimizer(identifier)
    # Register the wrapper with the Keras backend.
    K.track_tf_optimizer(opt)
    return opt
  elif isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    # A bare name is treated as a config with default constructor arguments.
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  else:
    raise ValueError(
        'Could not interpret optimizer identifier: {}'.format(identifier))