• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Version 2 of class Optimizer."""
16# pylint: disable=g-bad-name
17
18import abc
19import contextlib
20import functools
21import warnings
22
23from tensorflow.python.distribute import central_storage_strategy
24from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
25from tensorflow.python.distribute import parameter_server_strategy
26from tensorflow.python.distribute import parameter_server_strategy_v2
27from tensorflow.python.distribute import values as ds_values
28from tensorflow.python.eager import backprop
29from tensorflow.python.eager import context
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.keras import backend
34from tensorflow.python.keras import initializers
35from tensorflow.python.keras.engine import base_layer_utils
36from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
37from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
38from tensorflow.python.keras.utils import generic_utils
39from tensorflow.python.keras.utils import layer_utils
40from tensorflow.python.keras.utils import tf_inspect
41from tensorflow.python.keras.utils import tf_utils
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import gen_resource_variable_ops
45from tensorflow.python.ops import gradients
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import variables as tf_variables
48from tensorflow.python.saved_model import revived_types
49from tensorflow.python.training.tracking import base as trackable
50from tensorflow.python.util import nest
51from tensorflow.python.util.tf_export import keras_export
52
53
# Default set of dtypes treated as valid for optimizer-managed variables:
# every floating-point width (including bfloat16) plus both complex types.
# NOTE(review): the consumer of this constant is outside this view; presumably
# it backs the optimizer's variable-dtype validation — confirm.
_DEFAULT_VALID_DTYPES = frozenset((
    dtypes.float16,
    dtypes.bfloat16,
    dtypes.float32,
    dtypes.float64,
    dtypes.complex64,
    dtypes.complex128,
))
58
59
def _deduplicate_indexed_slices(values, indices):
  """Merges repeated rows of a sparse update by summing them.

  Args:
    values: A `Tensor` with rank >= 1 whose leading dimension lines up with
      `indices`.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A pair `(summed_values, unique_indices)`. `unique_indices` holds every
    distinct entry of `indices` exactly once, and row `i` of `summed_values`
    is the sum of all `values` slices whose index equals `unique_indices[i]`.
  """
  distinct_indices, segment_ids = array_ops.unique(indices)
  # One output segment per distinct index.
  num_segments = array_ops.shape(distinct_indices)[0]
  summed_values = math_ops.unsorted_segment_sum(values, segment_ids,
                                                num_segments)
  return summed_values, distinct_indices
78
79
class NullContextmanager(object):
  """A context manager that does nothing.

  Any constructor arguments are accepted and ignored, so it can stand in for a
  real context manager (such as a name scope) without changing call sites.
  Entering yields `None`, and exceptions raised inside the `with` body are
  never suppressed.
  """

  def __init__(self, *args, **kwargs):
    del args, kwargs  # Unused; accepted only for call-site compatibility.

  def __enter__(self):
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # A falsy return value tells Python to re-raise any pending exception.
    return False
90
91
def name_scope_only_in_function_or_graph(name):
  """Internal-only entry point for `name_scope*`.

  Returns a `compat.v1` name scope when ops are being built into a graph or
  traced into a function. When running fully eagerly, it instead returns a
  `NullContextmanager`, so no name scope is entered.

  Args:
    name: The name argument that is passed to the op function.

  Returns:
    Either `ops.name_scope_v1(name)` or a no-op context manager.
  """
  if context.executing_eagerly():
    return NullContextmanager()
  return ops.name_scope_v1(name)
108
109
110@keras_export("keras.optimizers.Optimizer", metaclass=abc.ABCMeta)
111class OptimizerV2(trackable.Trackable):
112  """Base class for Keras optimizers.
113
114  You should not use this class directly, but instead instantiate one of its
115  subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc.
116
117  ### Usage
118
119  ```python
120  # Create an optimizer with the desired parameters.
121  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
122  # `loss` is a callable that takes no argument and returns the value
123  # to minimize.
124  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
125  # In graph mode, returns op that minimizes the loss by updating the listed
126  # variables.
127  opt_op = opt.minimize(loss, var_list=[var1, var2])
128  opt_op.run()
129  # In eager mode, simply call minimize to update the list of variables.
130  opt.minimize(loss, var_list=[var1, var2])
131  ```
132
133  ### Usage in custom training loops
134
135  In Keras models, sometimes variables are created when the model is first
136  called, instead of construction time. Examples include 1) sequential models
137  without input shape pre-defined, or 2) subclassed models. Pass var_list as
138  callable in these cases.
139
140  Example:
141
142  ```python
143  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
144  model = tf.keras.Sequential()
145  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
146  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
147  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
148  var_list_fn = lambda: model.trainable_weights
149  for input, output in data:
150    opt.minimize(loss_fn, var_list_fn)
151  ```
152
153  ### Processing gradients before applying them
154
155  Calling `minimize()` takes care of both computing the gradients and
156  applying them to the variables.  If you want to process the gradients
157  before applying them you can instead use the optimizer in three steps:
158
159  1.  Compute the gradients with `tf.GradientTape`.
160  2.  Process the gradients as you wish.
161  3.  Apply the processed gradients with `apply_gradients()`.
162
163  Example:
164
165  ```python
166  # Create an optimizer.
167  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
168
169  # Compute the gradients for a list of variables.
170  with tf.GradientTape() as tape:
171    loss = <call_loss_function>
172  vars = <list_of_variables>
173  grads = tape.gradient(loss, vars)
174
175  # Process the gradients, for example cap them, etc.
176  # capped_grads = [MyCapper(g) for g in grads]
177  processed_grads = [process_gradient(g) for g in grads]
178
179  # Ask the optimizer to apply the processed gradients.
180  opt.apply_gradients(zip(processed_grads, vars))
181  ```
182
183  ### Use with `tf.distribute.Strategy`
184
185  This optimizer class is `tf.distribute.Strategy` aware, which means it
186  automatically sums gradients across all replicas. To average gradients,
187  you divide your loss by the global batch size, which is done
188  automatically if you use `tf.keras` built-in training or evaluation loops.
189  See the `reduction` argument of your loss which should be set to
190  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
191  `tf.keras.losses.Reduction.SUM` for not.
192
193  To aggregate gradients yourself, call `apply_gradients` with
194  `experimental_aggregate_gradients` set to False. This is useful if you need to
195  process aggregated gradients.
196
197  If you are not using these and you want to average gradients, you should use
198  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
199  global batch size. Note that when using `tf.distribute.Strategy`, the first
200  component of a tensor's shape is the *replica-local* batch size, which is off
201  by a factor equal to the number of replicas being used to compute a single
202  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
203  resulting in gradients that can be many times too big.
204
205  ### Variable Constraints
206
207  All Keras optimizers respect variable constraints. If constraint function is
208  passed to any variable, the constraint will be applied to the variable after
209  the gradient has been applied to the variable.
210  Important: If gradient is sparse tensor, variable constraint is not supported.
211
212  ### Thread Compatibility
213
214  The entire optimizer is currently thread compatible, not thread-safe. The user
215  needs to perform synchronization if necessary.
216
217  ### Slots
218
219  Many optimizer subclasses, such as `Adam` and `Adagrad` allocate and manage
220  additional variables associated with the variables to train.  These are called
221  <i>Slots</i>.  Slots have names and you can ask the optimizer for the names of
222  the slots that it uses.  Once you have a slot name you can ask the optimizer
223  for the variable it created to hold the slot value.
224
225  This can be useful if you want to log or debug a training algorithm, report stats
226  about the slots, etc.
227
228  ### Hyperparameters
229
230  These are arguments passed to the optimizer subclass constructor
231  (the `__init__` method), and then passed to `self._set_hyper()`.
232  They can be either regular Python values (like 1.0), tensors, or
233  callables. If they are callable, the callable will be called during
234  `apply_gradients()` to get the value for the hyper parameter.
235
236  Hyperparameters can be overwritten through user code:
237
238  Example:
239
240  ```python
241  # Create an optimizer with the desired parameters.
242  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
243  # `loss` is a callable that takes no argument and returns the value
244  # to minimize.
245  loss = lambda: 3 * var1 + 2 * var2
246  # In eager mode, simply call minimize to update the list of variables.
247  opt.minimize(loss, var_list=[var1, var2])
248  # update learning rate
249  opt.learning_rate = 0.05
250  opt.minimize(loss, var_list=[var1, var2])
251  ```
252
253  ### Callable learning rate
254
255  Optimizer accepts a callable learning rate in two ways. The first way is
256  through built-in or customized
257  `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
258  called on each iteration with `schedule(iteration)`, a `tf.Variable`
259  owned by the optimizer.
260
261  Example:
262
263  >>> var = tf.Variable(np.random.random(size=(1,)))
264  >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
265  ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
266  >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)
267  >>> loss = lambda: 3 * var
268  >>> opt.minimize(loss, var_list=[var])
269  <tf.Variable...
270
271  The second way is through a callable function that
272  does not accept any arguments.
273
274  Example:
275
276  >>> var = tf.Variable(np.random.random(size=(1,)))
277  >>> def lr_callable():
278  ...   return .1
279  >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable)
280  >>> loss = lambda: 3 * var
281  >>> opt.minimize(loss, var_list=[var])
282  <tf.Variable...
283
284  ### Creating a custom optimizer
285
286  If you intend to create your own optimization algorithm, simply inherit from
287  this class and override the following methods:
288
289    - `_resource_apply_dense` (update variable given gradient tensor is a dense
290      `tf.Tensor`)
291    - `_resource_apply_sparse` (update variable given gradient tensor is a
292      sparse `tf.IndexedSlices`. The most common way for this to happen
293      is if you are taking the gradient through a `tf.gather`.)
294    - `_create_slots`
295      (if your optimizer algorithm requires additional variables)
296    - `get_config`
297      (serialization of the optimizer, include all hyper parameters)
298  """
299
300  # Subclasses should set this to True unless they override `apply_gradients`
301  # with a version that does not have the `experimental_aggregate_gradients`
302  # argument.  Older versions of Keras did not have this argument so custom
303  # optimizers may have overridden `apply_gradients` without the
304  # `experimental_aggregate_gradients` argument. Keras only passes
305  # `experimental_aggregate_gradients` if this attribute is True.
306  # Note: This attribute will likely be removed in an upcoming release.
307  _HAS_AGGREGATE_GRAD = False
308
309  def __init__(self,
310               name,
311               gradient_aggregator=None,
312               gradient_transformers=None,
313               **kwargs):
314    """Create a new Optimizer.
315
316    This must be called by the constructors of subclasses.
317    Note that Optimizer instances should not bind to a single graph,
318    and so shouldn't keep Tensors as member variables. Generally
319    you should be able to use the _set_hyper()/state.get_hyper()
320    facility instead.
321
322    This class is stateful and thread-compatible.
323
324    Example of custom gradient transformations:
325
326    ```python
327    def my_gradient_transformer(grads_and_vars):
328      # Simple example, double the gradients.
329      return [(2. * g, v) for g, v in grads_and_vars]
330
331    optimizer = tf.keras.optimizers.SGD(
332        1e-3, gradient_transformers=[my_gradient_transformer])
333    ```
334
335    Args:
336      name: String. The name to use for momentum accumulator weights created
337        by the optimizer.
338      gradient_aggregator: The function to use to aggregate gradients across
339        devices (when using `tf.distribute.Strategy`). If `None`, defaults to
340        summing the gradients across devices. The function should accept and
341        return a list of `(gradient, variable)` tuples.
342      gradient_transformers: Optional. List of functions to use to transform
343        gradients before applying updates to Variables. The functions are
344        applied after `gradient_aggregator`. The functions should accept and
345        return a list of `(gradient, variable)` tuples.
346      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
347        `clipnorm`, `global_clipnorm`.
348        If `clipvalue` (float) is set, the gradient of each weight
349        is clipped to be no higher than this value.
350        If `clipnorm` (float) is set, the gradient of each weight
351        is individually clipped so that its norm is no higher than this value.
352        If `global_clipnorm` (float) is set the gradient of all weights is
353        clipped so that their global norm is no higher than this value.
354
355    Raises:
356      ValueError: in case of any invalid argument.
357    """
358    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
359    for k in kwargs:
360      if k not in allowed_kwargs:
361        raise TypeError("Unexpected keyword argument "
362                        "passed to optimizer: " + str(k))
363      # checks that all keyword arguments are non-negative.
364      if kwargs[k] is not None and kwargs[k] < 0:
365        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
366      if k == "lr":
367        warnings.warn(
368            "The `lr` argument is deprecated, use `learning_rate` instead.")
369
370    self._use_locking = True
371    self._init_set_name(name)
372    self._hyper = {}
373    # dict: {variable name : {slot name : variable}}
374    self._slots = {}
375    self._slot_names = []
376    self._weights = []
377    self._iterations = None
378
379    # For implementing Trackable. Stores information about how to restore
380    # slot variables which have not yet been created
381    # (trackable._CheckpointPosition objects).
382    #  {slot_name :
383    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
384    #   ... }
385    self._deferred_slot_restorations = {}
386
387    decay = kwargs.pop("decay", 0.0)
388    if decay < 0.:
389      raise ValueError("decay cannot be less than 0: {}".format(decay))
390    self._initial_decay = decay
391
392    self._hypers_created = False
393    # Store the distribution strategy object if the optimizer is created inside
394    # strategy scope, so it could be used to create variables later.
395    if distribute_ctx.has_strategy():
396      self._distribution_strategy = distribute_ctx.get_strategy()
397    else:
398      self._distribution_strategy = None
399
400    # Configure gradient transformations.
401    if gradient_aggregator is None:
402      gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
403    self.gradient_aggregator = gradient_aggregator
404    if gradient_transformers is None:
405      gradient_transformers = []
406    self.gradient_transformers = gradient_transformers
407    self.clipnorm = kwargs.pop("clipnorm", None)
408    self.global_clipnorm = kwargs.pop("global_clipnorm", None)
409    if self.clipnorm is not None and self.global_clipnorm is not None:
410      raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
411                       "passed `clipnorm` {}, `global_clipnorm` {}".format(
412                           self.clipnorm, self.global_clipnorm))
413    self.clipvalue = kwargs.pop("clipvalue", None)
414
415  @property
416  def clipnorm(self):
417    """`float` or `None`. If set, clips gradients to a maximum norm."""
418    return self._clipnorm
419
420  @property
421  def global_clipnorm(self):
422    """`float` or `None`. If set, clips gradients to a maximum norm."""
423    return self._global_clipnorm
424
425  @clipnorm.setter
426  def clipnorm(self, val):
427    if val is not None and self.gradient_transformers:
428      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
429                       "is set. Instead, use the `gradient_transformers` to "
430                       "specify clipping and other transformations.")
431    self._clipnorm = val
432    self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
433        self._clipnorm)
434
435  @global_clipnorm.setter
436  def global_clipnorm(self, val):
437    if val is not None and self.gradient_transformers:
438      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
439                       "is set. Instead, use the `gradient_transformers` to "
440                       "specify clipping and other transformations.")
441    self._global_clipnorm = val
442    self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
443        self._global_clipnorm)
444
445  @property
446  def clipvalue(self):
447    """`float` or `None`. If set, clips gradients to a maximum value."""
448    return self._clipvalue
449
450  @clipvalue.setter
451  def clipvalue(self, val):
452    if val is not None and self.gradient_transformers:
453      raise ValueError("`clipvalue` cannot be set when `gradient_transformers` "
454                       "is set. Instead, use the `gradient_transformers` to "
455                       "specify clipping and other transformations.")
456    self._clipvalue = val
457    self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
458        self._clipvalue)
459
460  def _transform_loss(self, loss):
461    """Called in `.minimize` to transform loss before computing gradients."""
462    return loss
463
464  def _get_gradients(self, tape, loss, var_list, grad_loss=None):
465    """Called in `minimize` to compute gradients from loss."""
466    grads = tape.gradient(loss, var_list, grad_loss)
467    return list(zip(grads, var_list))
468
469  def _transform_unaggregated_gradients(self, grads_and_vars):
470    """Called in `apply_gradients` before gradient aggregation."""
471    return grads_and_vars
472
473  def _aggregate_gradients(self, grads_and_vars):
474    """Called in `apply_gradients` to aggregate gradients across devices.
475
476    Note that user subclasses may override this, so the interface should not be
477    changed.
478
479    Args:
480      grads_and_vars: List of (gradient, variable) pairs.
481
482    Returns:
483      A list of (aggregrated_gradient, variable) pairs. By default, this calls
484      `self.gradient_aggregator`.
485    """
486    return self.gradient_aggregator(grads_and_vars)
487
488  def _transform_gradients(self, grads_and_vars):
489    """Called in `apply_gradients` after aggregation."""
490    if self._clipvalue is not None:
491      grads_and_vars = self._clipvalue_fn(grads_and_vars)
492    if self._clipnorm is not None:
493      grads_and_vars = self._clipnorm_fn(grads_and_vars)
494    if self._global_clipnorm is not None:
495      grads_and_vars = self._global_clipnorm_fn(grads_and_vars)
496
497    for fn in self.gradient_transformers:
498      grads_and_vars = fn(grads_and_vars)
499    return grads_and_vars
500
501  def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
502    """Minimize `loss` by updating `var_list`.
503
504    This method simply computes gradient using `tf.GradientTape` and calls
505    `apply_gradients()`. If you want to process the gradient before applying
506    then call `tf.GradientTape` and `apply_gradients()` explicitly instead
507    of using this function.
508
509    Args:
510      loss: `Tensor` or callable. If a callable, `loss` should take no arguments
511        and return the value to minimize. If a `Tensor`, the `tape` argument
512        must be passed.
513      var_list: list or tuple of `Variable` objects to update to minimize
514        `loss`, or a callable returning the list or tuple of `Variable` objects.
515        Use callable when the variable list would otherwise be incomplete before
516        `minimize` since the variables are created at the first time `loss` is
517        called.
518      grad_loss: (Optional). A `Tensor` holding the gradient computed for
519        `loss`.
520      name: (Optional) str. Name for the returned operation.
521      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
522        the tape that computed the `loss` must be provided.
523
524    Returns:
525      An `Operation` that updates the variables in `var_list`. The `iterations`
526      will be automatically increased by 1.
527
528    Raises:
529      ValueError: If some of the variables are not `Variable` objects.
530
531    """
532    grads_and_vars = self._compute_gradients(
533        loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
534    return self.apply_gradients(grads_and_vars, name=name)
535
536  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
537    """Compute gradients of `loss` for the variables in `var_list`.
538
539    This is the first part of `minimize()`.  It returns a list
540    of (gradient, variable) pairs where "gradient" is the gradient
541    for "variable".  Note that "gradient" can be a `Tensor`, an
542    `IndexedSlices`, or `None` if there is no gradient for the
543    given variable.
544
545    Args:
546      loss: `Tensor` or callable. If a callable, `loss` should take no
547        arguments and return the value to minimize. If a `Tensor`, the `tape`
548        argument must be passed.
549      var_list: list or tuple of `Variable` objects to update to minimize
550        `loss`, or a callable returning the list or tuple of `Variable` objects.
551        Use callable when the variable list would otherwise be incomplete before
552        `minimize` and the variables are created at the first time when `loss`
553        is called.
554      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
555      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
556        the tape that computed the `loss` must be provided.
557
558    Returns:
559      A list of (gradient, variable) pairs. Variable is always present, but
560      gradient can be `None`.
561
562    Raises:
563      TypeError: If `var_list` contains anything else than `Variable` objects.
564      ValueError: If some arguments are invalid, or var_list is None.
565    """
566    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
567    if not callable(loss) and tape is None:
568      raise ValueError("`tape` is required when a `Tensor` loss is passed.")
569    tape = tape if tape is not None else backprop.GradientTape()
570
571    if callable(loss):
572      with tape:
573        if not callable(var_list):
574          tape.watch(var_list)
575        loss = loss()
576        if callable(var_list):
577          var_list = var_list()
578
579    with tape:
580      loss = self._transform_loss(loss)
581
582    var_list = nest.flatten(var_list)
583    with ops.name_scope_v2(self._name + "/gradients"):
584      grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
585
586    self._assert_valid_dtypes([
587        v for g, v in grads_and_vars
588        if g is not None and v.dtype != dtypes.resource
589    ])
590
591    return grads_and_vars
592
593  def apply_gradients(self,
594                      grads_and_vars,
595                      name=None,
596                      experimental_aggregate_gradients=True):
597    """Apply gradients to variables.
598
599    This is the second part of `minimize()`. It returns an `Operation` that
600    applies gradients.
601
602    The method sums gradients from all replicas in the presence of
603    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
604    passing `experimental_aggregate_gradients=False`.
605
606    Example:
607
608    ```python
609    grads = tape.gradient(loss, vars)
610    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
611    # Processing aggregated gradients.
612    optimizer.apply_gradients(zip(grads, vars),
613        experimental_aggregate_gradients=False)
614
615    ```
616
617    Args:
618      grads_and_vars: List of (gradient, variable) pairs.
619      name: Optional name for the returned operation. Default to the name passed
620        to the `Optimizer` constructor.
621      experimental_aggregate_gradients: Whether to sum gradients from different
622        replicas in the presense of `tf.distribute.Strategy`. If False, it's
623        user responsibility to aggregate the gradients. Default to True.
624
625    Returns:
626      An `Operation` that applies the specified gradients. The `iterations`
627      will be automatically increased by 1.
628
629    Raises:
630      TypeError: If `grads_and_vars` is malformed.
631      ValueError: If none of the variables have gradients.
632      RuntimeError: If called in a cross-replica context.
633    """
634    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
635    var_list = [v for (_, v) in grads_and_vars]
636
637    with ops.name_scope_v2(self._name):
638      # Create iteration if necessary.
639      with ops.init_scope():
640        self._create_all_weights(var_list)
641
642      if not grads_and_vars:
643        # Distribution strategy does not support reducing an empty list of
644        # gradients
645        return control_flow_ops.no_op()
646
647      if distribute_ctx.in_cross_replica_context():
648        raise RuntimeError(
649            "`apply_gradients() cannot be called in cross-replica context. "
650            "Use `tf.distribute.Strategy.run` to enter replica "
651            "context.")
652
653      strategy = distribute_ctx.get_strategy()
654      if (not experimental_aggregate_gradients and strategy and
655          isinstance(strategy,
656                     (parameter_server_strategy.ParameterServerStrategyV1,
657                      parameter_server_strategy_v2.ParameterServerStrategyV2,
658                      central_storage_strategy.CentralStorageStrategy,
659                      central_storage_strategy.CentralStorageStrategyV1))):
660        raise NotImplementedError(
661            "`experimental_aggregate_gradients=False is not supported for "
662            "ParameterServerStrategy and CentralStorageStrategy")
663
664      apply_state = self._prepare(var_list)
665      if experimental_aggregate_gradients:
666        grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars)
667        grads_and_vars = self._aggregate_gradients(grads_and_vars)
668      grads_and_vars = self._transform_gradients(grads_and_vars)
669
670      if optimizer_utils.strategy_supports_no_merge_call():
671        return self._distributed_apply(strategy, grads_and_vars, name,
672                                       apply_state)
673      else:
674        return distribute_ctx.get_replica_context().merge_call(
675            functools.partial(self._distributed_apply, apply_state=apply_state),
676            args=(grads_and_vars,),
677            kwargs={
678                "name": name,
679            })
680
681  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
682    """`apply_gradients` using a `DistributionStrategy`."""
683
684    def apply_grad_to_update_var(var, grad):
685      """Apply gradient to variable."""
686      if isinstance(var, ops.Tensor):
687        raise NotImplementedError("Trying to update a Tensor ", var)
688
689      apply_kwargs = {}
690      if isinstance(grad, ops.IndexedSlices):
691        if var.constraint is not None:
692          raise RuntimeError(
693              "Cannot use a constraint function on a sparse variable.")
694        if "apply_state" in self._sparse_apply_args:
695          apply_kwargs["apply_state"] = apply_state
696        return self._resource_apply_sparse_duplicate_indices(
697            grad.values, var, grad.indices, **apply_kwargs)
698
699      if "apply_state" in self._dense_apply_args:
700        apply_kwargs["apply_state"] = apply_state
701      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
702      if var.constraint is not None:
703        with ops.control_dependencies([update_op]):
704          return var.assign(var.constraint(var))
705      else:
706        return update_op
707
708    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
709    update_ops = []
710    with name_scope_only_in_function_or_graph(name or self._name):
711      for grad, var in grads_and_vars:
712        # Colocate the update with variables to avoid unnecessary communication
713        # delays. See b/136304694.
714        with distribution.extended.colocate_vars_with(var):
715          with name_scope_only_in_function_or_graph(
716              "update" if eagerly_outside_functions else "update_" +
717              var.op.name):
718            update_op = distribution.extended.update(
719                var, apply_grad_to_update_var, args=(grad,), group=False)
720            if distribute_ctx.in_cross_replica_context():
721              # In cross-replica context, extended.update returns a list of
722              # update ops from all replicas (group=False).
723              update_ops.extend(update_op)
724            else:
725              # In replica context, extended.update return the single update op
726              # of current replica.
727              update_ops.append(update_op)
728
729      any_symbolic = any(isinstance(i, ops.Operation) or
730                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
731      if not context.executing_eagerly() or any_symbolic:
732        # If the current context is graph mode or any of the update ops are
733        # symbolic then the step update should be carried out under a graph
734        # context. (eager updates execute immediately)
735        with backend._current_graph(update_ops).as_default():  # pylint: disable=protected-access
736          with ops.control_dependencies([control_flow_ops.group(update_ops)]):
737            return self._iterations.assign_add(1, read_value=False)
738
739      return self._iterations.assign_add(1)
740
741  def get_gradients(self, loss, params):
742    """Returns gradients of `loss` with respect to `params`.
743
744    Should be used only in legacy v1 graph mode.
745
746    Args:
747      loss: Loss tensor.
748      params: List of variables.
749
750    Returns:
751      List of gradient tensors.
752
753    Raises:
754      ValueError: In case any gradient cannot be computed (e.g. if gradient
755        function not implemented).
756    """
757    params = nest.flatten(params)
758    with backend.get_graph().as_default(), backend.name_scope(self._name +
759                                                              "/gradients"):
760      grads = gradients.gradients(loss, params)
761      for grad, param in zip(grads, params):
762        if grad is None:
763          raise ValueError("Variable {} has `None` for gradient. "
764                           "Please make sure that all of your ops have a "
765                           "gradient defined (i.e. are differentiable). "
766                           "Common ops without gradient: "
767                           "K.argmax, K.round, K.eval.".format(param))
768    return grads
769
770  def get_updates(self, loss, params):
771    grads = self.get_gradients(loss, params)
772    grads_and_vars = list(zip(grads, params))
773    self._assert_valid_dtypes([
774        v for g, v in grads_and_vars
775        if g is not None and v.dtype != dtypes.resource
776    ])
777    return [self.apply_gradients(grads_and_vars)]
778
779  def _set_hyper(self, name, value):
780    """set hyper `name` to value. value can be callable, tensor, numeric."""
781    if isinstance(value, trackable.Trackable):
782      self._track_trackable(value, name, overwrite=True)
783    if name not in self._hyper:
784      self._hyper[name] = value
785    else:
786      prev_value = self._hyper[name]
787      if (callable(prev_value)
788          or isinstance(prev_value,
789                        (ops.Tensor, int, float,
790                         learning_rate_schedule.LearningRateSchedule))
791          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
792        self._hyper[name] = value
793      else:
794        backend.set_value(self._hyper[name], value)
795
796  def _get_hyper(self, name, dtype=None):
797    if not self._hypers_created:
798      self._create_hypers()
799    value = self._hyper[name]
800    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
801      return value
802    if callable(value):
803      value = value()
804    if dtype:
805      return math_ops.cast(value, dtype)
806    else:
807      return value
808
809  def _create_slots(self, var_list):
810    pass
811
812  def _create_all_weights(self, var_list):
813    """Creates all weights, including iterations, hyperparameters and slot vars.
814
815    This will add newly created variables to `optimizer.weights`.
816
817    New variables are only created when this method is called the first time, or
818    when called with different variables in the var_list.
819
820    Args:
821      var_list: list or tuple of `Variable` objects that will be minimized
822        using this optimizer.
823    """
824
825    _ = self.iterations
826    self._create_hypers()
827    self._create_slots(var_list)
828
  def __getattribute__(self, name):
    """Overridden to support hyperparameter access.

    Normal attribute lookup is tried first. Only if it raises
    `AttributeError` is `name` looked up in `self._hyper`, with `lr` treated
    as an alias for `learning_rate`, and resolved through `_get_hyper`.

    Args:
      name: Attribute name being looked up.

    Returns:
      The regular attribute value, or `self._get_hyper(name)`.

    Raises:
      AttributeError: The original lookup error, if `name` is neither a
        regular attribute nor a hyperparameter.
    """
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      # (`self._hyper` below would re-enter this method while `_hyper` is
      # unset, and `__setattr__` relies on `hasattr(self, "_hyper")` seeing
      # this AttributeError.)
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e
843
844  def __dir__(self):
845    result = set(super(OptimizerV2, self).__dir__())
846    if "_hyper" in result:
847      result |= self._hyper.keys()
848      if "learning_rate" in self._hyper.keys():
849        result.add("lr")
850    return list(result)
851
852  def __setattr__(self, name, value):
853    """Override setattr to support dynamic hyperparameter setting."""
854    # Backwards compatibility with Keras optimizers.
855    if name == "lr":
856      name = "learning_rate"
857    if hasattr(self, "_hyper") and name in self._hyper:
858      self._set_hyper(name, value)
859    else:
860      super(OptimizerV2, self).__setattr__(name, value)
861
862  def get_slot_names(self):
863    """A list of names for this optimizer's slots."""
864    return self._slot_names
865
  def add_slot(self, var, slot_name, initializer="zeros", shape=None):
    """Add a new slot variable for `var`.

    A slot variable is an additional variable associated with `var` to train.
    It is allocated and managed by optimizers, e.g. `Adam`.

    Slots are keyed by `_var_key(var)` and `slot_name`. If the slot already
    exists, it is returned as-is. Otherwise a new non-trainable variable with
    `var`'s dtype is created and colocated with `var`. It is then tracked and
    appended to `self.weights`, and any restorations deferred for it are
    applied.

    Args:
      var: a `Variable` object.
      slot_name: name of the slot variable.
      initializer: initializer of the slot variable. A string or callable is
        resolved through `initializers.get` and then called with the slot
        shape and `var.dtype`. Any other value is used directly as the
        variable's initial value, and `shape` is then ignored.
      shape: (Optional) shape of the slot variable. If not set, it will default
        to the shape of `var`. The exception is a
        `trackable.CheckpointInitialValueCallable` initializer, which always
        receives `shape` as given (possibly `None`).

    Returns:
      A slot variable.

    Raises:
      ValueError: If `var` was not created under the distribution strategy
        that is current while the slot is being created.
    """
    # Record the slot name in this optimizer's list (see `get_slot_names`).
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, str) or callable(initializer):
        initializer = initializers.get(initializer)
        # A checkpoint-restoring initializer gets `shape` unchanged. This
        # presumably lets it use the checkpointed shape (the restore path
        # passes `value_shape()`), which may differ from `var.shape`.
        if isinstance(
            initializer,
            trackable.CheckpointInitialValueCallable) or (shape is not None):
          slot_shape = shape
        else:
          slot_shape = var.shape
        initial_value = functools.partial(
            initializer, shape=slot_shape, dtype=var.dtype)
      else:
        initial_value = initializer

      with self._distribution_strategy_scope():
        strategy = distribute_ctx.get_strategy()
        if not strategy.extended.variable_created_in_scope(var):
          raise ValueError(
              "Trying to create optimizer slot variable under the scope for "
              "tf.distribute.Strategy ({}), which is different from the scope "
              "used for the original variable ({}). Make sure the slot "
              "variables are created under the same strategy scope. This may "
              "happen if you're restoring from a checkpoint outside the scope"
              .format(strategy, var))

        # Place the slot alongside `var`.
        with strategy.extended.colocate_vars_with(var):
          weight = tf_variables.Variable(
              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
              dtype=var.dtype,
              trainable=False,
              initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      # Apply any checkpoint restorations queued before this slot existed.
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight
925
926  def get_slot(self, var, slot_name):
927    var_key = _var_key(var)
928    slot_dict = self._slots[var_key]
929    return slot_dict[slot_name]
930
931  def _prepare(self, var_list):
932    keys = set()
933    for var in var_list:
934      if isinstance(var, ds_values.DistributedValues):
935        var_devices = var._devices   # pylint: disable=protected-access
936      else:
937        var_devices = [var.device]
938      var_dtype = var.dtype.base_dtype
939      for var_device in var_devices:
940        keys.add((var_device, var_dtype))
941
942    apply_state = {}
943    for var_device, var_dtype in keys:
944      apply_state[(var_device, var_dtype)] = {}
945      with ops.device(var_device):
946        self._prepare_local(var_device, var_dtype, apply_state)
947
948    return apply_state
949
950  def _prepare_local(self, var_device, var_dtype, apply_state):
951    if "learning_rate" in self._hyper:
952      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
953      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t
954
955  def _fallback_apply_state(self, var_device, var_dtype):
956    """Compatibility for subclasses that don't pass apply_state through."""
957    apply_state = {(var_device, var_dtype): {}}
958    self._prepare_local(var_device, var_dtype, apply_state)
959    return apply_state[(var_device, var_dtype)]
960
961  def _create_hypers(self):
962    if self._hypers_created:
963      return
964    with self._distribution_strategy_scope():
965      # Iterate hyper values deterministically.
966      for name, value in sorted(self._hyper.items()):
967        if isinstance(value,
968                      (ops.Tensor, tf_variables.Variable)) or callable(value):
969          # The check for `callable` covers the usage when `value` is a
970          # `LearningRateSchedule`, in which case it does not need to create a
971          # variable.
972          continue
973        else:
974          self._hyper[name] = self.add_weight(
975              name,
976              shape=[],
977              trainable=False,
978              initializer=value,
979              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
980    self._hypers_created = True
981
982  @property
983  def iterations(self):
984    """Variable. The number of training steps this Optimizer has run."""
985    if self._iterations is None:
986      with self._distribution_strategy_scope():
987        self._iterations = self.add_weight(
988            "iter",
989            shape=[],
990            dtype=dtypes.int64,
991            trainable=False,
992            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
993      self._weights.append(self._iterations)
994    return self._iterations
995
996  @iterations.setter
997  def iterations(self, variable):
998    if self._iterations is not None:
999      raise RuntimeError("Cannot set `iterations` to a new Variable after "
1000                         "the Optimizer weights have been created")
1001    self._iterations = variable
1002    self._weights.append(self._iterations)
1003
1004  def _decayed_lr(self, var_dtype):
1005    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
1006    lr_t = self._get_hyper("learning_rate", var_dtype)
1007    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
1008      local_step = math_ops.cast(self.iterations, var_dtype)
1009      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
1010    if self._initial_decay > 0.:
1011      local_step = math_ops.cast(self.iterations, var_dtype)
1012      decay_t = math_ops.cast(self._initial_decay, var_dtype)
1013      lr_t = lr_t / (1. + decay_t * local_step)
1014    return lr_t
1015
1016  @abc.abstractmethod
1017  def get_config(self):
1018    """Returns the config of the optimizer.
1019
1020    An optimizer config is a Python dictionary (serializable)
1021    containing the configuration of an optimizer.
1022    The same optimizer can be reinstantiated later
1023    (without any saved state) from this configuration.
1024
1025    Returns:
1026        Python dictionary.
1027    """
1028    config = {"name": self._name}
1029    if self.clipnorm is not None:
1030      config["clipnorm"] = self.clipnorm
1031    if self.clipvalue is not None:
1032      config["clipvalue"] = self.clipvalue
1033    if self.global_clipnorm is not None:
1034      config["global_clipnorm"] = self.global_clipnorm
1035    return config
1036
1037  @classmethod
1038  def from_config(cls, config, custom_objects=None):
1039    """Creates an optimizer from its config.
1040
1041    This method is the reverse of `get_config`,
1042    capable of instantiating the same optimizer from the config
1043    dictionary.
1044
1045    Args:
1046        config: A Python dictionary, typically the output of get_config.
1047        custom_objects: A Python dictionary mapping names to additional Python
1048          objects used to create this optimizer, such as a function used for a
1049          hyperparameter.
1050
1051    Returns:
1052        An optimizer instance.
1053    """
1054    if "lr" in config:
1055      config["learning_rate"] = config.pop("lr")
1056    if "learning_rate" in config:
1057      if isinstance(config["learning_rate"], dict):
1058        config["learning_rate"] = learning_rate_schedule.deserialize(
1059            config["learning_rate"], custom_objects=custom_objects)
1060    return cls(**config)
1061
1062  def _serialize_hyperparameter(self, hyperparameter_name):
1063    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
1064    value = self._hyper[hyperparameter_name]
1065    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
1066      return learning_rate_schedule.serialize(value)
1067    if callable(value):
1068      return value()
1069    if tensor_util.is_tf_type(value):
1070      return backend.get_value(value)
1071    return value
1072
1073  def variables(self):
1074    """Returns variables of this Optimizer based on the order created."""
1075    return self._weights
1076
1077  @property
1078  def weights(self):
1079    """Returns variables of this Optimizer based on the order created."""
1080    return self._weights
1081
1082  def get_weights(self):
1083    """Returns the current weights of the optimizer.
1084
1085    The weights of an optimizer are its state (ie, variables).
1086    This function returns the weight values associated with this
1087    optimizer as a list of Numpy arrays. The first value is always the
1088    iterations count of the optimizer, followed by the optimizer's state
1089    variables in the order they were created. The returned list can in turn
1090    be used to load state into similarly parameterized optimizers.
1091
1092    For example, the RMSprop optimizer for this simple model returns a list of
1093    three values-- the iteration count, followed by the root-mean-square value
1094    of the kernel and bias of the single Dense layer:
1095
1096    >>> opt = tf.keras.optimizers.RMSprop()
1097    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1098    >>> m.compile(opt, loss='mse')
1099    >>> data = np.arange(100).reshape(5, 20)
1100    >>> labels = np.zeros(5)
1101    >>> print('Training'); results = m.fit(data, labels)
1102    Training ...
1103    >>> len(opt.get_weights())
1104    3
1105
1106    Returns:
1107        Weights values as a list of numpy arrays.
1108    """
1109    params = self.weights
1110    return backend.batch_get_value(params)
1111
1112  # TODO(tanzheny): Maybe share this logic with base_layer.
1113  def set_weights(self, weights):
1114    """Set the weights of the optimizer.
1115
1116    The weights of an optimizer are its state (ie, variables).
1117    This function takes the weight values associated with this
1118    optimizer as a list of Numpy arrays. The first value is always the
1119    iterations count of the optimizer, followed by the optimizer's state
1120    variables in the order they are created. The passed values are used to set
1121    the new state of the optimizer.
1122
1123    For example, the RMSprop optimizer for this simple model takes a list of
1124    three values-- the iteration count, followed by the root-mean-square value
1125    of the kernel and bias of the single Dense layer:
1126
1127    >>> opt = tf.keras.optimizers.RMSprop()
1128    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1129    >>> m.compile(opt, loss='mse')
1130    >>> data = np.arange(100).reshape(5, 20)
1131    >>> labels = np.zeros(5)
1132    >>> print('Training'); results = m.fit(data, labels)
1133    Training ...
1134    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
1135    >>> opt.set_weights(new_weights)
1136    >>> opt.iterations
1137    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>
1138
1139    Args:
1140        weights: weight values as a list of numpy arrays.
1141    """
1142    params = self.weights
1143    if len(params) != len(weights):
1144      raise ValueError(
1145          "You called `set_weights(weights)` on optimizer " + self._name +
1146          " with a  weight list of length " + str(len(weights)) +
1147          ", but the optimizer was expecting " + str(len(params)) +
1148          " weights. Provided weights: " + str(weights)[:50] + "...")
1149    if not params:
1150      return
1151    weight_value_tuples = []
1152    param_values = backend.batch_get_value(params)
1153    for pv, p, w in zip(param_values, params, weights):
1154      if pv.shape != w.shape:
1155        raise ValueError("Optimizer weight shape " + str(pv.shape) +
1156                         " not compatible with "
1157                         "provided weight shape " + str(w.shape))
1158      weight_value_tuples.append((p, w))
1159    backend.batch_set_value(weight_value_tuples)
1160
1161  def add_weight(self,
1162                 name,
1163                 shape,
1164                 dtype=None,
1165                 initializer="zeros",
1166                 trainable=None,
1167                 synchronization=tf_variables.VariableSynchronization.AUTO,
1168                 aggregation=tf_variables.VariableAggregation.NONE):
1169
1170    if dtype is None:
1171      dtype = dtypes.float32
1172    if isinstance(initializer, str) or callable(initializer):
1173      initializer = initializers.get(initializer)
1174
1175    if synchronization == tf_variables.VariableSynchronization.ON_READ:
1176      if trainable:
1177        raise ValueError(
1178            "Synchronization value can be set to "
1179            "VariableSynchronization.ON_READ only for non-trainable variables. "
1180            "You have specified trainable=True and "
1181            "synchronization=VariableSynchronization.ON_READ.")
1182      else:
1183        # Set trainable to be false when variable is to be synced on read.
1184        trainable = False
1185    elif trainable is None:
1186      trainable = True
1187
1188    variable = self._add_variable_with_custom_getter(
1189        name=name,
1190        shape=shape,
1191        getter=base_layer_utils.make_variable,
1192        overwrite=True,
1193        initializer=initializer,
1194        dtype=dtype,
1195        trainable=trainable,
1196        use_resource=True,
1197        synchronization=synchronization,
1198        aggregation=aggregation)
1199    backend.track_variable(variable)
1200
1201    return variable
1202
1203  def _init_set_name(self, name, zero_based=True):
1204    if not name:
1205      self._name = backend.unique_object_name(
1206          generic_utils.to_snake_case(self.__class__.__name__),
1207          zero_based=zero_based)
1208    else:
1209      self._name = name
1210
1211  def _assert_valid_dtypes(self, tensors):
1212    """Asserts tensors are all valid types (see `_valid_dtypes`).
1213
1214    Args:
1215      tensors: Tensors to check.
1216
1217    Raises:
1218      ValueError: If any tensor is not a valid type.
1219    """
1220    valid_dtypes = self._valid_dtypes()
1221    for t in tensors:
1222      dtype = t.dtype.base_dtype
1223      if dtype not in valid_dtypes:
1224        raise ValueError("Invalid type %r for %s, expected: %s." %
1225                         (dtype, t.name, [v for v in valid_dtypes]))
1226
1227  def _valid_dtypes(self):
1228    """Valid types for loss, variables and gradients.
1229
1230    Subclasses should override to allow other float types.
1231
1232    Returns:
1233      Valid types for loss, variables and gradients.
1234    """
1235    return _DEFAULT_VALID_DTYPES
1236
1237  def _call_if_callable(self, param):
1238    """Call the function if param is callable."""
1239    return param() if callable(param) else param
1240
1241  def _resource_apply_dense(self, grad, handle, apply_state):
1242    """Add ops to apply dense gradients to the variable `handle`.
1243
1244    Args:
1245      grad: a `Tensor` representing the gradient.
1246      handle: a `Tensor` of dtype `resource` which points to the variable to be
1247        updated.
1248      apply_state: A dict which is used across multiple apply calls.
1249
1250    Returns:
1251      An `Operation` which updates the value of the variable.
1252    """
1253    raise NotImplementedError("Must be implemented in subclasses.")
1254
1255  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
1256                                               **kwargs):
1257    """Add ops to apply sparse gradients to `handle`, with repeated indices.
1258
1259    Optimizers which override this method must deal with repeated indices. See
1260    the docstring of `_apply_sparse_duplicate_indices` for details. By default
1261    the correct behavior, to sum non-unique indices and their associated
1262    gradients, is enforced by first pre-processing `grad` and `indices` and
1263    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
1264    with duplicate indices may instead override this method to avoid the
1265    overhead of summing.
1266
1267    Args:
1268      grad: a `Tensor` representing the gradient for the affected indices.
1269      handle: a `Tensor` of dtype `resource` which points to the variable to be
1270        updated.
1271      indices: a `Tensor` of integral type representing the indices for which
1272        the gradient is nonzero. Indices may be repeated.
1273      **kwargs: May optionally contain `apply_state`
1274
1275    Returns:
1276      An `Operation` which updates the value of the variable.
1277    """
1278    summed_grad, unique_indices = _deduplicate_indexed_slices(
1279        values=grad, indices=indices)
1280    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
1281                                       **kwargs)
1282
1283  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
1284    """Add ops to apply sparse gradients to the variable `handle`.
1285
1286    Similar to `_apply_sparse`, the `indices` argument to this method has been
1287    de-duplicated. Optimizers which deal correctly with non-unique indices may
1288    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
1289    overhead.
1290
1291    Args:
1292      grad: a `Tensor` representing the gradient for the affected indices.
1293      handle: a `Tensor` of dtype `resource` which points to the variable to be
1294        updated.
1295      indices: a `Tensor` of integral type representing the indices for which
1296        the gradient is nonzero. Indices are unique.
1297      apply_state: A dict which is used across multiple apply calls.
1298
1299    Returns:
1300      An `Operation` which updates the value of the variable.
1301    """
1302    raise NotImplementedError("Must be implemented in subclasses.")
1303
1304  def _resource_scatter_add(self, x, i, v):
1305    with ops.control_dependencies([
1306        gen_resource_variable_ops.ResourceScatterAdd(
1307            resource=x.handle, indices=i, updates=v)
1308    ]):
1309      return x.value()
1310
1311  def _resource_scatter_update(self, x, i, v):
1312    with ops.control_dependencies(
1313        [gen_resource_variable_ops.ResourceScatterUpdate(
1314            resource=x.handle, indices=i, updates=v)]):
1315      return x.value()
1316
1317  @property
1318  @layer_utils.cached_per_instance
1319  def _dense_apply_args(self):
1320    return tf_inspect.getfullargspec(self._resource_apply_dense).args
1321
1322  @property
1323  @layer_utils.cached_per_instance
1324  def _sparse_apply_args(self):
1325    return tf_inspect.getfullargspec(self._resource_apply_sparse).args
1326
1327  # ---------------
1328  # For implementing the trackable interface
1329  # ---------------
1330
1331  def _restore_slot_variable(self, slot_name, variable, slot_variable):
1332    """Restore a newly created slot variable's value."""
1333    variable_key = _var_key(variable)
1334    deferred_restorations = self._deferred_slot_restorations.get(
1335        slot_name, {}).pop(variable_key, [])
1336    # Iterate over restores, highest restore UID first to minimize the number
1337    # of assignments.
1338    deferred_restorations.sort(key=lambda position: position.restore_uid,
1339                               reverse=True)
1340    for checkpoint_position in deferred_restorations:
1341      checkpoint_position.restore(slot_variable)
1342
  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    There are three outcomes:
      - the slot already exists: it is restored immediately;
      - the eager conditions below hold: the slot is created through
        `add_slot` with a `CheckpointInitialValueCallable` initializer, then
        restored;
      - otherwise: `slot_variable_position` is queued in
        `self._deferred_slot_restorations[slot_name][_var_key(variable)]`,
        and `_restore_slot_variable` consumes it later.

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        #
        # One notable case is with distribution strategy, which uses variable
        # creator scope but always desires the `variable` and the slot to use
        # the same scope, thus we can safely eagerly create/restore slot
        # variables.
        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
             self._distribution_strategy)):
      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      # Use the checkpointed shape rather than defaulting to variable.shape.
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name,
          shape=slot_variable_position.value_shape())
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)
1410
  @contextlib.contextmanager
  def _distribution_strategy_scope(self):
    """Context manager that enters the optimizer's original strategy scope.

    This applies when the optimizer recorded a strategy
    (`self._distribution_strategy`) and `distribute_ctx.has_strategy()` is
    False. In that case the body runs inside that strategy's `scope()`.
    Otherwise the body runs in the current context unchanged and `None` is
    yielded.
    """
    if self._distribution_strategy and not distribute_ctx.has_strategy():
      with self._distribution_strategy.scope():
        # NOTE(review): this yields a second, un-entered `scope()` object
        # rather than the strategy itself. Callers visible in this file use
        # the manager without `as`, so the yielded value appears unused.
        yield self._distribution_strategy.scope()
    else:
      yield
