# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains Loss Scale Gradient Tape."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import backprop
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import nest


def _convert_to_per_replicas(distribution, values):
  """Converts tensors and DistributedVariables to PerReplica values.

  Running an identity op on each value under `distribution.run` forces the
  result into per-replica form, regardless of the input's distributed type.

  Args:
    distribution: The distribution strategy in effect.
    values: A list of tensors, variables, DistributedValues, or anything else
      that can be converted to a PerReplica value.

  Returns:
    `values`, but each element has been converted to a PerReplica value.
  """
  return distribution.run(
      lambda values: [array_ops.identity(v) for v in values],
      args=(values,))


# TODO(reedwm): Expose this after testing it on several models.
class LossScaleGradientTape(backprop.GradientTape):
  """A gradient tape that scales losses and unscales resulting gradients.

  Operates as a normal gradient tape, but takes in a
  `tf.mixed_precision.experimental.LossScale` object. Losses are scaled up by
  some amount before the gradients are calculated and the resulting gradients
  are scaled down by the same amount.

  This has no net mathematical effect, but can be used to prevent vanishing
  gradients, for example in the case of mixed precision training.

  If a DynamicLossScale object is used and non-finite gradients are encountered,
  the loss scale will be updated and the gradients recomputed until either
  finite gradients are encountered or the loss scale becomes 1.

  This class should *not* be used with a LossScaleOptimizer, as both classes
  update the LossScale object. Use a non-loss scaling optimizer instead.

  Usage:
  ```
  opt = tf.keras.optimizers.SGD(1.0)
  model_loss_scale = tf.mixed_precision.experimental.DynamicLossScale()

  for step in training_steps:
    with LossScaleGradientTape(model_loss_scale) as tape:
      logits = ...  # Run model and get logits
      loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                     labels=labels)
      loss = tf.reduce_mean(loss)
    vars = tape.watched_variables()
    grads = tape.gradient(loss, vars)
    opt.apply_gradients(zip(grads, vars))
  ```

  WARNING: Computing second-order (or higher) gradients with a
  `LossScaleGradientTape` does not yet work properly when a
  `tf.distribute.Strategy` is used. Computing second-order gradients will return
  None instead of the gradient tensors. This only occurs when you nest multiple
  gradient tapes under each other; if you do not nest them, this issue will not
  occur.
  """

  def __init__(self,
               loss_scale,
               persistent=False,
               watch_accessed_variables=True):
    """Creates a new LossScaleGradientTape.

    Args:
      loss_scale: `tf.mixed_precision.experimental.LossScale` object that
        manages what quantity to scale by. This is typically either a
        FixedLossScale object with a constant scalar or a
        `tf.mixed_precision.experimental.DynamicLossScale` object that will
        adjust the scalar appropriately if any non-finite gradients are
        encountered.
      persistent: Boolean controlling whether a persistent gradient tape is
        created. False by default, which means at most one call can be made to
        the gradient() method on this object.
      watch_accessed_variables: Boolean controlling whether the tape will
        automatically `watch` any (trainable) variables accessed while the tape
        is active. Defaults to True meaning gradients can be requested from any
        result computed in the tape derived from reading a trainable `Variable`.
        If False users must explicitly `watch` any `Variable`s they want to
        request gradients from.

    Raises:
      ValueError: if `loss_scale` is not a LossScale instance, or if not
        executing eagerly (outside functions).
    """
    if not isinstance(loss_scale, loss_scale_module.LossScale):
      raise ValueError("`loss_scale` must be an instance of LossScale, "
                       "but got: %s" % (loss_scale,))
    if not ops.executing_eagerly_outside_functions():
      raise ValueError("LossScaleGradientTape is only supported in Eager mode.")

    # Always make a persistent tape: gradient() may need to recompute the
    # gradients several times while looping over loss scales, which requires
    # calling the underlying tape's gradient() more than once. The
    # caller-requested persistence is tracked separately in
    # `_outer_persistent` and honored at the end of gradient().
    super(LossScaleGradientTape, self).__init__(True,
                                                watch_accessed_variables)
    self._outer_persistent = persistent
    self._loss_scale = loss_scale

  def gradient(self,
               target,
               sources,
               output_gradients=None,
               unconnected_gradients=UnconnectedGradients.NONE):
    """Computes the gradient using operations recorded in context of this tape.

    Uses the `LossScale` object provided in the constructor to scale `target`
    and then to unscale the resulting gradients.

    Args:
      target: a list or nested structure of Tensors or Variables to be
        differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of target.
        Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`. If non-finite gradients are encountered
      after dynamic scaling, the loss scale will be updated and the gradients
      recomputed until either finite gradients are encountered or the loss
      scale becomes 1.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
        than once on a non-persistent tape.
      ValueError: if the target is a variable or if unconnected gradients is
        called with an unknown value.
    """
    if self._tape is None:  # pylint: disable=access-member-before-definition
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError("LossScaleGradientTape.gradient() must be called in a "
                       "replica context.")

    # Note: DistributionStrategy does not support running a while loop in a
    # replica context. So, we call `_compute_gradients_until_finite` in a
    # cross-replica context.
    replica_context = distribution_strategy_context.get_replica_context()
    grads = replica_context.merge_call(
        _compute_gradients_until_finite,
        args=(self, self._loss_scale, target, sources, output_gradients,
              unconnected_gradients))

    if not self._outer_persistent:
      self._tape = None  # free up resources if a persistent tape was not needed
    return grads

  def jacobian(self,
               target,
               sources,
               unconnected_gradients=UnconnectedGradients.NONE,
               parallel_iterations=None,
               experimental_use_pfor=True):
    # TODO(reedwm): Implement this
    raise NotImplementedError("LossScaleGradientTape.jacobian is not "
                              "yet implemented")

  def batch_jacobian(self,
                     target,
                     source,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     parallel_iterations=None,
                     experimental_use_pfor=True):
    # TODO(reedwm): Implement this
    raise NotImplementedError("LossScaleGradientTape.batch_jacobian is not "
                              "yet implemented")


