• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains Loss Scale Gradient Tape."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import distribution_strategy_context
22from tensorflow.python.eager import backprop
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
28from tensorflow.python.training.experimental import loss_scale as loss_scale_module
29from tensorflow.python.util import nest
30
31
def _convert_to_per_replicas(distribution, values):
  """Converts tensors and DistributedVariables to PerReplica values.

  Args:
    distribution: The distribution strategy in effect.
    values: A list of tensors, variables, DistributedValues, or anything else
      that can be converted to a PerReplica value.

  Returns:
    `values`, but each element has been converted to a PerReplica value.
  """
  # Running an identity op on each value inside the strategy produces one
  # tensor per replica, i.e. a PerReplica value, regardless of the input type.
  return distribution.run(
      lambda values: [array_ops.identity(v) for v in values],
      args=(values,)
  )
47
48
# TODO(reedwm): Expose this after testing it on several models.
class LossScaleGradientTape(backprop.GradientTape):
  """A gradient tape that applies loss scaling transparently.

  This behaves like a regular gradient tape, except that it takes a
  `tf.mixed_precision.experimental.LossScale` object. The target (loss) is
  multiplied by the loss scale before gradients are computed, and the
  resulting gradients are divided by the same factor afterwards.

  Mathematically this is a no-op, but scaling the loss up keeps small
  gradient values representable, which helps prevent vanishing gradients,
  for example during mixed precision training.

  When a DynamicLossScale object is used and non-finite gradients are
  produced, the loss scale is updated and the gradients recomputed,
  repeating until either the gradients are finite or the loss scale
  reaches 1.

  Do *not* combine this class with a LossScaleOptimizer, since both update
  the LossScale object. Pair it with a regular, non-loss-scaling optimizer
  instead.

  Usage:
  ```
  opt = tf.keras.optimizers.SGD(1.0)
  model_loss_scale = tf.mixed_precision.experimental.DynamicLossScale()

  for step in training_steps:
    with LossScaleGradientTape(model_loss_scale) as tape:
      logits = ...  # Run model and get logits
      loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                     labels=labels)
      loss = tf.reduce_mean(loss)
    vars = tape.watched_variables()
    grads = tape.gradient(loss, vars)
    opt.apply_gradients(zip(grads, vars))
  ```

  WARNING: Under a `tf.distribute.Strategy`, computing second-order (or
  higher) gradients with a `LossScaleGradientTape` does not yet work
  properly: the gradients come back as None instead of tensors. This only
  happens when multiple gradient tapes are nested under each other; tapes
  that are not nested are unaffected.
  """

  def __init__(self, loss_scale, persistent=False,
               watch_accessed_variables=True):
    """Creates a new LossScaleGradientTape.

    Args:
      loss_scale: `tf.mixed_precision.experimental.LossScale` object that
        manages what quantity to scale by. Typically either a FixedLossScale
        object holding a constant scalar, or a
        `tf.mixed_precision.experimental.DynamicLossScale` object that
        adjusts the scalar when non-finite gradients are encountered.
      persistent: Boolean controlling whether a persistent gradient tape is
        created. False by default, which means at most one call can be made
        to the gradient() method on this object.
      watch_accessed_variables: Boolean controlling whether the tape will
        automatically `watch` any (trainable) variables accessed while the
        tape is active. Defaults to True, meaning gradients can be requested
        from any result computed in the tape derived from reading a
        trainable `Variable`. If False, users must explicitly `watch` any
        `Variable`s they want to request gradients from.
    """
    if not isinstance(loss_scale, loss_scale_module.LossScale):
      raise ValueError("`loss_scale` must be an instance of LossScale, "
                       "but got: %s" % (loss_scale,))
    if not ops.executing_eagerly_outside_functions():
      raise ValueError("LossScaleGradientTape is only supported in Eager mode.")

    # The underlying tape is always persistent: the dynamic loss-scaling loop
    # may need to differentiate the same target several times.
    super(LossScaleGradientTape, self).__init__(True, watch_accessed_variables)
    self._outer_persistent = persistent
    self._loss_scale = loss_scale

  def gradient(self, target, sources, output_gradients=None,
               unconnected_gradients=UnconnectedGradients.NONE):
    """Computes the gradient using operations recorded in context of this tape.

    The `LossScale` given to the constructor is used to scale `target` before
    differentiating and to unscale the resulting gradients, so the returned
    values are mathematically unaffected by the scaling.

    Args:
      target: a list or nested structure of Tensors or Variables to be
        differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of target.
        Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero'
        and alters the value which will be returned if the target and sources
        are unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None), one
      for each element in `sources`, with the same structure as `sources`.
      If non-finite gradients are encountered after dynamic scaling, the loss
      scale is updated and the gradients recomputed until either finite
      gradients are encountered or the loss scale becomes 1.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called
       more than once on a non-persistent tape.
      ValueError: if the target is a variable or if unconnected gradients is
       called with an unknown value.
    """
    if self._tape is None:  # pylint: disable=access-member-before-definition
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError("LossScaleGradientTape.gradient() must be called in a "
                       "replica context.")

    # DistributionStrategy cannot run a while loop from within a replica
    # context, so the loss-scaling retry loop runs in a cross-replica context
    # via merge_call.
    ctx = distribution_strategy_context.get_replica_context()
    gradients = ctx.merge_call(
        _compute_gradients_until_finite,
        args=(self, self._loss_scale, target, sources, output_gradients,
              unconnected_gradients))

    if not self._outer_persistent:
      # The caller asked for a non-persistent tape; drop the underlying
      # (always-persistent) tape to free its resources.
      self._tape = None
    return gradients

  def jacobian(self, target, sources,
               unconnected_gradients=UnconnectedGradients.NONE,
               parallel_iterations=None, experimental_use_pfor=True):
    """Not supported yet; always raises NotImplementedError."""
    # TODO(reedwm): Implement this
    raise NotImplementedError("LossScaleGradientTape.jacobian is not "
                              "yet implemented")

  def batch_jacobian(self, target, source,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     parallel_iterations=None, experimental_use_pfor=True):
    """Not supported yet; always raises NotImplementedError."""
    # TODO(reedwm): Implement this
    raise NotImplementedError("LossScaleGradientTape.batch_jacobian is not "
                              "yet implemented")
