# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@six.add_metaclass(abc.ABCMeta)
@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
                                  'train.experimental.LossScale')
@tf_export(
    'mixed_precision.experimental.LossScale',
    'train.experimental.LossScale',
    v1=[
        'mixed_precision.LossScale',
        'mixed_precision.experimental.LossScale',
        'train.experimental.LossScale'
    ])
class LossScale(trackable.Trackable):
  """Base class for all TF1 loss scales.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace in a future version of TensorFlow. Once this occurs, this class will
  only be accessible as `tf.compat.v1.mixed_precision.LossScale`. All the
  functionality in this class has been merged into
  `tf.keras.mixed_precision.LossScaleOptimizer`, so this class is no longer
  needed.

  This is an abstract base class, so you cannot instantiate it directly.
  Instead, use one of its concrete subclasses:
    * `tf.compat.v1.mixed_precision.DynamicLossScale`
    * `tf.compat.v1.mixed_precision.FixedLossScale`

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling instances of this
  class returns the loss scale as a scalar float32 tensor, while method
  `update()` updates the loss scale depending on the values of the gradients.
  Optimizers use instances of this class to scale loss and gradients.

  In most functions that accept a LossScale, you can also pass an int (such as
  8) to create a `FixedLossScale` or the string `"dynamic"` to create a dynamic
  loss scale.
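
  For example, a manual training step might use a loss scale roughly as follows
  (an illustrative sketch; `loss`, `vars`, and the `gradients` function are
  placeholders rather than part of this class):

  ```
  loss_scale = tf.compat.v1.mixed_precision.DynamicLossScale()
  scaled_loss = loss * loss_scale()       # __call__ returns the current scale.
  grads = gradients(scaled_loss, vars)
  unscaled_grads = [g / loss_scale() for g in grads]
  update_op, should_apply = loss_scale.update(unscaled_grads)
  ```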
89  """
90
91  def __init__(self):
92    """Initializes the loss scale class."""
93    self._weights = {}
94
95  @abc.abstractmethod
96  def __call__(self):
97    """Returns the current loss scale as a scalar `float32` tensor."""
98    pass
99
100  @abc.abstractmethod
101  def update(self, grads):
102    """Updates the value of the loss scale.
103
104    The loss scale will be potentially updated, based on the value of `grads`.
105    The tensor returned by calling this class is only updated when this function
106    is evaluated.
107
108    In eager mode, this directly updates the loss scale, so that calling
109    `__call__` will return the newly updated loss scale. In graph mode,
110    this returns an op that, when evaluated, updates the loss scale.
111
112    This function also returns a `should_apply_gradients` bool. If False,
113    gradients should not be applied to the variables that step, as nonfinite
114    gradients were found, and the loss scale has been be updated to reduce the
115    chance of finding nonfinite gradients in the next step. Some loss scale
116    classes will always return True, as they cannot adjust themselves in
117    response to nonfinite gradients.
118
119    When a DistributionStrategy is used, this function may only be called in a
120    cross-replica context.
121
122    Args:
123      grads: A nested structure of unscaled gradients, each which is the
124        gradient of the loss with respect to a weight. The gradients should have
125        already been divided by the loss scale being before passed to this
126        function. 'None' gradients are accepted, and are ignored.
127
128    Returns:
129      update_op: In eager mode, None. In graph mode, an op to update the loss
130        scale.
131      should_apply_gradients: Either a bool or a scalar boolean tensor. If
132        False, the caller should skip applying `grads` to the variables this
133        step.
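
    For example, in eager mode the return values might be consumed roughly like
    this (an illustrative sketch; `optimizer` and `vars` are placeholders):

    ```
    update_op, should_apply = loss_scale.update(grads)
    if should_apply:
      optimizer.apply_gradients(zip(grads, vars))
    ```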
134    """
135    pass
136
137  def _add_weight(self, name, initial_value, dtype=None):
138    """Adds a weight to this loss scale.
139
140    Args:
141      name: Variable name.
142      initial_value: The variable's initial value.
143      dtype: The type of the variable.
144
145    Returns:
146      A variable.
147
148    Raises:
149      RuntimeError: If a weight with `name` has already been added.
150    """
151    variable = variable_scope.variable(
152        initial_value=initial_value,
153        name=name,
154        dtype=dtype,
155        trainable=False,
156        use_resource=True,
157        synchronization=variables.VariableSynchronization.AUTO,
158        # Set aggregation to NONE, as loss scaling variables should never be
159        # aggregated.
160        aggregation=variables.VariableAggregation.NONE)
161    if context.executing_eagerly():
162      graph_key = None
163    else:
164      graph = ops.get_default_graph()
165      graph_key = graph._graph_key  # pylint: disable=protected-access
166
167    key = (name, graph_key)
168    if self._weights.get(key, None) is not None:
169      raise RuntimeError('Duplicate variables detected. {}'.format(key))
170    self._weights[key] = variable
171    self._handle_deferred_dependencies(name=name, trackable=variable)
172    return variable
173
174  @property
175  def _checkpoint_dependencies(self):
176    """From Trackable. Gather graph-specific weights to save."""
177    if context.executing_eagerly():
178      graph_key = None
179    else:
180      graph = ops.get_default_graph()
181      graph_key = graph._graph_key  # pylint: disable=protected-access
182    weights = []
183    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
184      if g == graph_key:
185        weights.append(trackable.TrackableReference(name=name, ref=v))
186    return super(LossScale, self)._checkpoint_dependencies + weights
187
188  def _lookup_dependency(self, name):
189    """From Trackable. Find a weight in the current graph."""
190    unconditional = super(LossScale, self)._lookup_dependency(name)
191    if unconditional is not None:
192      return unconditional
193    if context.executing_eagerly():
194      graph_key = None
195    else:
196      graph = ops.get_default_graph()
197      graph_key = graph._graph_key  # pylint: disable=protected-access
198    return self._weights.get((name, graph_key), None)
199
200  @abc.abstractmethod
201  def get_config(self):
202    """Returns the config of this loss scale."""
203    pass
204
205  @classmethod
206  def from_config(cls, config):
207    """Creates the LossScale from its config."""
208    return cls(**config)
209
210
211@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
212                                  'train.experimental.FixedLossScale')
213@tf_export(
214    'mixed_precision.experimental.FixedLossScale',
215    'train.experimental.FixedLossScale',
216    v1=[
217        'mixed_precision.FixedLossScale',
218        'mixed_precision.experimental.FixedLossScale',
219        'train.experimental.FixedLossScale'
220    ])
221class FixedLossScale(LossScale):
222  """Loss scale with a fixed value.
223
224  WARNING: This class is deprecated and will be unexposed from the TF 2
225  namespace in a future version of TensorFlow. Once this occurs, this class will
226  only be accessible as `tf.compat.v1.mixed_precision.FixedLossScale`. All the
227  functionality in this class has been merged into
228  `tf.keras.mixed_precision.LossScaleOptimizer`, so this class is no longer
229  needed.
230
231  The loss scale is not updated for the lifetime of instances of this class.
232  A given instance of this class always returns the same number when called.
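
  For example (an illustrative sketch; `grads` stands in for a list of gradient
  tensors):

  ```
  loss_scale = tf.compat.v1.mixed_precision.FixedLossScale(128)
  loss_scale()              # A scalar float32 tensor holding 128.0.
  loss_scale.update(grads)  # A no-op: the loss scale never changes.
  ```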
233  """
234
235  @deprecation.deprecated(
236      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
237            'LossScaleOptimizer now has all the functionality of '
238            'FixedLossScale')
239  def __init__(self, loss_scale_value):
240    """Creates the fixed loss scale.
241
242    Args:
243      loss_scale_value: A Python float. Its ideal value varies depending on
244        models to run. Choosing a too small loss_scale might affect model
245        quality; a too big loss_scale might cause inf or nan. There is no single
246        right loss_scale to apply. There is no harm choosing a relatively big
247        number as long as no nan or inf is encountered in training.
248
249    Raises:
250      ValueError: If loss_scale_value is less than 1.
251    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, six.integer_types + (float,)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might be on a different device or tf.function than where the
    # tensor is used. This would hurt performance. Therefore, we do not create
    # a tensor from loss_scale_value, but instead leave it as a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
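    # A fixed loss scale never changes, so there is nothing to update and
    # gradients can always be applied.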
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}


def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  is_finite_per_grad = [
      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
  ]
  return math_ops.reduce_all(is_finite_per_grad)


def _op_in_graph_mode(tensor):
  """Returns the tensor's op in graph mode, or the tensor in eager mode.

  This is useful because sometimes an op is needed in graph mode instead of a
  tensor. In eager mode, there are no ops.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode, or the tensor itself in eager mode.
298  """
299  if context.executing_eagerly():
300    return tensor
301  return tensor.op
302
303
304def _assign_if_finite(var, value):
305  """Assigns a value to a variable if the value is finite."""
306  return control_flow_ops.cond(
307      math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
308      control_flow_ops.no_op)
309
310
311@deprecation.deprecated_endpoints(
312    'mixed_precision.experimental.DynamicLossScale',
313    'train.experimental.DynamicLossScale')
314@tf_export(
315    'mixed_precision.experimental.DynamicLossScale',
316    'train.experimental.DynamicLossScale',
317    v1=[
318        'mixed_precision.DynamicLossScale',
319        'mixed_precision.experimental.DynamicLossScale',
320        'train.experimental.DynamicLossScale'
321    ])
322class DynamicLossScale(LossScale):
323  """Loss scale that dynamically adjusts itself.
324
325  WARNING: This class is deprecated and will be unexposed from the TF 2
326  namespace in a future version of TensorFlow. Once this occurs, this class will
327  only be accessible as `tf.compat.v1.mixed_precision.DynamicLossScale`. All the
328  functionality in this class has been merged into
329  `tf.keras.mixed_precision.LossScaleOptimizer`, so this class is no longer
330  needed.
331
332  Dynamic loss scaling works by adjusting the loss scale as training progresses.
333  The goal is to keep the loss scale as high as possible without overflowing the
334  gradients. As long as the gradients do not overflow, raising the loss scale
335  never hurts.
336
  The algorithm starts by setting the loss scale to an initial value. Every N
  consecutive steps in which the gradients are finite, the loss scale is
  increased by some factor. However, if a NaN or Inf gradient is found, the
  gradients for that step are not applied, and the loss scale is decreased by
  the factor. This process tends to keep the loss scale as high as possible
  without the gradients overflowing.
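
  For example (an illustrative sketch assuming eager execution; the small
  values are chosen only to make the behavior easy to follow):

  ```
  loss_scale = DynamicLossScale(initial_loss_scale=4, increment_period=1,
                                multiplier=2)
  loss_scale.update([tf.constant(1.0)])            # Finite: scale doubles to 8.
  loss_scale.update([tf.constant(float('inf'))])   # Nonfinite: scale drops to 4.
  loss_scale()                                     # A float32 tensor holding 4.
  ```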
343  """
344
345  @deprecation.deprecated(
346      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
347            'LossScaleOptimizer now has all the functionality of '
348            'DynamicLossScale')
349  def __init__(self,
350               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
351               increment_period=2000,
352               multiplier=2.):
353    """Creates the dynamic loss scale.
354
355    Args:
356      initial_loss_scale: A Python float.  The loss scale to use at the
357        beginning. It's better to start this at a very high number, because a
358        loss scale that is too high gets lowered far more quickly than a loss
359        scale that is too low gets raised. The default is 2 ** 15, which is
360        approximately half the maximum float16 value.
361      increment_period: Increases loss scale every `increment_period`
362        consecutive steps that finite gradients are encountered. If a nonfinite
363        gradient is encountered, the count is reset back to zero.
364      multiplier: The multiplier to use when increasing or decreasing the loss
365        scale.
366    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)

    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale.
    self._num_good_steps = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
397    """Updates loss scale based on if gradients are finite in current step."""
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
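      # Each replica contributes 1.0 if all of its gradients are finite and 0.0
      # otherwise, so the sum equals `num_replicas_in_sync` exactly when every
      # replica saw only finite gradients.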
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return control_flow_ops.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period, self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period, self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,
        'increment_period': self.increment_period,
        'multiplier': self.multiplier,
    }


def get(identifier):
  """Get a loss scale object."""
  if isinstance(identifier, six.integer_types + (float,)):
    return FixedLossScale(identifier)
  if identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     identifier)