1419
1420
def _var_key(var):
  """Returns the key under which slots for `var` are stored.

  If `var` has a `_distributed_container` attribute, the container it returns
  is keyed instead, so a distributed variable's components share one key. The
  key is the variable's `_shared_name` in graph mode and its `_unique_id`
  otherwise.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """
  # pylint: disable=protected-access
  if hasattr(var, "_distributed_container"):
    primary = var._distributed_container()
  else:
    primary = var
  return primary._shared_name if primary._in_graph_mode else primary._unique_id
1442
1443
def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: `"<_var_key(var)>/<slot_name>"`."""
  return "/".join((_var_key(var), slot_name))
1449
1450
class RestoredOptimizer(OptimizerV2):
  """Placeholder optimizer used for checkpoint compatibility.

  When an optimizer is restored from a SavedModel, this object holds its slot
  variables and hyperparameters. Functions in the SavedModel may reference
  those variables along with ops created by the original optimizer. Using the
  restored object itself (e.g. through `apply_gradients`) is not supported.
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    # Makes `_create_hypers` a no-op. Presumably this is because restored
    # hyperparameters are supplied through `_set_hyper`, the revived-type
    # setter.
    self._hypers_created = True

  def get_config(self):
    """Always raises: configs of restored optimizers are not available."""
    # TODO(allenl): Save and restore the Optimizer's config
    message = (
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")
    raise NotImplementedError(message)
1472
# Register RestoredOptimizer as the revived type for SavedModel objects with
# the identifier "tf_deprecated_optimizer". The predicate matches any
# `OptimizerV2` instance, and `object_factory` builds a bare
# `RestoredOptimizer` regardless of the proto. `setter` is
# `RestoredOptimizer._set_hyper`; presumably the loader uses it to assign
# restored attributes/hyperparameters. NOTE(review): confirm the exact
# semantics of `setter` and of the producer/consumer version fields in
# `revived_types`.
revived_types.register_revived_type(
    "tf_deprecated_optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])
1483