• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the loss scaling optimizer class."""
16
17from tensorflow.python.distribute import collective_all_reduce_strategy
18from tensorflow.python.distribute import distribution_strategy_context
19from tensorflow.python.distribute import mirrored_strategy
20from tensorflow.python.distribute import one_device_strategy
21from tensorflow.python.distribute import tpu_strategy
22from tensorflow.python.eager import backprop
23from tensorflow.python.eager import context
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import smart_cond
27from tensorflow.python.keras import backend
28from tensorflow.python.keras import optimizers
29from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module
30from tensorflow.python.keras.optimizer_v2 import optimizer_v2
31from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.ops import variables
36from tensorflow.python.platform import tf_logging
37from tensorflow.python.training.experimental import loss_scale as loss_scale_module
38from tensorflow.python.training.experimental import mixed_precision
39from tensorflow.python.training.tracking import base as trackable
40from tensorflow.python.util import nest
41from tensorflow.python.util.tf_export import keras_export
42
43
class _UnwrapPreventer(object):
  """Opaque holder that keeps DistributionStrategy from unwrapping a value.

  When `call_for_each_replica` moves from a cross-replica context into a
  replica context, DistributionStrategy normally unwraps the values handed to
  it. It does not unwrap instances of this class, so storing something in
  `value` keeps that object from being unwrapped.

  TODO(reedwm): Find/implement a better way of preventing values from being
  unwrapped by DistributionStrategy
  """

  # The single payload attribute; no per-instance `__dict__` is created.
  __slots__ = ['value']

  def __init__(self, value):
    self.value = value
60
61
class _DelegatingTrackableMixin(object):
  """A mixin that delegates all Trackable methods to another trackable object.

  This class must be used with multiple inheritance. A class that subclasses
  Trackable can also subclass this class, which causes all Trackable methods to
  be delegated to the trackable object passed in the constructor.

  A subclass can use this mixin to appear as if it were the trackable passed to
  the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
  mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
  the checkpoint format for a normal optimizer. This allows a model to be saved
  with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
  The only difference in checkpoint format is that the loss scale is also saved
  with a LossScaleOptimizer.

  Every member below forwards to the same-named member of the wrapped object
  and returns whatever that member returns.
  """

  def __init__(self, trackable_obj):
    """Stores the object that all Trackable calls are forwarded to.

    Args:
      trackable_obj: The trackable object to delegate to.
    """
    self._trackable = trackable_obj

  # pylint: disable=protected-access
  @property
  def _setattr_tracking(self):
    return self._trackable._setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._trackable._setattr_tracking = value

  @property
  def _update_uid(self):
    return self._trackable._update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._trackable._update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    return self._trackable._unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    return self._trackable._unconditional_dependency_names

  @property
  def _name_based_restores(self):
    return self._trackable._name_based_restores

  def _maybe_initialize_trackable(self):
    return self._trackable._maybe_initialize_trackable()

  @property
  def _object_identifier(self):
    return self._trackable._object_identifier

  @property
  def _tracking_metadata(self):
    return self._trackable._tracking_metadata

  def _no_dependency(self, value):
    return self._trackable._no_dependency(value)

  def _name_based_attribute_restore(self, checkpoint):
    return self._trackable._name_based_attribute_restore(checkpoint)

  @property
  def _checkpoint_dependencies(self):
    return self._trackable._checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    return self._trackable._deferred_dependencies

  def _lookup_dependency(self, name):
    # The result must be returned: callers use it to find the dependency
    # registered under `name`. Dropping it made every lookup yield None.
    return self._trackable._lookup_dependency(name)

  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    return self._trackable._add_variable_with_custom_getter(
        name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)

  def _preload_simple_restoration(self, name):
    return self._trackable._preload_simple_restoration(name)

  def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
    return self._trackable._track_trackable(trackable, name, overwrite)

  def _handle_deferred_dependencies(self, name, trackable):  # pylint: disable=redefined-outer-name
    return self._trackable._handle_deferred_dependencies(name, trackable)

  def _restore_from_checkpoint_position(self, checkpoint_position):
    return self._trackable._restore_from_checkpoint_position(
        checkpoint_position)

  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
                                                   visit_queue):
    return self._trackable._single_restoration_from_checkpoint_position(
        checkpoint_position, visit_queue)

  def _gather_saveables_for_checkpoint(self):
    return self._trackable._gather_saveables_for_checkpoint()

  def _list_extra_dependencies_for_serialization(self, serialization_cache):
    return self._trackable._list_extra_dependencies_for_serialization(
        serialization_cache)

  def _list_functions_for_serialization(self, serialization_cache):
    return self._trackable._list_functions_for_serialization(
        serialization_cache)
  # pylint: enable=protected-access
178
179
def _is_all_finite(grads):
  """Returns a scalar boolean tensor that is true iff every gradient is finite.

  Args:
    grads: An iterable of gradients. `None` entries are skipped.

  Returns:
    The `reduce_all` of the per-gradient "all elements finite" results.
  """
  finite_flags = []
  for grad in grads:
    if grad is None:
      continue
    finite_flags.append(math_ops.reduce_all(math_ops.is_finite(grad)))
  return math_ops.reduce_all(finite_flags)
186
187
def _op_in_graph_mode(tensor):
  """Returns `tensor.op` when building a graph, or `tensor` itself in eager.

  Some call sites need an op rather than a tensor in graph mode. Eager mode has
  no ops, so the tensor is handed back unchanged there.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode. The tensor in eager mode.
  """
  return tensor if context.executing_eagerly() else tensor.op
203
204
def _assign_if_finite(var, value):
  """Assigns `value` to `var` only when `value` is finite; otherwise a no-op.

  Args:
    var: The variable to assign to.
    value: The candidate value, tested with `math_ops.is_finite`.

  Returns:
    The result of the `cond` that selects between the assignment and a no-op.
  """

  def assign_value():
    return _op_in_graph_mode(var.assign(value))

  value_is_finite = math_ops.is_finite(value)
  return control_flow_ops.cond(
      value_is_finite, assign_value, control_flow_ops.no_op)
210
211
class _DynamicLossScaleState(trackable.Trackable):
  """The state of a dynamic loss scale.

  Holds two non-trainable variables: the current loss scale (float32) and a
  counter (int64) of consecutive finite-gradient steps. `update` grows the loss
  scale by `multiplier` after `growth_steps` consecutive finite steps, and
  divides it by `multiplier` (never going below 1) on a nonfinite step.
  """

  def __init__(self,
               initial_loss_scale,
               growth_steps,
               multiplier):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: The starting loss scale; converted with `float`.
      growth_steps: Number of consecutive finite steps before the loss scale is
        increased; converted with `int`.
      multiplier: Factor by which the loss scale is increased or decreased;
        converted with `float`.
    """
    super(_DynamicLossScaleState, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._growth_steps = int(growth_steps)
    self._multiplier = float(multiplier)

    # Maps (name, graph_key) -> variable. graph_key is None in eager mode.
    self._weights = {}
    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale. The name is 'good_steps' for
    # backwards compatibility with older checkpoints.
    self._counter = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @staticmethod
  def _current_graph_key():
    """Returns the default graph's key in graph mode, or None in eager mode."""
    if context.executing_eagerly():
      return None
    return ops.get_default_graph()._graph_key  # pylint: disable=protected-access

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added.
    """
    key = (name, self._current_graph_key())
    # Enforce the documented contract instead of silently replacing a weight
    # that is already tracked under the same name in the same graph.
    if self._weights.get(key, None) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    backend.track_variable(variable)
    return variable

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific weights to save."""
    graph_key = self._current_graph_key()
    # Sort by weight name so the order of the references is deterministic.
    weights = [
        trackable.TrackableReference(name=name, ref=v)
        for (name, g), v in sorted(
            self._weights.items(), key=lambda i: i[0][0])
        if g == graph_key
    ]
    return (super(_DynamicLossScaleState, self)._checkpoint_dependencies +
            weights)

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    return self._weights.get((name, self._current_graph_key()), None)

  @property
  def initial_loss_scale(self):
    """The initial loss scale, as a Python float."""
    return self._initial_loss_scale

  @property
  def growth_steps(self):
    """The number of finite steps before the loss scale grows, as an int."""
    return self._growth_steps

  @property
  def multiplier(self):
    """The factor used to grow or shrink the loss scale, as a Python float."""
    return self._multiplier

  @property
  def current_loss_scale(self):
    """Returns the current loss scale as a float32 `tf.Variable`."""
    return self._current_loss_scale

  @property
  def counter(self):
    """Returns the counter as an int64 `tf.Variable`."""
    return self._counter

  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    return ops.convert_to_tensor_v2_with_dispatch(self._current_loss_scale)

  def update(self, grads):
    """Updates the value of the loss scale.

    Args:
      grads: A nested structure of unscaled gradients, each which is an
        all-reduced gradient of the loss with respect to a weight.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    grads = nest.flatten(grads)
    if (distribution_strategy_context.has_strategy() and
        distribution_strategy_context.in_cross_replica_context()):
      distribution = distribution_strategy_context.get_strategy()
      is_finite_per_replica = distribution.extended.call_for_each_replica(
          _is_all_finite, args=(grads,))
      # Each replica computed the same `is_finite` value, since `grads` is
      # all-reduced across replicas. Arbitrarily take `is_finite` from the first
      # replica.
      is_finite = (
          distribution.experimental_local_results(is_finite_per_replica)[0])
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        # Only store the grown loss scale if it did not overflow to a
        # nonfinite value; the counter is reset either way.
        new_loss_scale = self.current_loss_scale * self.multiplier
        return control_flow_ops.group(
            _assign_if_finite(self.current_loss_scale, new_loss_scale),
            self.counter.assign(0))

      return control_flow_ops.cond(
          self.counter + 1 >= self.growth_steps,
          incr_loss_scale,
          lambda: _op_in_graph_mode(self.counter.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      # Shrink the loss scale, but never below 1.
      new_loss_scale = math_ops.maximum(
          self.current_loss_scale / self.multiplier, 1)
      return control_flow_ops.group(
          self.counter.assign(0),
          self.current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients
380
381
# Defaults that LossScaleOptimizer uses for `initial_scale` and
# `dynamic_growth_steps` when dynamic loss scaling is enabled and the caller
# passes None for them.
# See LossScaleOptimizer docstring for why this is so big
_DEFAULT_INITIAL_SCALE = 2 ** 15
_DEFAULT_GROWTH_STEPS = 2000
385
386
387# pylint: disable=g-classes-have-attributes
388@keras_export('keras.mixed_precision.LossScaleOptimizer')
389class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
390  """An optimizer that applies loss scaling to prevent numeric underflow.
391
392  Loss scaling is a technique to prevent numeric underflow in intermediate
393  gradients when float16 is used. To prevent underflow, the loss is multiplied
394  (or "scaled") by a certain factor called the "loss scale", which causes
395  intermediate gradients to be scaled by the loss scale as well. The final
396  gradients are divided (or "unscaled") by the loss scale to bring them back to
397  their original value.
398
399  `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it.
400  By default, the loss scale is dynamically updated over time so you do not have
401  to choose the loss scale. The `minimize` method automatically scales the loss,
402  unscales the gradients, and updates the loss scale so all you have to do is
403  wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For
404  example:
405
406  >>> opt = tf.keras.optimizers.SGD(0.25)
407  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
408  >>> var = tf.Variable(1.)
409  >>> loss_fn = lambda: var ** 2
  >>> # 'minimize' applies loss scaling and updates the loss scale.
411  >>> opt.minimize(loss_fn, var_list=var)
412  >>> var.numpy()
413  0.5
414
415  If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you
416  must scale the loss and gradients manually. This can be done with the
417  `LossScaleOptimizer.get_scaled_loss` and
418  `LossScaleOptimizer.get_unscaled_gradients` methods. For example:
419
420  >>> with tf.GradientTape() as tape:
421  ...   loss = loss_fn()
422  ...   scaled_loss = opt.get_scaled_loss(loss)
423  >>> scaled_grad = tape.gradient(scaled_loss, var)
424  >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
425  >>> opt.apply_gradients([(grad, var)])  # Loss scale is updated here
426  >>> var.numpy()
427  0.25
428
429  Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients`
430  (or both) when using a `tf.GradientTape`, the model will likely converge to a
431  worse quality. Please make sure you call each function exactly once.
432
433  When mixed precision with float16 is used, there is typically no risk of
434  underflow affecting model quality if loss scaling is properly used. See
435  [the mixed precision guide](
436  https://www.tensorflow.org/guide/keras/mixed_precision) for more information
437  on how to use mixed precision.
438
439  Args:
440    inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap.
441    dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to
442      True. If True, the loss scale will be dynamically updated over time using
443      an algorithm that keeps the loss scale at approximately its optimal value.
444      If False, a single fixed loss scale is used and `initial_scale` must be
445      specified, which is used as the loss scale. Recommended to keep as True,
446      as choosing a fixed loss scale can be tricky. Currently, there is a small
447      performance overhead to dynamic loss scaling compared to fixed loss
448      scaling.
449    initial_scale: The initial loss scale. If `dynamic` is True, this defaults
450      to `2 ** 15`. If `dynamic` is False, this must be specified and acts as
451      the sole loss scale, as the loss scale does not change over time. When
      dynamic loss scaling is used, it is better for this to be a very high
      number, because a loss scale that is too high gets lowered far more
      quickly than a loss scale that is too low gets raised.
455    dynamic_growth_steps: With dynamic loss scaling, every
456      `dynamic_growth_steps` steps with finite gradients, the loss scale is
457      doubled. Defaults to 2000. If a nonfinite gradient is encountered, the
458      count is reset back to zero, gradients are skipped that step, and the loss
459      scale is halved. The count can be queried with
460      `LossScaleOptimizer.dynamic_counter`. This argument can only be specified
461      if `dynamic` is True.
462
463  `LossScaleOptimizer` will occasionally skip applying gradients to the
464  variables, in which case the trainable variables will not change that step.
465  This is done because the dynamic loss scale will sometimes be raised too
466  high, causing overflow in the gradients. Typically, the first 2 to 15 steps of
467  the model are skipped as the initial loss scale is very high, but afterwards
468  steps will only be skipped on average 0.05% of the time (the fraction of steps
469  skipped is `1 / dynamic_growth_steps`).
470
471  `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner
472  optimizer. Additionally, in methods `minimize` and `get_gradients`, it scales
473  the loss and unscales the gradients. In methods `minimize` and
474  `apply_gradients`, it additionally updates the loss scale and skips applying
475  gradients if any gradient has a nonfinite value.
476
477  ### Hyperparameters
478
479  Hyperparameters can be accessed and set on the LossScaleOptimizer, which will
480  be delegated to the wrapped optimizer.
481
482  >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5)
483  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
484  >>> opt.beta_1  # Equivalent to `opt.inner_optimizer.beta_1`
485  0.8
486  >>> opt.beta_1 = 0.7  # Equivalent to `opt.inner_optimizer.beta_1 = 0.7`
487  >>> opt.beta_1
488  0.7
489  >>> opt.inner_optimizer.beta_1
490  0.7
491
  However, accessing or setting non-hyperparameters on the LossScaleOptimizer
  is not delegated to the inner optimizer. In an Adam optimizer, `beta_1` is a
  hyperparameter but `epsilon` is not, as the Adam optimizer only calls
  `Optimizer._set_hyper` on `beta_1`.
496
497  >>> opt.inner_optimizer.epsilon
498  1e-5
499  >>> opt.epsilon
500  Traceback (most recent call last):
501  ...
502  AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon'
503  >>> opt.epsilon = 1e-4  # This does NOT set epsilon on `opt.inner_optimizer`
504  >>> opt.inner_optimizer.epsilon
  1e-5
506
507  In the above example, despite epsilon being set on the LossScaleOptimizer, the
508  old epsilon value will still be used when training as epsilon was not set on
509  the inner optimizer.
510  """
511
512  _HAS_AGGREGATE_GRAD = True
513
514  def __init__(self, inner_optimizer, dynamic=True, initial_scale=None,
515               dynamic_growth_steps=None):
516    if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2):
517      raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, '
518                      'but got: %s' % inner_optimizer)
519    if not isinstance(dynamic, bool):
520      # Catch errors if a user incorrectly passes a string or float to the
521      # second argument argument, as this is commonly done for
522      # LossScaleOptimizerV1.
523      raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must '
524                      'be a bool, but got: %r' % (dynamic,))
525    if isinstance(inner_optimizer, LossScaleOptimizer):
526      raise TypeError('LossScaleOptimizer cannot wrap another '
527                      'LossScaleOptimizer, but got: %s' % (inner_optimizer,))
528    self._raise_if_strategy_unsupported()
529    if getattr(inner_optimizer, '_is_wrapped_by_loss_scale_optimizer', False):
530      # TODO(reedwm): Maybe support this. The difficulty is that LSO has the
531      # same checkpoint format as the inner optimizer, so multiple LSOs wrapping
532      # the same optimizer causes the checkpointing logic to become confused.
533      raise ValueError('"inner_optimizer" is already wrapped by a '
534                       'LossScaleOptimizer. An optimizer can only be wrapped '
535                       'by a single LossScaleOptimizer')
536    self._optimizer = inner_optimizer
537    self._optimizer._is_wrapped_by_loss_scale_optimizer = True
538
539    # We don't call super().__init__, since we do not want to call OptimizerV2's
540    # constructor.
541    _DelegatingTrackableMixin.__init__(self, self._optimizer)
542
543    if dynamic:
544      if initial_scale is None:
545        initial_scale = _DEFAULT_INITIAL_SCALE
546      if dynamic_growth_steps is None:
547        dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
548      self._loss_scale = _DynamicLossScaleState(
549          initial_scale, dynamic_growth_steps, multiplier=2)
550      self._track_trackable(self._loss_scale, 'loss_scale')
551    else:
552      if initial_scale is None:
553        raise ValueError('"initial_scale" must be specified if "dynamic" is '
554                         'False')
555      self._loss_scale = float(initial_scale)
556      if dynamic_growth_steps is not None:
557        raise ValueError('"dynamic_growth_steps" must be None if "dynamic" '
558                         'is False, but got: %s' % (dynamic_growth_steps,))
559
560    # To support restoring TensorFlow 2.2 checkpoints.
561    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
562                          'base_optimizer')
563
564  @property
565  def dynamic(self):
566    """Bool indicating whether dynamic loss scaling is used."""
567    return isinstance(self._loss_scale, _DynamicLossScaleState)
568
569  @property
570  def loss_scale(self):
571    """The current loss scale as a float32 scalar tensor."""
572    if isinstance(self._loss_scale, _DynamicLossScaleState):
573      return ops.convert_to_tensor_v2_with_dispatch(
574          self._loss_scale.current_loss_scale)
575    else:
576      return ops.convert_to_tensor_v2_with_dispatch(self._loss_scale)
577
578  @property
579  def dynamic_counter(self):
580    """The number of steps since the loss scale was last increased or decreased.
581
582    This is None if `LossScaleOptimizer.dynamic` is False.
583
584    The counter is incremented every step. Once it reaches
585    `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled
586    and the counter will be reset back to zero. If nonfinite gradients are
587    encountered, the loss scale will be halved and the counter will be reset
588    back to zero.
589    """
590    if isinstance(self._loss_scale, _DynamicLossScaleState):
591      return self._loss_scale.counter
592    else:
593      return None
594
595  @property
596  def initial_scale(self):
597    """The initial loss scale.
598
599    If `LossScaleOptimizer.dynamic` is False, this is the same number as
600    `LossScaleOptimizer.loss_scale`, as the loss scale never changes.
601    """
602    if isinstance(self._loss_scale, _DynamicLossScaleState):
603      return self._loss_scale.initial_loss_scale
604    else:
605      return self._loss_scale
606
607  @property
608  def dynamic_growth_steps(self):
609    """The number of steps it takes to increase the loss scale.
610
611    This is None if `LossScaleOptimizer.dynamic` is False.
612
613    Every `dynamic_growth_steps` consecutive steps with finite gradients, the
614    loss scale is increased.
615    """
616    if isinstance(self._loss_scale, _DynamicLossScaleState):
617      return self._loss_scale.growth_steps
618    else:
619      return None
620
621  @property
622  def inner_optimizer(self):
623    """The optimizer that this LossScaleOptimizer is wrapping."""
624    return self._optimizer
625
626  def get_scaled_loss(self, loss):
627    """Scales the loss by the loss scale.
628
629    This method is only needed if you compute gradients manually, e.g. with
630    `tf.GradientTape`. In that case, call this method to scale the loss before
631    passing the loss to `tf.GradientTape`. If you use
632    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
633    scaling is automatically applied and this method is unneeded.
634
635    If this method is called, `get_unscaled_gradients` should also be called.
636    See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for
637    an example.
638
639    Args:
640      loss: The loss, which will be multiplied by the loss scale. Can either be
641        a tensor or a callable returning a tensor.
642
643    Returns:
644      `loss` multiplied by `LossScaleOptimizer.loss_scale`.
645    """
646    if callable(loss):
647      def new_loss():
648        loss_val = loss()
649        return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype)
650      return new_loss
651    else:
652      return loss * math_ops.cast(self.loss_scale, loss.dtype)
653
654  def get_unscaled_gradients(self, grads):
655    """Unscales the gradients by the loss scale.
656
657    This method is only needed if you compute gradients manually, e.g. with
658    `tf.GradientTape`. In that case, call this method to unscale the gradients
659    after computing them with `tf.GradientTape`. If you use
660    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
661    scaling is automatically applied and this method is unneeded.
662
663    If this method is called, `get_scaled_loss` should also be called. See
664    the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an
665    example.
666
667    Args:
668      grads: A list of tensors, each which will be divided by the loss scale.
669        Can have None values, which are ignored.
670
671    Returns:
672      A new list the same size as `grads`, where every non-None value in `grads`
673      is divided by `LossScaleOptimizer.loss_scale`.
674    """
675    loss_scale_reciprocal = 1. / self.loss_scale
676    return [
677        _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None
678        for g in grads
679    ]
680
681  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
682    tape = backprop.GradientTape() if tape is None else tape
683    with tape:
684      loss = self.get_scaled_loss(loss)
685    grads_and_vars = self._optimizer._compute_gradients(  # pylint: disable=protected-access
686        loss,
687        var_list,
688        grad_loss,
689        tape=tape)
690    grads = [g for g, _ in grads_and_vars]
691    weights = [v for _, v in grads_and_vars]
692    unscaled_grads = self.get_unscaled_gradients(grads)
693    return list(zip(unscaled_grads, weights))
694
695  def get_gradients(self, loss, params):
696    loss = self.get_scaled_loss(loss)
697    grads = self._optimizer.get_gradients(loss, params)
698    return self.get_unscaled_gradients(grads)
699
700  def _create_all_weights(self, var_list):
701    self._optimizer._create_all_weights(var_list)    # pylint: disable=protected-access
702
  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Applies gradients via the inner optimizer, skipping nonfinite steps.

    With dynamic loss scaling, the loss scale state is updated from the
    gradients, and the inner optimizer only applies them when that update
    reports they should be applied. Otherwise only the inner optimizer's
    `iterations` is incremented. With a fixed loss scale the gradients are
    always applied.

    Args:
      grads_and_vars: Iterable of (gradient, variable) pairs.
      name: Forwarded to `self._apply_gradients`.
      experimental_aggregate_gradients: If True, the gradients are transformed
        and aggregated here using the inner optimizer's
        `_transform_unaggregated_gradients` and `_aggregate_gradients`.

    Returns:
      The grouping of the maybe-apply op and the loss scale update op. When
      the strategy does not support skipping `merge_call`, this is whatever
      `merge_call` returns for that grouping.

    Raises:
      ValueError: If called in a cross-replica context.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')
    # We check for the strategy here despite already checking in the constructor
    # as frequently the optimizer is created outside the strategy's scope.
    self._raise_if_strategy_unsupported()

    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    if experimental_aggregate_gradients:
      # We must aggregate the gradients here instead of in
      # self.optimizer.apply_gradients, so that any NaN or Inf gradients are
      # propagated to each replica. If any replica has a NaN or Inf gradient,
      # they must all have a NaN or Inf gradient so that they all skip the step.
      # pylint: disable=protected-access
      grads_and_vars = self._optimizer._transform_unaggregated_gradients(
          grads_and_vars)
      grads_and_vars = self._optimizer._aggregate_gradients(grads_and_vars)
      # pylint: enable=protected-access

    # Materialize the pairs, since they are iterated twice below.
    grads_and_vars = tuple(grads_and_vars)
    grads = [g for g, _ in grads_and_vars]
    # We do not want DistributionStrategy to unwrap any MirroredVariables in
    # grads_and_vars, because even in a replica context, the wrapped
    # optimizer expects mirrored variables. So we wrap the variables with an
    # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
    # MirroredVariables.
    wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])

    def do_not_apply_fn():
      # Normally self._optimizer.iterations is incremented in
      # self._optimizer.apply_gradients(). Since that is not called in this
      # branch, we increment it here instead.
      return self._optimizer.iterations.assign_add(1, read_value=False)

    def _if_should_apply_grads(grads):
      # Returns (loss_scale_update_op, should_apply_grads). A fixed loss scale
      # has nothing to update and always applies the gradients.
      if isinstance(self._loss_scale, _DynamicLossScaleState):
        return self._loss_scale.update(grads)
      else:
        return (control_flow_ops.no_op(), True)

    if optimizer_utils.strategy_supports_no_merge_call():
      loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)
      def apply_fn():
        return self._apply_gradients(grads, wrapped_vars, name)

      maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                             do_not_apply_fn)
      return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)

    else:

      def _apply_gradients_cross_replica(distribution, grads, wrapped_vars,
                                         name):
        loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads)

        def apply_fn():
          return distribution.extended.call_for_each_replica(
              self._apply_gradients,
              args=(grads, wrapped_vars, name))

        # Note: We must call this cond() in a cross-replica context.
        # DistributionStrategy does not support having a cond in a replica
        # context with a branch that calls `merge_call`, and
        # self._optimizer.apply_gradients calls `merge_call`.
        maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                               do_not_apply_fn)
        return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
      return distribution_strategy_context.get_replica_context().merge_call(
          _apply_gradients_cross_replica,
          args=(grads, wrapped_vars, name))
776
777  def _apply_gradients(self, grads, wrapped_vars, name):
778    # Pass experimental_aggregate_gradients=False since LossScaleOptimizer
779    # already aggregated the gradients.
780    # TODO(reedwm): This will raise a fairly cryptic error message if
781    # self._optimizer.apply_gradients does not take
782    # experimental_aggregate_gradients.
783    return self._optimizer.apply_gradients(
784        list(zip(grads, wrapped_vars.value)), name,
785        experimental_aggregate_gradients=False)
786
787  def get_config(self):
788    serialized_optimizer = optimizers.serialize(self._optimizer)
789    return {
790        'inner_optimizer': serialized_optimizer,
791        'dynamic': self.dynamic,
792        'initial_scale': self.initial_scale,
793        'dynamic_growth_steps': self.dynamic_growth_steps,
794    }
795
796  @classmethod
797  def from_config(cls, config, custom_objects=None):
798    config = config.copy()  # Make a copy, since we mutate config
799    if 'loss_scale' in config:
800      # If loss_scale is in config, we assume we are deserializing a
801      # LossScaleOptimizer from TF 2.3 or below. We convert the config so it
802      # can be deserialized in the current LossScaleOptimizer.
803      loss_scale = keras_loss_scale_module.deserialize(
804          config.pop('loss_scale'))
805      if isinstance(loss_scale, loss_scale_module.FixedLossScale):
806        config['dynamic'] = False
807        config['initial_scale'] = loss_scale._loss_scale_value  # pylint: disable=protected-access
808      elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
809        config['dynamic'] = True
810        config['initial_scale'] = loss_scale.initial_loss_scale
811        config['dynamic_growth_steps'] = loss_scale.increment_period
812        if loss_scale.multiplier != 2:
813          raise ValueError('Cannot deserialize LossScaleOptimizer with a '
814                           'DynamicLossScale whose multiplier is not 2. Got '
815                           'DynamicLossScale: %s' % (loss_scale,))
816      else:
817        raise ValueError(
818            'Serialized LossScaleOptimizers with a LossScale that is neither a '
819            'FixedLossScale nor a DynamicLossScale can no longer be '
820            'deserialized')
821      config['inner_optimizer'] = config.pop('optimizer')
822    config['inner_optimizer'] = optimizers.deserialize(
823        config['inner_optimizer'], custom_objects=custom_objects)
824    return cls(**config)
825
826  def _raise_if_strategy_unsupported(self):
827    if not strategy_supports_loss_scaling():
828      strategy = distribution_strategy_context.get_strategy()
829      if isinstance(strategy,
830                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
831                     tpu_strategy.TPUStrategyV2)):
832        raise ValueError(
833            'Loss scaling is not supported with TPUStrategy. Loss scaling is '
834            'unnecessary with TPUs, since they support bfloat16 instead of '
835            'float16 and bfloat16 does not require loss scaling. You should '
836            'remove the use of the LossScaleOptimizer when TPUs are used.')
837      else:
838        raise ValueError('Loss scaling is not supported with the '
839                         'tf.distribute.Strategy: %s. Try using a different '
840                         'Strategy, e.g. a MirroredStrategy' %
841                         strategy.__class__.__name__)
842
843  # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
844  # below.
845
846  @property
847  def iterations(self):
848    return self._optimizer.iterations
849
850  @iterations.setter
851  def iterations(self, variable):
852    self._optimizer.iterations = variable
853
854  def get_slot_names(self):
855    return self._optimizer.get_slot_names()
856
857  def variables(self):
858    return self._optimizer.variables()
859
860  @property
861  def weights(self):
862    return self._optimizer.weights
863
864  def get_weights(self):
865    return self._optimizer.get_weights()
866
867  def set_weights(self, weights):
868    return self._optimizer.set_weights(weights)
869
870  @property
871  def clipnorm(self):
872    return self._optimizer.clipnorm
873
874  @clipnorm.setter
875  def clipnorm(self, val):
876    self._optimizer.clipnorm = val
877
878  @property
879  def global_clipnorm(self):
880    return self._optimizer.global_clipnorm
881
882  @global_clipnorm.setter
883  def global_clipnorm(self, val):
884    self._optimizer.global_clipnorm = val
885
886  @property
887  def clipvalue(self):
888    return self._optimizer.clipvalue
889
890  @clipvalue.setter
891  def clipvalue(self, val):
892    self._optimizer.clipvalue = val
893
894  def _aggregate_gradients(self, grads_and_vars):
895    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access
896
897  def _restore_slot_variable(self, slot_name, variable, slot_variable):
898    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
899                                                  slot_variable)
900
901  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
902                                       variable):
903    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
904        slot_variable_position, slot_name, variable)
905
906  def get_slot(self, var, slot_name):
907    return self._optimizer.get_slot(var, slot_name)
908
909  def add_slot(self, var, slot_name, initializer='zeros'):
910    return self._optimizer.add_slot(var, slot_name, initializer)
911
912  def __getattribute__(self, name):
913    try:
914      return object.__getattribute__(self, name)
915    except AttributeError as e:
916      if name == '_optimizer' or name == '_hyper':
917        # Avoid infinite recursion
918        raise e
919
920      # Delegate hyperparameter accesses to inner optimizer.
921      if name == 'lr':
922        name = 'learning_rate'
923      if name in self._optimizer._hyper:
924        return self._optimizer._get_hyper(name)
925      raise e
926
927  def __dir__(self):
928    result = set(super(LossScaleOptimizer, self).__dir__())
929    if '_optimizer' in result:
930      result |= self._optimizer._hyper.keys()
931      if 'learning_rate' in self._optimizer._hyper.keys():
932        result.add('lr')
933    return list(result)
934
935  def __setattr__(self, name, value):
936    if name == 'lr':
937      name = 'learning_rate'
938    # Delegate setting hyperparameter to inner optimizer if the attribute does
939    # not exist on the LossScaleOptimizer
940    try:
941      # We cannot check for the 'iterations' attribute as it cannot be set after
942      # it is accessed.
943      if name != 'iterations':
944        object.__getattribute__(self, name)
945      has_attribute = True
946    except AttributeError:
947      has_attribute = False
948    if (name != '_optimizer' and name in self._optimizer._hyper
949        and not has_attribute):
950      self._optimizer._set_hyper(name, value)
951    else:
952      super(LossScaleOptimizer, self).__setattr__(name, value)
953
954  # Explicitly delegate learning_rate. Normally hyperparameters are delegated in
955  # __getattribute__, but if a hyperparameter is not in self._optimizer._hyper
956  # (e.g. because self._optimizer itself wraps another optimizer), then it won't
957  # be delegated. Since learning_rate is a very commonly accessed
958  # hyperparameter, we delegate it here.
959  @property
960  def learning_rate(self):
961    return self._optimizer.learning_rate
962
963  @learning_rate.setter
964  def learning_rate(self, value):
965    self._optimizer.learning_rate = value
966
967  @property
968  def lr(self):
969    return self._optimizer.learning_rate
970
971  @lr.setter
972  def lr(self, value):
973    self._optimizer.lr = value
974
975  # We do not override some OptimizerV2 methods. For each, we describe why we do
976  # not delegate them to self._optimizer:
977  # * get_updates: get_updates() calls get_gradients(). Since we override
978  #   get_gradients(), we cannot delegate get_updates() to self._optimizer,
979  #   otherwise the overridden get_gradients() method would not be called.
980  #   Luckily, get_updates() does not access any OptimizerV2 fields, so
981  #   inheriting the OptimizerV2 version works fine.
  # * minimize: We don't delegate for a reason similar to get_updates(): it
  #   calls both self._compute_gradients() and self.apply_gradients(), and
  #   both need to have the LossScaleOptimizer version called.
985
986  # TODO(reedwm): Maybe throw an error if mixed precision is used without this
987  # optimizer being used.
988
989
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizerV1(LossScaleOptimizer):
  """A deprecated optimizer that applies loss scaling.

  Warning: This class is deprecated and will be removed in a future version of
  TensorFlow. Please use the non-experimental class
  `tf.keras.mixed_precision.LossScaleOptimizer` instead.

  This class is identical to the non-experimental
  `keras.mixed_precision.LossScaleOptimizer` except its constructor takes
  different arguments. For this class (the experimental version), the
  constructor takes a `loss_scale` argument.  For the non-experimental class,
  the constructor encodes the loss scaling information in multiple arguments.
  Note that unlike this class, the non-experimental class does not accept a
  `tf.compat.v1.mixed_precision.LossScale`, which is deprecated.

  If you currently use this class, you should switch to the non-experimental
  `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several
  examples of converting the use of the experimental class to the equivalent
  non-experimental class.

  >>> # In all of the examples below, `opt1` and `opt2` are identical
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> assert opt1.get_config() == opt2.get_config()

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=123)
  >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123
  >>> # refers to the initial loss scale, which is the single fixed loss scale
  >>> # when dynamic=False.
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123)
  >>> assert opt1.get_config() == opt2.get_config()

  >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale(
  ...     initial_loss_scale=2048, increment_period=500)
  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale=loss_scale)
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), initial_scale=2048,
  ...     dynamic_growth_steps=500)
  >>> assert opt1.get_config() == opt2.get_config()

  Make sure to also switch from this class to the non-experimental class in
  isinstance checks, if you have any. If you do not do this, your model may run
  into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses
  the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to
  switch isinstance checks to the non-experimental `LossScaleOptimizer` even
  before using the non-experimental `LossScaleOptimizer`.

  >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(), loss_scale='dynamic')
  >>> # The experimental class subclasses the non-experimental class
  >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer)
  True
  >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD())
  >>> # The non-experimental class does NOT subclass the experimental class.
  >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
  False

  Args:
    optimizer: The Optimizer instance to wrap.
    loss_scale: The loss scale to scale the loss and gradients. This can
      either be an int/float to use a fixed loss scale, the string "dynamic"
      to use dynamic loss scaling, or an instance of a LossScale. The string
      "dynamic" is equivalent to passing `DynamicLossScale()`, and passing an
      int/float is equivalent to passing a FixedLossScale with the given loss
      scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must
      be 2 (the default).
  """

  def __init__(self, optimizer, loss_scale):
    """Translates `loss_scale` into the parent constructor's arguments.

    A deprecation warning is logged for every accepted form of `loss_scale`.

    Raises:
      ValueError: If `loss_scale` is a DynamicLossScale whose multiplier is
        not 2, or is not one of the accepted kinds of value.
      TypeError: If `loss_scale` is a LossScale that is neither a
        FixedLossScale nor a DynamicLossScale.
    """
    warn_msg_prefix = (
        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
        'instead. ')

    # A dict is treated as a serialized LossScale.
    if isinstance(loss_scale, dict):
      loss_scale = keras_loss_scale_module.deserialize(loss_scale)

    # A plain number and a FixedLossScale both select fixed loss scaling; they
    # differ only in where the scale value comes from.
    uses_fixed_scale = True
    if isinstance(loss_scale, (int, float)):
      fixed_value = loss_scale
    elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
      fixed_value = loss_scale._loss_scale_value  # pylint: disable=protected-access
    else:
      uses_fixed_scale = False

    if uses_fixed_scale:
      tf_logging.warning(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt, dynamic=False, initial_scale={})'.format(fixed_value))
      super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
                                                 initial_scale=fixed_value)
      return

    if loss_scale == 'dynamic':
      tf_logging.warning(
          warn_msg_prefix + 'For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt)')
      super(LossScaleOptimizerV1, self).__init__(optimizer)
      return

    if isinstance(loss_scale, loss_scale_module.DynamicLossScale):
      # Only settings that differ from the defaults are forwarded to the
      # parent constructor and shown in the example in the warning.
      overrides = {}
      shown_arguments = []
      initial_scale = loss_scale.initial_loss_scale
      if initial_scale != _DEFAULT_INITIAL_SCALE:
        overrides['initial_scale'] = initial_scale
        shown_arguments.append(', initial_scale=%s' % initial_scale)
      growth_steps = loss_scale.increment_period
      if growth_steps != _DEFAULT_GROWTH_STEPS:
        overrides['dynamic_growth_steps'] = growth_steps
        shown_arguments.append(', dynamic_growth_steps=%s' % growth_steps)
      if loss_scale.multiplier != 2:
        raise ValueError('When passing a DynamicLossScale to "loss_scale", '
                         'DynamicLossScale.multiplier must be 2. Got: %s'
                         % (loss_scale,))
      tf_logging.warning(
          warn_msg_prefix +
          'Note that the non-experimental LossScaleOptimizer does not take a '
          'DynamicLossScale but instead takes the dynamic configuration '
          'directly in the constructor. For example:\n'
          '  opt = tf.keras.mixed_precision.LossScaleOptimizer('
          'opt{})\n'.format(''.join(shown_arguments)))
      super(LossScaleOptimizerV1, self).__init__(optimizer, **overrides)
      return

    if isinstance(loss_scale, loss_scale_module.LossScale):
      raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
                      'DynamicLossScale is no longer supported. Got: {}'
                      .format(loss_scale))
    raise ValueError('Invalid value passed to loss_scale. loss_scale '
                     'must be the string "dynamic" (recommended), an int, '
                     'a float, a FixedLossScale, or a DynamicLossScale. Got '
                     'value: {}'.format(loss_scale))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from a config dict in either supported format.

    Args:
      config: A config dict. A `loss_scale` key marks the older format;
        otherwise the format produced by `LossScaleOptimizer.get_config` is
        expected.
      custom_objects: Forwarded to `optimizers.deserialize` when the wrapped
        optimizer is deserialized.

    Returns:
      The result of calling `cls` with `optimizer` and `loss_scale` arguments.

    Raises:
      ValueError: If an old-format config holds a DynamicLossScale whose
        multiplier is not 2.
    """
    config = config.copy()  # Work on a copy, since the config is mutated

    # If loss_scale is in config, we assume we are deserializing a
    # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are
    # deserializing a LossScaleOptimizer from TF 2.4 or above.
    if 'loss_scale' in config:
      legacy_scale = keras_loss_scale_module.deserialize(config['loss_scale'])
      config['loss_scale'] = legacy_scale
      if (isinstance(legacy_scale, loss_scale_module.DynamicLossScale)
          and legacy_scale.multiplier != 2):
        raise ValueError('Cannot deserialize LossScaleOptimizer with a '
                         'DynamicLossScale whose multiplier is not 2. Got '
                         'DynamicLossScale: %s' % (legacy_scale,))
      config['optimizer'] = optimizers.deserialize(
          config['optimizer'], custom_objects=custom_objects)
      return cls(**config)

    # Convert the config, as generated by LossScaleOptimizer.get_config, to a
    # version that can be passed to LossScaleOptimizerV1.__init__
    if config['dynamic']:
      converted_scale = loss_scale_module.DynamicLossScale(
          config['initial_scale'], config['dynamic_growth_steps'], multiplier=2)
    else:
      converted_scale = loss_scale_module.FixedLossScale(
          config['initial_scale'])
    config['loss_scale'] = converted_scale

    # These keys are not accepted by LossScaleOptimizerV1.__init__.
    for obsolete_key in ('dynamic', 'initial_scale', 'dynamic_growth_steps'):
      del config[obsolete_key]
    serialized_inner = config.pop('inner_optimizer')
    config['optimizer'] = optimizers.deserialize(
        serialized_inner, custom_objects=custom_objects)
    return cls(**config)
1162
1163
class FakeOptimizerForRestoration(trackable.Trackable):
  """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.

  The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class
  exists to support restoring TF 2.2 checkpoints in newer versions of
  TensorFlow.

  In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the
  following in LossScaleOptimizer.__init__

  ```
  self._track_trackable(self._optimizer, 'base_optimizer')
  ```

  This means a dependency from the LossScaleOptimizer to the wrapped optimizer
  would be stored in the checkpoint. However now, the checkpoint format with a
  LossScaleOptimizer is the same as the format without a LossScaleOptimizer,
  except the loss scale is also stored. This means there is no dependency from
  the LossScaleOptimizer to the wrapped optimizer. Instead, the
  LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's
  perspective, by overriding all Trackable methods and delegating them to the
  wrapped optimizer.

  To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a dependency
  on this class instead of the inner optimizer. When restored, this class will
  instead restore the slot variables of the inner optimizer. Since this class
  has no variables, it does not affect the checkpoint when saved.
  """

  def __init__(self, optimizer):
    """Stores `optimizer`, the optimizer that calls are forwarded to."""
    self._optimizer = optimizer

  def get_slot_names(self):
    """Returns the result of `get_slot_names()` on the stored optimizer."""
    inner = self._optimizer
    return inner.get_slot_names()

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    """Forwards the call, arguments unchanged, to the stored optimizer."""
    inner = self._optimizer
    return inner._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position, slot_name, variable)
1202
1203
# Register LossScaleOptimizerV1 as the loss scale wrapper associated with
# OptimizerV2. This runs as a side effect of importing this module.
# NOTE(review): how the registry uses this pairing is defined in the
# `mixed_precision` module, not in this file -- confirm there.
mixed_precision.register_loss_scale_wrapper(optimizer_v2.OptimizerV2,
                                            LossScaleOptimizerV1)
1206
1207
def _multiply_gradient(gradient, scale):
  """Multiply a (possibly sparse) gradient by the given scale factor.

  Args:
    gradient: The gradient to scale. An `ops.IndexedSlices` is handled
      separately from any other value.
    scale: The scale factor. It is cast to `gradient.dtype` before use.

  Returns:
    For an `ops.IndexedSlices`, a new `ops.IndexedSlices` with scaled values
    and the original indices and dense shape. Otherwise `gradient * scale`.
  """
  factor = math_ops.cast(scale, gradient.dtype)
  if not isinstance(gradient, ops.IndexedSlices):
    return gradient * factor
  # Only the values of a sparse gradient are scaled.
  scaled_values = gradient.values * factor
  return ops.IndexedSlices(
      scaled_values,
      gradient.indices,
      dense_shape=gradient.dense_shape)
1218
1219
def strategy_supports_loss_scaling():
  """Returns True if the current Strategy supports loss scaling.

  True is returned when there is no current strategy, or when the current
  strategy is an instance of one of the supported strategy classes listed
  below.
  """
  if not distribution_strategy_context.has_strategy():
    return True

  # Strategies are supported if either there is only one replica or if variables
  # are replicated per device. Otherwise, the current model.fit() implementation
  # and most custom training loops incorrectly unscale the gradients. Currently,
  # gradients are unscaled once per compute replica, but they should be unscaled
  # once per variable replica. When there is one variable replica for each
  # compute replica, this works fine, but otherwise issues will occur.
  # TODO(reedwm): Support all strategies.
  current_strategy = distribution_strategy_context.get_strategy()
  supported_strategy_types = (
      collective_all_reduce_strategy.CollectiveAllReduceStrategy,
      collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
      one_device_strategy.OneDeviceStrategy,
      one_device_strategy.OneDeviceStrategyV1,
      mirrored_strategy.MirroredStrategy,
      mirrored_strategy.MirroredStrategyV1,
  )
  return isinstance(current_strategy, supported_strategy_types)
1240