# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy and optimizer combinations for combinations.combine()."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import strategy_combinations as strategy_combinations_base
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_keras_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_keras_v2
from tensorflow.python.keras.optimizer_v2 import ftrl as ftrl_keras_v2
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_keras_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2
from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop


# Each `*_fn` object below pairs a name with a zero-argument lambda, so no
# optimizer is constructed at import time; one is built only when the lambda
# is invoked.

# V1 (tensorflow.python.training) optimizer factories.
gradient_descent_optimizer_v1_fn = combinations.NamedObject(
    "GradientDescentV1",
    lambda: gradient_descent.GradientDescentOptimizer(0.001))
adagrad_optimizer_v1_fn = combinations.NamedObject(
    "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
adam_optimizer_v1_fn = combinations.NamedObject(
    "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
ftrl_optimizer_v1_fn = combinations.NamedObject(
    "FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
rmsprop_optimizer_v1_fn = combinations.NamedObject(
    "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))

# NOTE: `adam_optimizer_v1_fn` is defined above but is not part of this list.
# TODO(shiningsun): consider adding the other v1 optimizers
optimizers_v1 = [
    gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
    ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
]

# Keras V2 (tensorflow.python.keras.optimizer_v2) optimizer factories.
adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))
adagrad_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001))
adam_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0))
adamax_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamaxKerasV2", lambda: adamax_keras_v2.Adamax(0.001, epsilon=1.0))
nadam_optimizer_keras_v2_fn = combinations.NamedObject(
    "NadamKerasV2", lambda: nadam_keras_v2.Nadam(0.001, epsilon=1.0))
ftrl_optimizer_keras_v2_fn = combinations.NamedObject(
    "FtrlKerasV2", lambda: ftrl_keras_v2.Ftrl(0.001))
gradient_descent_optimizer_keras_v2_fn = combinations.NamedObject(
    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.001))
rmsprop_optimizer_keras_v2_fn = combinations.NamedObject(
    "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001))

# NOTE: only the SGD and Adagrad factories are part of this list; the other
# Keras V2 factories above are defined but not included.
# TODO(shiningsun): consider adding the other v2 optimizers
optimizers_v2 = [
    gradient_descent_optimizer_keras_v2_fn, adagrad_optimizer_keras_v2_fn
]

# V1 factories first, followed by the Keras V2 factories.
optimizers_v1_and_v2 = optimizers_v1 + optimizers_v2


def _combine_with_distributions(optimizer_fns):
  """Combines the shared distribution strategies with `optimizer_fns`.

  The strategy list is built on every call rather than at import time, so the
  `strategy_combinations_base` attributes are only looked up when a
  combination is actually requested.

  Args:
    optimizer_fns: list of optimizer factory objects to pass to
      `combinations.combine` as `optimizer_fn`.

  Returns:
    The result of `combinations.combine` called with `distribution` set to
    the one-device, mirrored GPU+CPU and mirrored two-GPU strategies and
    `optimizer_fn` set to `optimizer_fns`.
  """
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
      ],
      optimizer_fn=optimizer_fns)


def distributions_and_v1_optimizers():
  """A common set of combination with DistributionStrategies and V1 Optimizers.

  Returns:
    The result of `combinations.combine` over the shared distribution
    strategies and the factories in `optimizers_v1`.
  """
  return _combine_with_distributions(optimizers_v1)


def distributions_and_v2_optimizers():
  """A common set of combination with DistributionStrategies and V2 Optimizers.

  Returns:
    The result of `combinations.combine` over the shared distribution
    strategies and the factories in `optimizers_v2`.
  """
  return _combine_with_distributions(optimizers_v2)


def distributions_and_v1_and_v2_optimizers():
  """A common set of combination with DistributionStrategies and all Optimizers.

  Returns:
    The result of `combinations.combine` over the shared distribution
    strategies and the factories in `optimizers_v1_and_v2`.
  """
  return _combine_with_distributions(optimizers_v1_and_v2)