# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest


@six.add_metaclass(abc.ABCMeta)
class _OptimizableVariable(object):
  """Interface for abstracting over variables in the optimizers."""

  @abc.abstractmethod
  def target(self):
    """Returns the optimization target for this variable."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def update_op(self, optimizer, g, *args):
    """Returns the update ops for updating the variable."""
    raise NotImplementedError("Calling an abstract method.")


class _RefVariableProcessor(_OptimizableVariable):
  """Processor for Variable."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g, *args):
    if isinstance(g, ops.Tensor):
      update_op = optimizer._apply_dense(g, self._v, *args)  # pylint: disable=protected-access
      if self._v.constraint is not None:
        with ops.control_dependencies([update_op]):
          return self._v.assign(self._v.constraint(self._v))
      else:
        return update_op
    else:
      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
                                                "tensor nor IndexedSlices.")
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)


class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    # pylint: disable=protected-access
    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return optimizer._resource_apply_sparse_duplicate_indices(
          g.values, self._v, g.indices, *args)
    update_op = optimizer._resource_apply_dense(g, self._v, *args)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer. Updating
  the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


def _var_key_v2(var):
  """Key for representing a primary variable, for looking up slots."""
  # pylint: disable=protected-access
  if hasattr(var, "_distributed_container"):
    distributed_container = var._distributed_container()
    assert distributed_container is not None
    if context.executing_eagerly():
      return distributed_container._unique_id
    return distributed_container._shared_name
  if context.executing_eagerly():
    return var._unique_id
  return var.op.name


def _resolve(value, name):
  if callable(value):
    value = value()
  return ops.convert_to_tensor(value, name=name)


def _is_dynamic(value):
  """Returns true if __init__ arg `value` should be re-evaluated each step."""
  if callable(value):
    return True
  # Don't need to do anything special in graph mode, since dynamic values
  # will propagate correctly automatically.
  # TODO(josh11b): Add per-device caching across steps using variables for
  # truly static values once we add distributed support.
  if context.executing_eagerly() and isinstance(
      value, resource_variable_ops.ResourceVariable):
    return True
  return False


class _OptimizerV2State(object):
  """Holds per-graph and per-step optimizer state.

  Use _init_with_static_hyper() to create the state for a graph, and then
  _copy_with_dynamic_hyper() to convert that to state for a particular step.
  The difference between the two is that the former only has hyper
  parameter values that are static and the latter also has values that
  can change every step (according to _is_dynamic()).
  """

  def __init__(self, op_name):
    self._op_name = op_name

  def _init_with_static_hyper(self, hyper):
    """Initialize a fresh state object from hyper dict."""
    # self._hyper contains a dict from name to a dict with the Tensor values.
    # This dict starts with a single item with key "None" with the hyper
    # parameter value converted to a Tensor. Other items have dtype keys
    # with that Tensor cast to that dtype.
    with ops.init_scope():
      self._hyper = {
          name: {
              None: ops.convert_to_tensor(value, name=name)
          } for name, (dynamic, value) in sorted(hyper.items()) if not dynamic
      }
    self._slots = {}
    self._non_slot_dict = {}
    # Extra state to help Optimizers implement Trackable. Holds information
    # about variables which will be restored as soon as they're created.
    self._deferred_dependencies = {}  # Non-slot variables
    self._deferred_slot_restorations = {}  # Slot variables

  def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
    """Create a new state object for a particular step."""
    ret = _OptimizerV2State(self._op_name)
    # pylint: disable=protected-access
    ret._slots = self._slots
    ret._non_slot_dict = self._non_slot_dict
    ret._deferred_dependencies = self._deferred_dependencies
    ret._deferred_slot_restorations = self._deferred_slot_restorations
    ret._hyper = {
        name: {
            None: _resolve(value, name)
        } for name, (dynamic, value) in sorted(hyper.items()) if dynamic
    }
    ret._hyper.update(self._hyper)
    ret._non_slot_devices = non_slot_devices
    ret._distribution = distribution
    return ret

  def _variables(self):
    """Returns a list of all variables held by self."""
    optimizer_variables = list(self._non_slot_dict.values())
    for variable_dict in self._slots.values():
      for slot_for_variable in variable_dict.values():
        optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def create_slot(self, var, val, slot_name, optional_op_name=None):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`.  The initial value of the slot.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    var_key = _var_key_v2(var)
    if var_key not in named_slots:
      new_slot_variable = slot_creator.create_slot(
          var, val, optional_op_name or self._op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var, slot_variable=new_slot_variable)
      named_slots[var_key] = new_slot_variable
    return named_slots[var_key]

  def create_slot_with_initializer(self,
                                   var,
                                   initializer,
                                   shape,
                                   dtype,
                                   slot_name,
                                   optional_op_name=None):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    var_key = _var_key_v2(var)
    if var_key not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, optional_op_name or self._op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var, slot_variable=new_slot_variable)
      named_slots[var_key] = new_slot_variable
    return named_slots[var_key]

  def zeros_slot(self, var, slot_name, optional_op_name=None):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    var_key = _var_key_v2(var)
    if var_key not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(
          var, optional_op_name or self._op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var, slot_variable=new_slot_variable)
      named_slots[var_key] = new_slot_variable
    return named_slots[var_key]

  def _create_or_restore_slot_variable(self,
                                       slot_variable_position,
                                       slot_name,
                                       variable,
                                       optional_op_name=None):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
      optional_op_name: Name to use when scoping the Variable that needs to be
        created for the slot.
    """
    slot_variable = self.get_slot(var=variable, name=slot_name)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.create_slot(
          var=variable,
          val=initializer,
          slot_name=slot_name,
          optional_op_name=optional_op_name)
      # Optimizers do not have unconditional dependencies on their slot
      # variables (nor do any other objects). They are only saved if the
      # variables they were created for are also saved.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      variable_key = _var_key_v2(variable)
      self._deferred_slot_restorations.setdefault(slot_name, {}).setdefault(
          variable_key, []).append(slot_variable_position)

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables.  For example
    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None
    return named_slots.get(_var_key_v2(var), None)

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def create_non_slot(self, initial_value, name, colocate_with=None):
    """Add an extra variable, not associated with a slot."""
    v = self._non_slot_dict.get(name, None)
    if v is None:
      if colocate_with is None:
        colocate_with = self._non_slot_devices
      with self._distribution.extended.colocate_vars_with(colocate_with):
        # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      self._non_slot_dict[name] = v
      deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
      for checkpoint_position in sorted(
          deferred_dependencies_list,
          key=lambda restore: restore.checkpoint.restore_uid,
          reverse=True):
        checkpoint_position.restore(v)
    return v

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key_v2(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(
        key=lambda position: position.restore_uid, reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def get_non_slot(self, name):
    """Returns the non-slot variable identified by `name`."""
    return self._non_slot_dict.get(name, None)

  def get_hyper(self, name, dtype=None):
    """Returns the `name` hyper parameter, optionally cast to `dtype`."""
    dtype_dict = self._hyper[name]
    # Do we have the value cast to dtype already cached? This should always
    # succeed when dtype is None.
    if dtype in dtype_dict:
      return dtype_dict[dtype]
    # Not cached, cast to dtype and save the result in the cache.
    result = math_ops.cast(dtype_dict[None], dtype)
    dtype_dict[dtype] = result
    return result


class OptimizerV2(optimizer_v1.Optimizer):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.  You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables.  If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1.  Compute the gradients with `compute_gradients()`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.  This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results.  For example, the two gradients of `matmul` depend on the
  input values: with `GATE_NONE` one of the gradients could be applied to one
  of the inputs _before_ the other gradient is computed, resulting in
  non-reproducible results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used.  This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used.  This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.
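
  For illustration, a minimal sketch of requesting full gating (here `cost`,
  `var0` and `var1` are placeholder names, and `GradientDescentOptimizer` is
  assumed to be an available subclass):

  ```python
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # With GATE_GRAPH, every gradient is computed before any variable is updated.
  grads_and_vars = opt.compute_gradients(
      cost, var_list=[var0, var1], gate_gradients=opt.GATE_GRAPH)
  opt.apply_gradients(grads_and_vars)
  ```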

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`,
  allocate and manage additional variables associated with the variables to
  train.  These are called <i>Slots</i>.  Slots have names and you can ask the
  optimizer for the names of the slots that it uses.  Once you have a slot name
  you can ask the optimizer for the variable it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
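
  For example (a sketch; `MomentumOptimizer`, `loss` and the variable `var` are
  assumed to exist, and the slot name shown is illustrative):

  ```python
  opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
  opt.minimize(loss, var_list=[var])
  print(opt.get_slot_names())            # e.g. ["momentum"]
  momentum_slot = opt.get_slot(var, "momentum")
  ```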

  ### Non-slot variables

  Some optimizer subclasses, such as `AdamOptimizer`, have variables that
  are not associated with the variables to train, just the step itself.

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.
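
  For example, a learning rate can be supplied as a callable so that it is
  re-evaluated on each `apply_gradients()` call (a sketch, assuming the
  subclass routes `learning_rate` through `_set_hyper()`):

  ```python
  def learning_rate():
    # Called again on every apply_gradients(); return any scalar value here.
    return 0.1
  opt = GradientDescentOptimizer(learning_rate=learning_rate)
  ```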

  ### State

  Internal methods are passed a `state` argument with the correct
  values to use for the slot and non-slot variables, and the hyper
  parameters.
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/state.get_hyper()
    facility instead.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates to
        variables.
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
          _create_vars.
    """
    # Note: We intentionally don't call parent __init__.

    # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
    if (self.__class__._create_slots.__code__ is not  # pylint: disable=protected-access
        OptimizerV2._create_slots.__code__):
      raise RuntimeError(
          "Override _create_vars instead of _create_slots when "
          "descending from OptimizerV2 (class %s)" % self.__class__.__name__)
    if not name:
      raise ValueError("Must specify the optimizer name")

    self._use_locking = use_locking
    self._name = name
    # Map from graph_key to state for that graph. We use the graph_key
    # since it works in both eager and graph mode, and gives the outer
    # graph inside functions.
    replica_context = distribute_ctx.get_replica_context()
    if replica_context is None:
      # In a cross-replica context for a DistributionStrategy, which means
      # only one Optimizer will be created, not one per replica.
      self._per_graph_state = {}
    else:
      # We use get_replica_context().merge_call() to get a single dict
      # shared across all model replicas when running with a
      # DistributionStrategy.
      self._per_graph_state = replica_context.merge_call(lambda _: {})

    # Hyper parameters, and whether they should be re-evaluated every step.
    self._hyper = {}

  def _set_hyper(self, name, value):
    self._hyper[name] = (_is_dynamic(value), value)

  def minimize(self,
               loss,
               global_step=None,
               var_list=None,
               gate_gradients=GATE_OP,
               aggregation_method=None,
               name=None,
               grad_loss=None,
               stop_gradients=None,
               scale_loss_by_num_replicas=False):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`.  Defaults to the list of variables collected in the
        graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to differentiate
        through.
      scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down
        by the number of replicas. DEPRECATED and generally no longer needed.

    Returns:
      An Operation that updates the variables in `var_list`.  If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes elements of `var_list` as arguments and computes the value to be
    minimized. If `var_list` is None, `loss` should take no arguments.
    Minimization (and gradient computation) is done with respect to the
    elements of `var_list` if not None, else with respect to any trainable
    variables created during the execution of the `loss` function.
    `gate_gradients`, `aggregation_method`, and `grad_loss` are ignored when
    eager execution is enabled.
    @end_compatibility
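
    A minimal eager-mode sketch (assuming eager execution is enabled and that
    `GradientDescentOptimizer` is an available subclass; `w` is a hypothetical
    variable):

    ```python
    w = tf.Variable(2.0)
    opt = GradientDescentOptimizer(learning_rate=0.1)
    # `loss` takes no arguments since `var_list` is not given.
    opt.minimize(lambda: (w - 1.0) ** 2)
    ```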
    """
    grads_and_vars = self.compute_gradients(
        loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        grad_loss=grad_loss,
        stop_gradients=stop_gradients,
        scale_loss_by_num_replicas=scale_loss_by_num_replicas)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        grad_loss=None,
                        stop_gradients=None,
                        scale_loss_by_num_replicas=False):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to differentiate
        through.
      scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down
        by the number of replicas. DEPRECATED and generally no longer needed.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients` and `aggregation_method`
    are ignored.
    @end_compatibility
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss for number of replicas (callable-loss case).
        loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))
    if context.executing_eagerly():
      raise RuntimeError("`loss` passed to Optimizer.compute_gradients should "
                         "be a function when eager execution is enabled.")

    # Scale loss for number of replicas (non-callable-loss case).
    loss = self._scale_loss(loss, scale_loss_by_num_replicas)

    if gate_gradients not in [
        optimizer_v1.Optimizer.GATE_NONE, optimizer_v1.Optimizer.GATE_OP,
        optimizer_v1.Optimizer.GATE_GRAPH
    ]:
      raise ValueError(
          "gate_gradients must be one of: Optimizer.GATE_NONE, "
          "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" % gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() + ops.get_collection(
              ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss,
        var_refs,
        grad_ys=grad_loss,
        gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        stop_gradients=stop_gradients)
    if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return grads_and_vars

  @staticmethod
  def _scale_loss(loss_value, scale_loss_by_num_replicas):
    """Scale loss for the number of replicas."""
    if scale_loss_by_num_replicas:
      num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
      if num_replicas > 1:
        loss_value *= 1. / num_replicas
    return loss_value

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation.  Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().

    # Filter out variables with gradients of `None`.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
    if not filtered:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v in grads_and_vars],))
    return distribute_ctx.get_replica_context().merge_call(
        self._distributed_apply, args=(filtered,),
        kwargs={"global_step": global_step, "name": name})

  def _get_or_create_state(self, var_list=None):
    """Either looks up or creates `_OptimizerV2State`.

    If any variables are available, they should be passed via the `var_list`
    argument, and these will be used to determine the graph to create/retrieve
    state for. Otherwise the returned state is for the current default graph.

    Args:
      var_list: A list of variables to extract a graph from.

    Returns:
      An `_OptimizerV2State` object.
    """
    # Determine the graph_key from the current graph.
    eager_execution = context.executing_eagerly()
    if eager_execution or var_list is None:
      graph = ops.get_default_graph()
    else:
      graph = ops._get_graph_from_inputs(var_list)  # pylint: disable=protected-access
    assert graph is not None
    graph_key = graph._graph_key  # pylint: disable=protected-access

    # Get the per graph state by looking up the graph_key.
    if graph_key in self._per_graph_state:
      per_graph_state = self._per_graph_state[graph_key]
    else:
      per_graph_state = _OptimizerV2State(self._name)
      per_graph_state._init_with_static_hyper(self._hyper)  # pylint: disable=protected-access
      self._per_graph_state[graph_key] = per_graph_state
    return per_graph_state

  def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
    """`apply_gradients` for use with a `DistributionStrategy`."""
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
    eager_execution = context.executing_eagerly()
    if eager_execution:
      # Give a clear error in this case instead of "name not supported
      # for Eager Tensors" when we compute non_slot_devices.
      for v in unwrapped_var_list:
        if isinstance(v, ops.Tensor):
          raise NotImplementedError("Trying to update a Tensor ", v)

    with ops.name_scope(name, self._name) as name:
      per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
      # Include the current value of any dynamic hyper parameters in `state`.
      non_slot_devices = distribution.extended.non_slot_devices(var_list)
      state = per_graph_state._copy_with_dynamic_hyper(  # pylint: disable=protected-access
          self._hyper, distribution, non_slot_devices)

    # Create any slot and non-slot variables we need in `state`.
    with ops.init_scope():
      self._create_vars(var_list, state)

    with ops.name_scope(name):  # Re-enter name_scope created above
      # Give the child class a chance to do something before we start
      # applying gradients.
      self._prepare(state)

      def update(v, g):
        """Update variable `v` using gradient `g`."""
        assert v is not None

        # Convert the grad to Tensor or IndexedSlices if necessary, and
        # look up a processor for each variable's type.
        try:
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError("Gradient must be convertible to a Tensor"
                          " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
        processor = _get_processor(v)

        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        scope_name = "" if eager_execution else v.op.name
        # device_policy is set because non-mirrored tensors will be read in
        # `update_op`.
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        device_policy = context.device_policy(
            context.DEVICE_PLACEMENT_SILENT)
        with ops.name_scope("update_" + scope_name), device_policy:
          return processor.update_op(self, g, state)

      # Use the processors to update the variables.
      update_ops = []
      for grad, var in grads_and_vars:
        update_ops.extend(distribution.extended.update(
            var, update, args=(grad,), group=False))

      # Give the child class a chance to do something after applying
      # gradients
      def finish():
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          return self._finish(state)

      update_ops = control_flow_ops.group(update_ops)
      with ops.control_dependencies([update_ops]):
        finish_updates = distribution.extended.update_non_slot(
            non_slot_devices, finish, group=False)
      # We said group=False, which means finish_updates is always a tuple.
      # It will be (None,) when finish() returns None.
      if finish_updates == (None,):
        finish_updates = (update_ops,)

      # Update `global_step` (if any).
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(finish_updates):

          def update_global_step(global_step, name):
            return global_step.assign_add(1, read_value=False, name=name)

          apply_updates = distribution.extended.update(
              global_step, update_global_step, args=(name,))

      # Add the training op to the TRAIN_OP graph collection in graph mode.
      if not eager_execution:
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables.  For example
    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    state = self._get_state_for_var(var)
    return state.get_slot(var, name) if state is not None else None

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    state = self._get_per_graph_state()
    return state.get_slot_names() if state is not None else []

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.

    Returns:
      A list of variables.
    """
    state = self._get_per_graph_state()
    return state._variables() if state is not None else []  # pylint: disable=protected-access

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _create_vars(self, var_list, state):
    """Create all slots needed by the variables and any non-slot variables.

    Args:
      var_list: A list of `Variable` objects.
      state: An object with these methods: `create_slot(var, val, slot_name,
        optional_op_name)`, `create_slot_with_initializer(` `var, initializer,
        shape, dtype, slot_name, optional_op_name)`, `zeros_slot(var, slot_name,
        optional_op_name)`, `create_non_slot_variable(initial_value, name,
        colocate_with)`, `get_hyper(name)`
    """
    # No slots needed by default
    pass

  def _prepare(self, state):
    """Code to execute before applying gradients.

    Note that most uses of _prepare() in Optimizer have been subsumed
    by explicit support for hyper parameters in OptimizerV2.

    Args:
      state: An object with a `get_hyper(name)` method.

    Returns:
      Return value will be ignored.
    """
    pass

  def _apply_dense(self, grad, var, state):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _resource_apply_dense(self, grad, handle, state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                               state):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    # pylint: disable=protected-access
    summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
        values=grad, indices=indices)
    # pylint: enable=protected-access
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       state)

  def _resource_apply_sparse(self, grad, handle, indices, state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var, state):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    # pylint: disable=protected-access
    summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    # pylint: enable=protected-access
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var, state)

  def _apply_sparse(self, grad, var, state):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, state):
    """Do what is needed to finish the update.

    This is called inside a scope colocated with any non-slot variables.

    Args:
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      The operation to apply updates, or None if no updates.
    """
    return None

  # --------------
  # Utility methods for subclasses.
  # --------------
  def _get_per_graph_state(self):
    # pylint: disable=protected-access
    return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)

  def _get_state_for_var(self, var):
    # pylint: disable=protected-access
    return self._per_graph_state.get(var._graph_key, None)

  # --------------
  # Overridden methods from Trackable.
  # --------------

  def _track_trackable(self, *args, **kwargs):
    """Optimizers may not track dependencies. Raises an error."""
    raise NotImplementedError(
        "Optimizers may not have dependencies. File a feature request if this "
        "limitation bothers you.")

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific non-slot variables to save."""
    current_graph_non_slot_variables = []
    state = self._get_per_graph_state()
    if state is not None:
      for name, variable_object in sorted(
          state._non_slot_dict.items(),  # pylint: disable=protected-access
          # Avoid comparing variables
          key=lambda item: item[0]):
        current_graph_non_slot_variables.append(
            trackable.TrackableReference(
                name=name, ref=variable_object))
    # Note: ignores super(); Optimizers may not have any dependencies outside of
    # state objects.
    return current_graph_non_slot_variables

  def _lookup_dependency(self, name):
    """From Trackable. Find a non-slot variable in the current graph."""
    state = self._get_per_graph_state()
    if state is None:
      return None
    else:
      return state.get_non_slot(name)

  @property
  def _deferred_dependencies(self):
    """Lets Trackable know where non-slot variables are created.

    If necessary, creates a new state object for the current default graph.
    Trackable will then add entries to that state's deferred dependency
    dictionary. The state object will check that dictionary when creating
    non-slot variables, restoring their value if an entry is found.

    Returns:
      A dictionary which holds deferred dependencies for the current default
      graph.
    """
    state = self._get_or_create_state()
    return state._deferred_dependencies  # pylint: disable=protected-access

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    """Trackable: Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored.

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    state = self._get_or_create_state(var_list=[variable])
    state._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position=slot_variable_position,
        slot_name=slot_name,
        variable=variable,
        optional_op_name=self._name)

  # --------------
  # Unsupported parent methods
  # --------------
  def _slot_dict(self, slot_name):
    raise NotImplementedError("_slot_dict() method unsupported in OptimizerV2")

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    raise NotImplementedError(
        "_get_or_make_slot() method unsupported in OptimizerV2")

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    raise NotImplementedError(
        "_get_or_make_slot_with_initializer() method unsupported in "
        "OptimizerV2")

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    raise NotImplementedError(
        "_create_non_slot_variable() method unsupported in OptimizerV2")

  def _get_non_slot_variable(self, name, graph=None):
    raise NotImplementedError(
        "_get_non_slot_variable() method unsupported in OptimizerV2")

  def _non_slot_variables(self):
    raise NotImplementedError(
        "_non_slot_variables() method unsupported in OptimizerV2")