# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy and optimizer combinations for combinations.combine()."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import strategy_combinations as strategy_combinations_base
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_keras_v2
from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2
from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2
from tensorflow.python.keras.optimizer_v2 import adamax as adamax_keras_v2
from tensorflow.python.keras.optimizer_v2 import ftrl as ftrl_keras_v2
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2
from tensorflow.python.keras.optimizer_v2 import nadam as nadam_keras_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2
from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop

gradient_descent_optimizer_v1_fn = combinations.NamedObject(
    "GradientDescentV1",
    lambda: gradient_descent.GradientDescentOptimizer(0.001))
adagrad_optimizer_v1_fn = combinations.NamedObject(
    "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
adam_optimizer_v1_fn = combinations.NamedObject(
    "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
ftrl_optimizer_v1_fn = combinations.NamedObject(
    "FtrlV1", lambda: ftrl.FtrlOptimizer(0.001))
rmsprop_optimizer_v1_fn = combinations.NamedObject(
    "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))

# TODO(shiningsun): consider adding the other v1 optimizers
optimizers_v1 = [
    gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn,
    ftrl_optimizer_v1_fn, rmsprop_optimizer_v1_fn
]
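# Note: adam_optimizer_v1_fn is defined above but not included in this list.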

adadelta_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdadeltaKerasV2", lambda: adadelta_keras_v2.Adadelta(0.001))
adagrad_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001))
adam_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0))
adamax_optimizer_keras_v2_fn = combinations.NamedObject(
    "AdamaxKerasV2", lambda: adamax_keras_v2.Adamax(0.001, epsilon=1.0))
nadam_optimizer_keras_v2_fn = combinations.NamedObject(
    "NadamKerasV2", lambda: nadam_keras_v2.Nadam(0.001, epsilon=1.0))
ftrl_optimizer_keras_v2_fn = combinations.NamedObject(
    "FtrlKerasV2", lambda: ftrl_keras_v2.Ftrl(0.001))
gradient_descent_optimizer_keras_v2_fn = combinations.NamedObject(
    "GradientDescentKerasV2", lambda: gradient_descent_keras_v2.SGD(0.001))
rmsprop_optimizer_keras_v2_fn = combinations.NamedObject(
    "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001))

# TODO(shiningsun): consider adding the other v2 optimizers
optimizers_v2 = [
    gradient_descent_optimizer_keras_v2_fn, adagrad_optimizer_keras_v2_fn
]

optimizers_v1_and_v2 = optimizers_v1 + optimizers_v2


def distributions_and_v1_optimizers():
  """Combinations of DistributionStrategies and v1 Optimizers."""
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
      ],
      optimizer_fn=optimizers_v1)
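
# A minimal usage sketch (hypothetical test method; assumes the `generate`
# decorator from tensorflow.python.distribute.combinations, imported here
# under the hypothetical alias `ds_combinations`, which resolves the named
# strategies above into concrete strategy objects):
#
#   @ds_combinations.generate(distributions_and_v1_optimizers())
#   def testMinimize(self, distribution, optimizer_fn):
#     with distribution.scope():
#       optimizer = optimizer_fn()  # build the optimizer inside the scope
#       ...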


def distributions_and_v2_optimizers():
  """Combinations of DistributionStrategies and Keras v2 Optimizers."""
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
      ],
      optimizer_fn=optimizers_v2)


def distributions_and_v1_and_v2_optimizers():
  """Combinations of DistributionStrategies and both v1 and v2 Optimizers."""
  return combinations.combine(
      distribution=[
          strategy_combinations_base.one_device_strategy,
          strategy_combinations_base.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations_base.mirrored_strategy_with_two_gpus,
      ],
      optimizer_fn=optimizers_v1_and_v2)
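
# These helpers compose with additional test axes, e.g. sweeping the execution
# mode as well (a sketch, assuming `times` from test_combinations, which forms
# the cross product of two combination lists):
#
#   combinations.times(distributions_and_v1_and_v2_optimizers(),
#                      combinations.combine(mode=["graph", "eager"]))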