• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Defines `{Additive,Multiplicative}SwapRegretOptimizer`s.
16
17These optimizers minimize a `ConstrainedMinimizationProblem` by using a
18swap-regret minimizing algorithm (either SGD or multiplicative weights) to learn
19what weights should be associated with the objective function and constraints.
20These algorithms do *not* use Lagrange multipliers, but the idea is similar.
21The main differences between the formulation used here, and the standard
22Lagrangian formulation, are that (i) the objective function is weighted, in
23addition to the constraints, and (ii) we learn a matrix of weights, instead of a
24vector.
25
26For the purposes of constrained optimization, at least in theory,
27external-regret minimization suffices if the `ConstrainedMinimizationProblem`
28we're optimizing doesn't have any `proxy_constraints`, while swap-regret
29minimization should be used if `proxy_constraints` are present.
30
31For more specifics, please refer to:
32
33> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
34> Constrained Optimization".
35> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
36
37The formulation used by both of the SwapRegretOptimizers can be found in
38Definition 2, and is discussed in Section 4. The
39`MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 in Section 4,
40with the difference being that it uses `tf.train.Optimizer`s, instead of SGD,
41for the "inner" updates. The `AdditiveSwapRegretOptimizer` differs further in
42that it performs additive (instead of multiplicative) updates of the stochastic
43matrix.
44"""
45
46from __future__ import absolute_import
47from __future__ import division
48from __future__ import print_function
49
50import abc
51import math
52
53import six
54
55from tensorflow.contrib.constrained_optimization.python import constrained_optimizer
56
57from tensorflow.python.framework import dtypes
58from tensorflow.python.framework import ops
59from tensorflow.python.ops import control_flow_ops
60from tensorflow.python.ops import standard_ops
61from tensorflow.python.ops import state_ops
62from tensorflow.python.training import optimizer as train_optimizer
63
64
def _maximal_eigenvector_power_method(matrix,
                                      epsilon=1e-6,
                                      maximum_iterations=100):
  """Returns the maximal right-eigenvector of `matrix` using the power method.

  The iteration starts from the normalized all-ones vector, and repeatedly
  multiplies by `matrix` and re-normalizes. At least one iteration is always
  performed.

  Args:
    matrix: 2D Tensor, the matrix of which we will find the maximal
      right-eigenvector.
    epsilon: strictly positive float, if two iterations of the power method
      differ (in L2 norm) by no more than epsilon, we will terminate.
    maximum_iterations: strictly positive int, if we perform this many
      iterations, we will terminate.

  Returns:
    The maximal right-eigenvector of `matrix`, as estimated by the final
    iterate of the power method (each iterate is divided by its norm).

  Raises:
    ValueError: If the `matrix` tensor is not floating-point, or if the
      `epsilon` or `maximum_iterations` parameters violate their bounds.
  """
  if not matrix.dtype.is_floating:
    raise ValueError("matrix must have a floating-point dtype")
  if epsilon <= 0.0:
    raise ValueError("epsilon must be strictly positive")
  if maximum_iterations <= 0:
    raise ValueError("maximum_iterations must be strictly positive")

  def while_loop_condition(iteration, eigenvector, old_eigenvector):
    """Returns false if the while loop should terminate."""
    not_done = (iteration < maximum_iterations)
    not_converged = (standard_ops.norm(eigenvector - old_eigenvector) > epsilon)
    return standard_ops.logical_and(not_done, not_converged)

  def while_loop_body(iteration, eigenvector, old_eigenvector):
    """Performs one iteration of the power method."""
    del old_eigenvector  # Needed by the condition, but not the body.
    iteration += 1
    # We need to use tf.matmul() and tf.expand_dims(), instead of
    # tf.tensordot(), since the former will infer the shape of the result, while
    # the latter will not (tf.while_loop() needs the shapes).
    new_eigenvector = standard_ops.matmul(
        matrix, standard_ops.expand_dims(eigenvector, 1))[:, 0]
    new_eigenvector /= standard_ops.norm(new_eigenvector)
    # The current iterate becomes the "old" one, so that the condition can
    # measure how far the estimate moved in this step.
    return (iteration, new_eigenvector, eigenvector)

  # Start from the all-ones vector (one entry per row of `matrix`), scaled to
  # unit norm.
  iteration = standard_ops.constant(0)
  eigenvector = standard_ops.ones_like(matrix[:, 0])
  eigenvector /= standard_ops.norm(eigenvector)

  # We actually want a do-while loop, so we explicitly call while_loop_body()
  # once before tf.while_loop().
  iteration, eigenvector, old_eigenvector = while_loop_body(
      iteration, eigenvector, eigenvector)
  iteration, eigenvector, old_eigenvector = control_flow_ops.while_loop(
      while_loop_condition,
      while_loop_body,
      loop_vars=(iteration, eigenvector, old_eigenvector),
      name="power_method")

  return eigenvector
125
126
def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
  """Projects its argument onto the set of left-stochastic matrices.

  This algorithm is O(n^3) at worst, where `matrix` is n*n. It can be done in
  O(n^2 * log(n)) time by sorting each column (and maybe better with a different
  algorithm), but the algorithm implemented here is easier to implement in
  TensorFlow.

  Each iteration spreads the amount by which every column sum differs from one
  evenly over that column's "inactive" (still strictly positive) entries, and
  then zeroes out every entry that is no longer strictly positive. The loop
  stops once the set of inactive entries no longer changes, or after n
  iterations.

  Args:
    matrix: 2d square tensor, the matrix to project.

  Returns:
    The 2d square tensor that results from projecting `matrix` onto the set of
      left-stochastic matrices w.r.t. the Euclidean norm applied column-wise
      (i.e. the Frobenius norm).

  Raises:
    ValueError: if the `matrix` tensor is not floating-point, does not have a
      fully-known shape, or is not two-dimensional and square.
  """
  if not matrix.dtype.is_floating:
    raise ValueError("matrix must have a floating-point dtype")
  matrix_shape = matrix.get_shape()
  if matrix_shape.ndims is None:
    raise ValueError("matrix must have known shape")
  if matrix_shape.ndims != 2:
    raise ValueError(
        "matrix must be two dimensional (instead is %d-dimensional)" %
        matrix_shape.ndims)
  if matrix_shape[0] != matrix_shape[1]:
    raise ValueError("matrix must be square (instead has shape (%d,%d))" %
                     (matrix_shape[0], matrix_shape[1]))
  dimension = matrix_shape.dims[0].value
  if dimension is None:
    raise ValueError("matrix must have fully-known shape")

  def while_loop_condition(iteration, matrix, inactive, old_inactive):
    """Returns false if the while loop should terminate."""
    del matrix  # Needed by the body, but not the condition.
    not_done = (iteration < dimension)
    not_converged = standard_ops.reduce_any(
        standard_ops.not_equal(inactive, old_inactive))
    return standard_ops.logical_and(not_done, not_converged)

  def while_loop_body(iteration, matrix, inactive, old_inactive):
    """Performs one iteration of the projection."""
    del old_inactive  # Needed by the condition, but not the body.
    iteration += 1
    # Per-column correction: the shortfall of the column sum from one, divided
    # by the number of inactive entries in that column (clamped to at least
    # one, to avoid dividing by zero).
    scale = (1.0 - standard_ops.reduce_sum(
        matrix, axis=0, keepdims=True)) / standard_ops.maximum(
            1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
    matrix = matrix + (scale * inactive)
    # Entries that are not strictly positive after the correction are zeroed
    # out, and are excluded from the correction on later iterations.
    new_inactive = standard_ops.cast(matrix > 0, matrix.dtype)
    matrix = matrix * new_inactive
    return (iteration, matrix, new_inactive, inactive)

  # Initially, every entry is treated as inactive.
  iteration = standard_ops.constant(0)
  inactive = standard_ops.ones_like(matrix, dtype=matrix.dtype)

  # We actually want a do-while loop, so we explicitly call while_loop_body()
  # once before tf.while_loop().
  iteration, matrix, inactive, old_inactive = while_loop_body(
      iteration, matrix, inactive, inactive)
  iteration, matrix, inactive, old_inactive = control_flow_ops.while_loop(
      while_loop_condition,
      while_loop_body,
      loop_vars=(iteration, matrix, inactive, old_inactive),
      name="euclidean_projection")

  return matrix
197
198
def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
  """Projects its argument onto the set of log-left-stochastic matrices.

  In log-space, this amounts to subtracting from every column the logarithm of
  the sum of that column's exponentiated entries.

  Args:
    log_matrix: 2d square tensor, the element-wise logarithm of the matrix to
      project.

  Returns:
    The 2d square tensor that results from projecting exp(`matrix`) onto the set
      of left-stochastic matrices w.r.t. the KL-divergence applied column-wise.
  """
  # For numerical reasons, shift every column so that its largest element is
  # zero before exponentiating.
  column_maxima = standard_ops.reduce_max(log_matrix, axis=0, keepdims=True)
  shifted = log_matrix - column_maxima

  # Subtracting the log of each column's sum normalizes the columns of the
  # exponentiated matrix.
  column_sums = standard_ops.reduce_sum(
      standard_ops.exp(shifted), axis=0, keepdims=True)
  return shifted - standard_ops.log(column_sums)
219
220
@six.add_metaclass(abc.ABCMeta)
class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
  """Abstract base class shared by the swap-regret optimizers.

  Most of the logic needed for constrained optimization, with the constraints
  player minimizing swap regret, lives here. What this class does *not* do is
  decide how the internal state (the stochastic matrix) is represented:
  derived classes supply that through the _initial_state(),
  _stochastic_matrix(), _constraint_grad_and_var() and _projection_op()
  methods.

  This split makes it easy to plug in different representations of the state.
  For additive updates it is most natural to store the stochastic matrix
  itself, whereas for multiplicative updates it is most natural to store its
  element-wise logarithm.

  For more specifics, please refer to:

  > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
  > Constrained Optimization".
  > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)

  The formulation used by `_SwapRegretOptimizer`s can be found in Definition 2,
  and is discussed in Section 4. Such optimizers are most similar to Algorithm
  2 in Section 4. Most notably, the internal state is a left-stochastic matrix
  of shape (m+1,m+1), where m is the number of constraints.
  """

  def __init__(self, optimizer, constraint_optimizer=None):
    """Constructs a new `_SwapRegretOptimizer`.

    `optimizer` learns the model parameters, while `constraint_optimizer` (if
    provided) is used for the update to the constraint/objective weight matrix
    (the analogue of Lagrange multipliers). If no `constraint_optimizer` is
    provided, then `optimizer` is used for both.

    Args:
      optimizer: tf.train.Optimizer, used to optimize the objective and
        proxy_constraints portion of ConstrainedMinimizationProblem. If
        constraint_optimizer is not provided, this will also be used to optimize
        the Lagrange multiplier analogues.
      constraint_optimizer: optional tf.train.Optimizer, used to optimize the
        Lagrange multiplier analogues.

    Returns:
      A new `_SwapRegretOptimizer`.
    """
    super(_SwapRegretOptimizer, self).__init__(optimizer=optimizer)
    self._constraint_optimizer = constraint_optimizer

  @property
  def constraint_optimizer(self):
    """Returns the `tf.train.Optimizer` used for the matrix."""
    return self._constraint_optimizer

  @abc.abstractmethod
  def _initial_state(self, num_constraints):
    """Returns the initial value of the state for `num_constraints`."""

  @abc.abstractmethod
  def _stochastic_matrix(self, state):
    """Returns the stochastic matrix represented by `state`."""

  def _distribution(self, state):
    """Returns the weights placed on the objective and constraints.

    The maximal right-eigenvector of the stochastic matrix is estimated with
    the power method, and its element-wise absolute values are then rescaled
    so that they sum to one.
    """
    eigenvector = _maximal_eigenvector_power_method(
        self._stochastic_matrix(state))
    magnitudes = standard_ops.abs(eigenvector)
    return magnitudes / standard_ops.reduce_sum(magnitudes)

  @abc.abstractmethod
  def _constraint_grad_and_var(self, state, gradient):
    """Returns the (gradient, variable) pair used to update `state`."""

  @abc.abstractmethod
  def _projection_op(self, state, name=None):
    """Returns an op that projects `state` back onto its feasible set."""

  def _minimize_constrained(self,
                            minimization_problem,
                            global_step=None,
                            var_list=None,
                            gate_gradients=train_optimizer.Optimizer.GATE_OP,
                            aggregation_method=None,
                            colocate_gradients_with_ops=False,
                            name=None,
                            grad_loss=None):
    """Returns an `Operation` for minimizing the constrained problem.

    The `optimizer` constructor parameter will be used to update the model
    parameters, while the constraint/objective weight matrix (the analogue of
    Lagrange multipliers) will be updated using `constraint_optimizer` (if
    provided) or `optimizer` (if not). Whether the matrix updates are additive
    or multiplicative depends on the derived class.

    Args:
      minimization_problem: ConstrainedMinimizationProblem, the problem to
        optimize.
      global_step: as in `tf.train.Optimizer`'s `minimize` method.
      var_list: as in `tf.train.Optimizer`'s `minimize` method.
      gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
      aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
      colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
        method.
      name: as in `tf.train.Optimizer`'s `minimize` method.
      grad_loss: as in `tf.train.Optimizer`'s `minimize` method.

    Raises:
      ValueError: If the minimization_problem tensors have different dtypes.

    Returns:
      `Operation`, the train_op.
    """
    objective = minimization_problem.objective
    constraints = minimization_problem.constraints
    # Without explicit proxy constraints, the constraints act as their own
    # proxies.
    proxy_constraints = minimization_problem.proxy_constraints
    if proxy_constraints is None:
      proxy_constraints = constraints

    # The objective, constraints and proxy constraints must all share a dtype.
    objective_dtype = objective.dtype.base_dtype
    dtypes_agree = (
        objective_dtype == constraints.dtype.base_dtype and
        objective_dtype == proxy_constraints.dtype.base_dtype)
    if not dtypes_agree:
      raise ValueError("objective, constraints and proxy_constraints must "
                       "have the same dtype")

    # Flatten both constraints tensors to 1d.
    num_constraints = minimization_problem.num_constraints
    flat_shape = (num_constraints,)
    constraints = standard_ops.reshape(constraints, shape=flat_shape)
    proxy_constraints = standard_ops.reshape(
        proxy_constraints, shape=flat_shape)

    # The state is initialized through a lambda so that, if this function call
    # is inside the scope of a tf.control_dependencies() block, the
    # dependencies will not be applied to the initializer.
    state = standard_ops.Variable(
        lambda: self._initial_state(num_constraints),
        trainable=False,
        name="swap_regret_optimizer_state")

    # Index zero stands for the objective: it contributes a zero to the
    # constraints vector, and the objective itself to the proxy vector.
    leading_zero = standard_ops.zeros((1,), dtype=constraints.dtype)
    zero_and_constraints = standard_ops.concat(
        (leading_zero, constraints), axis=0)
    objective_and_proxy_constraints = standard_ops.concat(
        (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0)

    distribution = self._distribution(state)
    # The loss seen by the model parameters is the distribution-weighted
    # combination of the objective and proxy constraints.
    loss = standard_ops.tensordot(
        standard_ops.cast(distribution, objective_and_proxy_constraints.dtype),
        objective_and_proxy_constraints, 1)
    # The gradient w.r.t. the matrix is the outer product of
    # zero_and_constraints (as a column) and the distribution (as a row).
    constraints_column = standard_ops.expand_dims(
        standard_ops.cast(zero_and_constraints, distribution.dtype), 1)
    distribution_row = standard_ops.expand_dims(distribution, 0)
    matrix_gradient = standard_ops.matmul(constraints_column, distribution_row)

    # self._optimizer always computes the gradients of the model parameters;
    # only the way in which the state is updated depends on whether we have a
    # separate constraint_optimizer.
    grads_and_vars = self.optimizer.compute_gradients(
        loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)
    matrix_grads_and_vars = [
        self._constraint_grad_and_var(state, matrix_gradient)
    ]

    if self.constraint_optimizer is None:
      # A single optimizer updates both the model parameters and the internal
      # state.
      update_ops = [
          self.optimizer.apply_gradients(
              grads_and_vars + matrix_grads_and_vars, name="update")
      ]
    else:
      # Two optimizers: make sure that every gradient has been computed before
      # either of them applies its update.
      gradients = [
          gradient for gradient, _ in grads_and_vars + matrix_grads_and_vars
          if gradient is not None
      ]
      with ops.control_dependencies(gradients):
        parameters_update = self.optimizer.apply_gradients(
            grads_and_vars, name="update")
        state_update = self.constraint_optimizer.apply_gradients(
            matrix_grads_and_vars, name="optimizer_state_update")
      update_ops = [parameters_update, state_update]

    with ops.control_dependencies(update_ops):
      if global_step is None:
        # If we don't have a global step, just project, and we're done.
        return self._projection_op(state, name=name)

      # If we have a global step, then we need to increment it in addition to
      # projecting.
      projection_op = self._projection_op(state, name="project")
      with ops.colocate_with(global_step):
        global_step_op = state_ops.assign_add(
            global_step, 1, name="global_step_increment")
      return control_flow_ops.group(projection_op, global_step_op, name=name)
431
432
class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer):
  """A `ConstrainedOptimizer` based on swap-regret minimization.

  This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
  minimize over the model parameters, and maximize over constraint/objective
  weight matrix (the analogue of Lagrange multipliers), with the latter
  maximization using additive updates and an algorithm that minimizes swap
  regret.

  For more specifics, please refer to:

  > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
  > Constrained Optimization".
  > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)

  The formulation used by this optimizer can be found in Definition 2, and is
  discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with
  the differences being that it uses `tf.train.Optimizer`s, instead of SGD, for
  the "inner" updates, and performs additive (instead of multiplicative) updates
  of the stochastic matrix.
  """

  def __init__(self, optimizer, constraint_optimizer=None):
    """Constructs a new `AdditiveSwapRegretOptimizer`.

    Args:
      optimizer: tf.train.Optimizer, used to optimize the objective and
        proxy_constraints portion of ConstrainedMinimizationProblem. If
        constraint_optimizer is not provided, this will also be used to optimize
        the Lagrange multiplier analogues.
      constraint_optimizer: optional tf.train.Optimizer, used to optimize the
        Lagrange multiplier analogues.

    Returns:
      A new `AdditiveSwapRegretOptimizer`.
    """
    # TODO(acotter): add a parameter determining the initial values of the
    # matrix elements (like initial_multiplier_radius in
    # MultiplicativeSwapRegretOptimizer).
    super(AdditiveSwapRegretOptimizer, self).__init__(
        optimizer=optimizer, constraint_optimizer=constraint_optimizer)

  def _initial_state(self, num_constraints):
    """Returns the initial (m+1,m+1) left-stochastic matrix."""
    # The state of an AdditiveSwapRegretOptimizer is the left-stochastic matrix
    # itself: a tensor of shape (m+1,m+1), where m is the number of
    # constraints.
    dimension = num_constraints + 1
    # The first row (the objective) starts with all of the weight, and the
    # remaining rows (the constraints) with none of it.
    objective_row = standard_ops.ones((1, dimension))
    constraint_rows = standard_ops.zeros((dimension - 1, dimension))
    return standard_ops.concat((objective_row, constraint_rows), axis=0)

  def _stochastic_matrix(self, state):
    """Returns `state` unchanged, since it is the stochastic matrix itself."""
    return state

  def _constraint_grad_and_var(self, state, gradient):
    """Returns the negated `gradient`, paired with `state`."""
    # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
    negated_gradient = -gradient
    return (negated_gradient, state)

  def _projection_op(self, state, name=None):
    """Returns an op assigning to `state` its Euclidean projection."""
    with ops.colocate_with(state):
      projected_state = _project_stochastic_matrix_wrt_euclidean_norm(state)
      return state_ops.assign(state, projected_state, name=name)
500
501
class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer):
  """A `ConstrainedOptimizer` based on swap-regret minimization.

  This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
  minimize over the model parameters, and maximize over constraint/objective
  weight matrix (the analogue of Lagrange multipliers), with the latter
  maximization using multiplicative updates and an algorithm that minimizes swap
  regret.

  For more specifics, please refer to:

  > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
  > Constrained Optimization".
  > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)

  The formulation used by this optimizer can be found in Definition 2, and is
  discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with
  the difference being that it uses `tf.train.Optimizer`s, instead of SGD, for
  the "inner" updates.
  """

  def __init__(self,
               optimizer,
               constraint_optimizer=None,
               minimum_multiplier_radius=1e-3,
               initial_multiplier_radius=None):
    """Constructs a new `MultiplicativeSwapRegretOptimizer`.

    Args:
      optimizer: tf.train.Optimizer, used to optimize the objective and
        proxy_constraints portion of ConstrainedMinimizationProblem. If
        constraint_optimizer is not provided, this will also be used to optimize
        the Lagrange multiplier analogues.
      constraint_optimizer: optional tf.train.Optimizer, used to optimize the
        Lagrange multiplier analogues.
      minimum_multiplier_radius: float in the range (0,1), each element of the
        matrix will be lower bounded by `minimum_multiplier_radius` divided by
        one plus the number of constraints.
      initial_multiplier_radius: float in the range
        [minimum_multiplier_radius,1], the initial value of each element of the
        matrix associated with a constraint (i.e. excluding those elements
        associated with the objective) will be `initial_multiplier_radius`
        divided by one plus the number of constraints. Defaults to the value of
        `minimum_multiplier_radius`.

    Returns:
      A new `MultiplicativeSwapRegretOptimizer`.

    Raises:
      ValueError: If `minimum_multiplier_radius` is outside of (0,1), or if
        `initial_multiplier_radius` is outside of
        [minimum_multiplier_radius,1].
    """
    super(MultiplicativeSwapRegretOptimizer, self).__init__(
        optimizer=optimizer, constraint_optimizer=constraint_optimizer)

    if (minimum_multiplier_radius <= 0.0) or (minimum_multiplier_radius >= 1.0):
      raise ValueError("minimum_multiplier_radius must be in the range (0,1)")
    if initial_multiplier_radius is None:
      initial_multiplier_radius = minimum_multiplier_radius
    elif (initial_multiplier_radius <
          minimum_multiplier_radius) or (initial_multiplier_radius > 1.0):
      # The upper bound must be checked against initial_multiplier_radius: a
      # sufficiently large value would make the argument of the first math.log()
      # call in _initial_state() non-positive.
      raise ValueError("initial_multiplier_radius must be in the range "
                       "[minimum_multiplier_radius,1]")

    self._minimum_multiplier_radius = minimum_multiplier_radius
    self._initial_multiplier_radius = initial_multiplier_radius

  def _initial_state(self, num_constraints):
    """Returns the initial (m+1,m+1) element-wise log of the matrix."""
    # For a MultiplicativeSwapRegretOptimizer, the internal state is a tensor of
    # shape (m+1,m+1), where m is the number of constraints, representing the
    # element-wise logarithm of a left-stochastic matrix.
    dimension = num_constraints + 1
    # Initialize by putting as much weight as possible on the objective, and as
    # little as possible on the constraints. Each of the (dimension - 1)
    # constraint entries of a column gets initial_multiplier_radius / dimension,
    # and the objective entry gets the remainder, so that every column of the
    # exponentiated state sums to one.
    log_initial_one = math.log(1.0 - (self._initial_multiplier_radius *
                                      (dimension - 1) / (dimension)))
    log_initial_zero = math.log(self._initial_multiplier_radius / dimension)
    # FUTURE WORK: make the dtype a parameter.
    return standard_ops.concat(
        (standard_ops.constant(
            log_initial_one, dtype=dtypes.float32, shape=(1, dimension)),
         standard_ops.constant(
             log_initial_zero,
             dtype=dtypes.float32,
             shape=(dimension - 1, dimension))),
        axis=0)

  def _stochastic_matrix(self, state):
    """Returns the element-wise exponential of `state`."""
    return standard_ops.exp(state)

  def _constraint_grad_and_var(self, state, gradient):
    """Returns the negated `gradient`, paired with `state`."""
    # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
    return (-gradient, state)

  def _projection_op(self, state, name=None):
    """Returns an op assigning to `state` its projected, lower-bounded value.

    The state is first projected w.r.t. the KL-divergence, and every element of
    the result is then lower bounded by the logarithm of
    `minimum_multiplier_radius` divided by the dimension of the state.
    """
    with ops.colocate_with(state):
      # Gets the dimension of the state (num_constraints + 1)--all of these
      # assertions are of things that should be impossible, since the state
      # passed into this method will have the same shape as that returned by
      # _initial_state().
      state_shape = state.get_shape()
      assert state_shape is not None
      assert state_shape.ndims == 2
      assert state_shape[0] == state_shape[1]
      dimension = state_shape.dims[0].value
      assert dimension is not None

      minimum_log_multiplier = standard_ops.log(
          self._minimum_multiplier_radius / standard_ops.to_float(dimension))

      return state_ops.assign(
          state,
          standard_ops.maximum(
              _project_log_stochastic_matrix_wrt_kl_divergence(state),
              minimum_log_multiplier),
          name=name)
616