# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base class for optimizers."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpointable
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def _get_variable_for(v):
44  """Returns the ResourceVariable responsible for v, or v if not necessary."""
45  if context.in_eager_mode():
46    return v
47  if v.op.type == "VarHandleOp":
48    for var in variables.trainable_variables():
49      if (isinstance(var, resource_variable_ops.ResourceVariable)
50          and var.handle.op is v.op):
51        return var
52    raise ValueError("Got %s but could not locate source variable." % (str(v)))
53  return v
54
55
56def _deduplicate_indexed_slices(values, indices):
57  """Sums `values` associated with any non-unique `indices`.
58
59  Args:
60    values: A `Tensor` with rank >= 1.
61    indices: A one-dimensional integer `Tensor`, indexing into the first
62      dimension of `values` (as in an IndexedSlices object).
63  Returns:
64    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
65    de-duplicated version of `indices` and `summed_values` contains the sum of
66    `values` slices associated with each unique index.
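
  For example, `values=[[1.], [2.], [3.]]` with `indices=[0, 0, 2]` yields
  `summed_values=[[3.], [3.]]` and `unique_indices=[0, 2]`.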
67  """
68  unique_indices, new_index_positions = array_ops.unique(indices)
69  summed_values = math_ops.unsorted_segment_sum(
70      values, new_index_positions,
71      array_ops.shape(unique_indices)[0])
72  return (summed_values, unique_indices)
73
74
75def _var_key(var):
76  if context.in_eager_mode():
77    return var._shared_name  # pylint: disable=protected-access
78  return (var.op.graph, var.op.name)
79
80
81class _OptimizableVariable(object):
82  """Interface for abstracting over variables in the optimizers."""
83
84  @abc.abstractmethod
85  def target(self):
86    """Returns the optimization target for this variable."""
87    raise NotImplementedError("Calling an abstract method.")
88
89  @abc.abstractmethod
90  def update_op(self, optimizer, g):
91    """Returns the update ops for updating the variable."""
92    raise NotImplementedError("Calling an abstract method.")
93
94
95class _RefVariableProcessor(_OptimizableVariable):
96  """Processor for Variable."""
97
98  def __init__(self, v):
99    self._v = v
100
101  def target(self):
102    return self._v._ref()  # pylint: disable=protected-access
103
104  def update_op(self, optimizer, g):
105    if isinstance(g, ops.Tensor):
106      update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
107      if self._v.constraint is not None:
108        with ops.control_dependencies([update_op]):
109          return self._v.assign(self._v.constraint(self._v))
110      else:
111        return update_op
112    else:
113      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
114                                                "tensor nor IndexedSlices.")
115      if self._v.constraint is not None:
116        raise RuntimeError(
117            "Cannot use a constraint function on a sparse variable.")
118      # pylint: disable=protected-access
119      return optimizer._apply_sparse_duplicate_indices(g, self._v)
120
121
122class _DenseReadResourceVariableProcessor(_OptimizableVariable):
123  """Processor for dense ResourceVariables."""
124
125  def __init__(self, v):
126    self._v = v
127
128  def target(self):
129    return self._v
130
131  def update_op(self, optimizer, g):
132    # pylint: disable=protected-access
133    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
134    if self._v.constraint is not None:
135      with ops.control_dependencies([update_op]):
136        return self._v.assign(self._v.constraint(self._v))
137    else:
138      return update_op
139
140
141class _DenseResourceVariableProcessor(_OptimizableVariable):
142  """Processor for dense ResourceVariables."""
143
144  def __init__(self, v):
145    self._v = v
146
147  def target(self):
148    return self._v
149
150  def update_op(self, optimizer, g):
151    # pylint: disable=protected-access
152    if isinstance(g, ops.IndexedSlices):
153      if self._v.constraint is not None:
154        raise RuntimeError(
155            "Cannot use a constraint function on a sparse variable.")
156      return optimizer._resource_apply_sparse_duplicate_indices(
157          g.values, self._v, g.indices)
158    update_op = optimizer._resource_apply_dense(g, self._v)
159    if self._v.constraint is not None:
160      with ops.control_dependencies([update_op]):
161        return self._v.assign(self._v.constraint(self._v))
162    else:
163      return update_op
164
165
166class _StreamingModelPortProcessor(_OptimizableVariable):
167  """Processor for streaming ModelPorts."""
168
169  def __init__(self, v):
170    self._v = v
171
172  def target(self):
173    return self._v
174
175  def update_op(self, optimizer, g):
176    return g
177
178
179class _TensorProcessor(_OptimizableVariable):
180  """Processor for ordinary Tensors.
181
182  Even though a Tensor can't really be updated, sometimes it is useful to
183  compute the gradients with respect to a Tensor using the optimizer. Updating
184  the Tensor is, of course, unsupported.
185  """
186
187  def __init__(self, v):
188    self._v = v
189
190  def target(self):
191    return self._v
192
193  def update_op(self, optimizer, g):
194    raise NotImplementedError("Trying to update a Tensor ", self._v)
195
196
197def _get_processor(v):
198  """The processor of v."""
199  if context.in_eager_mode():
200    if isinstance(v, ops.Tensor):
201      return _TensorProcessor(v)
202    else:
203      return _DenseResourceVariableProcessor(v)
204  if v.op.type == "VarHandleOp":
205    return _DenseResourceVariableProcessor(v)
206  if isinstance(v, variables.Variable):
207    return _RefVariableProcessor(v)
208  if v.op.type == "SubmodelPort":
209    return _StreamingModelPortProcessor(v)
210  if isinstance(v, ops.Tensor):
211    return _TensorProcessor(v)
212  raise NotImplementedError("Trying to optimize unsupported type ", v)
213
214
215@tf_export("train.Optimizer")
216class Optimizer(checkpointable.Checkpointable):
217  """Base class for optimizers.
218
219  This class defines the API to add Ops to train a model.  You never use this
220  class directly, but instead instantiate one of its subclasses such as
221  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
222
223  ### Usage
224
225  ```python
226  # Create an optimizer with the desired parameters.
227  opt = GradientDescentOptimizer(learning_rate=0.1)
228  # Add Ops to the graph to minimize a cost by updating a list of variables.
229  # "cost" is a Tensor, and the list of variables contains tf.Variable
230  # objects.
231  opt_op = opt.minimize(cost, var_list=<list of variables>)
232  ```
233
234  In the training program you will just have to run the returned Op.
235
236  ```python
237  # Execute opt_op to do one step of training:
238  opt_op.run()
239  ```
240
241  ### Processing gradients before applying them.
242
243  Calling `minimize()` takes care of both computing the gradients and
244  applying them to the variables.  If you want to process the gradients
245  before applying them you can instead use the optimizer in three steps:
246
247  1.  Compute the gradients with `compute_gradients()`.
248  2.  Process the gradients as you wish.
249  3.  Apply the processed gradients with `apply_gradients()`.
250
251  Example:
252
253  ```python
254  # Create an optimizer.
255  opt = GradientDescentOptimizer(learning_rate=0.1)
256
257  # Compute the gradients for a list of variables.
258  grads_and_vars = opt.compute_gradients(loss, <list of variables>)
259
260  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
261  # need to the 'gradient' part, for example cap them, etc.
262  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
263
264  # Ask the optimizer to apply the capped gradients.
265  opt.apply_gradients(capped_grads_and_vars)
266  ```
267
268  ### Gating Gradients
269
270  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
271  argument that controls the degree of parallelism during the application of
272  the gradients.
273
274  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
275
  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.  This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results.  For example, the two gradients of `matmul` depend on the
  input values: with `GATE_NONE` one of the gradients could be applied to one
  of the inputs _before_ the other gradient is computed, resulting in
  non-reproducible results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used.  This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used.  This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.
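
  For example, a sketch of requesting the most conservative gating when using
  the three-step pattern above (`loss` and the variables are placeholders):

  ```python
  opt = GradientDescentOptimizer(learning_rate=0.1)
  grads_and_vars = opt.compute_gradients(
      loss, <list of variables>,
      gate_gradients=GradientDescentOptimizer.GATE_GRAPH)
  opt.apply_gradients(grads_and_vars)
  ```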

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and
  `AdagradOptimizer`, allocate and manage additional variables associated with
  the variables to train.  These are called <i>Slots</i>.  Slots have names and
  you can ask the optimizer for the names of the slots that it uses.  Once you
  have a slot name you can ask the optimizer for the variable it created to
  hold the slot value.

  This can be useful if you want to debug a training algorithm, report stats
  about the slots, etc.
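
  For example, a sketch of inspecting the slots created by `MomentumOptimizer`
  (`cost`, `var` and the variable list are placeholders for your own objects):

  ```python
  opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  print(opt.get_slot_names())  # -> ['momentum']
  momentum_slot = opt.get_slot(var, "momentum")
  ```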
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
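
    As a sketch, a subclass constructor typically forwards these two arguments
    and stores its own hyperparameters (the names below are illustrative):

    ```python
    class MyOptimizer(optimizer.Optimizer):

      def __init__(self, learning_rate, use_locking=False, name="MyOptimizer"):
        super(MyOptimizer, self).__init__(use_locking, name)
        self._learning_rate = learning_rate
    ```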
    """
    if not name:
      raise ValueError("Must specify the optimizer name")
    self._use_locking = use_locking
    self._name = name
    # Dictionary of slots.
    #  {slot_name :
    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
    #   ... }
    self._slots = {}
    self._non_slot_dict = {}
    # For implementing Checkpointable. Stores information about how to restore
    # slot variables which have not yet been created
    # (checkpointable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

  def get_name(self):
    return self._name

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`.  Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      An Operation that updates the variables in `var_list`.  If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes elements of `var_list` as arguments and computes the value to be
    minimized. If `var_list` is None, `loss` should take no arguments.
    Minimization (and gradient computation) is done with respect to the
    elements of `var_list` if not None, else with respect to any trainable
    variables created during the execution of the `loss` function.
    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
    `grad_loss` are ignored when eager execution is enabled.
    @end_compatibility
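
    A minimal sketch of the eager pattern described above, assuming eager
    execution has been enabled and `model`, `loss_fn` and `dataset` are
    defined elsewhere:

    ```python
    opt = GradientDescentOptimizer(learning_rate=0.1)
    for x, y in dataset:
      opt.minimize(lambda: loss_fn(model(x), y))
    ```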
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()
      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))
    if context.in_eager_mode():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")
    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
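
    A sketch of the typical graph-mode call, clipping gradients and stepping a
    global step counter (`opt` and `loss` are placeholders defined elsewhere):

    ```python
    global_step = tf.train.get_or_create_global_step()
    grads_and_vars = opt.compute_gradients(loss)
    grads, tvars = zip(*grads_and_vars)
    clipped_grads, _ = tf.clip_by_global_norm(grads, 5.0)
    train_op = opt.apply_gradients(
        zip(clipped_grads, tvars), global_step=global_step)
    ```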
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots([_get_variable_for(v) for v in var_list])
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        scope_name = var.op.name if context.in_graph_mode() else ""
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if context.in_graph_mode():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables.  For example,
    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None
    return named_slots.get(_var_key(var), None)

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.

    Returns:
      A list of variables.
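
    A sketch of how this might be used to checkpoint optimizer state alongside
    model variables (`loss` is a placeholder defined elsewhere):

    ```python
    opt = AdamOptimizer(learning_rate=0.001)
    train_op = opt.minimize(loss)
    saver = tf.train.Saver(tf.trainable_variables() + opt.variables())
    ```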
    """
    executing_eagerly = context.in_eager_mode()
    current_graph = ops.get_default_graph()

    def _from_current_graph(variable):
      if executing_eagerly:
        # No variable.op in eager mode. We don't expect lots of eager graphs,
        # but behavior should be consistent with graph mode.
        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access
      else:
        return variable.op.graph is current_graph

    optimizer_variables = [v for v in self._non_slot_variables()
                           if _from_current_graph(v)]
    for _, variable_dict in self._slots.items():
      for _, slot_for_variable in variable_dict.items():
        if _from_current_graph(slot_for_variable):
          optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    if context.in_graph_mode():
      graph = colocate_with.graph
    else:
      graph = None

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      with ops.colocate_with(colocate_with):
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      self._non_slot_dict[key] = v

    return v

  def _get_non_slot_variable(self, name, graph=None):
    return self._non_slot_dict.get((name, graph), None)

  def _non_slot_variables(self):
    """Additional variables created by the `Optimizer`.

    Returns:
      A list or tuple of variables.
    """
    return self._non_slot_dict.values()

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError(
            "Invalid type %r for %s, expected: %s." % (
                dtype, t.name, [v for v in valid_dtypes]))

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _create_slots(self, var_list):
    """Create all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _prepare(self):
    """Create all needed tensors before applying gradients.

    This is called with the name_scope using the "name" that
    users have chosen for the application of gradients.
    """
    pass

  def _apply_dense(self, grad, var):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
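
    As a sketch (not the actual `GradientDescentOptimizer` implementation), a
    plain gradient-descent subclass could implement this roughly as:

    ```python
    def _apply_dense(self, grad, var):
      # self._learning_rate_tensor is assumed to be created in _prepare().
      return state_ops.assign_sub(var, self._learning_rate_tensor * grad,
                                  use_locking=self._use_locking)
    ```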
    """
    raise NotImplementedError()

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
      indices: a `Tensor` of integral type representing the indices for
       which the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
      indices: a `Tensor` of integral type representing the indices for
       which the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    summed_values, unique_indices = _deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var)

  def _apply_sparse(self, grad, var):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, update_ops, name_scope):
    """Do what is needed to finish the update.

    This is called with the `name_scope` using the "name" that
    users have chosen for the application of gradients.

    Args:
      update_ops: List of `Operation` objects to update variables.  This list
        contains the values returned by the `_apply_dense()` and
        `_apply_sparse()` calls.
      name_scope: String.  Name to use for the returned operation.

    Returns:
      The operation to apply updates.
    """
    return control_flow_ops.group(*update_ops, name=name_scope)

  # --------------
  # Utility methods for subclasses.
  # --------------

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`.  The initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot(var, val, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
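
    As a sketch, a subclass that needs one accumulator per trainable variable
    might use this from its `_create_slots` override (the slot name
    "accumulator" is illustrative):

    ```python
    def _create_slots(self, var_list):
      for v in var_list:
        self._zeros_slot(v, "accumulator", self._name)
    ```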
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  # --------------
  # For implementing the Checkpointable interface.
  # --------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `checkpointable._CheckpointPosition` object
        indicating the slot variable `Checkpointable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    named_slots = self._slot_dict(slot_name)
    variable_key = _var_key(variable)
    slot_variable = named_slots.get(variable_key, None)
    if (slot_variable is None
        and context.in_eager_mode()
        and slot_variable_position.is_simple_variable()):
      initializer = checkpointable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self._get_or_make_slot(
          var=variable,
          val=initializer,
          slot_name=slot_name,
          op_name=self._name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)