• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Base class for optimizers."""
17# pylint: disable=g-bad-name
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import abc
24
25import six
26
27from tensorflow.python.distribute import distribute_lib
28from tensorflow.python.distribute import distribute_utils
29from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
30from tensorflow.python.distribute import reduce_util as ds_reduce_util
31from tensorflow.python.eager import backprop
32from tensorflow.python.eager import context
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import gradients
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import resource_variable_ops
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.ops import variables
43from tensorflow.python.training import slot_creator
44from tensorflow.python.training.tracking import base as trackable
45from tensorflow.python.util import nest
46from tensorflow.python.util.tf_export import tf_export
47
48
49def get_filtered_grad_fn(grad_fn):
50  # `distributed_context.join()` requires that its arguments are parallel
51  # across threads, and in particular that `grads_and_vars` has the same
52  # variables in the same order.
53
54  # When computing gradients in eager mode with multiple threads, you
55  # can get extra variables with a gradient of `None`. This happens when
56  # those variables are accessed in another thread during the gradient
57  # computation. To get a consistent set of variables, we filter out
58  # those with `None` gradients.
59  def filtered_grad_fn(*args, **kwargs):
60    return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None]
61
62  return filtered_grad_fn
63
64
65def _deduplicate_indexed_slices(values, indices):
66  """Sums `values` associated with any non-unique `indices`.
67
68  Args:
69    values: A `Tensor` with rank >= 1.
70    indices: A one-dimensional integer `Tensor`, indexing into the first
71      dimension of `values` (as in an IndexedSlices object).
72  Returns:
73    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
74    de-duplicated version of `indices` and `summed_values` contains the sum of
75    `values` slices associated with each unique index.
76  """
77  unique_indices, new_index_positions = array_ops.unique(indices)
78  summed_values = math_ops.unsorted_segment_sum(
79      values, new_index_positions,
80      array_ops.shape(unique_indices)[0])
81  return (summed_values, unique_indices)
82
83
84def _var_key(var):
85  """Returns slot key for `var`."""
86  # pylint: disable=protected-access
87  if hasattr(var, "_distributed_container"):
88    var = var._distributed_container()
89  if (distribute_utils.is_distributed_variable(var) and
90      not ops.executing_eagerly_outside_functions()):
91    return (var.graph, var._shared_name)
92  if hasattr(var, "op"):
93    return (var.op.graph, var.op.name)
94  return var._unique_id
95  # pylint: enable=protected-access
96
97
98@six.add_metaclass(abc.ABCMeta)
99class _OptimizableVariable(object):
100  """Interface for abstracting over variables in the optimizers."""
101
102  @abc.abstractmethod
103  def target(self):
104    """Returns the optimization target for this variable."""
105    raise NotImplementedError("Calling an abstract method.")
106
107  @abc.abstractmethod
108  def update_op(self, optimizer, g):
109    """Returns the update ops for updating the variable."""
110    raise NotImplementedError("Calling an abstract method.")
111
112
113class _RefVariableProcessor(_OptimizableVariable):
114  """Processor for Variable."""
115
116  def __init__(self, v):
117    self._v = v
118
119  def __str__(self):
120    return "<_RefVariableProcessor(%s)>" % self._v
121
122  def target(self):
123    return self._v._ref()  # pylint: disable=protected-access
124
125  def update_op(self, optimizer, g):
126    if isinstance(g, ops.Tensor):
127      update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
128      if self._v.constraint is not None:
129        with ops.control_dependencies([update_op]):
130          return self._v.assign(self._v.constraint(self._v))
131      else:
132        return update_op
133    else:
134      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
135                                                "tensor nor IndexedSlices.")
136      if self._v.constraint is not None:
137        raise RuntimeError(
138            "Cannot use a constraint function on a sparse variable.")
139      # pylint: disable=protected-access
140      return optimizer._apply_sparse_duplicate_indices(g, self._v)
141
142
143class _DenseReadResourceVariableProcessor(_OptimizableVariable):
144  """Processor for dense ResourceVariables."""
145
146  def __init__(self, v):
147    self._v = v
148
149  def target(self):
150    return self._v
151
152  def update_op(self, optimizer, g):
153    # pylint: disable=protected-access
154    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
155    if self._v.constraint is not None:
156      with ops.control_dependencies([update_op]):
157        return self._v.assign(self._v.constraint(self._v))
158    else:
159      return update_op
160
161
162class _DenseResourceVariableProcessor(_OptimizableVariable):
163  """Processor for dense ResourceVariables."""
164
165  def __init__(self, v):
166    self._v = v
167
168  def target(self):
169    return self._v
170
171  def update_op(self, optimizer, g):
172    # pylint: disable=protected-access
173    if isinstance(g, ops.IndexedSlices):
174      if self._v.constraint is not None:
175        raise RuntimeError(
176            "Cannot use a constraint function on a sparse variable.")
177      return optimizer._resource_apply_sparse_duplicate_indices(
178          g.values, self._v, g.indices)
179    update_op = optimizer._resource_apply_dense(g, self._v)
180    if self._v.constraint is not None:
181      with ops.control_dependencies([update_op]):
182        return self._v.assign(self._v.constraint(self._v))
183    else:
184      return update_op
185
186
187class _TensorProcessor(_OptimizableVariable):
188  """Processor for ordinary Tensors.
189
190  Even though a Tensor can't really be updated, sometimes it is useful to
191  compute the gradients with respect to a Tensor using the optimizer. Updating
192  the Tensor is, of course, unsupported.
193  """
194
195  def __init__(self, v):
196    self._v = v
197
198  def target(self):
199    return self._v
200
201  def update_op(self, optimizer, g):
202    raise NotImplementedError("Trying to update a Tensor ", self._v)
203
204
205def _get_processor(v):
206  """The processor of v."""
207  if context.executing_eagerly():
208    if isinstance(v, ops.Tensor):
209      return _TensorProcessor(v)
210    else:
211      return _DenseResourceVariableProcessor(v)
212  if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
213    # True if and only if `v` was initialized eagerly.
214    return _DenseResourceVariableProcessor(v)
215  if v.op.type == "VarHandleOp":
216    return _DenseResourceVariableProcessor(v)
217  if isinstance(v, variables.Variable):
218    return _RefVariableProcessor(v)
219  if isinstance(v, ops.Tensor):
220    return _TensorProcessor(v)
221  raise NotImplementedError("Trying to optimize unsupported type ", v)
222
223
224@tf_export(v1=["train.Optimizer"])
225class Optimizer(
226    # Optimizers inherit from Trackable rather than AutoTrackable
227    # since they do most of their dependency management themselves (slot
228    # variables are special-cased, and non-slot variables are keyed to graphs).
229    trackable.Trackable):
230  """Base class for optimizers.
231
232  This class defines the API to add Ops to train a model.  You never use this
233  class directly, but instead instantiate one of its subclasses such as
234  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
235
236  ### Usage
237
238  ```python
239  # Create an optimizer with the desired parameters.
240  opt = GradientDescentOptimizer(learning_rate=0.1)
241  # Add Ops to the graph to minimize a cost by updating a list of variables.
242  # "cost" is a Tensor, and the list of variables contains tf.Variable
243  # objects.
244  opt_op = opt.minimize(cost, var_list=<list of variables>)
245  ```
246
247  In the training program you will just have to run the returned Op.
248
249  ```python
250  # Execute opt_op to do one step of training:
251  opt_op.run()
252  ```
253
254  ### Processing gradients before applying them.
255
256  Calling `minimize()` takes care of both computing the gradients and
257  applying them to the variables.  If you want to process the gradients
258  before applying them you can instead use the optimizer in three steps:
259
260  1.  Compute the gradients with `compute_gradients()`.
261  2.  Process the gradients as you wish.
262  3.  Apply the processed gradients with `apply_gradients()`.
263
264  Example:
265
266  ```python
267  # Create an optimizer.
268  opt = GradientDescentOptimizer(learning_rate=0.1)
269
270  # Compute the gradients for a list of variables.
271  grads_and_vars = opt.compute_gradients(loss, <list of variables>)
272
273  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
274  # need to the 'gradient' part, for example cap them, etc.
275  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
276
277  # Ask the optimizer to apply the capped gradients.
278  opt.apply_gradients(capped_grads_and_vars)
279  ```
280
281  ### Gating Gradients
282
283  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
284  argument that controls the degree of parallelism during the application of
285  the gradients.
286
287  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
288
289  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.  This provides
290  the maximum parallelism in execution, at the cost of some non-reproducibility
291  in the results.  For example the two gradients of `matmul` depend on the input
292  values: With `GATE_NONE` one of the gradients could be applied to one of the
293  inputs _before_ the other gradient is computed resulting in non-reproducible
294  results.
295
296  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
297  they are used.  This prevents race conditions for Ops that generate gradients
298  for multiple inputs where the gradients depend on the inputs.
299
300  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
301  before any one of them is used.  This provides the least parallelism but can
302  be useful if you want to process all gradients before applying any of them.
303
304  ### Slots
305
306  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
307  allocate and manage additional variables associated with the variables to
308  train.  These are called <i>Slots</i>.  Slots have names and you can ask the
309  optimizer for the names of the slots that it uses.  Once you have a slot name
310  you can ask the optimizer for the variable it created to hold the slot value.
311
312  This can be useful if you want to log debug a training algorithm, report stats
313  about the slots, etc.
314  """
315
316  # Values for gate_gradients.
317  GATE_NONE = 0
318  GATE_OP = 1
319  GATE_GRAPH = 2
320
321  def __init__(self, use_locking, name):
322    """Create a new Optimizer.
323
324    This must be called by the constructors of subclasses.
325
326    Args:
327      use_locking: Bool. If True apply use locks to prevent concurrent updates
328        to variables.
329      name: A non-empty string.  The name to use for accumulators created
330        for the optimizer.
331
332    Raises:
333      ValueError: If name is malformed.
334    """
335    if not name:
336      raise ValueError("Must specify the optimizer name")
337    self._use_locking = use_locking
338    self._name = name
339    # Dictionary of slots.
340    #  {slot_name :
341    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
342    #   ... }
343    self._slots = {}
344    self._non_slot_dict = {}
345    # For implementing Trackable. Stores information about how to restore
346    # slot variables which have not yet been created
347    # (trackable._CheckpointPosition objects).
348    #  {slot_name :
349    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
350    #   ... }
351    self._deferred_slot_restorations = {}
352
353    # TODO(isaprykin): When using a DistributionStrategy, and when an
354    # optimizer is created in each replica, it might be dangerous to
355    # rely on some Optimizer methods.  When such methods are called on a
356    # per-replica optimizer, an exception needs to be thrown.  We do
357    # allow creation per-replica optimizers however, because the
358    # compute_gradients()->apply_gradients() sequence is safe.
359
360  def get_name(self):
361    return self._name
362
363  def minimize(self, loss, global_step=None, var_list=None,
364               gate_gradients=GATE_OP, aggregation_method=None,
365               colocate_gradients_with_ops=False, name=None,
366               grad_loss=None):
367    """Add operations to minimize `loss` by updating `var_list`.
368
369    This method simply combines calls `compute_gradients()` and
370    `apply_gradients()`. If you want to process the gradient before applying
371    them call `compute_gradients()` and `apply_gradients()` explicitly instead
372    of using this function.
373
374    Args:
375      loss: A `Tensor` containing the value to minimize.
376      global_step: Optional `Variable` to increment by one after the
377        variables have been updated.
378      var_list: Optional list or tuple of `Variable` objects to update to
379        minimize `loss`.  Defaults to the list of variables collected in
380        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
381      gate_gradients: How to gate the computation of gradients.  Can be
382        `GATE_NONE`, `GATE_OP`, or  `GATE_GRAPH`.
383      aggregation_method: Specifies the method used to combine gradient terms.
384        Valid values are defined in the class `AggregationMethod`.
385      colocate_gradients_with_ops: If True, try colocating gradients with
386        the corresponding op.
387      name: Optional name for the returned operation.
388      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
389
390    Returns:
391      An Operation that updates the variables in `var_list`.  If `global_step`
392      was not `None`, that operation also increments `global_step`.
393
394    Raises:
395      ValueError: If some of the variables are not `Variable` objects.
396
397    @compatibility(eager)
398    When eager execution is enabled, `loss` should be a Python function that
399    takes no arguments and computes the value to be minimized. Minimization (and
400    gradient computation) is done with respect to the elements of `var_list` if
401    not None, else with respect to any trainable variables created during the
402    execution of the `loss` function. `gate_gradients`, `aggregation_method`,
403    `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
404    execution is enabled.
405    @end_compatibility
406    """
407    grads_and_vars = self.compute_gradients(
408        loss, var_list=var_list, gate_gradients=gate_gradients,
409        aggregation_method=aggregation_method,
410        colocate_gradients_with_ops=colocate_gradients_with_ops,
411        grad_loss=grad_loss)
412
413    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
414    if not vars_with_grad:
415      raise ValueError(
416          "No gradients provided for any variable, check your graph for ops"
417          " that do not support gradients, between variables %s and loss %s." %
418          ([str(v) for _, v in grads_and_vars], loss))
419
420    return self.apply_gradients(grads_and_vars, global_step=global_step,
421                                name=name)
422
423  def compute_gradients(self, loss, var_list=None,
424                        gate_gradients=GATE_OP,
425                        aggregation_method=None,
426                        colocate_gradients_with_ops=False,
427                        grad_loss=None):
428    """Compute gradients of `loss` for the variables in `var_list`.
429
430    This is the first part of `minimize()`.  It returns a list
431    of (gradient, variable) pairs where "gradient" is the gradient
432    for "variable".  Note that "gradient" can be a `Tensor`, an
433    `IndexedSlices`, or `None` if there is no gradient for the
434    given variable.
435
436    Args:
437      loss: A Tensor containing the value to minimize or a callable taking
438        no arguments which returns the value to minimize. When eager execution
439        is enabled it must be a callable.
440      var_list: Optional list or tuple of `tf.Variable` to update to minimize
441        `loss`.  Defaults to the list of variables collected in the graph
442        under the key `GraphKeys.TRAINABLE_VARIABLES`.
443      gate_gradients: How to gate the computation of gradients.  Can be
444        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
445      aggregation_method: Specifies the method used to combine gradient terms.
446        Valid values are defined in the class `AggregationMethod`.
447      colocate_gradients_with_ops: If True, try colocating gradients with
448        the corresponding op.
449      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
450
451    Returns:
452      A list of (gradient, variable) pairs. Variable is always present, but
453      gradient can be `None`.
454
455    Raises:
456      TypeError: If `var_list` contains anything else than `Variable` objects.
457      ValueError: If some arguments are invalid.
458      RuntimeError: If called with eager execution enabled and `loss` is
459        not callable.
460
461    @compatibility(eager)
462    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
463    and `colocate_gradients_with_ops` are ignored.
464    @end_compatibility
465    """
466    if callable(loss):
467      with backprop.GradientTape() as tape:
468        if var_list is not None:
469          tape.watch(var_list)
470        loss_value = loss()
471
472        # Scale loss if using a "mean" loss reduction and multiple replicas.
473        # Have to be careful to call distribute_lib.get_loss_reduction()
474        # *after* loss() is evaluated, so we know what loss reduction it uses.
475        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
476        loss_value = self._scale_loss(loss_value)
477
478      if var_list is None:
479        var_list = tape.watched_variables()
480      # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
481      # to be executed.
482      with ops.control_dependencies([loss_value]):
483        grads = tape.gradient(loss_value, var_list, grad_loss)
484      return list(zip(grads, var_list))
485
486    # Non-callable/Tensor loss case
487    if context.executing_eagerly():
488      raise RuntimeError(
489          "`loss` passed to Optimizer.compute_gradients should "
490          "be a function when eager execution is enabled.")
491
492    # Scale loss if using a "mean" loss reduction and multiple replicas.
493    loss = self._scale_loss(loss)
494
495    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
496                              Optimizer.GATE_GRAPH]:
497      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
498                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
499                       gate_gradients)
500    self._assert_valid_dtypes([loss])
501    if grad_loss is not None:
502      self._assert_valid_dtypes([grad_loss])
503    if var_list is None:
504      var_list = (
505          variables.trainable_variables() +
506          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
507    else:
508      var_list = nest.flatten(var_list)
509    # pylint: disable=protected-access
510    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
511    # pylint: enable=protected-access
512    processors = [_get_processor(v) for v in var_list]
513    if not var_list:
514      raise ValueError("No variables to optimize.")
515    var_refs = [p.target() for p in processors]
516    grads = gradients.gradients(
517        loss, var_refs, grad_ys=grad_loss,
518        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
519        aggregation_method=aggregation_method,
520        colocate_gradients_with_ops=colocate_gradients_with_ops)
521    if gate_gradients == Optimizer.GATE_GRAPH:
522      grads = control_flow_ops.tuple(grads)
523    grads_and_vars = list(zip(grads, var_list))
524    self._assert_valid_dtypes(
525        [v for g, v in grads_and_vars
526         if g is not None and v.dtype != dtypes.resource])
527    return grads_and_vars
528
529  @staticmethod
530  def _scale_loss(loss_value):
531    ops.get_default_graph()._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
532    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
533      num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
534      if num_replicas > 1:
535        loss_value *= (1. / num_replicas)
536        ops.get_default_graph()._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
537    return loss_value
538
539  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
540    """Apply gradients to variables.
541
542    This is the second part of `minimize()`. It returns an `Operation` that
543    applies gradients.
544
545    Args:
546      grads_and_vars: List of (gradient, variable) pairs as returned by
547        `compute_gradients()`.
548      global_step: Optional `Variable` to increment by one after the
549        variables have been updated.
550      name: Optional name for the returned operation.  Default to the
551        name passed to the `Optimizer` constructor.
552
553    Returns:
554      An `Operation` that applies the specified gradients. If `global_step`
555      was not None, that operation also increments `global_step`.
556
557    Raises:
558      TypeError: If `grads_and_vars` is malformed.
559      ValueError: If none of the variables have gradients.
560      RuntimeError: If you should use `_distributed_apply()` instead.
561    """
562    # This is a default implementation of apply_gradients() that can be shared
563    # by most optimizers.  It relies on the subclass implementing the following
564    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
565
566    # TODO(isaprykin): Get rid of `has_strategy()` check by
567    # always calling _distributed_apply(), using the default distribution
568    # as needed.
569    if distribute_ctx.has_strategy():
570      # Handle DistributionStrategy case.
571      if distribute_ctx.in_cross_replica_context():
572        raise RuntimeError("Use `_distributed_apply()` instead of "
573                           "`apply_gradients()` in a cross-replica context.")
574
575      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
576      return distribute_ctx.get_replica_context().merge_call(
577          self._distributed_apply, args=(grads_and_vars, global_step, name))
578
579    # No DistributionStrategy case.
580    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
581    if not grads_and_vars:
582      raise ValueError("No variables provided.")
583    converted_grads_and_vars = []
584    for g, v in grads_and_vars:
585      if g is not None:
586        try:
587          # Convert the grad to Tensor or IndexedSlices if necessary.
588          g = ops.convert_to_tensor_or_indexed_slices(g)
589        except TypeError:
590          raise TypeError(
591              "Gradient must be convertible to a Tensor"
592              " or IndexedSlices, or None: %s" % g)
593        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
594          raise TypeError(
595              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
596      p = _get_processor(v)
597      converted_grads_and_vars.append((g, v, p))
598
599    converted_grads_and_vars = tuple(converted_grads_and_vars)
600    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
601    if not var_list:
602      raise ValueError("No gradients provided for any variable: %s." %
603                       ([str(v) for _, v, _ in converted_grads_and_vars],))
604    with ops.init_scope():
605      self._create_slots(var_list)
606    update_ops = []
607    with ops.name_scope(name, self._name, skip_on_eager=False) as name:
608      self._prepare()
609      for grad, var, processor in converted_grads_and_vars:
610        if grad is None:
611          continue
612        # We colocate all ops created in _apply_dense or _apply_sparse
613        # on the same device as the variable.
614        # TODO(apassos): figure out how to get the variable name here.
615        if (context.executing_eagerly() or
616            resource_variable_ops.is_resource_variable(var)
617            and not var._in_graph_mode):  # pylint: disable=protected-access
618          scope_name = ""
619        else:
620          scope_name = var.op.name
621        with ops.name_scope(
622            "update_" + scope_name,
623            skip_on_eager=False), ops.colocate_with(var):
624          update_ops.append(processor.update_op(self, grad))
625      if global_step is None:
626        apply_updates = self._finish(update_ops, name)
627      else:
628        with ops.control_dependencies([self._finish(update_ops, "update")]):
629          with ops.colocate_with(global_step):
630            if isinstance(
631                global_step, resource_variable_ops.BaseResourceVariable):
632              # TODO(apassos): the implicit read in assign_add is slow; consider
633              # making it less so.
634              apply_updates = resource_variable_ops.assign_add_variable_op(
635                  global_step.handle,
636                  ops.convert_to_tensor(1, dtype=global_step.dtype),
637                  name=name)
638            else:
639              apply_updates = state_ops.assign_add(global_step, 1, name=name)
640
641      if not context.executing_eagerly():
642        if isinstance(apply_updates, ops.Tensor):
643          apply_updates = apply_updates.op
644        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
645        if apply_updates not in train_op:
646          train_op.append(apply_updates)
647
648      return apply_updates
649
650  def _distributed_apply(self,
651                         distribution,
652                         grads_and_vars,
653                         global_step=None,
654                         name=None):
655    """A version of `apply_gradients` for cross-replica context.
656
657    This is a version of `apply_gradients()` for when you are using a
658    `DistributionStrategy` and are in a cross-replica context. If in a
659    replica context, use `apply_gradients()` as normal.
660
661    Args:
662      distribution: A `DistributionStrategy` object.
663      grads_and_vars: List of (gradient, variable) pairs as returned by
664        `compute_gradients()`, and then aggregated across replicas.
665      global_step: Optional (mirrored) `Variable` to increment by one
666        after the variables have been updated.
667      name: Optional name for the returned operation.  Default to the
668        name passed to the `Optimizer` constructor.
669
670    Returns:
671      An `Operation` that applies the specified gradients across all
672      replicas. If `global_step` was not None, that operation also
673      increments `global_step`
674    """
675    reduced_grads = distribution.extended.batch_reduce_to(
676        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
677    var_list = [v for _, v in grads_and_vars]
678    grads_and_vars = zip(reduced_grads, var_list)
679
680    # Note that this is called in a cross-replica context.
681    with ops.init_scope():
682      self._create_slots(var_list)
683
684    def update(v, g):
685      """Apply gradients to a replica variable."""
686      assert v is not None
687
688      try:
689        # Convert the grad to Tensor or IndexedSlices if necessary.
690        g = ops.convert_to_tensor_or_indexed_slices(g)
691      except TypeError:
692        raise TypeError("Gradient must be convertible to a Tensor"
693                        " or IndexedSlices, or None: %s" % g)
694      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
695        raise TypeError(
696            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
697      p = _get_processor(v)
698
699      if context.executing_eagerly() or (
700          resource_variable_ops.is_resource_variable(v) and
701          not v._in_graph_mode):  # pylint: disable=protected-access
702        scope_name = v.name.split(":")[0]
703      else:
704        scope_name = v.op.name
705
706      # device_policy is set because non-mirrored tensors will be read in
707      # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
708      # is an example.
709      with ops.name_scope("update_" + scope_name):
710        return p.update_op(self, g)
711
712    with ops.name_scope(name, self._name) as name:
713      self._prepare()
714
715      update_ops = [
716          op
717          for grad, var in grads_and_vars
718          for op in distribution.extended.update(
719              var, update, args=(grad,), group=False)
720      ]
721
722      def finish(self, update_ops):
723        return self._finish(update_ops, "update")
724
725      non_slot_devices = distribution.extended.non_slot_devices(var_list)
726      finish_updates = distribution.extended.update_non_slot(
727          non_slot_devices, finish, args=(self, update_ops), group=False)
728      if global_step is None:
729        apply_updates = distribution.group(finish_updates, name=name)
730      else:
731        with ops.control_dependencies(finish_updates):
732          apply_updates = distribution.extended.update(
733              global_step, state_ops.assign_add, args=(1,),
734              kwargs={"name": name})
735
736      if not context.executing_eagerly():
737        if isinstance(apply_updates, ops.Tensor):
738          apply_updates = apply_updates.op
739        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
740        if apply_updates not in train_op:
741          train_op.append(apply_updates)
742
743      return apply_updates
744
745  def get_slot(self, var, name):
746    """Return a slot named `name` created for `var` by the Optimizer.
747
748    Some `Optimizer` subclasses use additional variables.  For example
749    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
750    gives access to these `Variable` objects if for some reason you need them.
751
752    Use `get_slot_names()` to get the list of slot names created by the
753    `Optimizer`.
754
755    Args:
756      var: A variable passed to `minimize()` or `apply_gradients()`.
757      name: A string.
758
759    Returns:
760      The `Variable` for the slot if it was created, `None` otherwise.
761    """
762    named_slots = self._slots.get(name, None)
763    if not named_slots:
764      return None
765    slot = named_slots.get(_var_key(var), None)
766    if (distribute_utils.is_distributed_variable(slot) and
767        not distribute_utils.is_distributed_variable(var)):
768      # Make sure var and slot are either both DistributedVariable, or both
769      # per replica variables.
770      slot = slot._get_on_device_or_primary()  # pylint: disable=protected-access
771    return slot
772
773  def get_slot_names(self):
774    """Return a list of the names of slots created by the `Optimizer`.
775
776    See `get_slot()`.
777
778    Returns:
779      A list of strings.
780    """
781    return sorted(self._slots.keys())
782
783  def variables(self):
784    """A list of variables which encode the current state of `Optimizer`.
785
786    Includes slot variables and additional global variables created by the
787    optimizer in the current default graph.
788
789    Returns:
790      A list of variables.
791    """
792    current_graph = ops.get_default_graph()
793
794    def _from_current_graph(variable):
795      if variable._in_graph_mode:  # pylint: disable=protected-access
796        return variable.op.graph is current_graph
797      else:
798        # No variable.op in eager mode. We don't expect lots of eager graphs,
799        # but behavior should be consistent with graph mode.
800        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access
801
802    optimizer_variables = [v for v in self._non_slot_variables()
803                           if _from_current_graph(v)]
804    for _, variable_dict in self._slots.items():
805      for _, slot_for_variable in variable_dict.items():
806        if _from_current_graph(slot_for_variable):
807          optimizer_variables.append(slot_for_variable)
808    # Sort variables by name so that the return is deterministic.
809    return sorted(optimizer_variables, key=lambda v: v.name)
810
811  def _create_non_slot_variable(self, initial_value, name, colocate_with):
812    """Add an extra variable, not associated with a slot."""
813    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
814    eager = context.executing_eagerly()
815    graph = None if eager else colocate_with.graph
816
817    key = (name, graph)
818    v = self._non_slot_dict.get(key, None)
819    if v is None:
820      self._maybe_initialize_trackable()
821      distribution_strategy = distribute_ctx.get_strategy()
822      with distribution_strategy.extended.colocate_vars_with(colocate_with):
823        if eager:
824          restored_initial_value = self._preload_simple_restoration(
825              name=name)
826          if restored_initial_value is not None:
827            initial_value = restored_initial_value
828        v = variable_scope.variable(
829            initial_value, name=name, trainable=False,
830            use_resource=resource_variable_ops.is_resource_variable(
831                colocate_with))
832      # Restore this variable by name if necessary, but don't add a
833      # Trackable dependency. Optimizers return the current graph's
834      # non-slot variables from _checkpoint_dependencies explicitly rather
835      # than unconditionally adding dependencies (since there may be multiple
836      # non-slot variables with the same name in different graphs, trying to
837      # save all of them would result in errors).
838      self._handle_deferred_dependencies(name=name, trackable=v)
839      self._non_slot_dict[key] = v
840
841    return v
842
843  @property
844  def _checkpoint_dependencies(self):
845    """From Trackable. Gather graph-specific non-slot variables to save."""
846    current_graph_non_slot_variables = []
847    current_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
848    for (name, _), variable_object in sorted(self._non_slot_dict.items(),
849                                             # Avoid comparing graphs
850                                             key=lambda item: item[0][0]):
851      if variable_object._graph_key == current_graph_key:  # pylint: disable=protected-access
852        current_graph_non_slot_variables.append(
853            trackable.TrackableReference(
854                name=name, ref=variable_object))
855    return (super(Optimizer, self)._checkpoint_dependencies
856            + current_graph_non_slot_variables)
857
858  def _lookup_dependency(self, name):
859    """From Trackable. Find a non-slot variable in the current graph."""
860    unconditional = super(Optimizer, self)._lookup_dependency(name)
861    if unconditional is not None:
862      return unconditional
863    graph = None if context.executing_eagerly() else ops.get_default_graph()
864    return self._get_non_slot_variable(name, graph=graph)
865
866  def _get_non_slot_variable(self, name, graph=None):
867    non_slot = self._non_slot_dict.get((name, graph), None)
868    if hasattr(non_slot, "_distributed_container"):
869      # This is a mirrored non-slot.  In order to enable code like `_finish`
870      # to assign to a non-slot, return the current context replica.
871      return non_slot.get()
872    else:
873      return non_slot
874
875  def _non_slot_variables(self):
876    """Additional variables created by the `Optimizer`.
877
878    Returns:
879      A list or tuple of variables.
880    """
881    return self._non_slot_dict.values()
882
883  def _assert_valid_dtypes(self, tensors):
884    """Asserts tensors are all valid types (see `_valid_dtypes`).
885
886    Args:
887      tensors: Tensors to check.
888
889    Raises:
890      ValueError: If any tensor is not a valid type.
891    """
892    valid_dtypes = self._valid_dtypes()
893    for t in tensors:
894      dtype = t.dtype.base_dtype
895      if dtype not in valid_dtypes:
896        raise ValueError(
897            "Invalid type %r for %s, expected: %s." % (
898                dtype, t.name, [v for v in valid_dtypes]))
899
900  # --------------
901  # Methods to be implemented by subclasses if they want to use the
902  # inherited implementation of apply_gradients() or compute_gradients().
903  # --------------
904  def _valid_dtypes(self):
905    """Valid types for loss, variables and gradients.
906
907    Subclasses should override to allow other float types.
908
909    Returns:
910      Valid types for loss, variables and gradients.
911    """
912    return set(
913        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])
914
915  def _create_slots(self, var_list):
916    """Create all slots needed by the variables.
917
918    Args:
919      var_list: A list of `Variable` objects.
920    """
921    # No slots needed by default
922    pass
923
924  def _prepare(self):
925    """Create all needed tensors before applying gradients.
926
927    This is called with the name_scope using the "name" that
928    users have chosen for the application of gradients.
929    """
930    pass
931
932  def _apply_dense(self, grad, var):
933    """Add ops to apply dense gradients to `var`.
934
935    Args:
936      grad: A `Tensor`.
937      var: A `Variable` object.
938
939    Returns:
940      An `Operation`.
941    """
942    raise NotImplementedError()
943
944  def _resource_apply_dense(self, grad, handle):
945    """Add ops to apply dense gradients to the variable `handle`.
946
947    Args:
948      grad: a `Tensor` representing the gradient.
949      handle: a `Tensor` of dtype `resource` which points to the variable
950       to be updated.
951
952    Returns:
953      An `Operation` which updates the value of the variable.
954    """
955    raise NotImplementedError()
956
957  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
958    """Add ops to apply sparse gradients to `handle`, with repeated indices.
959
960    Optimizers which override this method must deal with repeated indices. See
961    the docstring of `_apply_sparse_duplicate_indices` for details. By default
962    the correct behavior, to sum non-unique indices and their associated
963    gradients, is enforced by first pre-processing `grad` and `indices` and
964    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
965    with duplicate indices may instead override this method to avoid the
966    overhead of summing.
967
968    Args:
969      grad: a `Tensor` representing the gradient for the affected indices.
970      handle: a `Tensor` of dtype `resource` which points to the variable
971       to be updated.
972      indices: a `Tensor` of integral type representing the indices for
973       which the gradient is nonzero. Indices may be repeated.
974
975    Returns:
976      An `Operation` which updates the value of the variable.
977    """
978    summed_grad, unique_indices = _deduplicate_indexed_slices(
979        values=grad, indices=indices)
980    return self._resource_apply_sparse(summed_grad, handle, unique_indices)
981
982  def _resource_apply_sparse(self, grad, handle, indices):
983    """Add ops to apply sparse gradients to the variable `handle`.
984
985    Similar to `_apply_sparse`, the `indices` argument to this method has been
986    de-duplicated. Optimizers which deal correctly with non-unique indices may
987    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
988    overhead.
989
990    Args:
991      grad: a `Tensor` representing the gradient for the affected indices.
992      handle: a `Tensor` of dtype `resource` which points to the variable
993       to be updated.
994      indices: a `Tensor` of integral type representing the indices for
995       which the gradient is nonzero. Indices are unique.
996
997    Returns:
998      An `Operation` which updates the value of the variable.
999    """
1000    raise NotImplementedError()
1001
1002  def _apply_sparse_duplicate_indices(self, grad, var):
1003    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
1004
1005    Optimizers which override this method must deal with IndexedSlices objects
1006    such as the following:
1007
1008      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
1009
1010    The correct interpretation is:
1011
1012      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
1013
1014    Many optimizers deal incorrectly with repeated indices when updating based
1015    on sparse gradients (e.g. summing squares rather than squaring the sum, or
1016    applying momentum terms multiple times). Adding first is always the correct
1017    behavior, so this is enforced here by reconstructing the IndexedSlices to
1018    have only unique indices, then calling _apply_sparse.
1019
1020    Optimizers which deal correctly with repeated indices may instead override
1021    this method to avoid the overhead of summing indices.
1022
1023    Args:
1024      grad: `IndexedSlices`.
1025      var: A `Variable` object.
1026
1027    Returns:
1028      An `Operation`.
1029    """
1030    summed_values, unique_indices = _deduplicate_indexed_slices(
1031        values=grad.values, indices=grad.indices)
1032    gradient_no_duplicate_indices = ops.IndexedSlices(
1033        indices=unique_indices,
1034        values=summed_values,
1035        dense_shape=grad.dense_shape)
1036    return self._apply_sparse(gradient_no_duplicate_indices, var)
1037
1038  def _apply_sparse(self, grad, var):
1039    """Add ops to apply sparse gradients to `var`.
1040
1041    The IndexedSlices object passed to `grad` in this function is by default
1042    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
1043    indices (see its docstring for details). Optimizers which can tolerate or
1044    have correct special cases for duplicate sparse indices may override
1045    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
1046    overhead.
1047
1048    Args:
1049      grad: `IndexedSlices`, with no repeated indices.
1050      var: A `Variable` object.
1051
1052    Returns:
1053      An `Operation`.
1054    """
1055    raise NotImplementedError()
1056
1057  def _finish(self, update_ops, name_scope):
1058    """Do what is needed to finish the update.
1059
1060    This is called with the `name_scope` using the "name" that
1061    users have chosen for the application of gradients.
1062
1063    Args:
1064      update_ops: List of `Operation` objects to update variables.  This list
1065        contains the values returned by the `_apply_dense()` and
1066        `_apply_sparse()` calls.
1067      name_scope: String.  Name to use for the returned operation.
1068
1069    Returns:
1070      The operation to apply updates.
1071    """
1072    return control_flow_ops.group(*update_ops, name=name_scope)
1073
1074  # --------------
1075  # Utility methods for subclasses.
1076  # --------------
1077
1078  def _slot_dict(self, slot_name):
1079    """Returns a dict for caching slots created under the given name.
1080
1081    Args:
1082      slot_name: Name for the slot.
1083
1084    Returns:
1085      A dict that maps primary `Variable` objects to the slot created
1086      for that variable, under the given slot name.
1087    """
1088    named_slots = self._slots.get(slot_name, None)
1089    if named_slots is None:
1090      named_slots = {}
1091      self._slots[slot_name] = named_slots
1092    return named_slots
1093
1094  def _get_or_make_slot(self, var, val, slot_name, op_name):
1095    """Find or create a slot for a variable.
1096
1097    Args:
1098      var: A `Variable` object.
1099      val: A `Tensor`.  The initial value of the slot.
1100      slot_name: Name for the slot.
1101      op_name: Name to use when scoping the Variable that
1102        needs to be created for the slot.
1103
1104    Returns:
1105      A `Variable` object.
1106    """
1107    named_slots = self._slot_dict(slot_name)
1108    if _var_key(var) not in named_slots:
1109      new_slot_variable = slot_creator.create_slot(var, val, op_name)
1110      self._restore_slot_variable(
1111          slot_name=slot_name, variable=var,
1112          slot_variable=new_slot_variable)
1113      named_slots[_var_key(var)] = new_slot_variable
1114    return named_slots[_var_key(var)]
1115
1116  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
1117                                         slot_name, op_name):
1118    """Find or create a slot for a variable, using an Initializer.
1119
1120    Args:
1121      var: A `Variable` object.
1122      initializer: An `Initializer`.  The initial value of the slot.
1123      shape: Shape of the initial value of the slot.
1124      dtype: Type of the value of the slot.
1125      slot_name: Name for the slot.
1126      op_name: Name to use when scoping the Variable that
1127        needs to be created for the slot.
1128
1129    Returns:
1130      A `Variable` object.
1131    """
1132    named_slots = self._slot_dict(slot_name)
1133    if _var_key(var) not in named_slots:
1134      new_slot_variable = slot_creator.create_slot_with_initializer(
1135          var, initializer, shape, dtype, op_name)
1136      self._restore_slot_variable(
1137          slot_name=slot_name, variable=var,
1138          slot_variable=new_slot_variable)
1139      named_slots[_var_key(var)] = new_slot_variable
1140    return named_slots[_var_key(var)]
1141
1142  def _zeros_slot(self, var, slot_name, op_name):
1143    """Find or create a slot initialized with 0.0.
1144
1145    Args:
1146      var: A `Variable` object.
1147      slot_name: Name for the slot.
1148      op_name: Name to use when scoping the Variable that
1149        needs to be created for the slot.
1150
1151    Returns:
1152      A `Variable` object.
1153    """
1154    named_slots = self._slot_dict(slot_name)
1155    if _var_key(var) not in named_slots:
1156      new_slot_variable = slot_creator.create_zeros_slot(
1157          var, op_name, copy_xla_sharding=True)
1158      self._restore_slot_variable(
1159          slot_name=slot_name, variable=var,
1160          slot_variable=new_slot_variable)
1161      named_slots[_var_key(var)] = new_slot_variable
1162    return named_slots[_var_key(var)]
1163
1164  # --------------
1165  # For implementing the Trackable interface.
1166  # --------------
1167
1168  def _restore_slot_variable(self, slot_name, variable, slot_variable):
1169    """Restore a newly created slot variable's value."""
1170    variable_key = _var_key(variable)
1171    deferred_restorations = self._deferred_slot_restorations.get(
1172        slot_name, {}).pop(variable_key, [])
1173    # Iterate over restores, highest restore UID first to minimize the number
1174    # of assignments.
1175    deferred_restorations.sort(key=lambda position: position.restore_uid,
1176                               reverse=True)
1177    for checkpoint_position in deferred_restorations:
1178      checkpoint_position.restore(slot_variable)
1179
1180  def _create_or_restore_slot_variable(
1181      self, slot_variable_position, slot_name, variable):
1182    """Restore a slot variable's value, possibly creating it.
1183
1184    Called when a variable which has an associated slot variable is created or
1185    restored. When executing eagerly, we create the slot variable with a
1186    restoring initializer.
1187
1188    No new variables are created when graph building. Instead,
1189    _restore_slot_variable catches these after normal creation and adds restore
1190    ops to the graph. This method is nonetheless important when graph building
1191    for the case when a slot variable has already been created but `variable`
1192    has just been added to a dependency graph (causing us to realize that the
1193    slot variable needs to be restored).
1194
1195    Args:
1196      slot_variable_position: A `trackable._CheckpointPosition` object
1197        indicating the slot variable `Trackable` object to be restored.
1198      slot_name: The name of this `Optimizer`'s slot to restore into.
1199      variable: The variable object this slot is being created for.
1200    """
1201    named_slots = self._slot_dict(slot_name)
1202    variable_key = _var_key(variable)
1203    slot_variable = named_slots.get(variable_key, None)
1204    if (slot_variable is None and context.executing_eagerly() and
1205        slot_variable_position.is_simple_variable()
1206        # Defer slot variable creation if there is an active variable creator
1207        # scope. Generally we'd like to eagerly create/restore slot variables
1208        # when possible, but this may mean that scopes intended to catch
1209        # `variable` also catch its eagerly created slot variable
1210        # unintentionally (specifically make_template would add a dependency on
1211        # a slot variable if not for this case). Deferring is mostly harmless
1212        # (aside from double initialization), and makes variable creator scopes
1213        # behave the same way they do when graph building.
1214        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
1215      initializer = trackable.CheckpointInitialValueCallable(
1216          checkpoint_position=slot_variable_position)
1217      # CheckpointInitialValueCallable will ignore the shape and dtype
1218      # parameters but they must be passed.
1219      slot_variable = self._get_or_make_slot_with_initializer(
1220          var=variable,
1221          initializer=initializer,
1222          shape=variable.shape,
1223          dtype=variable.dtype,
1224          slot_name=slot_name,
1225          op_name=self._name)
1226      # Slot variables are not owned by any one object (because we don't want to
1227      # save the slot variable if the optimizer is saved without the non-slot
1228      # variable, or if the non-slot variable is saved without the optimizer;
1229      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
1230      # variable, variable)). So we don't _track_ slot variables anywhere, and
1231      # instead special-case this dependency and otherwise pretend it's a normal
1232      # graph.
1233    if slot_variable is not None:
1234      # If we've either made this slot variable, or if we've pulled out an
1235      # existing slot variable, we should restore it.
1236      slot_variable_position.restore(slot_variable)
1237    else:
1238      # We didn't make the slot variable. Defer restoring until it gets created
1239      # normally. We keep a list rather than the one with the highest restore
1240      # UID in case slot variables have their own dependencies, in which case
1241      # those could differ between restores.
1242      self._deferred_slot_restorations.setdefault(
1243          slot_name, {}).setdefault(variable_key, []).append(
1244              slot_variable_position)
1245
1246  def _call_if_callable(self, param):
1247    """Call the function if param is callable."""
1248    return param() if callable(param) else param
1249