# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow interface for third-party optimizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging

__all__ = ['ExternalOptimizerInterface', 'ScipyOptimizerInterface']


class ExternalOptimizerInterface(object):
  """Base class for interfaces with external optimization algorithms.

  Subclass this and implement `_minimize` in order to wrap a new optimization
  algorithm.

  `ExternalOptimizerInterface` should not be instantiated directly; instead use
  e.g. `ScipyOptimizerInterface`.
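
  Example of wrapping a new algorithm by subclassing (a minimal sketch that
  ignores constraints and bounds and runs plain gradient descent on the packed
  variable vector; the class name and the `learning_rate` / `num_steps` kwargs
  are illustrative, not part of this module):

  ```python
  class GradientDescentOptimizerInterface(ExternalOptimizerInterface):

    def _minimize(self, initial_val, loss_grad_func, step_callback,
                  optimizer_kwargs, **unused_kwargs):
      # Plain gradient descent on the packed (flattened) variable vector.
      learning_rate = optimizer_kwargs.get('learning_rate', 0.1)
      x = initial_val
      for _ in range(optimizer_kwargs.get('num_steps', 100)):
        _, grad = loss_grad_func(x)
        x = x - learning_rate * grad
        step_callback(x)
      return x
  ```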

  @@__init__

  @@minimize
  """

  def __init__(self,
               loss,
               var_list=None,
               equalities=None,
               inequalities=None,
               var_to_bounds=None,
               **optimizer_kwargs):
    """Initialize a new interface instance.

    Args:
      loss: A scalar `Tensor` to be minimized.
      var_list: Optional `list` of `Variable` objects to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      equalities: Optional `list` of equality constraint scalar `Tensor`s to be
        held equal to zero.
      inequalities: Optional `list` of inequality constraint scalar `Tensor`s
        to be held nonnegative.
      var_to_bounds: Optional `dict` where each key is an optimization
        `Variable` and each corresponding value is a length-2 tuple of
        `(low, high)` bounds. Although enforcing this kind of simple constraint
        could be accomplished with the `inequalities` arg, not all optimization
        algorithms support general inequality constraints, e.g. L-BFGS-B. Both
        `low` and `high` can either be numbers or anything convertible to a
        NumPy array that can be broadcast to the shape of `var` (using
        `np.broadcast_to`). To indicate that there is no bound, use `None` (or
        `+/- np.infty`). For example, if `var` is a 2x3 matrix, then any of
        the following corresponding `bounds` could be supplied:
        * `(0, np.infty)`: Each element of `var` held nonnegative.
        * `(-np.infty, [1, 2, 3])`: First column less than 1, second column
          less than 2, third column less than 3.
        * `(-np.infty, [[1], [2]])`: First row less than 1, second row less
          than 2.
        * `(-np.infty, [[1, 2, 3], [4, 5, 6]])`: Entry `var[0, 0]` less than 1,
          `var[0, 1]` less than 2, etc.
      **optimizer_kwargs: Other subclass-specific keyword arguments.
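
    For example, bounds for a 2x3 `var` could be built like this (a sketch;
    assumes a concrete subclass such as `ScipyOptimizerInterface`):

    ```python
    var = tf.Variable(tf.zeros([2, 3]), name='var')
    loss = tf.reduce_sum(tf.square(var - 2.))
    # Keep every entry nonnegative and bound the first row above by 1.
    optimizer = ScipyOptimizerInterface(
        loss, var_to_bounds={var: (0, [[1.], [np.infty]])})
    ```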
    """
    self._loss = loss
    self._equalities = equalities or []
    self._inequalities = inequalities or []

    if var_list is None:
      self._vars = variables.trainable_variables()
    else:
      self._vars = list(var_list)

    # Broadcast each variable's bounds to its shape and flatten them into a
    # single list of (low, high) pairs, one per element of the packed vector.
    packed_bounds = None
    if var_to_bounds is not None:
      left_packed_bounds = []
      right_packed_bounds = []
      for var in self._vars:
        shape = var.get_shape().as_list()
        bounds = (-np.infty, np.infty)
        if var in var_to_bounds:
          bounds = var_to_bounds[var]
        left_packed_bounds.extend(list(np.broadcast_to(bounds[0], shape).flat))
        right_packed_bounds.extend(list(np.broadcast_to(bounds[1], shape).flat))
      packed_bounds = list(zip(left_packed_bounds, right_packed_bounds))
    self._packed_bounds = packed_bounds

    # Placeholders and assign ops used to write the optimizer's final result
    # back into the variables at the end of minimization.
    self._update_placeholders = [
        array_ops.placeholder(var.dtype) for var in self._vars
    ]
    self._var_updates = [
        var.assign(array_ops.reshape(placeholder, _get_shape_tuple(var)))
        for var, placeholder in zip(self._vars, self._update_placeholders)
    ]

    loss_grads = _compute_gradients(loss, self._vars)
    equalities_grads = [
        _compute_gradients(equality, self._vars)
        for equality in self._equalities
    ]
    inequalities_grads = [
        _compute_gradients(inequality, self._vars)
        for inequality in self._inequalities
    ]

    self.optimizer_kwargs = optimizer_kwargs

    # Flatten the variables and their gradients into single rank-1 tensors.
    self._packed_var = self._pack(self._vars)
    self._packed_loss_grad = self._pack(loss_grads)
    self._packed_equality_grads = [
        self._pack(equality_grads) for equality_grads in equalities_grads
    ]
    self._packed_inequality_grads = [
        self._pack(inequality_grads) for inequality_grads in inequalities_grads
    ]

    # Slices mapping each variable to its block within the packed vector.
    dims = [_prod(_get_shape_tuple(var)) for var in self._vars]
    accumulated_dims = list(_accumulate(dims))
    self._packing_slices = [
        slice(start, end)
        for start, end in zip(accumulated_dims[:-1], accumulated_dims[1:])
    ]

  def minimize(self,
               session=None,
               feed_dict=None,
               fetches=None,
               step_callback=None,
               loss_callback=None,
               **run_kwargs):
    """Minimize a scalar `Tensor`.

    Variables subject to optimization are updated in-place at the end of
    optimization.

    Note that this method does *not* just return a minimization `Op`, unlike
    `Optimizer.minimize()`; instead it actually performs minimization by
    executing commands to control a `Session`.

    Args:
      session: A `Session` instance.
      feed_dict: A feed dict to be passed to calls to `session.run`.
      fetches: A list of `Tensor`s to fetch and supply to `loss_callback`
        as positional arguments.
      step_callback: A function to be called at each optimization step;
        arguments are the current values of all optimization variables
        flattened into a single vector.
      loss_callback: A function to be called every time the loss and gradients
        are computed, with evaluated fetches supplied as positional arguments.
      **run_kwargs: kwargs to pass to `session.run`.
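
    Example (a minimal usage sketch; assumes an existing scalar `loss` built
    from trainable variables):

    ```python
    optimizer = ScipyOptimizerInterface(loss)

    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      optimizer.minimize(
          session,
          fetches=[loss],
          loss_callback=lambda loss_value: print('loss:', loss_value))
    ```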
    """
    session = session or ops.get_default_session()
    feed_dict = feed_dict or {}
    fetches = fetches or []

    loss_callback = loss_callback or (lambda *fetches: None)
    step_callback = step_callback or (lambda xk: None)

    # Construct loss function and associated gradient.
    loss_grad_func = self._make_eval_func([self._loss,
                                           self._packed_loss_grad], session,
                                          feed_dict, fetches, loss_callback)

    # Construct equality constraint functions and associated gradients.
    equality_funcs = self._make_eval_funcs(self._equalities, session, feed_dict,
                                           fetches)
    equality_grad_funcs = self._make_eval_funcs(self._packed_equality_grads,
                                                session, feed_dict, fetches)

    # Construct inequality constraint functions and associated gradients.
    inequality_funcs = self._make_eval_funcs(self._inequalities, session,
                                             feed_dict, fetches)
    inequality_grad_funcs = self._make_eval_funcs(self._packed_inequality_grads,
                                                  session, feed_dict, fetches)

    # Get initial value from TF session.
    initial_packed_var_val = session.run(self._packed_var)

    # Perform minimization.
    packed_var_val = self._minimize(
        initial_val=initial_packed_var_val,
        loss_grad_func=loss_grad_func,
        equality_funcs=equality_funcs,
        equality_grad_funcs=equality_grad_funcs,
        inequality_funcs=inequality_funcs,
        inequality_grad_funcs=inequality_grad_funcs,
        packed_bounds=self._packed_bounds,
        step_callback=step_callback,
        optimizer_kwargs=self.optimizer_kwargs)
    var_vals = [
        packed_var_val[packing_slice] for packing_slice in self._packing_slices
    ]

    # Set optimization variables to their new values.
    session.run(
        self._var_updates,
        feed_dict=dict(zip(self._update_placeholders, var_vals)),
        **run_kwargs)

  def _minimize(self, initial_val, loss_grad_func, equality_funcs,
                equality_grad_funcs, inequality_funcs, inequality_grad_funcs,
                packed_bounds, step_callback, optimizer_kwargs):
    """Wrapper for a particular optimization algorithm implementation.

    It would be appropriate for a subclass implementation of this method to
    raise `NotImplementedError` if unsupported arguments are passed: e.g. if an
    algorithm does not support constraints but `len(equality_funcs) > 0`.

    Args:
      initial_val: A NumPy vector of initial values.
      loss_grad_func: A function accepting a NumPy packed variable vector and
        returning two outputs, a loss value and the gradient of that loss with
        respect to the packed variable vector.
      equality_funcs: A list of functions each of which specifies a scalar
        quantity that an optimizer should hold exactly zero.
      equality_grad_funcs: A list of gradients of equality_funcs.
      inequality_funcs: A list of functions each of which specifies a scalar
        quantity that an optimizer should hold >= 0.
      inequality_grad_funcs: A list of gradients of inequality_funcs.
      packed_bounds: A list of bounds for each index, or `None`.
      step_callback: A callback function to execute at each optimization step,
        supplied with the current value of the packed variable vector.
      optimizer_kwargs: Other key-value arguments available to the optimizer.

    Returns:
      The optimal variable vector as a NumPy vector.
    """
    raise NotImplementedError(
        'To use ExternalOptimizerInterface, subclass from it and implement '
        'the _minimize() method.')

  @classmethod
  def _pack(cls, tensors):
    """Pack a list of `Tensor`s into a single, flattened, rank-1 `Tensor`."""
    if not tensors:
      return None
    elif len(tensors) == 1:
      return array_ops.reshape(tensors[0], [-1])
    else:
      flattened = [array_ops.reshape(tensor, [-1]) for tensor in tensors]
      return array_ops.concat(flattened, 0)

  def _make_eval_func(self, tensors, session, feed_dict, fetches,
                      callback=None):
    """Construct a function that evaluates a `Tensor` or list of `Tensor`s."""
    if not isinstance(tensors, list):
      tensors = [tensors]
    num_tensors = len(tensors)

    def eval_func(x):
      """Function to evaluate a `Tensor`."""
      # Feed each variable its slice of the packed vector `x`; feeding a
      # variable overrides its value for this `session.run` call only, so the
      # variables themselves are not assigned here.
      augmented_feed_dict = {
          var: x[packing_slice].reshape(_get_shape_tuple(var))
          for var, packing_slice in zip(self._vars, self._packing_slices)
      }
      augmented_feed_dict.update(feed_dict)
      augmented_fetches = tensors + fetches

      augmented_fetch_vals = session.run(
          augmented_fetches, feed_dict=augmented_feed_dict)

      if callable(callback):
        callback(*augmented_fetch_vals[num_tensors:])

      return augmented_fetch_vals[:num_tensors]

    return eval_func

  def _make_eval_funcs(self,
                       tensors,
                       session,
                       feed_dict,
                       fetches,
                       callback=None):
    """Construct a list of functions, one evaluating each item of `tensors`."""
    return [
        self._make_eval_func(tensor, session, feed_dict, fetches, callback)
        for tensor in tensors
    ]


class ScipyOptimizerInterface(ExternalOptimizerInterface):
  """Wrapper allowing `scipy.optimize.minimize` to operate a `tf.Session`.

  Example:

  ```python
  vector = tf.Variable([7., 7.], name='vector')

  # Make vector norm as small as possible.
  loss = tf.reduce_sum(tf.square(vector))

  optimizer = ScipyOptimizerInterface(loss, options={'maxiter': 100})

  with tf.Session() as session:
    optimizer.minimize(session)

  # The value of vector should now be [0., 0.].
  ```

  Example with simple bound constraints:

  ```python
  vector = tf.Variable([7., 7.], name='vector')

  # Make vector norm as small as possible.
  loss = tf.reduce_sum(tf.square(vector))

  optimizer = ScipyOptimizerInterface(
      loss, var_to_bounds={vector: ([1, 2], np.infty)})

  with tf.Session() as session:
    optimizer.minimize(session)

  # The value of vector should now be [1., 2.].
  ```

  Example with more complicated constraints:

  ```python
  vector = tf.Variable([7., 7.], name='vector')

  # Make vector norm as small as possible.
  loss = tf.reduce_sum(tf.square(vector))
  # Ensure the vector's y component is exactly 1.
  equalities = [vector[1] - 1.]
  # Ensure the vector's x component is >= 1.
  inequalities = [vector[0] - 1.]

  # Our default SciPy optimization algorithm, L-BFGS-B, does not support
  # general constraints. Thus we use SLSQP instead.
  optimizer = ScipyOptimizerInterface(
      loss, equalities=equalities, inequalities=inequalities, method='SLSQP')

  with tf.Session() as session:
    optimizer.minimize(session)

  # The value of vector should now be [1., 1.].
  ```
  """

  _DEFAULT_METHOD = 'L-BFGS-B'

  def _minimize(self, initial_val, loss_grad_func, equality_funcs,
                equality_grad_funcs, inequality_funcs, inequality_grad_funcs,
                packed_bounds, step_callback, optimizer_kwargs):

    def loss_grad_func_wrapper(x):
      # SciPy's L-BFGS-B Fortran implementation requires gradients as doubles.
      loss, gradient = loss_grad_func(x)
      return loss, gradient.astype('float64')

    optimizer_kwargs = dict(optimizer_kwargs.items())
    method = optimizer_kwargs.pop('method', self._DEFAULT_METHOD)

    constraints = []
    for func, grad_func in zip(equality_funcs, equality_grad_funcs):
      constraints.append({'type': 'eq', 'fun': func, 'jac': grad_func})
    for func, grad_func in zip(inequality_funcs, inequality_grad_funcs):
      constraints.append({'type': 'ineq', 'fun': func, 'jac': grad_func})

    minimize_args = [loss_grad_func_wrapper, initial_val]
    minimize_kwargs = {
        'jac': True,
        'callback': step_callback,
        'method': method,
        'constraints': constraints,
        'bounds': packed_bounds,
    }

    for kwarg in minimize_kwargs:
      if kwarg in optimizer_kwargs:
        if kwarg == 'bounds':
          # Special handling for 'bounds' kwarg since ability to specify bounds
          # was added after this module was already publicly released.
          raise ValueError(
              'Bounds must be set using the var_to_bounds argument')
        raise ValueError(
            'Optimizer keyword arg \'{}\' is set '
            'automatically and cannot be injected manually'.format(kwarg))

    minimize_kwargs.update(optimizer_kwargs)

    import scipy.optimize  # pylint: disable=g-import-not-at-top
    result = scipy.optimize.minimize(*minimize_args, **minimize_kwargs)

    message_lines = [
        'Optimization terminated with:',
        '  Message: %s',
        '  Objective function value: %f',
    ]
    message_args = [result.message, result.fun]
    if hasattr(result, 'nit'):
      # Some optimization methods might not provide information such as nit and
      # nfev in the return. Log only the available information.
      message_lines.append('  Number of iterations: %d')
      message_args.append(result.nit)
    if hasattr(result, 'nfev'):
      message_lines.append('  Number of function evaluations: %d')
      message_args.append(result.nfev)
    logging.info('\n'.join(message_lines), *message_args)

    return result['x']


def _accumulate(list_):
  """Yields the cumulative sums of `list_`, starting with 0."""
  total = 0
  yield total
  for x in list_:
    total += x
    yield total


def _get_shape_tuple(tensor):
  """Returns the static shape of `tensor` as a tuple."""
  return tuple(tensor.get_shape().as_list())


def _prod(array):
  """Returns the product of the elements of `array` (1 if `array` is empty)."""
  prod = 1
  for value in array:
    prod *= value
  return prod


def _compute_gradients(tensor, var_list):
  """Computes gradients of `tensor` w.r.t. each variable in `var_list`."""
  grads = gradients.gradients(tensor, var_list)
  # tf.gradients sometimes returns `None` when it should return 0.
  return [
      grad if grad is not None else array_ops.zeros_like(var)
      for var, grad in zip(var_list, grads)
  ]