1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains the loss scaling optimizer class.""" 16 17from tensorflow.python.distribute import collective_all_reduce_strategy 18from tensorflow.python.distribute import distribution_strategy_context 19from tensorflow.python.distribute import mirrored_strategy 20from tensorflow.python.distribute import one_device_strategy 21from tensorflow.python.distribute import tpu_strategy 22from tensorflow.python.eager import backprop 23from tensorflow.python.eager import context 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import smart_cond 27from tensorflow.python.keras import backend 28from tensorflow.python.keras import optimizers 29from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module 30from tensorflow.python.keras.optimizer_v2 import optimizer_v2 31from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import variable_scope 35from tensorflow.python.ops import variables 36from tensorflow.python.platform import tf_logging 37from tensorflow.python.training.experimental import loss_scale as loss_scale_module 38from 
tensorflow.python.training.experimental import mixed_precision 39from tensorflow.python.training.tracking import base as trackable 40from tensorflow.python.util import nest 41from tensorflow.python.util.tf_export import keras_export 42 43 44class _UnwrapPreventer(object): 45 """Wrapper that DistributionStrategy will not unwrap. 46 47 Typically, DistributionStrategy will unwrap values when going from a cross- 48 replica context to a replica context via `call_for_each_replica`. This class 49 is a wrapper that DistributionStrategy will not unwrap, so it can be used to 50 prevent it from unwrapping a value. 51 52 TODO(reedwm): Find/implement a better way of preventing values from being 53 unwrapped by DistributionStrategy 54 """ 55 56 __slots__ = ['value'] 57 58 def __init__(self, value): 59 self.value = value 60 61 62class _DelegatingTrackableMixin(object): 63 """A mixin that delegates all Trackable methods to another trackable object. 64 65 This class must be used with multiple inheritance. A class that subclasses 66 Trackable can also subclass this class, which causes all Trackable methods to 67 be delegated to the trackable object passed in the constructor. 68 69 A subclass can use this mixin to appear as if it were the trackable passed to 70 the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this 71 mixin, so that the checkpoint format for a LossScaleOptimizer is identical to 72 the checkpoint format for a normal optimizer. This allows a model to be saved 73 with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa. 74 The only difference in checkpoint format is that the loss scale is also saved 75 with a LossScaleOptimizer. 
76 """ 77 78 def __init__(self, trackable_obj): 79 self._trackable = trackable_obj 80 81 # pylint: disable=protected-access 82 @property 83 def _setattr_tracking(self): 84 return self._trackable._setattr_tracking 85 86 @_setattr_tracking.setter 87 def _setattr_tracking(self, value): 88 self._trackable._setattr_tracking = value 89 90 @property 91 def _update_uid(self): 92 return self._trackable._update_uid 93 94 @_update_uid.setter 95 def _update_uid(self, value): 96 self._trackable._update_uid = value 97 98 @property 99 def _unconditional_checkpoint_dependencies(self): 100 return self._trackable._unconditional_checkpoint_dependencies 101 102 @property 103 def _unconditional_dependency_names(self): 104 return self._trackable._unconditional_dependency_names 105 106 @property 107 def _name_based_restores(self): 108 return self._trackable._name_based_restores 109 110 def _maybe_initialize_trackable(self): 111 return self._trackable._maybe_initialize_trackable() 112 113 @property 114 def _object_identifier(self): 115 return self._trackable._object_identifier 116 117 @property 118 def _tracking_metadata(self): 119 return self._trackable._tracking_metadata 120 121 def _no_dependency(self, value): 122 return self._trackable._no_dependency(value) 123 124 def _name_based_attribute_restore(self, checkpoint): 125 return self._trackable._name_based_attribute_restore(checkpoint) 126 127 @property 128 def _checkpoint_dependencies(self): 129 return self._trackable._checkpoint_dependencies 130 131 @property 132 def _deferred_dependencies(self): 133 return self._trackable._deferred_dependencies 134 135 def _lookup_dependency(self, name): 136 self._trackable._lookup_dependency(name) 137 138 def _add_variable_with_custom_getter(self, 139 name, 140 shape=None, 141 dtype=dtypes.float32, 142 initializer=None, 143 getter=None, 144 overwrite=False, 145 **kwargs_for_getter): 146 return self._trackable._add_variable_with_custom_getter( 147 name, shape, dtype, initializer, getter, overwrite, 
**kwargs_for_getter) 148 149 def _preload_simple_restoration(self, name): 150 return self._trackable._preload_simple_restoration(name) 151 152 def _track_trackable(self, trackable, name, overwrite=False): # pylint: disable=redefined-outer-name 153 return self._trackable._track_trackable(trackable, name, overwrite) 154 155 def _handle_deferred_dependencies(self, name, trackable): # pylint: disable=redefined-outer-name 156 return self._trackable._handle_deferred_dependencies(name, trackable) 157 158 def _restore_from_checkpoint_position(self, checkpoint_position): 159 return self._trackable._restore_from_checkpoint_position( 160 checkpoint_position) 161 162 def _single_restoration_from_checkpoint_position(self, checkpoint_position, 163 visit_queue): 164 return self._trackable._single_restoration_from_checkpoint_position( 165 checkpoint_position, visit_queue) 166 167 def _gather_saveables_for_checkpoint(self): 168 return self._trackable._gather_saveables_for_checkpoint() 169 170 def _list_extra_dependencies_for_serialization(self, serialization_cache): 171 return self._trackable._list_extra_dependencies_for_serialization( 172 serialization_cache) 173 174 def _list_functions_for_serialization(self, serialization_cache): 175 return self._trackable._list_functions_for_serialization( 176 serialization_cache) 177 # pylint: enable=protected-access 178 179 180def _is_all_finite(grads): 181 """Returns a scalar boolean tensor indicating if all gradients are finite.""" 182 is_finite_per_grad = [ 183 math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None 184 ] 185 return math_ops.reduce_all(is_finite_per_grad) 186 187 188def _op_in_graph_mode(tensor): 189 """Returns the tensor's op in graph mode, or the tensor in eager mode. 190 191 This is useful because sometimes an op is needed in graph mode instead of a 192 tensor. In eager mode, there are no ops. 193 194 Args: 195 tensor: A tensor. 196 197 Returns: 198 The tensor's op in graph mode. 
The tensor in eager mode. 199 """ 200 if context.executing_eagerly(): 201 return tensor 202 return tensor.op 203 204 205def _assign_if_finite(var, value): 206 """Assigns a value to a variable if the value is finite.""" 207 return control_flow_ops.cond( 208 math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)), 209 control_flow_ops.no_op) 210 211 212class _DynamicLossScaleState(trackable.Trackable): 213 """The state of a dynamic loss scale.""" 214 215 def __init__(self, 216 initial_loss_scale, 217 growth_steps, 218 multiplier): 219 """Creates the dynamic loss scale.""" 220 super(_DynamicLossScaleState, self).__init__() 221 self._initial_loss_scale = float(initial_loss_scale) 222 self._growth_steps = int(growth_steps) 223 self._multiplier = float(multiplier) 224 225 self._weights = {} 226 self._current_loss_scale = self._add_weight( 227 name='current_loss_scale', 228 dtype=dtypes.float32, 229 initial_value=self._initial_loss_scale) 230 # The number of consecutive steps with finite gradients since the last 231 # nonfinite gradient or change in loss scale. The name is 'good_steps' for 232 # backwards compatibility with older checkpoints. 233 self._counter = self._add_weight( 234 name='good_steps', dtype=dtypes.int64, initial_value=0) 235 236 def _add_weight(self, name, initial_value, dtype=None): 237 """Adds a weight to this loss scale. 238 239 Args: 240 name: Variable name. 241 initial_value: The variable's initial value. 242 dtype: The type of the variable. 243 244 Returns: 245 A variable. 246 247 Raises: 248 RuntimeError: If a weight with `name` has already been added. 249 """ 250 variable = variable_scope.variable( 251 initial_value=initial_value, 252 name=name, 253 dtype=dtype, 254 trainable=False, 255 use_resource=True, 256 synchronization=variables.VariableSynchronization.AUTO, 257 # Set aggregation to NONE, as loss scaling variables should never be 258 # aggregated. 
259 aggregation=variables.VariableAggregation.NONE) 260 if context.executing_eagerly(): 261 graph_key = None 262 else: 263 graph = ops.get_default_graph() 264 graph_key = graph._graph_key # pylint: disable=protected-access 265 266 key = (name, graph_key) 267 self._weights[key] = variable 268 self._handle_deferred_dependencies(name=name, trackable=variable) 269 backend.track_variable(variable) 270 return variable 271 272 @property 273 def _checkpoint_dependencies(self): 274 """From Trackable. Gather graph-specific weights to save.""" 275 if context.executing_eagerly(): 276 graph_key = None 277 else: 278 graph = ops.get_default_graph() 279 graph_key = graph._graph_key # pylint: disable=protected-access 280 weights = [] 281 for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]): 282 if g == graph_key: 283 weights.append(trackable.TrackableReference(name=name, ref=v)) 284 return (super(_DynamicLossScaleState, self)._checkpoint_dependencies + 285 weights) 286 287 def _lookup_dependency(self, name): 288 """From Trackable. 
Find a weight in the current graph.""" 289 unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name) 290 if unconditional is not None: 291 return unconditional 292 if context.executing_eagerly(): 293 graph_key = None 294 else: 295 graph = ops.get_default_graph() 296 graph_key = graph._graph_key # pylint: disable=protected-access 297 return self._weights.get((name, graph_key), None) 298 299 @property 300 def initial_loss_scale(self): 301 return self._initial_loss_scale 302 303 @property 304 def growth_steps(self): 305 return self._growth_steps 306 307 @property 308 def multiplier(self): 309 return self._multiplier 310 311 @property 312 def current_loss_scale(self): 313 """Returns the current loss scale as a float32 `tf.Variable`.""" 314 return self._current_loss_scale 315 316 @property 317 def counter(self): 318 """Returns the counter as a float32 `tf.Variable`.""" 319 return self._counter 320 321 def __call__(self): 322 """Returns the current loss scale as a scalar `float32` tensor.""" 323 return ops.convert_to_tensor_v2_with_dispatch(self._current_loss_scale) 324 325 def update(self, grads): 326 """Updates the value of the loss scale. 327 328 Args: 329 grads: A nested structure of unscaled gradients, each which is an 330 all-reduced gradient of the loss with respect to a weight. 331 332 Returns: 333 update_op: In eager mode, None. In graph mode, an op to update the loss 334 scale. 335 should_apply_gradients: Either a bool or a scalar boolean tensor. If 336 False, the caller should skip applying `grads` to the variables this 337 step. 
338 """ 339 grads = nest.flatten(grads) 340 if distribution_strategy_context.has_strategy( 341 ) and distribution_strategy_context.in_cross_replica_context(): 342 distribution = distribution_strategy_context.get_strategy() 343 is_finite_per_replica = distribution.extended.call_for_each_replica( 344 _is_all_finite, args=(grads,)) 345 # Each replica computed the same `is_finite` value, since `grads` is 346 # all-reduced across replicas. Arbitrarily take `is_finite` from the first 347 # replica. 348 is_finite = ( 349 distribution.experimental_local_results(is_finite_per_replica)[0]) 350 else: 351 is_finite = _is_all_finite(grads) 352 353 def update_if_finite_grads(): 354 """Update assuming the gradients are finite.""" 355 356 def incr_loss_scale(): 357 new_loss_scale = self.current_loss_scale * self.multiplier 358 return control_flow_ops.group( 359 _assign_if_finite(self.current_loss_scale, new_loss_scale), 360 self.counter.assign(0)) 361 362 return control_flow_ops.cond( 363 self.counter + 1 >= self.growth_steps, 364 incr_loss_scale, 365 lambda: _op_in_graph_mode(self.counter.assign_add(1))) 366 367 def update_if_not_finite_grads(): 368 """Update assuming the gradients are nonfinite.""" 369 370 new_loss_scale = math_ops.maximum( 371 self.current_loss_scale / self.multiplier, 1) 372 return control_flow_ops.group( 373 self.counter.assign(0), 374 self.current_loss_scale.assign(new_loss_scale)) 375 376 update_op = control_flow_ops.cond(is_finite, update_if_finite_grads, 377 update_if_not_finite_grads) 378 should_apply_gradients = is_finite 379 return update_op, should_apply_gradients 380 381 382# See LossScaleOptimizer docstring for why this is so big 383_DEFAULT_INITIAL_SCALE = 2 ** 15 384_DEFAULT_GROWTH_STEPS = 2000 385 386 387# pylint: disable=g-classes-have-attributes 388@keras_export('keras.mixed_precision.LossScaleOptimizer') 389class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): 390 """An optimizer that applies loss scaling to prevent 
numeric underflow. 391 392 Loss scaling is a technique to prevent numeric underflow in intermediate 393 gradients when float16 is used. To prevent underflow, the loss is multiplied 394 (or "scaled") by a certain factor called the "loss scale", which causes 395 intermediate gradients to be scaled by the loss scale as well. The final 396 gradients are divided (or "unscaled") by the loss scale to bring them back to 397 their original value. 398 399 `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it. 400 By default, the loss scale is dynamically updated over time so you do not have 401 to choose the loss scale. The `minimize` method automatically scales the loss, 402 unscales the gradients, and updates the loss scale so all you have to do is 403 wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For 404 example: 405 406 >>> opt = tf.keras.optimizers.SGD(0.25) 407 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 408 >>> var = tf.Variable(1.) 409 >>> loss_fn = lambda: var ** 2 410 >>> # 'minimize' applies loss scaling and updates the loss sale. 411 >>> opt.minimize(loss_fn, var_list=var) 412 >>> var.numpy() 413 0.5 414 415 If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you 416 must scale the loss and gradients manually. This can be done with the 417 `LossScaleOptimizer.get_scaled_loss` and 418 `LossScaleOptimizer.get_unscaled_gradients` methods. For example: 419 420 >>> with tf.GradientTape() as tape: 421 ... loss = loss_fn() 422 ... scaled_loss = opt.get_scaled_loss(loss) 423 >>> scaled_grad = tape.gradient(scaled_loss, var) 424 >>> (grad,) = opt.get_unscaled_gradients([scaled_grad]) 425 >>> opt.apply_gradients([(grad, var)]) # Loss scale is updated here 426 >>> var.numpy() 427 0.25 428 429 Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients` 430 (or both) when using a `tf.GradientTape`, the model will likely converge to a 431 worse quality. 
Please make sure you call each function exactly once. 432 433 When mixed precision with float16 is used, there is typically no risk of 434 underflow affecting model quality if loss scaling is properly used. See 435 [the mixed precision guide]( 436 https://www.tensorflow.org/guide/keras/mixed_precision) for more information 437 on how to use mixed precision. 438 439 Args: 440 inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap. 441 dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to 442 True. If True, the loss scale will be dynamically updated over time using 443 an algorithm that keeps the loss scale at approximately its optimal value. 444 If False, a single fixed loss scale is used and `initial_scale` must be 445 specified, which is used as the loss scale. Recommended to keep as True, 446 as choosing a fixed loss scale can be tricky. Currently, there is a small 447 performance overhead to dynamic loss scaling compared to fixed loss 448 scaling. 449 initial_scale: The initial loss scale. If `dynamic` is True, this defaults 450 to `2 ** 15`. If `dynamic` is False, this must be specified and acts as 451 the sole loss scale, as the loss scale does not change over time. When 452 dynamic loss scaling is used, is better for this to be a very high number, 453 because a loss scale that is too high gets lowered far more quickly than a 454 loss scale that is too low gets raised. 455 dynamic_growth_steps: With dynamic loss scaling, every 456 `dynamic_growth_steps` steps with finite gradients, the loss scale is 457 doubled. Defaults to 2000. If a nonfinite gradient is encountered, the 458 count is reset back to zero, gradients are skipped that step, and the loss 459 scale is halved. The count can be queried with 460 `LossScaleOptimizer.dynamic_counter`. This argument can only be specified 461 if `dynamic` is True. 
462 463 `LossScaleOptimizer` will occasionally skip applying gradients to the 464 variables, in which case the trainable variables will not change that step. 465 This is done because the dynamic loss scale will sometimes be raised too 466 high, causing overflow in the gradients. Typically, the first 2 to 15 steps of 467 the model are skipped as the initial loss scale is very high, but afterwards 468 steps will only be skipped on average 0.05% of the time (the fraction of steps 469 skipped is `1 / dynamic_growth_steps`). 470 471 `LossScaleOptimizer` delegates all public `Optimizer` methods to the inner 472 optimizer. Additionally, in methods `minimize` and `get_gradients`, it scales 473 the loss and unscales the gradients. In methods `minimize` and 474 `apply_gradients`, it additionally updates the loss scale and skips applying 475 gradients if any gradient has a nonfinite value. 476 477 ### Hyperparameters 478 479 Hyperparameters can be accessed and set on the LossScaleOptimizer, which will 480 be delegated to the wrapped optimizer. 481 482 >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) 483 >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) 484 >>> opt.beta_1 # Equivalent to `opt.inner_optimizer.beta_1` 485 0.8 486 >>> opt.beta_1 = 0.7 # Equivalent to `opt.inner_optimizer.beta_1 = 0.7` 487 >>> opt.beta_1 488 0.7 489 >>> opt.inner_optimizer.beta_1 490 0.7 491 492 However, accessing or setting non-hyperparameters is not delegated to the 493 LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but 494 `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on 495 `beta_1`. 496 497 >>> opt.inner_optimizer.epsilon 498 1e-5 499 >>> opt.epsilon 500 Traceback (most recent call last): 501 ... 
502 AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon' 503 >>> opt.epsilon = 1e-4 # This does NOT set epsilon on `opt.inner_optimizer` 504 >>> opt.inner_optimizer.epsilon 505 >>> 1e-5 506 507 In the above example, despite epsilon being set on the LossScaleOptimizer, the 508 old epsilon value will still be used when training as epsilon was not set on 509 the inner optimizer. 510 """ 511 512 _HAS_AGGREGATE_GRAD = True 513 514 def __init__(self, inner_optimizer, dynamic=True, initial_scale=None, 515 dynamic_growth_steps=None): 516 if not isinstance(inner_optimizer, optimizer_v2.OptimizerV2): 517 raise TypeError('"inner_optimizer" must be an instance of OptimizerV2, ' 518 'but got: %s' % inner_optimizer) 519 if not isinstance(dynamic, bool): 520 # Catch errors if a user incorrectly passes a string or float to the 521 # second argument argument, as this is commonly done for 522 # LossScaleOptimizerV1. 523 raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must ' 524 'be a bool, but got: %r' % (dynamic,)) 525 if isinstance(inner_optimizer, LossScaleOptimizer): 526 raise TypeError('LossScaleOptimizer cannot wrap another ' 527 'LossScaleOptimizer, but got: %s' % (inner_optimizer,)) 528 self._raise_if_strategy_unsupported() 529 if getattr(inner_optimizer, '_is_wrapped_by_loss_scale_optimizer', False): 530 # TODO(reedwm): Maybe support this. The difficulty is that LSO has the 531 # same checkpoint format as the inner optimizer, so multiple LSOs wrapping 532 # the same optimizer causes the checkpointing logic to become confused. 533 raise ValueError('"inner_optimizer" is already wrapped by a ' 534 'LossScaleOptimizer. An optimizer can only be wrapped ' 535 'by a single LossScaleOptimizer') 536 self._optimizer = inner_optimizer 537 self._optimizer._is_wrapped_by_loss_scale_optimizer = True 538 539 # We don't call super().__init__, since we do not want to call OptimizerV2's 540 # constructor. 
541 _DelegatingTrackableMixin.__init__(self, self._optimizer) 542 543 if dynamic: 544 if initial_scale is None: 545 initial_scale = _DEFAULT_INITIAL_SCALE 546 if dynamic_growth_steps is None: 547 dynamic_growth_steps = _DEFAULT_GROWTH_STEPS 548 self._loss_scale = _DynamicLossScaleState( 549 initial_scale, dynamic_growth_steps, multiplier=2) 550 self._track_trackable(self._loss_scale, 'loss_scale') 551 else: 552 if initial_scale is None: 553 raise ValueError('"initial_scale" must be specified if "dynamic" is ' 554 'False') 555 self._loss_scale = float(initial_scale) 556 if dynamic_growth_steps is not None: 557 raise ValueError('"dynamic_growth_steps" must be None if "dynamic" ' 558 'is False, but got: %s' % (dynamic_growth_steps,)) 559 560 # To support restoring TensorFlow 2.2 checkpoints. 561 self._track_trackable(FakeOptimizerForRestoration(self._optimizer), 562 'base_optimizer') 563 564 @property 565 def dynamic(self): 566 """Bool indicating whether dynamic loss scaling is used.""" 567 return isinstance(self._loss_scale, _DynamicLossScaleState) 568 569 @property 570 def loss_scale(self): 571 """The current loss scale as a float32 scalar tensor.""" 572 if isinstance(self._loss_scale, _DynamicLossScaleState): 573 return ops.convert_to_tensor_v2_with_dispatch( 574 self._loss_scale.current_loss_scale) 575 else: 576 return ops.convert_to_tensor_v2_with_dispatch(self._loss_scale) 577 578 @property 579 def dynamic_counter(self): 580 """The number of steps since the loss scale was last increased or decreased. 581 582 This is None if `LossScaleOptimizer.dynamic` is False. 583 584 The counter is incremented every step. Once it reaches 585 `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled 586 and the counter will be reset back to zero. If nonfinite gradients are 587 encountered, the loss scale will be halved and the counter will be reset 588 back to zero. 
589 """ 590 if isinstance(self._loss_scale, _DynamicLossScaleState): 591 return self._loss_scale.counter 592 else: 593 return None 594 595 @property 596 def initial_scale(self): 597 """The initial loss scale. 598 599 If `LossScaleOptimizer.dynamic` is False, this is the same number as 600 `LossScaleOptimizer.loss_scale`, as the loss scale never changes. 601 """ 602 if isinstance(self._loss_scale, _DynamicLossScaleState): 603 return self._loss_scale.initial_loss_scale 604 else: 605 return self._loss_scale 606 607 @property 608 def dynamic_growth_steps(self): 609 """The number of steps it takes to increase the loss scale. 610 611 This is None if `LossScaleOptimizer.dynamic` is False. 612 613 Every `dynamic_growth_steps` consecutive steps with finite gradients, the 614 loss scale is increased. 615 """ 616 if isinstance(self._loss_scale, _DynamicLossScaleState): 617 return self._loss_scale.growth_steps 618 else: 619 return None 620 621 @property 622 def inner_optimizer(self): 623 """The optimizer that this LossScaleOptimizer is wrapping.""" 624 return self._optimizer 625 626 def get_scaled_loss(self, loss): 627 """Scales the loss by the loss scale. 628 629 This method is only needed if you compute gradients manually, e.g. with 630 `tf.GradientTape`. In that case, call this method to scale the loss before 631 passing the loss to `tf.GradientTape`. If you use 632 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 633 scaling is automatically applied and this method is unneeded. 634 635 If this method is called, `get_unscaled_gradients` should also be called. 636 See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for 637 an example. 638 639 Args: 640 loss: The loss, which will be multiplied by the loss scale. Can either be 641 a tensor or a callable returning a tensor. 642 643 Returns: 644 `loss` multiplied by `LossScaleOptimizer.loss_scale`. 
645 """ 646 if callable(loss): 647 def new_loss(): 648 loss_val = loss() 649 return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype) 650 return new_loss 651 else: 652 return loss * math_ops.cast(self.loss_scale, loss.dtype) 653 654 def get_unscaled_gradients(self, grads): 655 """Unscales the gradients by the loss scale. 656 657 This method is only needed if you compute gradients manually, e.g. with 658 `tf.GradientTape`. In that case, call this method to unscale the gradients 659 after computing them with `tf.GradientTape`. If you use 660 `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss 661 scaling is automatically applied and this method is unneeded. 662 663 If this method is called, `get_scaled_loss` should also be called. See 664 the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an 665 example. 666 667 Args: 668 grads: A list of tensors, each which will be divided by the loss scale. 669 Can have None values, which are ignored. 670 671 Returns: 672 A new list the same size as `grads`, where every non-None value in `grads` 673 is divided by `LossScaleOptimizer.loss_scale`. 674 """ 675 loss_scale_reciprocal = 1. 
/ self.loss_scale 676 return [ 677 _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None 678 for g in grads 679 ] 680 681 def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): 682 tape = backprop.GradientTape() if tape is None else tape 683 with tape: 684 loss = self.get_scaled_loss(loss) 685 grads_and_vars = self._optimizer._compute_gradients( # pylint: disable=protected-access 686 loss, 687 var_list, 688 grad_loss, 689 tape=tape) 690 grads = [g for g, _ in grads_and_vars] 691 weights = [v for _, v in grads_and_vars] 692 unscaled_grads = self.get_unscaled_gradients(grads) 693 return list(zip(unscaled_grads, weights)) 694 695 def get_gradients(self, loss, params): 696 loss = self.get_scaled_loss(loss) 697 grads = self._optimizer.get_gradients(loss, params) 698 return self.get_unscaled_gradients(grads) 699 700 def _create_all_weights(self, var_list): 701 self._optimizer._create_all_weights(var_list) # pylint: disable=protected-access 702 703 def apply_gradients(self, 704 grads_and_vars, 705 name=None, 706 experimental_aggregate_gradients=True): 707 if distribution_strategy_context.in_cross_replica_context(): 708 raise ValueError('apply_gradients() must be called in a replica context.') 709 # We check for the strategy here despite already checking in the constructor 710 # as frequently the optimizer is created outside the strategy's scope. 711 self._raise_if_strategy_unsupported() 712 713 grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars) 714 if experimental_aggregate_gradients: 715 # We must aggregate the gradients here instead of in 716 # self.optimizer.apply_gradients, so that any NaN or Inf gradients are 717 # propogated to each replica. If any replica has a NaN or Inf gradient, 718 # they must all have a NaN or Inf gradient so that they all skip the step. 
719 # pylint: disable=protected-access 720 grads_and_vars = self._optimizer._transform_unaggregated_gradients( 721 grads_and_vars) 722 grads_and_vars = self._optimizer._aggregate_gradients(grads_and_vars) 723 # pylint: enable=protected-access 724 725 grads_and_vars = tuple(grads_and_vars) 726 grads = [g for g, _ in grads_and_vars] 727 # We do not want DistributionStrategy to unwrap any MirroredVariables in 728 # grads_and_vars, because even in a replica context, the wrapped 729 # optimizer expects mirrored variables. So we wrap the variables with an 730 # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the 731 # MirroredVariables. 732 wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars]) 733 734 def do_not_apply_fn(): 735 # Normally self._optimizer.iterations is incremented in 736 # self._optimizer.apply_gradients(). Since that is not called in this 737 # branch, we increment it here instead. 738 return self._optimizer.iterations.assign_add(1, read_value=False) 739 740 def _if_should_apply_grads(grads): 741 if isinstance(self._loss_scale, _DynamicLossScaleState): 742 return self._loss_scale.update(grads) 743 else: 744 return (control_flow_ops.no_op(), True) 745 746 if optimizer_utils.strategy_supports_no_merge_call(): 747 loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads) 748 def apply_fn(): 749 return self._apply_gradients(grads, wrapped_vars, name) 750 751 maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, 752 do_not_apply_fn) 753 return control_flow_ops.group(maybe_apply_op, loss_scale_update_op) 754 755 else: 756 757 def _apply_gradients_cross_replica(distribution, grads, wrapped_vars, 758 name): 759 loss_scale_update_op, should_apply_grads = _if_should_apply_grads(grads) 760 761 def apply_fn(): 762 return distribution.extended.call_for_each_replica( 763 self._apply_gradients, 764 args=(grads, wrapped_vars, name)) 765 766 # Note: We must call this cond() in a cross-replica context. 
767 # DistributionStrategy does not support having a cond in a replica 768 # context with a branch that calls `merge_call`, and 769 # self._optimizer.apply_gradients calls `merge_call`. 770 maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, 771 do_not_apply_fn) 772 return control_flow_ops.group(maybe_apply_op, loss_scale_update_op) 773 return distribution_strategy_context.get_replica_context().merge_call( 774 _apply_gradients_cross_replica, 775 args=(grads, wrapped_vars, name)) 776 777 def _apply_gradients(self, grads, wrapped_vars, name): 778 # Pass experimental_aggregate_gradients=False since LossScaleOptimizer 779 # already aggregated the gradients. 780 # TODO(reedwm): This will raise a fairly cryptic error message if 781 # self._optimizer.apply_gradients does not take 782 # experimental_aggregate_gradients. 783 return self._optimizer.apply_gradients( 784 list(zip(grads, wrapped_vars.value)), name, 785 experimental_aggregate_gradients=False) 786 787 def get_config(self): 788 serialized_optimizer = optimizers.serialize(self._optimizer) 789 return { 790 'inner_optimizer': serialized_optimizer, 791 'dynamic': self.dynamic, 792 'initial_scale': self.initial_scale, 793 'dynamic_growth_steps': self.dynamic_growth_steps, 794 } 795 796 @classmethod 797 def from_config(cls, config, custom_objects=None): 798 config = config.copy() # Make a copy, since we mutate config 799 if 'loss_scale' in config: 800 # If loss_scale is in config, we assume we are deserializing a 801 # LossScaleOptimizer from TF 2.3 or below. We convert the config so it 802 # can be deserialized in the current LossScaleOptimizer. 
803 loss_scale = keras_loss_scale_module.deserialize( 804 config.pop('loss_scale')) 805 if isinstance(loss_scale, loss_scale_module.FixedLossScale): 806 config['dynamic'] = False 807 config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access 808 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): 809 config['dynamic'] = True 810 config['initial_scale'] = loss_scale.initial_loss_scale 811 config['dynamic_growth_steps'] = loss_scale.increment_period 812 if loss_scale.multiplier != 2: 813 raise ValueError('Cannot deserialize LossScaleOptimizer with a ' 814 'DynamicLossScale whose multiplier is not 2. Got ' 815 'DynamicLossScale: %s' % (loss_scale,)) 816 else: 817 raise ValueError( 818 'Serialized LossScaleOptimizers with a LossScale that is neither a ' 819 'FixedLossScale nor a DynamicLossScale can no longer be ' 820 'deserialized') 821 config['inner_optimizer'] = config.pop('optimizer') 822 config['inner_optimizer'] = optimizers.deserialize( 823 config['inner_optimizer'], custom_objects=custom_objects) 824 return cls(**config) 825 826 def _raise_if_strategy_unsupported(self): 827 if not strategy_supports_loss_scaling(): 828 strategy = distribution_strategy_context.get_strategy() 829 if isinstance(strategy, 830 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1, 831 tpu_strategy.TPUStrategyV2)): 832 raise ValueError( 833 'Loss scaling is not supported with TPUStrategy. Loss scaling is ' 834 'unnecessary with TPUs, since they support bfloat16 instead of ' 835 'float16 and bfloat16 does not require loss scaling. You should ' 836 'remove the use of the LossScaleOptimizer when TPUs are used.') 837 else: 838 raise ValueError('Loss scaling is not supported with the ' 839 'tf.distribute.Strategy: %s. Try using a different ' 840 'Strategy, e.g. a MirroredStrategy' % 841 strategy.__class__.__name__) 842 843 # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer 844 # below. 
  # The accessors below forward directly to the wrapped optimizer
  # (`self._optimizer`): reading or assigning them on the LossScaleOptimizer
  # reads or assigns the corresponding value of the inner optimizer.

  @property
  def iterations(self):
    """The wrapped optimizer's `iterations` (forwarded, not copied)."""
    return self._optimizer.iterations

  @iterations.setter
  def iterations(self, variable):
    self._optimizer.iterations = variable

  def get_slot_names(self):
    """Returns the wrapped optimizer's slot names."""
    return self._optimizer.get_slot_names()

  def variables(self):
    """Returns the wrapped optimizer's variables."""
    return self._optimizer.variables()

  @property
  def weights(self):
    """The wrapped optimizer's weights."""
    return self._optimizer.weights

  def get_weights(self):
    """Returns the wrapped optimizer's `get_weights()` result."""
    return self._optimizer.get_weights()

  def set_weights(self, weights):
    """Forwards `weights` to the wrapped optimizer's `set_weights`."""
    return self._optimizer.set_weights(weights)

  @property
  def clipnorm(self):
    """The wrapped optimizer's `clipnorm`."""
    return self._optimizer.clipnorm

  @clipnorm.setter
  def clipnorm(self, val):
    self._optimizer.clipnorm = val

  @property
  def global_clipnorm(self):
    """The wrapped optimizer's `global_clipnorm`."""
    return self._optimizer.global_clipnorm

  @global_clipnorm.setter
  def global_clipnorm(self, val):
    self._optimizer.global_clipnorm = val

  @property
  def clipvalue(self):
    """The wrapped optimizer's `clipvalue`."""
    return self._optimizer.clipvalue

  @clipvalue.setter
  def clipvalue(self, val):
    self._optimizer.clipvalue = val

  def _aggregate_gradients(self, grads_and_vars):
    # Aggregation is performed by the wrapped optimizer's implementation.
    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    # Slot variables live on the wrapped optimizer, so restoration is
    # forwarded to it.
    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
                                                  slot_variable)

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    # Forwarded to the wrapped optimizer, which owns the slot variables.
    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position, slot_name, variable)

  def get_slot(self, var, slot_name):
    """Returns the wrapped optimizer's slot `slot_name` for `var`."""
    return self._optimizer.get_slot(var, slot_name)

  def add_slot(self, var, slot_name, initializer='zeros'):
    """Adds slot `slot_name` for `var` on the wrapped optimizer."""
    return self._optimizer.add_slot(var, slot_name, initializer)

  def __getattribute__(self, name):
    """Looks up `name`, falling back to the wrapped optimizer's hyperparameters.

    Normal attribute lookup is tried first. If it raises AttributeError and
    `name` (with `lr` mapped to `learning_rate`) is a key of
    `self._optimizer._hyper`, the value of `self._optimizer._get_hyper(name)` is
    returned. Otherwise the original AttributeError is re-raised.
    """
    try:
      return object.__getattribute__(self, name)
    except AttributeError as e:
      # `_optimizer` and `_hyper` are read by the fallback below; delegating
      # their lookup would recurse into this method forever.
      if name == '_optimizer' or name == '_hyper':
        # Avoid infinite recursion
        raise e

      # Delegate hyperparameter accesses to inner optimizer.
      if name == 'lr':
        name = 'learning_rate'
      if name in self._optimizer._hyper:
        return self._optimizer._get_hyper(name)
      raise e

  def __dir__(self):
    """Lists normal attributes plus the wrapped optimizer's hyperparameters.

    When `_optimizer` is present, every key of `self._optimizer._hyper` is
    added. `lr` is also added when `learning_rate` is one of those keys.
    """
    result = set(super(LossScaleOptimizer, self).__dir__())
    if '_optimizer' in result:
      result |= self._optimizer._hyper.keys()
      if 'learning_rate' in self._optimizer._hyper.keys():
        result.add('lr')
    return list(result)

  def __setattr__(self, name, value):
    """Sets `name`, routing hyperparameters to the wrapped optimizer.

    `lr` is mapped to `learning_rate`. If `name` is a key of
    `self._optimizer._hyper` and the attribute does not already exist on this
    object, the value is stored via `self._optimizer._set_hyper`. Otherwise it
    is set on this object normally.

    NOTE: for any `name` other than `_optimizer`, this reads
    `self._optimizer._hyper`, so `_optimizer` has to be assigned before any
    other attribute.
    """
    if name == 'lr':
      name = 'learning_rate'
    # Delegate setting hyperparameter to inner optimizer if the attribute does
    # not exist on the LossScaleOptimizer
    try:
      # We cannot check for the 'iterations' attribute as it cannot be set after
      # it is accessed.
      if name != 'iterations':
        object.__getattribute__(self, name)
      has_attribute = True
    except AttributeError:
      has_attribute = False
    if (name != '_optimizer' and name in self._optimizer._hyper
        and not has_attribute):
      self._optimizer._set_hyper(name, value)
    else:
      super(LossScaleOptimizer, self).__setattr__(name, value)

  # Explicitly delegate learning_rate. Normally hyperparameters are delegated in
  # __getattribute__, but if a hyperparameter is not in self._optimizer._hyper
  # (e.g. because self._optimizer itself wraps another optimizer), then it won't
  # be delegated. Since learning_rate is a very commonly accessed
  # hyperparameter, we delegate it here.
959 @property 960 def learning_rate(self): 961 return self._optimizer.learning_rate 962 963 @learning_rate.setter 964 def learning_rate(self, value): 965 self._optimizer.learning_rate = value 966 967 @property 968 def lr(self): 969 return self._optimizer.learning_rate 970 971 @lr.setter 972 def lr(self, value): 973 self._optimizer.lr = value 974 975 # We do not override some OptimizerV2 methods. For each, we describe why we do 976 # not delegate them to self._optimizer: 977 # * get_updates: get_updates() calls get_gradients(). Since we override 978 # get_gradients(), we cannot delegate get_updates() to self._optimizer, 979 # otherwise the overridden get_gradients() method would not be called. 980 # Luckily, get_updates() does not access any OptimizerV2 fields, so 981 # inheriting the OptimizerV2 version works fine. 982 # * minimize: We don't delegate for a similar as get_updates(): it calls 983 # both self._compute_gradients() and self.apply_gradients(), and both need 984 # to have the LossScaleOptimizer version called. 985 986 # TODO(reedwm): Maybe throw an error if mixed precision is used without this 987 # optimizer being used. 988 989 990@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') 991class LossScaleOptimizerV1(LossScaleOptimizer): 992 """An deprecated optimizer that applies loss scaling. 993 994 Warning: This class is deprecated and will be removed in a future version of 995 TensorFlow. Please use the non-experimental class 996 `tf.keras.mixed_precision.LossScaleOptimizer` instead. 997 998 This class is identical to the non-experimental 999 `keras.mixed_precision.LossScaleOptimizer` except its constructor takes 1000 different arguments. For this class (the experimental version), the 1001 constructor takes a `loss_scale` argument. For the non-experimental class, 1002 the constructor encodes the loss scaling information in multiple arguments. 
1003 Note that unlike this class, the non-experimental class does not accept a 1004 `tf.compat.v1.mixed_precision.LossScale`, which is deprecated. 1005 1006 If you currently use this class, you should switch to the non-experimental 1007 `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several 1008 examples of converting the use of the experimental class to the equivalent 1009 non-experimental class. 1010 1011 >>> # In all of the the examples below, `opt1` and `opt2` are identical 1012 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 1013 ... tf.keras.optimizers.SGD(), loss_scale='dynamic') 1014 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 1015 ... tf.keras.optimizers.SGD()) 1016 >>> assert opt1.get_config() == opt2.get_config() 1017 1018 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 1019 ... tf.keras.optimizers.SGD(), loss_scale=123) 1020 >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123 1021 >>> # refers to the initial loss scale, which is the single fixed loss scale 1022 >>> # when dynamic=False. 1023 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 1024 ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123) 1025 >>> assert opt1.get_config() == opt2.get_config() 1026 1027 >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale( 1028 ... initial_loss_scale=2048, increment_period=500) 1029 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 1030 ... tf.keras.optimizers.SGD(), loss_scale=loss_scale) 1031 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 1032 ... tf.keras.optimizers.SGD(), initial_scale=2048, 1033 ... dynamic_growth_steps=500) 1034 >>> assert opt1.get_config() == opt2.get_config() 1035 1036 Make sure to also switch from this class to the non-experimental class in 1037 isinstance checks, if you have any. 
If you do not do this, your model may run 1038 into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses 1039 the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to 1040 switch isinstance checks to the non-experimental `LossScaleOptimizer` even 1041 before using the non-experimental `LossScaleOptimizer`. 1042 1043 >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( 1044 ... tf.keras.optimizers.SGD(), loss_scale='dynamic') 1045 >>> # The experimental class subclasses the non-experimental class 1046 >>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer) 1047 True 1048 >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( 1049 ... tf.keras.optimizers.SGD()) 1050 >>> # The non-experimental class does NOT subclass the experimental class. 1051 >>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer) 1052 False 1053 1054 Args: 1055 optimizer: The Optimizer instance to wrap. 1056 loss_scale: The loss scale to scale the loss and gradients. This can 1057 either be an int/float to use a fixed loss scale, the string "dynamic" 1058 to use dynamic loss scaling, or an instance of a LossScale. The string 1059 "dynamic" equivalent to passing `DynamicLossScale()`, and passing an 1060 int/float is equivalent to passing a FixedLossScale with the given loss 1061 scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must 1062 be 2 (the default). 1063 """ 1064 1065 def __init__(self, optimizer, loss_scale): 1066 warn_msg_prefix = ( 1067 'tf.keras.mixed_precision.experimental.LossScaleOptimizer is ' 1068 'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer ' 1069 'instead. 
') 1070 1071 if isinstance(loss_scale, dict): 1072 loss_scale = keras_loss_scale_module.deserialize(loss_scale) 1073 1074 if isinstance(loss_scale, (int, float)): 1075 tf_logging.warning( 1076 warn_msg_prefix + 'For example:\n' 1077 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 1078 'opt, dynamic=False, initial_scale={})'.format(loss_scale)) 1079 super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, 1080 initial_scale=loss_scale) 1081 elif isinstance(loss_scale, loss_scale_module.FixedLossScale): 1082 ls_val = loss_scale._loss_scale_value # pylint: disable=protected-access 1083 tf_logging.warning( 1084 warn_msg_prefix + 'For example:\n' 1085 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 1086 'opt, dynamic=False, initial_scale={})'.format(ls_val)) 1087 super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, 1088 initial_scale=ls_val) 1089 elif loss_scale == 'dynamic': 1090 tf_logging.warning( 1091 warn_msg_prefix + 'For example:\n' 1092 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 1093 'opt)') 1094 super(LossScaleOptimizerV1, self).__init__(optimizer) 1095 elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): 1096 kwargs = {} 1097 extra_arguments = '' 1098 if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE: 1099 kwargs['initial_scale'] = loss_scale.initial_loss_scale 1100 extra_arguments += (', initial_scale=%s' % 1101 loss_scale.initial_loss_scale) 1102 if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS: 1103 kwargs['dynamic_growth_steps'] = loss_scale.increment_period 1104 extra_arguments += (', dynamic_growth_steps=%s' % 1105 loss_scale.increment_period) 1106 if loss_scale.multiplier != 2: 1107 raise ValueError('When passing a DynamicLossScale to "loss_scale", ' 1108 'DynamicLossScale.multiplier must be 2. 
Got: %s' 1109 % (loss_scale,)) 1110 tf_logging.warning( 1111 warn_msg_prefix + 1112 'Note that the non-experimental LossScaleOptimizer does not take a ' 1113 'DynamicLossScale but instead takes the dynamic configuration ' 1114 'directly in the constructor. For example:\n' 1115 ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' 1116 'opt{})\n'.format(extra_arguments)) 1117 super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs) 1118 elif isinstance(loss_scale, loss_scale_module.LossScale): 1119 raise TypeError('Passing a LossScale that is not a FixedLossScale or a ' 1120 'DynamicLossScale is no longer supported. Got: {}' 1121 .format(loss_scale)) 1122 else: 1123 raise ValueError('Invalid value passed to loss_scale. loss_scale ' 1124 'must be the string "dynamic" (recommended), an int, ' 1125 'a float, a FixedLossScale, or a DynamicLossScale. Got ' 1126 'value: {}'.format(loss_scale)) 1127 1128 @classmethod 1129 def from_config(cls, config, custom_objects=None): 1130 config = config.copy() # Make a copy, since we mutate config 1131 1132 # If loss_scale is in config, we assume we are deserializing a 1133 # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are 1134 # deserializing a LossScaleOptimizer from TF 2.4 or above. 1135 if 'loss_scale' in config: 1136 config['loss_scale'] = keras_loss_scale_module.deserialize( 1137 config['loss_scale']) 1138 if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale) 1139 and config['loss_scale'].multiplier != 2): 1140 raise ValueError('Cannot deserialize LossScaleOptimizer with a ' 1141 'DynamicLossScale whose multiplier is not 2. 
Got ' 1142 'DynamicLossScale: %s' % (config['loss_scale'],)) 1143 config['optimizer'] = optimizers.deserialize( 1144 config['optimizer'], custom_objects=custom_objects) 1145 return cls(**config) 1146 1147 # We convert the config, as generated by LossScaleOptimizer.get_config, to a 1148 # version that can be passed to LossScaleOptimizerV1.__init__ 1149 if config['dynamic']: 1150 config['loss_scale'] = loss_scale_module.DynamicLossScale( 1151 config['initial_scale'], config['dynamic_growth_steps'], multiplier=2) 1152 else: 1153 config['loss_scale'] = loss_scale_module.FixedLossScale( 1154 config['initial_scale']) 1155 1156 del config['dynamic'] 1157 del config['initial_scale'] 1158 del config['dynamic_growth_steps'] 1159 config['optimizer'] = optimizers.deserialize( 1160 config.pop('inner_optimizer'), custom_objects=custom_objects) 1161 return cls(**config) 1162 1163 1164class FakeOptimizerForRestoration(trackable.Trackable): 1165 """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints. 1166 1167 The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class 1168 exists to support restoring TF 2.2 checkpoints in newer version of TensorFlow. 1169 1170 In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the 1171 following in LossScaleOptimizer.__init__ 1172 1173 ``` 1174 self._track_trackable(self._optimizer, 'base_optimizer') 1175 ``` 1176 1177 This means a dependency from the LossScaleOptimizer to the wrapped optimizer 1178 would be stored in the checkpoint. However now, the checkpoint format with a 1179 LossScaleOptimizer is the same as the format without a LossScaleOptimizer, 1180 except the loss scale is also stored. This means there is no dependency from 1181 the LossScaleOptimizer to the wrapped optimizer. Instead, the 1182 LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's 1183 perspective, by overriding all Trackable methods and delegating them to the 1184 wrapped optimizer. 
1185 1186 To allow restoring TF 2.2. checkpoints, LossScaleOptimizer adds a dependency 1187 on this class instead of the inner optimizer. When restored, this class will 1188 instead restore the slot variables of the inner optimizer. Since this class 1189 has no variables, it does not affect the checkpoint when saved. 1190 """ 1191 1192 def __init__(self, optimizer): 1193 self._optimizer = optimizer 1194 1195 def get_slot_names(self): 1196 return self._optimizer.get_slot_names() 1197 1198 def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, 1199 variable): 1200 return self._optimizer._create_or_restore_slot_variable( # pylint: disable=protected-access 1201 slot_variable_position, slot_name, variable) 1202 1203 1204mixed_precision.register_loss_scale_wrapper(optimizer_v2.OptimizerV2, 1205 LossScaleOptimizerV1) 1206 1207 1208def _multiply_gradient(gradient, scale): 1209 """Multiply a (possibly sparse) gradient by the given scale factor.""" 1210 scale = math_ops.cast(scale, gradient.dtype) 1211 if isinstance(gradient, ops.IndexedSlices): 1212 return ops.IndexedSlices( 1213 gradient.values * scale, 1214 gradient.indices, 1215 dense_shape=gradient.dense_shape) 1216 else: 1217 return gradient * scale 1218 1219 1220def strategy_supports_loss_scaling(): 1221 """Returns True if the current Strategy supports loss scaling.""" 1222 if not distribution_strategy_context.has_strategy(): 1223 return True 1224 strategy = distribution_strategy_context.get_strategy() 1225 # Strategies are supported if either there is only one replica or if variables 1226 # are replicated per device. Otherwise, the current model.fit() implementation 1227 # and most custom training loops incorrectly unscale the gradients. Currently, 1228 # gradients are unscaled once per compute replica, but they should be unscaled 1229 # once per variable replica. When there is one variable replica for each 1230 # compute replica, this works fine, but otherwise issues will occur. 
1231 # TODO(reedwm): Support all strategies. 1232 return isinstance(strategy, ( 1233 collective_all_reduce_strategy.CollectiveAllReduceStrategy, 1234 collective_all_reduce_strategy.CollectiveAllReduceStrategyV1, 1235 one_device_strategy.OneDeviceStrategy, 1236 one_device_strategy.OneDeviceStrategyV1, 1237 mirrored_strategy.MirroredStrategy, 1238 mirrored_strategy.MirroredStrategyV1, 1239 )) 1240