# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(trackable.Trackable):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.  You never use this
  class directly, but instead instantiate one of its subclasses such as
  `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns op that minimizes the loss by updating the listed
  # variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Custom training loop with Keras models

  In Keras models, variables are sometimes created when the model is first
  called, instead of at construction time. Examples include 1) sequential models
  without a pre-defined input shape, or 2) subclassed models. Pass `var_list` as
  a callable in these cases.

  Example:
  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
  var_list_fn = lambda: model.trainable_weights
  for input, output in data:
    opt.minimize(loss_fn, var_list_fn)
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables.  If you want to process the gradients
  before applying them, you can instead use the optimizer in three steps:

  1.  Compute the gradients with `tf.GradientTape`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  vars = <list_of_variables>
  grads = tape.gradient(loss, vars)

  # Process the gradients, for example cap them, etc.
  # capped_grads = [MyCapper(g) for g in grads]
  processed_grads = [process_gradient(g) for g in grads]

  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, vars))
  ```

  ### Use with `tf.distribute.Strategy`.

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss, which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for summing.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.
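
  For example, a minimal sketch of manual loss scaling under a strategy
  (`global_batch_size`, `per_example_loss` and `train_step` are illustrative
  names, not part of this API):

  ```python
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  def train_step(features, labels, global_batch_size):
    with tf.GradientTape() as tape:
      # Per-example losses, shape [replica-local batch size].
      per_example_loss = tf.keras.losses.mse(labels, model(features))
      # Sum and divide by the *global* batch size, not the local one.
      loss = tf.math.reduce_sum(per_example_loss) / global_batch_size
    grads = tape.gradient(loss, model.trainable_weights)
    opt.apply_gradients(zip(grads, model.trainable_weights))
  ```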

  ### Variable Constraints

  All Keras optimizers respect variable constraints. If a constraint function is
  passed to any variable, the constraint will be applied to the variable after
  the gradient has been applied to it.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.
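
  For example, a sketch of a non-negativity constraint (the lambda shown is
  illustrative):

  ```python
  var = tf.Variable(
      initial_value=1.0,
      constraint=lambda w: tf.clip_by_value(w, 0., float('inf')))
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  opt.minimize(lambda: var ** 2, var_list=[var])
  # The constraint function is applied to `var` after each gradient update.
  ```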

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The user
  needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train.  These are called
  *Slots*.  Slots have names and you can ask the optimizer for the names of
  the slots that it uses.  Once you have a slot name you can ask the optimizer
  for the variable it created to hold the slot value.

  This can be useful if you want to debug a training algorithm, report stats
  about the slots, etc.
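
  For example, a sketch of inspecting the momentum slot of an SGD optimizer
  (the variable shown is illustrative):

  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  var = tf.Variable(1.0)
  opt.minimize(lambda: var ** 2, var_list=[var])
  opt.get_slot_names()                          # ['momentum']
  momentum_var = opt.get_slot(var, 'momentum')  # The accumulator for `var`.
  ```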

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  Hyper parameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # Update the learning rate.
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Write a customized optimizer.
  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods (a sketch follows the list):

    - `_resource_apply_dense` (update variable given gradient tensor is dense)
    - `_resource_apply_sparse` (update variable given gradient tensor is sparse)
    - `_create_slots` (if your optimizer algorithm requires additional variables)
    - `get_config` (serialization of the optimizer, include all hyper parameters)
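
  For example, a minimal sketch of a momentum-style optimizer (illustrative
  only; a complete implementation would also handle sparse gradients by
  overriding `_resource_apply_sparse`):

  ```python
  class MyMomentumOptimizer(tf.keras.optimizers.Optimizer):

    def __init__(self, learning_rate=0.01, momentum=0.9,
                 name="MyMomentumOptimizer", **kwargs):
      super(MyMomentumOptimizer, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
      self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
      # One accumulator slot per trainable variable.
      for var in var_list:
        self.add_slot(var, "velocity")

    def _resource_apply_dense(self, grad, var, apply_state=None):
      lr_t = self._decayed_lr(var.dtype.base_dtype)
      momentum = self._get_hyper("momentum", var.dtype.base_dtype)
      velocity = self.get_slot(var, "velocity")
      # v <- momentum * v - lr * grad;  var <- var + v
      velocity_t = velocity.assign(momentum * velocity - lr_t * grad)
      return var.assign_add(velocity_t)

    def get_config(self):
      config = super(MyMomentumOptimizer, self).get_config()
      config.update({
          "learning_rate": self._serialize_hyperparameter("learning_rate"),
          "momentum": self._serialize_hyperparameter("momentum"),
      })
      return config
  ```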
  """

  def __init__(self, name, **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/_get_hyper()
    facility instead.

    This class is stateful and thread-compatible.

    Args:
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate; `lr` is included for backward
        compatibility, but using `learning_rate` is recommended.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
          _create_vars.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # checks that all keyword arguments are non-negative.
      if kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

    self._use_locking = True
    self._init_set_name(name)
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay
    if "clipnorm" in kwargs:
      self.clipnorm = kwargs.pop("clipnorm")
    if "clipvalue" in kwargs:
      self.clipvalue = kwargs.pop("clipvalue")

    self._hypers_created = False

  def minimize(self, loss, var_list, grad_loss=None, name=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradient using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradient before applying
    then call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use callable when the variable list would otherwise be incomplete before
        `minimize` since the variables are created at the first time `loss` is
        called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      name: Optional name for the returned operation.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss)

    return self.apply_gradients(grads_and_vars, name=name)

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use callable when the variable list would otherwise be incomplete before
        `minimize` and the variables are created the first time `loss` is
        called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    with backprop.GradientTape() as tape:
      if not callable(var_list):
        tape.watch(var_list)
      loss_value = loss()
    if callable(var_list):
      var_list = var_list()
    var_list = nest.flatten(var_list)
    with backend.name_scope(self._name + "/gradients"):
      grads = tape.gradient(loss_value, var_list, grad_loss)

      if hasattr(self, "clipnorm"):
        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
      if hasattr(self, "clipvalue"):
        grads = [
            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
            for g in grads
        ]

    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name +
                                                              "/gradients"):
      grads = gradients.gradients(loss, params)
      for grad, param in zip(grads, params):
        if grad is None:
          raise ValueError("Variable {} has `None` for gradient. "
                           "Please make sure that all of your ops have a "
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
      if hasattr(self, "clipnorm"):
        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
      if hasattr(self, "clipvalue"):
        grads = [
            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
            for g in grads
        ]
    return grads

  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation.  Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with backend.name_scope(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        _ = self.iterations
        self._create_hypers()
        self._create_slots(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()
      apply_state = self._prepare(var_list)
      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={"name": name})

  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    """`apply_gradients` using a `DistributionStrategy`."""
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    with ops.name_scope(name or self._name, skip_on_eager=True):
      for grad, var in grads_and_vars:
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with distribution.extended.colocate_vars_with(var):
          with ops.name_scope("update" if eagerly_outside_functions else
                              "update_" + var.op.name, skip_on_eager=True):
            update_ops.extend(distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """set hyper `name` to value. value can be callable, tensor, numeric."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names

  def add_slot(self, var, slot_name, initializer="zeros"):
    """Add a new slot variable for `var`."""
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        initial_value = functools.partial(
            initializer, shape=var.shape, dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      if not strategy.extended.variable_created_in_scope(var):
        raise ValueError(
            "Trying to create optimizer slot variable under the scope for "
            "tf.distribute.Strategy ({}), which is different from the scope "
            "used for the original variable ({}). Make sure the slot "
            "variables are created under the same strategy scope. This may "
            "happen if you're restoring from a checkpoint outside the scope"
            .format(strategy, var))

      with strategy.extended.colocate_vars_with(var):
        weight = tf_variables.Variable(
            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    keys = set()
    for var in var_list:
      var_devices = (getattr(var, "devices", None) or  # Distributed
                     [var.device])                     # Regular
      var_dtype = var.dtype.base_dtype
      for var_device in var_devices:
        keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
      apply_state[(var_device, var_dtype)] = {}
      with ops.device(var_device):
        self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state

  def _prepare_local(self, var_device, var_dtype, apply_state):
    if "learning_rate" in self._hyper:
      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t

  def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]

  def _create_hypers(self):
    if self._hypers_created:
      return
    # Iterate hyper values deterministically.
    for name, value in sorted(self._hyper.items()):
      if isinstance(
          value, (ops.Tensor, tf_variables.Variable)) or callable(value):
        continue
      else:
        self._hyper[name] = self.add_weight(
            name,
            shape=[],
            trainable=False,
            initializer=value,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      self._iterations = self.add_weight(
          "iter",
          shape=[],
          dtype=dtypes.int64,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = self._get_hyper("decay", var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
        Python dictionary.
    """
    config = {"name": self._name}
    if hasattr(self, "clipnorm"):
      config["clipnorm"] = self.clipnorm
    if hasattr(self, "clipvalue"):
      config["clipvalue"] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.
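
    Example (a sketch; `tf.keras.optimizers.SGD` stands in for any subclass):

    ```python
    opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
    config = opt.get_config()
    # Recreate an equivalent optimizer (without any saved state).
    new_opt = tf.keras.optimizers.SGD.from_config(config)
    ```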

    Arguments:
        config: A Python dictionary, typically the output of get_config.
        custom_objects: A Python dictionary mapping names to additional Python
          objects used to create this optimizer, such as a function used for a
          hyperparameter.

    Returns:
        An optimizer instance.
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"], custom_objects=custom_objects)
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if tensor_util.is_tensor(value):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  @property
  def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  def get_weights(self):
    """Returns the current weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function returns the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they were created. The returned list can in turn
    be used to load state into similarly parameterized optimizers.

    For example, the RMSprop optimizer for this simple model returns a list of
    three values: the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> len(opt.get_weights())
    3

    Returns:
        Weights values as a list of numpy arrays.
    """
    params = self.weights
    return backend.batch_get_value(params)
  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
    """Set the weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function takes the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they are created. The passed values are used to set
    the new state of the optimizer.

    For example, the RMSprop optimizer for this simple model takes a list of
    three values: the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Arguments:
        weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set([
        dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
        dtypes.complex64, dtypes.complex128
    ])

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle, apply_state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                               **kwargs):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      **kwargs: May optionally contain `apply_state`

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       **kwargs)

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_scatter_update(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_update(x.handle, i, v)]):
      return x.value()

  @property
  @tracking.cached_per_instance
  def _dense_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_dense).args

  @property
  @tracking.cached_per_instance
  def _sparse_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_sparse).args

  # ---------------
  # For implementing the trackable interface
  # ---------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

def _filter_grads(grads_and_vars):
  """Filter out (gradient, variable) pairs whose gradient is None."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars
  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)
  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        ("Gradients do not exist for variables %s when minimizing the loss."),
        ([v.name for v in vars_with_empty_grads]))
  return filtered


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")

revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])