202
203
def _compute_gradients_until_finite(
    distribution, loss_scale_gradient_tapes, loss_scale, target, sources,
    output_gradients, unconnected_gradients):
  """Compute gradients and update the loss scale until the gradients are finite.

  This must be called in a cross-replica context.

  This is a function instead of a method of LossScaleGradientTape, as the `self`
  parameter would be meaningless. There is one LossScaleGradientTape per
  replica, but this function is called once total (not per replica), so there
  cannot be a singular `self` parameter.

  Args:
    distribution: The distribution strategy in effect.
    loss_scale_gradient_tapes: A PerReplica value of LossScaleGradientTapes.
      Contains the LossScaleGradientTape of each replica.
    loss_scale: The loss scale to use to scale the loss and unscale the
      gradient.
    target: a list or nested structure of Tensors or Variables to be
      differentiated.
    sources: a list or nested structure of Tensors or Variables. `target` will
      be differentiated against elements in `sources`.
    output_gradients: Passed to GradientTape.gradient
    unconnected_gradients: Pass to GradientTape.gradient.

  Returns:
    The gradients of `target` with respect to `sources`.
  """
  # Autograph cannot convert this function, so we must use an explicit
  # tf.while_loop.
  # TODO(b/143572314): Fix Autograph so that it can convert this function, then
  # replace the tf.while_loop with a Python while loop.

  # For convenience, we only deal with flattened sources; the result is
  # re-packed into the structure of `sources` at the very end.
  flattened_sources = nest.flatten(sources)

  # Define the initial loop variables of the while loop.

  # Dummy value for initial_grads. The first iteration of the loop will
  # overwrite `grads` to the actual gradients. The sources are used only
  # because they have the same shapes and dtypes the gradients will have.
  initial_grads = flattened_sources
  if distribution_strategy_context.has_strategy():
    # A while_loop requires the initial values to have the same types as the
    # return values from the body. However, 'initial_grads' may have type
    # 'DistributionVariable', while body returns a 'PerReplica'. While both
    # types subclass 'DistributedValues', while_loop will still throw an error.
    # So we convert 'initial_grads' to be PerReplica values.
    # TODO(b/146084534): Once the bug is fixed, remove this special case.
    initial_grads = _convert_to_per_replicas(distribution, initial_grads)
  initial_ready_to_update = False
  initial_is_first_iteration = True

  def cond(grads, ready_to_update, is_first_iteration):
    """The condition of the while loop: keep retrying while not done."""
    del grads  # Unused: only the flags decide whether to iterate again.
    # Equivalent to:
    # `is_first_iteration or (not ready_to_update and loss_scale() > 1)`
    # i.e. always run at least once, then keep re-running while the loss
    # scale update was rejected (non-finite grads) and the scale can still
    # be lowered.
    return math_ops.logical_or(
        is_first_iteration,
        math_ops.logical_and(
            math_ops.logical_not(ready_to_update),
            math_ops.greater(loss_scale(), 1)))

  # Boolean list specifying whether each gradient is None or not. Set by
  # body(), which runs once at trace time, so the list is populated before
  # it is read after the while_loop below.
  is_nones = []

  def body(grads, ready_to_update, is_first_iteration):
    """The body of the while loop: one scaled-gradient attempt."""
    del grads, ready_to_update, is_first_iteration  # Recomputed below.
    def replica_fn(gradient_tape, target, flattened_sources, output_gradients,
                   initial_grads):
      """Scales the loss, computes the gradients, and unscales the gradients."""
      # Read the (possibly dynamic) loss scale once so scaling and unscaling
      # use the same value within this iteration.
      loss_scale_val = loss_scale()
      with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
        scaled_target = nest.map_structure(
            lambda t: t * math_ops.cast(loss_scale_val, t.dtype), target)
      # Call the plain GradientTape.gradient, bypassing the loss-scaling
      # override on LossScaleGradientTape.
      scaled_grads = super(LossScaleGradientTape, gradient_tape).gradient(
          scaled_target, flattened_sources, output_gradients,
          unconnected_gradients)

      is_nones[:] = [g is None for g in scaled_grads]
      inv_loss_scale = 1.0 / loss_scale_val
      grads = []  # The unscaled gradients
      for g, initial_grad in zip(scaled_grads, initial_grads):
        if g is not None:
          # We call ensure_shape as shape information can be lost for certain
          # ops, such as tf.transpose, if the op is called in a tf.function and
          # has inputs created outside the tf.function.
          # TODO(b/132092188): Remove ensure_shape call after this has been
          # fixed.
          g = array_ops.ensure_shape(g, initial_grad.shape)
          grads.append(g * math_ops.cast(inv_loss_scale, g.dtype))
        else:
          # We cannot return None from a tf.while_loop, so we pass a dummy
          # tensor instead. We use initial_grad as a dummy tensor as it has the
          # correct shape and dtype. We replace it with None outside the while
          # loop.
          grads.append(initial_grad)
      return grads

    # Switch to a replica-context to compute gradients once per replica.
    grads = distribution.run(
        replica_fn,
        args=(loss_scale_gradient_tapes, target, flattened_sources,
              output_gradients, initial_grads))
    # Check for non-finite gradients possibly resulting from scaling.
    # loss_scale.update adjusts the scale on non-finite grads and reports via
    # ready_to_update whether this iteration's gradients are usable.
    _, ready_to_update = loss_scale.update(grads)
    is_first_iteration = False
    return grads, ready_to_update, is_first_iteration

  # Repeat gradient computation, lowering the loss scale on failure, until
  # cond() says the last attempt produced acceptable gradients.
  grads, _, _ = control_flow_ops.while_loop(
      cond, body, [initial_grads, initial_ready_to_update,
                   initial_is_first_iteration],
      )
  # Restore the None entries that were replaced by dummy tensors inside the
  # loop, then re-pack into the caller's source structure.
  grads = [None if is_none else g for g, is_none in zip(grads, is_nones)]
  grads = nest.pack_sequence_as(sources, grads)
  return grads
321