• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Defines base class for `ConstrainedOptimizer`s."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import six
24
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import standard_ops
28from tensorflow.python.training import optimizer as train_optimizer
29
30
@six.add_metaclass(abc.ABCMeta)
class ConstrainedOptimizer(object):
  """Base class representing a constrained optimizer.

  A ConstrainedOptimizer wraps a tf.train.Optimizer (or more than one), and
  applies it to a ConstrainedMinimizationProblem. Unlike a tf.train.Optimizer,
  which takes a tensor to minimize as a parameter to its minimize() method, a
  constrained optimizer instead takes a ConstrainedMinimizationProblem.
  """

  def __init__(self, optimizer):
    """Constructs a new `ConstrainedOptimizer`.

    Args:
      optimizer: tf.train.Optimizer, used to optimize the
        ConstrainedMinimizationProblem.
    """
    self._optimizer = optimizer

  @property
  def optimizer(self):
    """Returns the `tf.train.Optimizer` used for optimization."""
    return self._optimizer

  def _run_after_pre_train_ops(self, minimization_problem, train_op_callback):
    """Invokes `train_op_callback` after the problem's `pre_train_ops`.

    Shared by `minimize_constrained` and `minimize_unconstrained`: if the
    `minimization_problem` exposes a nonempty `pre_train_ops` collection, the
    callback is invoked inside a tf.control_dependencies() scope so that those
    ops execute before the returned train_op.

    Args:
      minimization_problem: ConstrainedMinimizationProblem whose
        `pre_train_ops` (if any) must run before the train_op.
      train_op_callback: nullary callable that creates and returns the
        train_op.

    Returns:
      `Operation`, the train_op created by `train_op_callback`.
    """
    pre_train_ops = minimization_problem.pre_train_ops
    if pre_train_ops:
      # Ensure that pre_train_ops execute before the train_op.
      with ops.control_dependencies(pre_train_ops):
        return train_op_callback()
    return train_op_callback()

  @abc.abstractmethod
  def _minimize_constrained(self,
                            minimization_problem,
                            global_step=None,
                            var_list=None,
                            gate_gradients=train_optimizer.Optimizer.GATE_OP,
                            aggregation_method=None,
                            colocate_gradients_with_ops=False,
                            name=None,
                            grad_loss=None):
    """Version of `minimize_constrained` to be overridden by subclasses.

    Implementations of this method should ignore the `pre_train_ops` property of
    the `minimization_problem`. The public `minimize_constrained` method will
    take care of executing these before the returned train_op.

    Args:
      minimization_problem: ConstrainedMinimizationProblem, the problem to
        optimize.
      global_step: as in `tf.train.Optimizer`'s `minimize` method.
      var_list: as in `tf.train.Optimizer`'s `minimize` method.
      gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
      aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
      colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
        method.
      name: as in `tf.train.Optimizer`'s `minimize` method.
      grad_loss: as in `tf.train.Optimizer`'s `minimize` method.

    Returns:
      `Operation`, the train_op.
    """
    pass

  def minimize_constrained(self,
                           minimization_problem,
                           global_step=None,
                           var_list=None,
                           gate_gradients=train_optimizer.Optimizer.GATE_OP,
                           aggregation_method=None,
                           colocate_gradients_with_ops=False,
                           name=None,
                           grad_loss=None):
    """Returns an `Operation` for minimizing the constrained problem.

    Unlike `minimize_unconstrained`, this function attempts to find a solution
    that minimizes the `objective` portion of the minimization problem while
    satisfying the `constraints` portion.

    Args:
      minimization_problem: ConstrainedMinimizationProblem, the problem to
        optimize.
      global_step: as in `tf.train.Optimizer`'s `minimize` method.
      var_list: as in `tf.train.Optimizer`'s `minimize` method.
      gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
      aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
      colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
        method.
      name: as in `tf.train.Optimizer`'s `minimize` method.
      grad_loss: as in `tf.train.Optimizer`'s `minimize` method.

    Returns:
      `Operation`, the train_op.
    """

    def train_op_callback():
      # Delegate the actual optimization to the subclass implementation.
      return self._minimize_constrained(
          minimization_problem,
          global_step=global_step,
          var_list=var_list,
          gate_gradients=gate_gradients,
          aggregation_method=aggregation_method,
          colocate_gradients_with_ops=colocate_gradients_with_ops,
          name=name,
          grad_loss=grad_loss)

    return self._run_after_pre_train_ops(minimization_problem,
                                         train_op_callback)

  def minimize_unconstrained(self,
                             minimization_problem,
                             global_step=None,
                             var_list=None,
                             gate_gradients=train_optimizer.Optimizer.GATE_OP,
                             aggregation_method=None,
                             colocate_gradients_with_ops=False,
                             name=None,
                             grad_loss=None):
    """Returns an `Operation` for minimizing the unconstrained problem.

    Unlike `minimize_constrained`, this function ignores the `constraints` (and
    `proxy_constraints`) portion of the minimization problem entirely, and only
    minimizes `objective`.

    Args:
      minimization_problem: ConstrainedMinimizationProblem, the problem to
        optimize.
      global_step: as in `tf.train.Optimizer`'s `minimize` method.
      var_list: as in `tf.train.Optimizer`'s `minimize` method.
      gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
      aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
      colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
        method.
      name: as in `tf.train.Optimizer`'s `minimize` method.
      grad_loss: as in `tf.train.Optimizer`'s `minimize` method.

    Returns:
      `Operation`, the train_op.
    """

    def train_op_callback():
      # Only the objective is minimized: constraints are ignored here.
      return self.optimizer.minimize(
          minimization_problem.objective,
          global_step=global_step,
          var_list=var_list,
          gate_gradients=gate_gradients,
          aggregation_method=aggregation_method,
          colocate_gradients_with_ops=colocate_gradients_with_ops,
          name=name,
          grad_loss=grad_loss)

    return self._run_after_pre_train_ops(minimization_problem,
                                         train_op_callback)

  def minimize(self,
               minimization_problem,
               unconstrained_steps=None,
               global_step=None,
               var_list=None,
               gate_gradients=train_optimizer.Optimizer.GATE_OP,
               aggregation_method=None,
               colocate_gradients_with_ops=False,
               name=None,
               grad_loss=None):
    """Returns an `Operation` for minimizing the constrained problem.

    This method combines the functionality of `minimize_unconstrained` and
    `minimize_constrained`. If global_step < unconstrained_steps, it will
    perform an unconstrained update, and if global_step >= unconstrained_steps,
    it will perform a constrained update.

    The reason for this functionality is that it may be best to initialize the
    constrained optimizer with an approximate optimum of the unconstrained
    problem.

    Args:
      minimization_problem: ConstrainedMinimizationProblem, the problem to
        optimize.
      unconstrained_steps: int, number of steps for which we should perform
        unconstrained updates, before transitioning to constrained updates.
      global_step: as in `tf.train.Optimizer`'s `minimize` method.
      var_list: as in `tf.train.Optimizer`'s `minimize` method.
      gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
      aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
      colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
        method.
      name: as in `tf.train.Optimizer`'s `minimize` method.
      grad_loss: as in `tf.train.Optimizer`'s `minimize` method.

    Returns:
      `Operation`, the train_op.

    Raises:
      ValueError: If unconstrained_steps is provided, but global_step is not.
    """

    def unconstrained_fn():
      """Returns an `Operation` for minimizing the unconstrained problem."""
      return self.minimize_unconstrained(
          minimization_problem=minimization_problem,
          global_step=global_step,
          var_list=var_list,
          gate_gradients=gate_gradients,
          aggregation_method=aggregation_method,
          colocate_gradients_with_ops=colocate_gradients_with_ops,
          name=name,
          grad_loss=grad_loss)

    def constrained_fn():
      """Returns an `Operation` for minimizing the constrained problem."""
      return self.minimize_constrained(
          minimization_problem=minimization_problem,
          global_step=global_step,
          var_list=var_list,
          gate_gradients=gate_gradients,
          aggregation_method=aggregation_method,
          colocate_gradients_with_ops=colocate_gradients_with_ops,
          name=name,
          grad_loss=grad_loss)

    if unconstrained_steps is None:
      # No warm-up phase was requested: always perform constrained updates.
      return constrained_fn()

    if global_step is None:
      raise ValueError(
          "global_step cannot be None if unconstrained_steps is provided")
    # Choose between the two update types in-graph, since global_step is a
    # tensor whose value is only known at execution time.
    unconstrained_steps_tensor = ops.convert_to_tensor(unconstrained_steps)
    dtype = unconstrained_steps_tensor.dtype
    return control_flow_ops.cond(
        standard_ops.cast(global_step, dtype) < unconstrained_steps_tensor,
        true_fn=unconstrained_fn,
        false_fn=constrained_fn)
275