def _compute_gradients_until_finite(
    distribution, loss_scale_gradient_tapes, loss_scale, target, sources,
    output_gradients, unconnected_gradients):
  """Compute gradients and update the loss scale until the gradients are finite.

  This must be called in a cross-replica context.

  This is a function instead of a method of LossScaleGradientTape, as the `self`
  parameter would be meaningless. There is one LossScaleGradientTape per
  replica, but this function is called once total (not per replica), so there
  cannot be a singular `self` parameter.

  Args:
    distribution: The distribution strategy in effect.
    loss_scale_gradient_tapes: A PerReplica value of LossScaleGradientTapes.
      Contains the LossScaleGradientTape of each replica.
    loss_scale: The loss scale to use to scale the loss and unscale the
      gradient.
    target: a list or nested structure of Tensors or Variables to be
      differentiated.
    sources: a list or nested structure of Tensors or Variables. `target` will
      be differentiated against elements in `sources`.
    output_gradients: Passed to GradientTape.gradient
    unconnected_gradients: Pass to GradientTape.gradient.

  Returns:
    The gradients of `target` with respect to `sources`.
  """
  # Autograph cannot convert this function, so we must use an explicit
  # tf.while_loop.
  # TODO(b/143572314): Fix Autograph so that it can convert this function, then
  # replace the tf.while_loop with a Python while loop.

  # For convenience, we only deal with flattened sources; the result is
  # repacked into the structure of `sources` before returning.
  flattened_sources = nest.flatten(sources)

  # Define the initial loop variables of the while loop.

  # Dummy value for initial_grads. The first iteration of the loop will
  # overwrite `grads` to the actual gradients. The sources are used as the
  # dummies because they have the correct shapes and dtypes.
  initial_grads = flattened_sources
  if distribution_strategy_context.has_strategy():
    # A while_loop requires the initial values to have the same types as the
    # return values from the body. However, 'initial_grads' may have type
    # 'DistributionVariable', while body returns a 'PerReplica'. While both
    # types subclass 'DistributedValues', while_loop will still throw an error.
    # So we convert 'initial_grads' to be PerReplica values.
    # TODO(b/146084534): Once the bug is fixed, remove this special case.
    initial_grads = _convert_to_per_replicas(distribution, initial_grads)
  initial_ready_to_update = False
  initial_is_first_iteration = True

  def cond(grads, ready_to_update, is_first_iteration):
    """The condition of the while loop: keep iterating until grads usable."""
    del grads
    # Equivalent to:
    # `is_first_iteration or (not ready_to_update and loss_scale() > 1)`
    return math_ops.logical_or(
        is_first_iteration,
        math_ops.logical_and(
            math_ops.logical_not(ready_to_update),
            math_ops.greater(loss_scale(), 1)))

  # Boolean list specifying whether each gradient is None or not. Set by
  # body(), read after the while loop to restore Nones (a tf.while_loop
  # cannot carry None values, so dummies are used inside the loop).
  is_nones = []

  def body(grads, ready_to_update, is_first_iteration):
    """The body of the while loop: recompute gradients at the current scale."""
    del grads, ready_to_update, is_first_iteration
    def replica_fn(gradient_tape, target, flattened_sources, output_gradients,
                   initial_grads):
      """Scales the loss, computes the gradients, and unscales the gradients."""
      loss_scale_val = loss_scale()
      with gradient_tape:  # re-enter gradient tape so it sees the loss scaling
        scaled_target = nest.map_structure(
            lambda t: t * math_ops.cast(loss_scale_val, t.dtype), target)
      scaled_grads = super(LossScaleGradientTape, gradient_tape).gradient(
          scaled_target, flattened_sources, output_gradients,
          unconnected_gradients)
      is_nones[:] = [g is None for g in scaled_grads]
      inv_loss_scale = 1.0 / loss_scale_val
      grads = []  # The unscaled gradients
      for g, initial_grad in zip(scaled_grads, initial_grads):
        if g is not None:
          # We call ensure_shape as shape information can be lost for certain
          # ops, such as tf.transpose, if the op is called in a tf.function and
          # has inputs created outside the tf.function.
          # TODO(b/132092188): Remove ensure_shape call after this has been
          # fixed.
          g = array_ops.ensure_shape(g, initial_grad.shape)
          grads.append(g * math_ops.cast(inv_loss_scale, g.dtype))
        else:
          # We cannot return None from a tf.while_loop, so we pass a dummy
          # tensor instead. We use initial_grad as a dummy tensor as it has the
          # correct shape and dtype. We replace it with None outside the while
          # loop.
          grads.append(initial_grad)
      return grads

    # Switch to a replica-context to compute gradients once per replica.
    grads = distribution.run(
        replica_fn,
        args=(loss_scale_gradient_tapes, target, flattened_sources,
              output_gradients, initial_grads))
    # Check for non-finite gradients possibly resulting from scaling.
    # `ready_to_update` comes from LossScale.update; presumably it is True
    # when the gradients are finite -- TODO(review): confirm against the
    # LossScale contract.
    _, ready_to_update = loss_scale.update(grads)
    is_first_iteration = False
    return grads, ready_to_update, is_first_iteration

  grads, _, _ = control_flow_ops.while_loop(
      cond, body, [initial_grads, initial_ready_to_update,
                   initial_is_first_iteration],
      )
  # Restore the Nones that were replaced by dummy tensors inside the loop.
  grads = [None if is_none else g for g, is_none in zip(grads, is_nones)]
  grads = nest.pack_sequence_as(sources, grads)
  return grads