# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

import abc
import contextlib
import functools
import warnings

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


_DEFAULT_VALID_DTYPES = frozenset([
    dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
    dtypes.complex64, dtypes.complex128
])


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
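
  For example, given `values=[[1.], [2.], [3.]]` and `indices=[0, 0, 1]`, this
  returns `summed_values=[[3.], [3.]]` and `unique_indices=[0, 1]`.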
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


class NullContextmanager(object):

  def __init__(self, *args, **kwargs):
    pass

  def __enter__(self):
    pass

  def __exit__(self, type_arg, value_arg, traceback_arg):
    return False  # False values do not suppress exceptions


def name_scope_only_in_function_or_graph(name):
  """Internal-only entry point for `name_scope*`.

  Enters a compat.v1.name_scope only when in a function or graph,
  not when running fully eagerly.

  Args:
    name: The name argument that is passed to the op function.

  Returns:
    `name_scope*` context manager.
  """
  if not context.executing_eagerly():
    return ops.name_scope_v1(name)
  else:
    return NullContextmanager()


@keras_export("keras.optimizers.Optimizer", metaclass=abc.ABCMeta)
class OptimizerV2(trackable.Trackable):
  """Base class for Keras optimizers.

  You should not use this class directly, but instead instantiate one of its
  subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns an op that minimizes the loss by updating the listed
  # variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Usage in custom training loops

  In Keras models, variables are sometimes created when the model is first
  called, instead of at construction time. Examples include 1) sequential
  models without a pre-defined input shape, or 2) subclassed models. Pass
  `var_list` as a callable in these cases.

  Example:

  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
  var_list_fn = lambda: model.trainable_weights
  for input, output in data:
    opt.minimize(loss_fn, var_list_fn)
  ```

  ### Processing gradients before applying them

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1. Compute the gradients with `tf.GradientTape`.
  2. Process the gradients as you wish.
  3. Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  vars = <list_of_variables>
  grads = tape.gradient(loss, vars)

  # Process the gradients, for example cap them, etc.
  # capped_grads = [MyCapper(g) for g in grads]
  processed_grads = [process_gradient(g) for g in grads]

  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, var_list))
  ```

  ### Use with `tf.distribute.Strategy`

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss, which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for summing.

  To aggregate gradients yourself, call `apply_gradients` with
  `experimental_aggregate_gradients` set to False. This is useful if you need
  to process aggregated gradients.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.

  ### Variable Constraints

  All Keras optimizers respect variable constraints. If a constraint function
  is passed to any variable, the constraint will be applied to the variable
  after the gradient has been applied to it.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The
  user needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train. These are called
  <i>Slots</i>. Slots have names and you can ask the optimizer for the names of
  the slots that it uses. Once you have a slot name you can ask the optimizer
  for the variable it created to hold the slot value.

  This can be useful if you want to debug a training algorithm, report stats
  about the slots, etc.

  ### Hyperparameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyperparameter.

  Hyperparameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # Update the learning rate.
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Callable learning rate

  The optimizer accepts a callable learning rate in two ways. The first way is
  through a built-in or customized
  `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
  called on each iteration with `schedule(iteration)`, where `iteration` is a
  `tf.Variable` owned by the optimizer.

  Example:

  >>> var = tf.Variable(np.random.random(size=(1,)))
  >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
  ...   initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
  >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)
  >>> loss = lambda: 3 * var
  >>> opt.minimize(loss, var_list=[var])
  <tf.Variable...

  The second way is through a callable function that
  does not accept any arguments.

  Example:

  >>> var = tf.Variable(np.random.random(size=(1,)))
  >>> def lr_callable():
  ...   return .1
  >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable)
  >>> loss = lambda: 3 * var
  >>> opt.minimize(loss, var_list=[var])
  <tf.Variable...

  ### Creating a custom optimizer

  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods (a minimal sketch follows the
  list):

  - `_resource_apply_dense` (update a variable given the gradient tensor is a
    dense `tf.Tensor`)
  - `_resource_apply_sparse` (update a variable given the gradient tensor is a
    sparse `tf.IndexedSlices`. The most common way for this to happen
    is if you are taking the gradient through a `tf.gather`.)
  - `_create_slots`
    (if your optimizer algorithm requires additional variables)
  - `get_config`
    (serialization of the optimizer; include all hyperparameters)
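
  A minimal sketch of such a subclass, with hypothetical names and shown only
  to illustrate the methods above (use the built-in `tf.keras.optimizers.SGD`
  in practice):

  ```python
  class MyPlainSGD(tf.keras.optimizers.Optimizer):
    # Hypothetical example: momentum-free gradient descent.

    def __init__(self, learning_rate=0.01, name="MyPlainSGD", **kwargs):
      super(MyPlainSGD, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    def _resource_apply_dense(self, grad, var, apply_state=None):
      lr_t = self._decayed_lr(var.dtype.base_dtype)
      return var.assign_sub(lr_t * grad)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
      lr_t = self._decayed_lr(var.dtype.base_dtype)
      return self._resource_scatter_add(var, indices, -lr_t * grad)

    def get_config(self):
      config = super(MyPlainSGD, self).get_config()
      config.update(
          {"learning_rate": self._serialize_hyperparameter("learning_rate")})
      return config
  ```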
  """

  # Subclasses should set this to True unless they override `apply_gradients`
  # with a version that does not have the `experimental_aggregate_gradients`
  # argument. Older versions of Keras did not have this argument so custom
  # optimizers may have overridden `apply_gradients` without the
  # `experimental_aggregate_gradients` argument. Keras only passes
  # `experimental_aggregate_gradients` if this attribute is True.
  # Note: This attribute will likely be removed in an upcoming release.
  _HAS_AGGREGATE_GRAD = False

  def __init__(self,
               name,
               gradient_aggregator=None,
               gradient_transformers=None,
               **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/_get_hyper()
    facility instead.

    This class is stateful and thread-compatible.

    Example of custom gradient transformations:

    ```python
    def my_gradient_transformer(grads_and_vars):
      # Simple example, double the gradients.
      return [(2. * g, v) for g, v in grads_and_vars]

    optimizer = tf.keras.optimizers.SGD(
        1e-3, gradient_transformers=[my_gradient_transformer])
    ```

    Args:
      name: String. The name to use for momentum accumulator weights created
        by the optimizer.
      gradient_aggregator: The function to use to aggregate gradients across
        devices (when using `tf.distribute.Strategy`). If `None`, defaults to
        summing the gradients across devices. The function should accept and
        return a list of `(gradient, variable)` tuples.
      gradient_transformers: Optional. List of functions to use to transform
        gradients before applying updates to Variables. The functions are
        applied after `gradient_aggregator`. The functions should accept and
        return a list of `(gradient, variable)` tuples.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set, the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Raises:
      ValueError: in case of any invalid argument.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # Checks that all keyword arguments are non-negative.
      if kwargs[k] is not None and kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
      if k == "lr":
        warnings.warn(
            "The `lr` argument is deprecated, use `learning_rate` instead.")

    self._use_locking = True
    self._init_set_name(name)
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay

    self._hypers_created = False
    # Store the distribution strategy object if the optimizer is created inside
    # strategy scope, so it could be used to create variables later.
    if distribute_ctx.has_strategy():
      self._distribution_strategy = distribute_ctx.get_strategy()
    else:
      self._distribution_strategy = None

    # Configure gradient transformations.
    if gradient_aggregator is None:
      gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
    self.gradient_aggregator = gradient_aggregator
    if gradient_transformers is None:
      gradient_transformers = []
    self.gradient_transformers = gradient_transformers
    self.clipnorm = kwargs.pop("clipnorm", None)
    self.global_clipnorm = kwargs.pop("global_clipnorm", None)
    if self.clipnorm is not None and self.global_clipnorm is not None:
      raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
                       "passed `clipnorm` {}, `global_clipnorm` {}".format(
                           self.clipnorm, self.global_clipnorm))
    self.clipvalue = kwargs.pop("clipvalue", None)

  @property
  def clipnorm(self):
    """`float` or `None`. If set, clips gradients to a maximum norm."""
    return self._clipnorm

  @property
  def global_clipnorm(self):
    """`float` or `None`. If set, clips gradients to a maximum norm."""
    return self._global_clipnorm

  @clipnorm.setter
  def clipnorm(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
                       "is set. Instead, use the `gradient_transformers` to "
                       "specify clipping and other transformations.")
    self._clipnorm = val
    self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
        self._clipnorm)

  @global_clipnorm.setter
  def global_clipnorm(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError(
          "`global_clipnorm` cannot be set when `gradient_transformers` "
          "is set. Instead, use the `gradient_transformers` to "
          "specify clipping and other transformations.")
    self._global_clipnorm = val
    self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
        self._global_clipnorm)

  @property
  def clipvalue(self):
    """`float` or `None`. If set, clips gradients to a maximum value."""
    return self._clipvalue

  @clipvalue.setter
  def clipvalue(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError("`clipvalue` cannot be set when `gradient_transformers` "
                       "is set. Instead, use the `gradient_transformers` to "
                       "specify clipping and other transformations.")
    self._clipvalue = val
    self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
        self._clipvalue)

  def _transform_loss(self, loss):
    """Called in `.minimize` to transform loss before computing gradients."""
    return loss

  def _get_gradients(self, tape, loss, var_list, grad_loss=None):
    """Called in `minimize` to compute gradients from loss."""
    grads = tape.gradient(loss, var_list, grad_loss)
    return list(zip(grads, var_list))

  def _transform_unaggregated_gradients(self, grads_and_vars):
    """Called in `apply_gradients` before gradient aggregation."""
    return grads_and_vars

  def _aggregate_gradients(self, grads_and_vars):
    """Called in `apply_gradients` to aggregate gradients across devices.

    Note that user subclasses may override this, so the interface should not be
    changed.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.

    Returns:
      A list of (aggregated_gradient, variable) pairs. By default, this calls
      `self.gradient_aggregator`.
    """
    return self.gradient_aggregator(grads_and_vars)

  def _transform_gradients(self, grads_and_vars):
    """Called in `apply_gradients` after aggregation."""
    if self._clipvalue is not None:
      grads_and_vars = self._clipvalue_fn(grads_and_vars)
    if self._clipnorm is not None:
      grads_and_vars = self._clipnorm_fn(grads_and_vars)
    if self._global_clipnorm is not None:
      grads_and_vars = self._global_clipnorm_fn(grads_and_vars)

    for fn in self.gradient_transformers:
      grads_and_vars = fn(grads_and_vars)
    return grads_and_vars

  def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradients using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradients before applying
    them, use `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: `Tensor` or callable. If a callable, `loss` should take no
        arguments and return the value to minimize. If a `Tensor`, the `tape`
        argument must be passed.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable`
        objects. Use a callable when the variable list would otherwise be
        incomplete before `minimize`, since the variables are created the
        first time `loss` is called.
      grad_loss: (Optional). A `Tensor` holding the gradient computed for
        `loss`.
      name: (Optional) str. Name for the returned operation.
      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
        the tape that computed the `loss` must be provided.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
    return self.apply_gradients(grads_and_vars, name=name)

  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: `Tensor` or callable. If a callable, `loss` should take no
        arguments and return the value to minimize. If a `Tensor`, the `tape`
        argument must be passed.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable`
        objects. Use a callable when the variable list would otherwise be
        incomplete before `minimize` and the variables are created the first
        time `loss` is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
        the tape that computed the `loss` must be provided.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    if not callable(loss) and tape is None:
      raise ValueError("`tape` is required when a `Tensor` loss is passed.")
    tape = tape if tape is not None else backprop.GradientTape()

    if callable(loss):
      with tape:
        if not callable(var_list):
          tape.watch(var_list)
        loss = loss()
        if callable(var_list):
          var_list = var_list()

    with tape:
      loss = self._transform_loss(loss)

    var_list = nest.flatten(var_list)
    with ops.name_scope_v2(self._name + "/gradients"):
      grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)

    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself
    by passing `experimental_aggregate_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars),
                              experimental_aggregate_gradients=False)
    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it is
        the user's responsibility to aggregate the gradients. Defaults to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If called in a cross-replica context.
    """
    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with ops.name_scope_v2(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        self._create_all_weights(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients.
        return control_flow_ops.no_op()

      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError(
            "`apply_gradients()` cannot be called in cross-replica context. "
            "Use `tf.distribute.Strategy.run` to enter replica "
            "context.")

      strategy = distribute_ctx.get_strategy()
      if (not experimental_aggregate_gradients and strategy and
          isinstance(strategy,
                     (parameter_server_strategy.ParameterServerStrategyV1,
                      parameter_server_strategy_v2.ParameterServerStrategyV2,
                      central_storage_strategy.CentralStorageStrategy,
                      central_storage_strategy.CentralStorageStrategyV1))):
        raise NotImplementedError(
            "`experimental_aggregate_gradients=False` is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")

      apply_state = self._prepare(var_list)
      if experimental_aggregate_gradients:
        grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars)
        grads_and_vars = self._aggregate_gradients(grads_and_vars)
      grads_and_vars = self._transform_gradients(grads_and_vars)

      if optimizer_utils.strategy_supports_no_merge_call():
        return self._distributed_apply(strategy, grads_and_vars, name,
                                       apply_state)
      else:
        return distribute_ctx.get_replica_context().merge_call(
            functools.partial(self._distributed_apply, apply_state=apply_state),
            args=(grads_and_vars,),
            kwargs={
                "name": name,
            })

  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    """`apply_gradients` using a `DistributionStrategy`."""

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    with name_scope_only_in_function_or_graph(name or self._name):
      for grad, var in grads_and_vars:
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with distribution.extended.colocate_vars_with(var):
          with name_scope_only_in_function_or_graph(
              "update" if eagerly_outside_functions else "update_" +
              var.op.name):
            update_op = distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False)
            if distribute_ctx.in_cross_replica_context():
              # In cross-replica context, extended.update returns a list of
              # update ops from all replicas (group=False).
              update_ops.extend(update_op)
            else:
              # In replica context, extended.update returns the single update
              # op of the current replica.
              update_ops.append(update_op)

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with backend._current_graph(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies([control_flow_ops.group(update_ops)]):
            return self._iterations.assign_add(1, read_value=False)

      return self._iterations.assign_add(1)

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Should be used only in legacy v1 graph mode.

    Args:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name +
                                                              "/gradients"):
      grads = gradients.gradients(loss, params)
      for grad, param in zip(grads, params):
        if grad is None:
          raise ValueError("Variable {} has `None` for gradient. "
                           "Please make sure that all of your ops have a "
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
    return grads

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """Set hyper `name` to value. value can be callable, tensor, numeric."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def _create_slots(self, var_list):
    pass

  def _create_all_weights(self, var_list):
    """Creates all weights, including iterations, hyperparameters and slot vars.

    This will add newly created variables to `optimizer.weights`.

    New variables are only created when this method is called the first time,
    or when called with different variables in the var_list.

    Args:
      var_list: list or tuple of `Variable` objects that will be minimized
        using this optimizer.
    """

    _ = self.iterations
    self._create_hypers()
    self._create_slots(var_list)

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __dir__(self):
    result = set(super(OptimizerV2, self).__dir__())
    if "_hyper" in result:
      result |= self._hyper.keys()
      if "learning_rate" in self._hyper.keys():
        result.add("lr")
    return list(result)

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names

  def add_slot(self, var, slot_name, initializer="zeros", shape=None):
    """Add a new slot variable for `var`.

    A slot variable is an additional variable associated with `var` to train.
    It is allocated and managed by optimizers, e.g. `Adam`.

    Args:
      var: a `Variable` object.
      slot_name: name of the slot variable.
      initializer: initializer of the slot variable.
      shape: (Optional) shape of the slot variable. If not set, it will default
        to the shape of `var`.

    Returns:
      A slot variable.
    """
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, str) or callable(initializer):
        initializer = initializers.get(initializer)
        if isinstance(
            initializer,
            trackable.CheckpointInitialValueCallable) or (shape is not None):
          slot_shape = shape
        else:
          slot_shape = var.shape
        initial_value = functools.partial(
            initializer, shape=slot_shape, dtype=var.dtype)
      else:
        initial_value = initializer

      with self._distribution_strategy_scope():
        strategy = distribute_ctx.get_strategy()
        if not strategy.extended.variable_created_in_scope(var):
          raise ValueError(
              "Trying to create optimizer slot variable under the scope for "
              "tf.distribute.Strategy ({}), which is different from the scope "
              "used for the original variable ({}). Make sure the slot "
              "variables are created under the same strategy scope. This may "
              "happen if you're restoring from a checkpoint outside the scope"
              .format(strategy, var))

        with strategy.extended.colocate_vars_with(var):
          weight = tf_variables.Variable(
              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
              dtype=var.dtype,
              trainable=False,
              initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    keys = set()
    for var in var_list:
      if isinstance(var, ds_values.DistributedValues):
        var_devices = var._devices  # pylint: disable=protected-access
      else:
        var_devices = [var.device]
      var_dtype = var.dtype.base_dtype
      for var_device in var_devices:
        keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
      apply_state[(var_device, var_dtype)] = {}
      with ops.device(var_device):
        self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state

  def _prepare_local(self, var_device, var_dtype, apply_state):
    if "learning_rate" in self._hyper:
      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t

  def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]

  def _create_hypers(self):
    if self._hypers_created:
      return
    with self._distribution_strategy_scope():
      # Iterate hyper values deterministically.
      for name, value in sorted(self._hyper.items()):
        if isinstance(value,
                      (ops.Tensor, tf_variables.Variable)) or callable(value):
          # The check for `callable` covers the usage when `value` is a
          # `LearningRateSchedule`, in which case it does not need to create a
          # variable.
          continue
        else:
          self._hyper[name] = self.add_weight(
              name,
              shape=[],
              trainable=False,
              initializer=value,
              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      with self._distribution_strategy_scope():
        self._iterations = self.add_weight(
            "iter",
            shape=[],
            dtype=dtypes.int64,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = math_ops.cast(self._initial_decay, var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
      Python dictionary.
    """
    config = {"name": self._name}
    if self.clipnorm is not None:
      config["clipnorm"] = self.clipnorm
    if self.clipvalue is not None:
      config["clipvalue"] = self.clipvalue
    if self.global_clipnorm is not None:
      config["global_clipnorm"] = self.global_clipnorm
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Args:
      config: A Python dictionary, typically the output of `get_config`.
      custom_objects: A Python dictionary mapping names to additional Python
        objects used to create this optimizer, such as a function used for a
        hyperparameter.

    Returns:
      An optimizer instance.
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"], custom_objects=custom_objects)
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if tensor_util.is_tf_type(value):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  @property
  def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  def get_weights(self):
    """Returns the current weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function returns the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they were created. The returned list can in turn
    be used to load state into similarly parameterized optimizers.

    For example, the RMSprop optimizer for this simple model returns a list of
    three values: the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> len(opt.get_weights())
    3

    Returns:
      Weights values as a list of numpy arrays.
    """
    params = self.weights
    return backend.batch_get_value(params)

  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
    """Set the weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function takes the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they are created. The passed values are used to set
    the new state of the optimizer.

    For example, the RMSprop optimizer for this simple model takes a list of
    three values: the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Args:
      weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, str) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return _DEFAULT_VALID_DTYPES

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle, apply_state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                               **kwargs):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      **kwargs: May optionally contain `apply_state`.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       **kwargs)

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies([
        gen_resource_variable_ops.ResourceScatterAdd(
            resource=x.handle, indices=i, updates=v)
    ]):
      return x.value()

  def _resource_scatter_update(self, x, i, v):
    with ops.control_dependencies(
        [gen_resource_variable_ops.ResourceScatterUpdate(
            resource=x.handle, indices=i, updates=v)]):
      return x.value()

  @property
  @layer_utils.cached_per_instance
  def _dense_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_dense).args

  @property
  @layer_utils.cached_per_instance
  def _sparse_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_sparse).args

  # ---------------
  # For implementing the trackable interface
  # ---------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        #
        # One notable case is with distribution strategy, which uses variable
        # creator scope but always desires the `variable` and the slot to use
        # the same scope, thus we can safely eagerly create/restore slot
        # variables.
        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
             self._distribution_strategy)):
      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name,
          shape=slot_variable_position.value_shape())
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  @contextlib.contextmanager
  def _distribution_strategy_scope(self):
    """Enters the `tf.distribute.Strategy` scope of this optimizer, if any."""
    if self._distribution_strategy and not distribute_ctx.has_strategy():
      with self._distribution_strategy.scope():
        yield self._distribution_strategy.scope()
    else:
      yield


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")


revived_types.register_revived_type(
    "tf_deprecated_optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])