# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@six.add_metaclass(abc.ABCMeta)
@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
                                  'train.experimental.LossScale')
@tf_export(
    'mixed_precision.experimental.LossScale',
    'train.experimental.LossScale',
    v1=[
        'mixed_precision.LossScale',
        'mixed_precision.experimental.LossScale',
        'train.experimental.LossScale'
    ])
class LossScale(trackable.Trackable):
  """Base class for all TF1 loss scales.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace in a future version of TensorFlow. Once this occurs, this class
  will only be accessible as `tf.compat.v1.mixed_precision.LossScale`. All the
  functionality in this class has been merged into
  `tf.keras.mixed_precision.LossScaleOptimizer`, so this class is no longer
  needed.

  This is an abstract base class, so you cannot instantiate it directly.
  Instead, use one of its concrete subclasses:
    * `tf.compat.v1.mixed_precision.DynamicLossScale`
    * `tf.compat.v1.mixed_precision.FixedLossScale`

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  Instances of this class represent a loss scale. Calling instances of this
  class returns the loss scale as a scalar float32 tensor, while method
  `update()` updates the loss scale depending on the values of the gradients.
  Optimizers use instances of this class to scale loss and gradients.
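
  For example, a training step might use a concrete subclass roughly as
  follows. This is only an illustrative sketch, written in the same pseudocode
  style as above (`gradients` and `apply_gradients` are placeholders, not part
  of this API):

  ```
  loss_scale = tf.compat.v1.mixed_precision.DynamicLossScale()
  loss = ...
  scaled_loss = loss * loss_scale()  # Calling the instance returns a tensor.
  scaled_grads = gradients(scaled_loss, vars)
  grads = [g / loss_scale() for g in scaled_grads]
  update_op, should_apply_grads = loss_scale.update(grads)
  # Apply `grads` only if `should_apply_grads` is True; otherwise skip this
  # step, as nonfinite gradients were found.
  ```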

  In most functions that accept a LossScale, you can also pass an int (such as
  8) to create a `FixedLossScale` or the string `"dynamic"` to create a dynamic
  loss scale.
  """

  def __init__(self):
    """Initializes the loss scale class."""
    self._weights = {}

  @abc.abstractmethod
  def __call__(self):
    """Returns the current loss scale as a scalar `float32` tensor."""
    pass

  @abc.abstractmethod
  def update(self, grads):
    """Updates the value of the loss scale.

    The loss scale will potentially be updated, based on the value of `grads`.
    The tensor returned by calling this class is only updated when this
    function is evaluated.

    In eager mode, this directly updates the loss scale, so that calling
    `__call__` will return the newly updated loss scale. In graph mode,
    this returns an op that, when evaluated, updates the loss scale.

    This function also returns a `should_apply_gradients` bool. If False,
    gradients should not be applied to the variables that step, as nonfinite
    gradients were found, and the loss scale has been updated to reduce the
    chance of finding nonfinite gradients in the next step. Some loss scale
    classes will always return True, as they cannot adjust themselves in
    response to nonfinite gradients.

    When a DistributionStrategy is used, this function may only be called in a
    cross-replica context.
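
    For example, with a `DynamicLossScale` (a brief sketch; the gradient
    values are made up for illustration):

    ```
    loss_scale = tf.compat.v1.mixed_precision.DynamicLossScale()
    grads = [tf.constant(4.), tf.constant(float('nan'))]
    update_op, should_apply_gradients = loss_scale.update(grads)
    # In eager mode the loss scale has already been halved at this point,
    # because a nonfinite gradient was found, and `should_apply_gradients`
    # evaluates to False, so the caller should skip applying `grads` this step.
    ```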

    Args:
      grads: A nested structure of unscaled gradients, each of which is the
        gradient of the loss with respect to a weight. The gradients should
        have already been divided by the loss scale before being passed to
        this function. 'None' gradients are accepted, and are ignored.

    Returns:
      update_op: In eager mode, None. In graph mode, an op to update the loss
        scale.
      should_apply_gradients: Either a bool or a scalar boolean tensor. If
        False, the caller should skip applying `grads` to the variables this
        step.
    """
    pass

  def _add_weight(self, name, initial_value, dtype=None):
    """Adds a weight to this loss scale.

    Args:
      name: Variable name.
      initial_value: The variable's initial value.
      dtype: The type of the variable.

    Returns:
      A variable.

    Raises:
      RuntimeError: If a weight with `name` has already been added.
    """
    variable = variable_scope.variable(
        initial_value=initial_value,
        name=name,
        dtype=dtype,
        trainable=False,
        use_resource=True,
        synchronization=variables.VariableSynchronization.AUTO,
        # Set aggregation to NONE, as loss scaling variables should never be
        # aggregated.
        aggregation=variables.VariableAggregation.NONE)
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access

    key = (name, graph_key)
    if self._weights.get(key, None) is not None:
      raise RuntimeError('Duplicate variables detected. {}'.format(key))
    self._weights[key] = variable
    self._handle_deferred_dependencies(name=name, trackable=variable)
    return variable

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific weights to save."""
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    weights = []
    for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
      if g == graph_key:
        weights.append(trackable.TrackableReference(name=name, ref=v))
    return super(LossScale, self)._checkpoint_dependencies + weights

  def _lookup_dependency(self, name):
    """From Trackable. Find a weight in the current graph."""
    unconditional = super(LossScale, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    if context.executing_eagerly():
      graph_key = None
    else:
      graph = ops.get_default_graph()
      graph_key = graph._graph_key  # pylint: disable=protected-access
    return self._weights.get((name, graph_key), None)

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of this loss scale."""
    pass

  @classmethod
  def from_config(cls, config):
    """Creates the LossScale from its config."""
    return cls(**config)


@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
                                  'train.experimental.FixedLossScale')
@tf_export(
    'mixed_precision.experimental.FixedLossScale',
    'train.experimental.FixedLossScale',
    v1=[
        'mixed_precision.FixedLossScale',
        'mixed_precision.experimental.FixedLossScale',
        'train.experimental.FixedLossScale'
    ])
class FixedLossScale(LossScale):
  """Loss scale with a fixed value.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace in a future version of TensorFlow. Once this occurs, this class
  will only be accessible as `tf.compat.v1.mixed_precision.FixedLossScale`.
  All the functionality in this class has been merged into
  `tf.keras.mixed_precision.LossScaleOptimizer`, so this class is no longer
  needed.

  The loss scale is not updated for the lifetime of instances of this class.
  A given instance of this class always returns the same number when called.
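
  For example (an illustrative sketch; `grads` stands for any list of unscaled
  gradients):

  ```
  loss_scale = tf.compat.v1.mixed_precision.FixedLossScale(128.)
  loss_scale()  # A scalar float32 tensor with value 128.0, on every call.
  update_op, should_apply_gradients = loss_scale.update(grads)
  # `update_op` is a no-op and `should_apply_gradients` is always True, since
  # a fixed loss scale cannot react to nonfinite gradients.
  ```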
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'FixedLossScale')
  def __init__(self, loss_scale_value):
    """Creates the fixed loss scale.

    Args:
      loss_scale_value: A Python float. Its ideal value varies depending on the
        model. A loss scale that is too small might affect model quality; one
        that is too big might cause inf or nan gradients. There is no single
        right loss scale to apply. There is no harm in choosing a relatively
        big number as long as no nan or inf is encountered in training.

    Raises:
      ValueError: If loss_scale_value is less than 1.
    """
    super(FixedLossScale, self).__init__()
    if not isinstance(loss_scale_value, six.integer_types + (float,)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # It's important we do not create tensors in the constructor, as such
    # tensors might be on a different device or in a different tf.function
    # than where the tensor is used. This would hurt performance. Therefore,
    # we do not create a tensor from loss_scale_value, but instead leave it as
    # a Python float.
    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
    # constructor.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    return ops.convert_to_tensor(self._loss_scale_value)

  def update(self, grads):
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}


def _is_all_finite(grads):
  """Returns a scalar boolean tensor indicating if all gradients are finite."""
  is_finite_per_grad = [
      math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
  ]
  return math_ops.reduce_all(is_finite_per_grad)


def _op_in_graph_mode(tensor):
  """Returns the tensor's op in graph mode, or the tensor in eager mode.

  This is useful because sometimes an op is needed in graph mode instead of a
  tensor. In eager mode, there are no ops.

  Args:
    tensor: A tensor.

  Returns:
    The tensor's op in graph mode. The tensor in eager mode.
  """
  if context.executing_eagerly():
    return tensor
  return tensor.op


def _assign_if_finite(var, value):
  """Assigns a value to a variable if the value is finite."""
  return control_flow_ops.cond(
      math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
      control_flow_ops.no_op)


@deprecation.deprecated_endpoints(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale')
@tf_export(
    'mixed_precision.experimental.DynamicLossScale',
    'train.experimental.DynamicLossScale',
    v1=[
        'mixed_precision.DynamicLossScale',
        'mixed_precision.experimental.DynamicLossScale',
        'train.experimental.DynamicLossScale'
    ])
class DynamicLossScale(LossScale):
  """Loss scale that dynamically adjusts itself.

  WARNING: This class is deprecated and will be unexposed from the TF 2
  namespace in a future version of TensorFlow. Once this occurs, this class
  will only be accessible as `tf.compat.v1.mixed_precision.DynamicLossScale`.
  All the functionality in this class has been merged into
  `tf.keras.mixed_precision.LossScaleOptimizer`, so this class is no longer
  needed.

  Dynamic loss scaling works by adjusting the loss scale as training
  progresses. The goal is to keep the loss scale as high as possible without
  overflowing the gradients. As long as the gradients do not overflow, raising
  the loss scale never hurts.

  The algorithm starts by setting the loss scale to an initial value. Every N
  steps that the gradients are finite, the loss scale is increased by some
  factor. However, if a NaN or Inf gradient is found, the gradients for that
  step are not applied, and the loss scale is decreased by the factor. This
  process tends to keep the loss scale as high as possible without gradients
  overflowing.
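
  For example, with the default arguments (an illustrative sketch of the
  algorithm, not a transcript of real training):

  ```
  loss_scale = tf.compat.v1.mixed_precision.DynamicLossScale(
      initial_loss_scale=2 ** 15, increment_period=2000, multiplier=2.)
  # Start:                                           loss scale is 2 ** 15.
  # After 2000 consecutive steps of finite grads:    loss scale is 2 ** 16.
  # After a later step with a NaN or Inf gradient:   loss scale drops back to
  #   2 ** 15 and that step's gradients are not applied.
  ```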
  """

  @deprecation.deprecated(
      None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
      'LossScaleOptimizer now has all the functionality of '
      'DynamicLossScale')
  def __init__(self,
               initial_loss_scale=2 ** 15,  # See docstring for why this is big.
               increment_period=2000,
               multiplier=2.):
    """Creates the dynamic loss scale.

    Args:
      initial_loss_scale: A Python float. The loss scale to use at the
        beginning. It's better to start this at a very high number, because a
        loss scale that is too high gets lowered far more quickly than a loss
        scale that is too low gets raised. The default is 2 ** 15, which is
        approximately half the maximum float16 value.
      increment_period: Increases loss scale every `increment_period`
        consecutive steps that finite gradients are encountered. If a nonfinite
        gradient is encountered, the count is reset back to zero.
      multiplier: The multiplier to use when increasing or decreasing the loss
        scale.
    """
    super(DynamicLossScale, self).__init__()
    self._initial_loss_scale = float(initial_loss_scale)
    self._increment_period = int(increment_period)
    self._multiplier = float(multiplier)

    self._current_loss_scale = self._add_weight(
        name='current_loss_scale',
        dtype=dtypes.float32,
        initial_value=self._initial_loss_scale)
    # The number of consecutive steps with finite gradients since the last
    # nonfinite gradient or change in loss scale.
    self._num_good_steps = self._add_weight(
        name='good_steps', dtype=dtypes.int64, initial_value=0)

  @property
  def initial_loss_scale(self):
    return self._initial_loss_scale

  @property
  def increment_period(self):
    return self._increment_period

  @property
  def multiplier(self):
    return self._multiplier

  def __call__(self):
    return ops.convert_to_tensor(self._current_loss_scale)

  def update(self, grads):
    """Updates loss scale based on if gradients are finite in current step."""
    grads = nest.flatten(grads)
    if distribution_strategy_context.has_strategy():
      distribution = distribution_strategy_context.get_cross_replica_context()

      def get_is_finite(grads):
        is_finite = _is_all_finite(grads)
        # We cast to float, because we cannot reduce booleans with
        # DistributionStrategy.
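        # Each replica contributes 1.0 if all of its gradients are finite and
        # 0.0 otherwise. Summing these across replicas and comparing the sum
        # to the number of replicas (below) tells us whether every replica saw
        # only finite gradients.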
        return math_ops.cast(is_finite, dtypes.float32)

      is_finite_float = distribution.extended.call_for_each_replica(
          get_is_finite, args=(grads,))
      reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
                                                    is_finite_float, axis=None)
      is_finite = math_ops.equal(reduced_is_finite_float,
                                 distribution.num_replicas_in_sync)
    else:
      is_finite = _is_all_finite(grads)

    def update_if_finite_grads():
      """Update assuming the gradients are finite."""

      def incr_loss_scale():
        new_loss_scale = self._current_loss_scale * self._multiplier
        return control_flow_ops.group(
            _assign_if_finite(self._current_loss_scale, new_loss_scale),
            self._num_good_steps.assign(0))

      return control_flow_ops.cond(
          self._num_good_steps + 1 >= self._increment_period,
          incr_loss_scale, lambda: _op_in_graph_mode(
              self._num_good_steps.assign_add(1)))

    def update_if_not_finite_grads():
      """Update assuming the gradients are nonfinite."""

      new_loss_scale = math_ops.maximum(
          self._current_loss_scale / self._multiplier, 1)
      return control_flow_ops.group(
          self._num_good_steps.assign(0),
          self._current_loss_scale.assign(new_loss_scale))

    update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
                                      update_if_not_finite_grads)
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period,
               self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period,
               self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,
        'increment_period': self.increment_period,
        'multiplier': self.multiplier,
    }


def get(identifier):
  """Get a loss scale object."""
  if isinstance(identifier, six.integer_types + (float,)):
    return FixedLossScale(identifier)
  if identifier == 'dynamic':
    return DynamicLossScale()
  if isinstance(identifier, LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     identifier)
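

# Illustrative summary of `get`'s identifier handling:
#
#   get(128)        -> FixedLossScale with a loss scale of 128
#   get('dynamic')  -> DynamicLossScale with default arguments
#   get(loss_scale) -> loss_scale, if it is already a LossScale instance
#   get(None)       -> None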