# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.distribute import values as distributed_values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
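
  For example, given `values=[[1, 1], [2, 2], [3, 3]]` and `indices=[0, 2, 0]`,
  this returns `summed_values=[[4, 4], [2, 2]]` and `unique_indices=[0, 2]`.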
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(trackable.Trackable):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.  You should never
  use this class directly; instead, instantiate one of its subclasses such as
  `tf.keras.optimizers.SGD` or `tf.keras.optimizers.Adam`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns op that minimizes the loss by updating the listed
  # variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables.  If you want to process the gradients
  before applying them, you can instead use the optimizer in three steps:

  1.  Compute the gradients with `tf.GradientTape`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  vars = <list_of_variables>
  grads = tape.gradient(loss, vars)
  processed_grads = [process_gradient(g) for g in grads]
  grads_and_vars = zip(processed_grads, vars)

  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Use with `tf.distribute.Strategy`.

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done automatically
  if you use a member of `tf.keras.losses` or `tf.losses`. See the
  `reduction` argument of your loss, which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or to
  `tf.keras.losses.Reduction.SUM` if you do not want averaging.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.
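
  For example, a minimal sketch of the recommended manual loss scaling when
  you are not using a built-in loss class (`model`, `opt`, `features`,
  `labels` and `global_batch_size` are assumed to be defined elsewhere, and
  the per-replica invocation through the strategy is omitted):

  ```python
  with tf.GradientTape() as tape:
    per_example_loss = tf.keras.losses.mean_squared_error(labels, model(features))
    # Sum the per-example losses and divide by the *global* batch size,
    # not by the replica-local batch size.
    loss = tf.math.reduce_sum(per_example_loss) / global_batch_size
  grads = tape.gradient(loss, model.trainable_variables)
  opt.apply_gradients(zip(grads, model.trainable_variables))
  ```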

  ### Variable Constraint

  All Keras optimizers respect variable constraints. If a constraint function
  is passed to any variable, the constraint will be applied to the variable
  after the gradient has been applied to it.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.

  ### Thread Compatibility

  The entire optimizer is currently thread-compatible, not thread-safe. The
  user needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train.  These are
  called <i>Slots</i>.  Slots have names and you can ask the optimizer for the
  names of the slots that it uses.  Once you have a slot name, you can ask the
  optimizer for the variable it created to hold the slot value.

  This can be useful if you want to debug a training algorithm, report
  statistics about the slots, etc.
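
  For example, a minimal sketch (it assumes an `Adam` optimizer whose moment
  slots are named `"m"` and `"v"`, and that gradients have already been
  applied to `var` at least once):

  ```python
  opt = tf.keras.optimizers.Adam(learning_rate=0.1)
  # ... after opt.apply_gradients(...) or opt.minimize(...) has run ...
  print(opt.get_slot_names())   # e.g. ['m', 'v']
  m_slot = opt.get_slot(var, 'm')
  ```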

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  Hyper parameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # update learning rate
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Write a customized optimizer.
  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods (a minimal sketch follows the
  list):

    - `_resource_apply_dense` (update a variable given a dense gradient tensor)
    - `_resource_apply_sparse` (update a variable given a sparse gradient tensor)
    - `_create_slots` (if your optimizer algorithm requires additional variables)
    - `get_config` (serialization of the optimizer; include all hyper parameters)
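
  A minimal sketch of a plain gradient-descent optimizer built on this class
  (the class name `MyMomentumFreeSGD` is made up for illustration; it relies
  on protected helpers defined in this file, so treat it as a sketch rather
  than a canonical implementation):

  ```python
  class MyMomentumFreeSGD(tf.keras.optimizers.Optimizer):

    def __init__(self, learning_rate=0.01, name="MyMomentumFreeSGD", **kwargs):
      super(MyMomentumFreeSGD, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", learning_rate)
      self._set_hyper("decay", self._initial_decay)

    def _resource_apply_dense(self, grad, var):
      lr_t = self._decayed_lr(var.dtype.base_dtype)
      return var.assign_sub(lr_t * grad, use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, var, indices):
      lr_t = self._decayed_lr(var.dtype.base_dtype)
      return self._resource_scatter_add(var, indices, -lr_t * grad)

    def get_config(self):
      config = super(MyMomentumFreeSGD, self).get_config()
      config["learning_rate"] = self._serialize_hyperparameter("learning_rate")
      config["decay"] = self._serialize_hyperparameter("decay")
      return config
  ```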
  """

  def __init__(self, name, **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the `_set_hyper()`/`_get_hyper()`
    facility instead.

    This class is stateful and thread-compatible.

    Args:
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of the learning rate; `lr` is included for
        backward compatibility, but using `learning_rate` is recommended.

    Raises:
      TypeError: If an unexpected keyword argument is passed.
      ValueError: If any of the allowed keyword arguments is negative.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # Checks that all keyword arguments are non-negative.
      if kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

    self._use_locking = True
    self._name = name
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay
    if "clipnorm" in kwargs:
      self.clipnorm = kwargs.pop("clipnorm")
    if "clipvalue" in kwargs:
      self.clipvalue = kwargs.pop("clipvalue")

    self._hypers_created = False

  def minimize(self, loss, var_list, grad_loss=None, name=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply computes the gradients using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      name: Optional name for the returned operation.

    Returns:
      An `Operation` that updates the variables in `var_list`. The operation
      also increments the optimizer's `iterations` counter by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes no arguments and computes the value to be minimized. Minimization (and
    gradient computation) is done with respect to the elements of `var_list`.
    `grad_loss` is ignored when eager execution is enabled.
    @end_compatibility
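
    Example (a minimal sketch; `var1` is assumed to be an existing
    `tf.Variable`):

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1)
    loss = lambda: (var1 ** 2) / 2.0   # d(loss)/d(var1) == var1
    opt.minimize(loss, var_list=[var1])
    ```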
    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss)

    return self.apply_gradients(grads_and_vars, name=name)

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: List or tuple of `tf.Variable` to update to minimize
        `loss`.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    var_list = nest.flatten(var_list)
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    with backprop.GradientTape() as tape:
      tape.watch(var_list)
      loss_value = loss()
    grads = tape.gradient(loss_value, var_list, grad_loss)

    if hasattr(self, "clipnorm"):
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if hasattr(self, "clipvalue"):
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]

    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    with backend.get_graph().as_default():
      grads = gradients.gradients(loss, params)
    if None in grads:
      raise ValueError("An operation has `None` for gradient. "
                       "Please make sure that all of your ops have a "
                       "gradient defined (i.e. are differentiable). "
                       "Common ops without gradient: "
                       "K.argmax, K.round, K.eval.")
    if hasattr(self, "clipnorm"):
      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
    if hasattr(self, "clipvalue"):
      grads = [
          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
          for g in grads
      ]
    return grads

  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation.  Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. The operation also
      increments the optimizer's `iterations` counter by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    # Create the `iterations` variable if necessary.
    _ = self.iterations
    self._create_hypers()
    with ops.init_scope():
      self._create_slots(var_list)

    self._prepare(var_list)

    return distribute_ctx.get_replica_context().merge_call(
        self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})

  def _distributed_apply(self, distribution, grads_and_vars, name):
    """`apply_gradients` using a `DistributionStrategy`."""
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices)
      update_op = self._resource_apply_dense(grad, var)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    update_ops = []
    with ops.name_scope(name, self._name) as name:
      for grad, var in grads_and_vars:
        scope_name = ("" if ops.executing_eagerly_outside_functions() else
                      "_" + var.op.name)
        with ops.name_scope("update" + scope_name):
          update_ops.extend(
              distribution.extended.update(
                  var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """Sets hyper `name` to `value` (a callable, tensor, or numeric value)."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names

  def add_slot(self, var, slot_name, initializer="zeros"):
    """Add a new slot variable for `var`."""
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        initial_value = functools.partial(
            initializer, shape=var.shape, dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      with strategy.colocate_vars_with(var):
        weight = tf_variables.Variable(
            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
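    """Returns the slot variable named `slot_name` created for `var`.

    Raises a `KeyError` if no such slot has been created (via `add_slot` or
    the subclass's `_create_slots`).
    """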
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    pass

  def _create_hypers(self):
    if self._hypers_created:
      return
    # Iterate hyper values deterministically.
    for name, value in sorted(self._hyper.items()):
      if isinstance(value, ops.Tensor) or callable(value):
        continue
      else:
        self._hyper[name] = self.add_weight(
            name,
            shape=[],
            trainable=False,
            initializer=value,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      self._iterations = self.add_weight(
          "iter",
          shape=[],
          dtype=dtypes.int64,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = self._get_hyper("decay", var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
        Python dictionary.
    """
    config = {"name": self._name}
    if hasattr(self, "clipnorm"):
      config["clipnorm"] = self.clipnorm
    if hasattr(self, "clipvalue"):
      config["clipvalue"] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Arguments:
        config: A Python dictionary, typically the output of get_config.
        custom_objects: A Python dictionary mapping names to additional Python
          objects used to create this optimizer, such as a function used for a
          hyperparameter.

    Returns:
        An optimizer instance.
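
    Example (a minimal sketch using the `SGD` subclass):

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    config = opt.get_config()
    # Rebuild an optimizer with the same hyperparameters, but none of the
    # accumulated slot or iteration state.
    new_opt = tf.keras.optimizers.SGD.from_config(config)
    ```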
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"])
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if isinstance(value, (ops.Tensor, tf_variables.Variable,
                          distributed_values.TPUMirroredVariable,
                          distributed_values.DistributedVariable)):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns the variables of this Optimizer in the order they were created."""
    return self._weights

  @property
  def weights(self):
    """Returns the variables of this Optimizer in the order they were created."""
    return self._weights

  def get_weights(self):
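    """Returns the current values of this optimizer's weights.

    The values are returned as a list of Numpy arrays, in the same order as
    the `weights` property.
    """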
    params = self.weights
    return backend.batch_get_value(params)

  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
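    """Sets the weights of this optimizer from a list of Numpy arrays.

    The list must have the same length, and each entry the same shape, as the
    `weights` property (typically the output of `get_weights()` from a
    compatible optimizer).
    """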
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):
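    """Creates a variable owned by this optimizer.

    Used for hyperparameters and the `iterations` counter. The variable is
    created through the trackable machinery as a resource variable and
    registered with the Keras backend so that it is initialized together with
    other Keras variables. `dtype` defaults to `float32`, and `trainable`
    defaults to `True` unless `synchronization` is `ON_READ`, in which case it
    is forced to `False`.
    """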

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_scatter_update(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_update(x.handle, i, v)]):
      return x.value()

  # ---------------
  # For implementing the trackable interface
  # ---------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)


def _filter_grads(grads_and_vars):
  """Filters out (gradient, variable) pairs whose gradient is None."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars
  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)
  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        ("Gradients do not exist for variables %s when minimizing the loss."),
        ([v.name for v in vars_with_empty_grads]))
  return filtered


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the key is derived from the variable's shared name.
  In eager mode the key is derived from the variable's unique id.
  If a distribution strategy exists, the primary variable is looked up first.

  Args:
    var: the variable.

  Returns:
    the unique key for the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if getattr(var, "_distributed_container", None) is not None:
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class _RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(_RestoredOptimizer, self).__init__("_RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")

revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: _RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])