# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy and optimizer combinations for combinations.combine().

Each optimizer is exposed as a `combinations.NamedObject` wrapping a
zero-argument factory, so a fresh optimizer instance is built every time the
factory is called. Optimizers come from both `tensorflow.python.training`
(the "V1" entries) and `tensorflow.python.keras.optimizer_v2` (the "KerasV2"
entries).
"""

from tensorflow.python.distribute import strategy_combinations as strategy_combinations_base
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_keras_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_keras_v2
from tensorflow.python.keras.optimizer_v2 import ftrl as ftrl_keras_v2
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_keras_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2
from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop


# Plain SGD from tensorflow.python.training, with a fixed learning rate of
# 0.001 passed by keyword.
gradient_descent_optimizer_v1_fn = combinations.NamedObject(
    "GradientDescentV1",
    lambda: gradient_descent.GradientDescentOptimizer(learning_rate=0.001))
adagrad_optimizer_v1_fn = combinations.NamedObject(
    "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
# epsilon is spelled as a float literal to match the Keras v2 Adam-family
# factories below; the numeric value is unchanged.
adam_optimizer_v1_fn = combinations.NamedObject(
    "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1.0))
ftrl_optimizer_v1_fn = combinations.NamedObject(
    "FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
rmsprop_optimizer_v1_fn = combinations.NamedObject(
    "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))

# v1 optimizers paired with strategies by `distributions_and_v1_optimizers`.
# `adam_optimizer_v1_fn` is defined above but is intentionally not listed.
# TODO(shiningsun): consider adding the other v1 optimizers
optimizers_v1 = [
    gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
    ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
]

adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))
adagrad_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001))
adam_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0))
adamax_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamaxKerasV2", lambda: adamax_keras_v2.Adamax(0.001, epsilon=1.0))
nadam_optimizer_keras_v2_fn = combinations.NamedObject(
    "NadamKerasV2", lambda: nadam_keras_v2.Nadam(0.001, epsilon=1.0))
ftrl_optimizer_keras_v2_fn = combinations.NamedObject(
    "FtrlKerasV2", lambda: ftrl_keras_v2.Ftrl(0.001))
gradient_descent_optimizer_keras_v2_fn = combinations.NamedObject(
    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.001))
rmsprop_optimizer_keras_v2_fn = combinations.NamedObject(
    "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001))

# Keras v2 optimizers paired with strategies by
# `distributions_and_v2_optimizers`. Only SGD and Adagrad are listed; the
# other Keras v2 factories above are available for tests to use directly.
# TODO(shiningsun): consider adding the other v2 optimizers
optimizers_v2 = [
    gradient_descent_optimizer_keras_v2_fn, adagrad_optimizer_keras_v2_fn
]

optimizers_v1_and_v2 = optimizers_v1 + optimizers_v2


def _distributions_and_optimizers(optimizer_fns):
  """Combines the shared distribution strategies with `optimizer_fns`.

  This is the single place that defines which strategies the public
  `distributions_and_*_optimizers` helpers cover, so the three of them cannot
  drift apart.

  Args:
    optimizer_fns: list of `combinations.NamedObject` optimizer factories,
      passed to `combinations.combine` as the values of `optimizer_fn`.

  Returns:
    Whatever `combinations.combine` returns for the keyword arguments
    `distribution` (the strategies below) and `optimizer_fn`.
  """
  # The strategy list is built at call time, as in the original per-function
  # copies, so attribute lookups on strategy_combinations_base stay lazy.
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
          strategy_combinations_base
          .mirrored_strategy_with_two_gpus_no_merge_call,
      ],
      optimizer_fn=optimizer_fns)


def distributions_and_v1_optimizers():
  """Combinations of the shared distribution strategies with `optimizers_v1`.

  Returns:
    The `combinations.combine` result for those strategies and the
    `tensorflow.python.training` optimizers listed in `optimizers_v1`.
  """
  return _distributions_and_optimizers(optimizers_v1)


def distributions_and_v2_optimizers():
  """Combinations of the shared distribution strategies with `optimizers_v2`.

  Returns:
    The `combinations.combine` result for those strategies and the Keras v2
    optimizers listed in `optimizers_v2`.
  """
  return _distributions_and_optimizers(optimizers_v2)


def distributions_and_v1_and_v2_optimizers():
  """Combinations of the shared strategies with `optimizers_v1_and_v2`.

  Returns:
    The `combinations.combine` result for those strategies and every
    optimizer in `optimizers_v1` followed by every optimizer in
    `optimizers_v2`.
  """
  return _distributions_and_optimizers(optimizers_v1_and_v2)