# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy and optimizer combinations for combinations.combine()."""

from tensorflow.python.distribute import strategy_combinations as strategy_combinations_base
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_keras_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_keras_v2
from tensorflow.python.keras.optimizer_v2 import ftrl as ftrl_keras_v2
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_keras_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2
from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop


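# Each combinations.NamedObject pairs a readable test-case name with a
# zero-argument factory; the wrapper delegates calls, so tests invoke it
# (e.g. `optimizer_fn()`) to construct a fresh optimizer per test case.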
gradient_descent_optimizer_v1_fn = combinations.NamedObject(
    "GradientDescentV1",
    lambda: gradient_descent.GradientDescentOptimizer(0.001))
adagrad_optimizer_v1_fn = combinations.NamedObject(
    "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
adam_optimizer_v1_fn = combinations.NamedObject(
    "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
ftrl_optimizer_v1_fn = combinations.NamedObject(
    "FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
rmsprop_optimizer_v1_fn = combinations.NamedObject(
    "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))

# TODO(shiningsun): consider adding the other v1 optimizers
optimizers_v1 = [
    gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
    ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
]

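# Keras (v2) counterparts of the wrappers above. Note the Adam-family
# entries pass epsilon=1.0 explicitly, matching AdamV1's epsilon=1 above.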
adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))
adagrad_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001))
adam_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0))
adamax_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamaxKerasV2", lambda: adamax_keras_v2.Adamax(0.001, epsilon=1.0))
nadam_optimizer_keras_v2_fn = combinations.NamedObject(
    "NadamKerasV2", lambda: nadam_keras_v2.Nadam(0.001, epsilon=1.0))
ftrl_optimizer_keras_v2_fn = combinations.NamedObject(
    "FtrlKerasV2", lambda: ftrl_keras_v2.Ftrl(0.001))
gradient_descent_optimizer_keras_v2_fn = combinations.NamedObject(
    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.001))
rmsprop_optimizer_keras_v2_fn = combinations.NamedObject(
    "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001))

# TODO(shiningsun): consider adding the other v2 optimizers
optimizers_v2 = [
    gradient_descent_optimizer_keras_v2_fn, adagrad_optimizer_keras_v2_fn
]

optimizers_v1_and_v2 = optimizers_v1 + optimizers_v2


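# Each helper below cross-products the listed DistributionStrategies with
# one of the optimizer lists via combinations.combine; tests expand the
# returned combinations with a corresponding `generate` decorator.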
def distributions_and_v1_optimizers():
  """Combinations of DistributionStrategies and v1 Optimizers."""
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
          strategy_combinations_base
          .mirrored_strategy_with_two_gpus_no_merge_call,
      ],
      optimizer_fn=optimizers_v1)


def distributions_and_v2_optimizers():
  """Combinations of DistributionStrategies and v2 (Keras) Optimizers."""
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
          strategy_combinations_base
          .mirrored_strategy_with_two_gpus_no_merge_call,
      ],
      optimizer_fn=optimizers_v2)


def distributions_and_v1_and_v2_optimizers():
  """Combinations of DistributionStrategies and both v1 and v2 Optimizers."""
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
          strategy_combinations_base
          .mirrored_strategy_with_two_gpus_no_merge_call,
      ],
      optimizer_fn=optimizers_v1_and_v2)

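# A minimal usage sketch (assumed, not part of this module; the test class
# and method names are hypothetical). A distribute-aware generate (e.g.
# tensorflow.python.distribute.combinations.generate) expands each entry
# into its own named test case, resolving `distribution` to a strategy and
# passing the NamedObject as `optimizer_fn`:
#
#   class MinimizeTest(test.TestCase, parameterized.TestCase):
#
#     @ds_combinations.generate(distributions_and_v1_and_v2_optimizers())
#     def test_minimize(self, distribution, optimizer_fn):
#       with distribution.scope():
#         optimizer = optimizer_fn()  # NamedObject delegates the call.
#         ...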