# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import functools

import six

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


keras_optimizers_gauge = monitoring.BoolGauge(
    "/tensorflow/api/keras/optimizers", "keras optimizer usage", "method")

_DEFAULT_VALID_DTYPES = frozenset([
    dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
    dtypes.complex64, dtypes.complex128
])


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)
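
# For illustration (not executed): with values = [[1., 2.], [3., 4.], [5., 6.]]
# and indices = [0, 2, 0], the two slices written to index 0 are summed,
# giving summed_values = [[6., 8.], [3., 4.]] and unique_indices = [0, 2].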


class NullContextmanager(object):

  def __init__(self, *args, **kwargs):
    pass

  def __enter__(self):
    pass

  def __exit__(self, type_arg, value_arg, traceback_arg):
    return False  # False values do not suppress exceptions


def name_scope_only_in_function_or_graph(name):
  """Internal-only entry point for `name_scope*`.

  Enters a compat.v1.name_scope only when in a function or graph,
  not when running fully eagerly.

  Args:
    name: The name argument that is passed to the op function.

  Returns:
    `name_scope*` context manager.
  """
  if not context.executing_eagerly():
    return ops.name_scope_v1(name)
  else:
    return NullContextmanager()


@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(trackable.Trackable):
  """Base class for Keras optimizers.

  You should not use this class directly, but instead instantiate one of its
  subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns op that minimizes the loss by updating the listed
  # variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Usage in custom training loops

  In Keras models, variables are sometimes created when the model is first
  called, instead of at construction time. Examples include 1) sequential
  models without a pre-defined input shape, or 2) subclassed models. Pass
  `var_list` as a callable in these cases.

  Example:

  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
  var_list_fn = lambda: model.trainable_weights
  for input, output in data:
    opt.minimize(loss_fn, var_list_fn)
  ```

  ### Processing gradients before applying them

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1. Compute the gradients with `tf.GradientTape`.
  2. Process the gradients as you wish.
  3. Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  vars = <list_of_variables>
  grads = tape.gradient(loss, vars)

  # Process the gradients, for example cap them, etc.
  # capped_grads = [MyCapper(g) for g in grads]
  processed_grads = [process_gradient(g) for g in grads]

  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, var_list))
  ```

  ### Use with `tf.distribute.Strategy`

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss, which should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for not averaging.

  To aggregate gradients yourself, call `apply_gradients` with
  `experimental_aggregate_gradients` set to False. This is useful if you need to
  process aggregated gradients.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.

  ### Variable Constraints

  All Keras optimizers respect variable constraints. If a constraint function is
  passed to any variable, the constraint will be applied to the variable after
  the gradient has been applied to it.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The user
  needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train. These are called
  <i>Slots</i>. Slots have names and you can ask the optimizer for the names of
  the slots that it uses. Once you have a slot name you can ask the optimizer
  for the variable it created to hold the slot value.

  This can be useful if you want to log debug information about a training
  algorithm, report stats about the slots, etc.

  ### Hyperparameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  Hyperparameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # update learning rate
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Callable learning rate

  Optimizer accepts a callable learning rate in two ways. The first way is
  through a built-in or customized
  `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
  called on each iteration with `schedule(iteration)`, where `iteration` is a
  `tf.Variable` owned by the optimizer.

  Example:

  >>> var = tf.Variable(np.random.random(size=(1,)))
  >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
  ...   initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
  >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)
  >>> loss = lambda: 3 * var
  >>> opt.minimize(loss, var_list=[var])
  <tf.Variable...

  The second way is through a callable function that
  does not accept any arguments.

  Example:

  >>> var = tf.Variable(np.random.random(size=(1,)))
  >>> def lr_callable():
  ...   return .1
  >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable)
  >>> loss = lambda: 3 * var
  >>> opt.minimize(loss, var_list=[var])
  <tf.Variable...

  ### Creating a custom optimizer

  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods:

    - `_resource_apply_dense` (update variable given gradient tensor is a dense
      `tf.Tensor`)
    - `_resource_apply_sparse` (update variable given gradient tensor is a
      sparse `tf.IndexedSlices`. The most common way for this to happen
      is if you are taking the gradient through a `tf.gather`.)
    - `_create_slots`
      (if your optimizer algorithm requires additional variables)
    - `get_config`
      (serialization of the optimizer, include all hyper parameters)
  """

  # Subclasses should set this to True unless they override `apply_gradients`
  # with a version that does not have the `experimental_aggregate_gradients`
  # argument. Older versions of Keras did not have this argument so custom
  # optimizers may have overridden `apply_gradients` without the
  # `experimental_aggregate_gradients` argument. Keras only passes
  # `experimental_aggregate_gradients` if this attribute is True.
  # Note: This attribute will likely be removed in an upcoming release.
  _HAS_AGGREGATE_GRAD = False

  def __init__(self,
               name,
               gradient_aggregator=None,
               gradient_transformers=None,
               **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the `_set_hyper()`/`_get_hyper()`
    facility instead.

    This class is stateful and thread-compatible.

    Example of custom gradient transformations:

    ```python
    def my_gradient_transformer(grads_and_vars):
      # Simple example, double the gradients.
      return [(2. * g, v) for g, v in grads_and_vars]

    optimizer = tf.keras.optimizers.SGD(
        1e-3, gradient_transformers=[my_gradient_transformer])
    ```

    Args:
      name: String. The name to use for momentum accumulator weights created
        by the optimizer.
      gradient_aggregator: The function to use to aggregate gradients across
        devices (when using `tf.distribute.Strategy`). If `None`, defaults to
        summing the gradients across devices. The function should accept and
        return a list of `(gradient, variable)` tuples.
      gradient_transformers: Optional. List of functions to use to transform
        gradients before applying updates to Variables. The functions are
        applied after `gradient_aggregator`. The functions should accept and
        return a list of `(gradient, variable)` tuples.
      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
        `clipnorm`, `global_clipnorm`.
        If `clipvalue` (float) is set, the gradient of each weight
        is clipped to be no higher than this value.
        If `clipnorm` (float) is set, the gradient of each weight
        is individually clipped so that its norm is no higher than this value.
        If `global_clipnorm` (float) is set, the gradient of all weights is
        clipped so that their global norm is no higher than this value.

    Raises:
      ValueError: in case of any invalid argument.
    """
    # Instrument optimizer usages
    keras_optimizers_gauge.get_cell(self.__class__.__name__).set(True)

    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # checks that all keyword arguments are non-negative.
      if kwargs[k] is not None and kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

    self._use_locking = True
    self._init_set_name(name)
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    # {slot_name :
    #     {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #  ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay

    self._hypers_created = False
    # Store the distribution strategy object if the optimizer is created inside
    # strategy scope, so it could be used to create variables later.
    if distribute_ctx.has_strategy():
      self._distribution_strategy = distribute_ctx.get_strategy()
    else:
      self._distribution_strategy = None

    # Configure gradient transformations.
    if gradient_aggregator is None:
      gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
    self.gradient_aggregator = gradient_aggregator
    if gradient_transformers is None:
      gradient_transformers = []
    self.gradient_transformers = gradient_transformers
    self.clipnorm = kwargs.pop("clipnorm", None)
    self.global_clipnorm = kwargs.pop("global_clipnorm", None)
    if self.clipnorm is not None and self.global_clipnorm is not None:
      raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
                       "passed `clipnorm` {}, `global_clipnorm` {}".format(
                           self.clipnorm, self.global_clipnorm))
    self.clipvalue = kwargs.pop("clipvalue", None)
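
  # Gradient clipping can be configured either at construction time, e.g.
  # `tf.keras.optimizers.SGD(1e-3, clipnorm=1.0)`, or afterwards through the
  # `clipvalue`, `clipnorm` and `global_clipnorm` properties below;
  # `clipnorm` and `global_clipnorm` are mutually exclusive.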

  @property
  def clipnorm(self):
    """`float` or `None`. If set, clips gradients to a maximum norm."""
    return self._clipnorm

  @property
  def global_clipnorm(self):
    """`float` or `None`. If set, clips gradients to a maximum norm."""
    return self._global_clipnorm

  @clipnorm.setter
  def clipnorm(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
                       "is set. Instead, use the `gradient_transformers` to "
                       "specify clipping and other transformations.")
    self._clipnorm = val
    self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
        self._clipnorm)

  @global_clipnorm.setter
  def global_clipnorm(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError(
          "`global_clipnorm` cannot be set when `gradient_transformers` "
          "is set. Instead, use the `gradient_transformers` to "
          "specify clipping and other transformations.")
    self._global_clipnorm = val
    self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
        self._global_clipnorm)

  @property
  def clipvalue(self):
    """`float` or `None`. If set, clips gradients to a maximum value."""
    return self._clipvalue

  @clipvalue.setter
  def clipvalue(self, val):
    if val is not None and self.gradient_transformers:
      raise ValueError("`clipvalue` cannot be set when `gradient_transformers` "
                       "is set. Instead, use the `gradient_transformers` to "
                       "specify clipping and other transformations.")
    self._clipvalue = val
    self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
        self._clipvalue)

  def _transform_loss(self, loss):
    """Called in `.minimize` to transform loss before computing gradients."""
    return loss

  def _get_gradients(self, tape, loss, var_list, grad_loss=None):
    """Called in `minimize` to compute gradients from loss."""
    grads = tape.gradient(loss, var_list, grad_loss)
    return list(zip(grads, var_list))

  def _transform_unaggregated_gradients(self, grads_and_vars):
    """Called in `apply_gradients` before gradient aggregation."""
    return grads_and_vars

  def _aggregate_gradients(self, grads_and_vars):
    """Called in `apply_gradients` to aggregate gradients across devices."""
    return self.gradient_aggregator(grads_and_vars)

  def _transform_gradients(self, grads_and_vars):
    """Called in `apply_gradients` after aggregation."""
    if self._clipvalue is not None:
      grads_and_vars = self._clipvalue_fn(grads_and_vars)
    if self._clipnorm is not None:
      grads_and_vars = self._clipnorm_fn(grads_and_vars)
    if self._global_clipnorm is not None:
      grads_and_vars = self._global_clipnorm_fn(grads_and_vars)

    for fn in self.gradient_transformers:
      grads_and_vars = fn(grads_and_vars)
    return grads_and_vars

  def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradient using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradient before applying
    then call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: `Tensor` or callable. If a callable, `loss` should take no arguments
        and return the value to minimize. If a `Tensor`, the `tape` argument
        must be passed.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize`, since the variables are created the first time
        `loss` is called.
      grad_loss: (Optional). A `Tensor` holding the gradient computed for
        `loss`.
      name: (Optional) str. Name for the returned operation.
      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
        the tape that computed the `loss` must be provided.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
    return self.apply_gradients(grads_and_vars, name=name)

  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: `Tensor` or callable. If a callable, `loss` should take no
        arguments and return the value to minimize. If a `Tensor`, the `tape`
        argument must be passed.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize` and the variables are created the first time `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
        the tape that computed the `loss` must be provided.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    if not callable(loss) and tape is None:
      raise ValueError("`tape` is required when a `Tensor` loss is passed.")
    tape = tape if tape is not None else backprop.GradientTape()

    if callable(loss):
      with tape:
        if not callable(var_list):
          tape.watch(var_list)
        loss = loss()
        if callable(var_list):
          var_list = var_list()

    with tape:
      loss = self._transform_loss(loss)

    var_list = nest.flatten(var_list)
    with ops.name_scope_v2(self._name + "/gradients"):
      grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)

    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars
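
  # Within `apply_gradients`, gradients flow through
  # `_transform_unaggregated_gradients`, then `_aggregate_gradients` (summing
  # across replicas by default), then `_transform_gradients` (clipping and any
  # user-supplied `gradient_transformers`) before the per-variable
  # `_resource_apply_dense`/`_resource_apply_sparse` updates.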

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
    passing `experimental_aggregate_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars),
        experimental_aggregate_gradients=False)

    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it's the
        user's responsibility to aggregate the gradients. Defaults to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If called in a cross-replica context.
    """
    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with ops.name_scope_v2(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        self._create_all_weights(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()

      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError(
            "`apply_gradients()` cannot be called in cross-replica context. "
            "Use `tf.distribute.Strategy.run` to enter replica "
            "context.")

      strategy = distribute_ctx.get_strategy()
      if (not experimental_aggregate_gradients and strategy and
          isinstance(strategy,
                     (parameter_server_strategy.ParameterServerStrategyV1,
                      parameter_server_strategy_v2.ParameterServerStrategyV2,
                      central_storage_strategy.CentralStorageStrategy,
                      central_storage_strategy.CentralStorageStrategyV1))):
        raise NotImplementedError(
            "`experimental_aggregate_gradients=False` is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")

      apply_state = self._prepare(var_list)
      if experimental_aggregate_gradients:
        grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars)
        grads_and_vars = self._aggregate_gradients(grads_and_vars)
      grads_and_vars = self._transform_gradients(grads_and_vars)

      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={
              "name": name,
          })

  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    """`apply_gradients` using a `DistributionStrategy`."""

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
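      # If the variable has a constraint, apply it after the gradient update
      # and make the constrained assignment the op that callers wait on.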
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    with name_scope_only_in_function_or_graph(name or self._name):
      for grad, var in grads_and_vars:
        # TODO(crccw): It's not allowed to assign PerReplica value to
        # MirroredVariable. Remove this after we relax this restriction.
        def _assume_mirrored(grad):
          if isinstance(grad, ds_values.PerReplica):
            return ds_values.Mirrored(grad.values)
          return grad

        grad = nest.map_structure(_assume_mirrored, grad)
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with distribution.extended.colocate_vars_with(var):
          with name_scope_only_in_function_or_graph(
              "update" if eagerly_outside_functions else "update_" +
              var.op.name):
            update_ops.extend(distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with backend._current_graph(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies([control_flow_ops.group(update_ops)]):
            return self._iterations.assign_add(1, read_value=False)

      return self._iterations.assign_add(1)

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Should be used only in legacy v1 graph mode.

    Args:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name +
                                                              "/gradients"):
      grads = gradients.gradients(loss, params)
      for grad, param in zip(grads, params):
        if grad is None:
          raise ValueError("Variable {} has `None` for gradient. "
                           "Please make sure that all of your ops have a "
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
    return grads

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """set hyper `name` to value. value can be callable, tensor, numeric."""
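    # Trackable hyperparameter values (e.g. a `LearningRateSchedule`) are
    # registered as dependencies so they are saved and restored together with
    # the optimizer's checkpointed state.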
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def _create_slots(self, var_list):
    pass

  def _create_all_weights(self, var_list):
    """Creates all weights, including iterations, hyperparameters and slot vars.

    This will add newly created variables to `optimizer.weights`.

    New variables are only created when this method is called the first time, or
    when called with different variables in the var_list.

    Args:
      var_list: list or tuple of `Variable` objects that will be minimized
        using this optimizer.
    """

    _ = self.iterations
    self._create_hypers()
    self._create_slots(var_list)

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __dir__(self):
    result = set(super(OptimizerV2, self).__dir__())
    if "_hyper" in result:
      result |= self._hyper.keys()
      if "learning_rate" in self._hyper.keys():
        result.add("lr")
    return list(result)

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names

  def add_slot(self, var, slot_name, initializer="zeros", shape=None):
    """Add a new slot variable for `var`.

    A slot variable is an additional variable associated with `var` to train.
    It is allocated and managed by optimizers, e.g. `Adam`.

    Args:
      var: a `Variable` object.
      slot_name: name of the slot variable.
      initializer: initializer of the slot variable.
      shape: (Optional) shape of the slot variable. If not set, it will default
        to the shape of `var`.

    Returns:
      A slot variable.
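
    Example (illustrative): a subclass that needs a per-variable "momentum"
    slot would typically create it in `_create_slots`:

    ```python
    def _create_slots(self, var_list):
      for var in var_list:
        self.add_slot(var, "momentum")
    ```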
    """
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        if isinstance(
            initializer,
            trackable.CheckpointInitialValueCallable) or (shape is not None):
          slot_shape = shape
        else:
          slot_shape = var.shape
        initial_value = functools.partial(
            initializer, shape=slot_shape, dtype=var.dtype)
      else:
        initial_value = initializer

      with self._distribution_strategy_scope():
        strategy = distribute_ctx.get_strategy()
        if not strategy.extended.variable_created_in_scope(var):
          raise ValueError(
              "Trying to create optimizer slot variable under the scope for "
              "tf.distribute.Strategy ({}), which is different from the scope "
              "used for the original variable ({}). Make sure the slot "
              "variables are created under the same strategy scope. This may "
              "happen if you're restoring from a checkpoint outside the scope"
              .format(strategy, var))

        with strategy.extended.colocate_vars_with(var):
          weight = tf_variables.Variable(
              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
              dtype=var.dtype,
              trainable=False,
              initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    keys = set()
    for var in var_list:
      if isinstance(var, ds_values.DistributedValues):
        var_devices = var._devices  # pylint: disable=protected-access
      else:
        var_devices = [var.device]
      var_dtype = var.dtype.base_dtype
      for var_device in var_devices:
        keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
      apply_state[(var_device, var_dtype)] = {}
      with ops.device(var_device):
        self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state

  def _prepare_local(self, var_device, var_dtype, apply_state):
    if "learning_rate" in self._hyper:
      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t

  def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]

  def _create_hypers(self):
    if self._hypers_created:
      return
    with self._distribution_strategy_scope():
      # Iterate hyper values deterministically.
      for name, value in sorted(self._hyper.items()):
        if isinstance(value,
                      (ops.Tensor, tf_variables.Variable)) or callable(value):
          # The check for `callable` covers the usage when `value` is a
          # `LearningRateSchedule`, in which case it does not need to create a
          # variable.
          continue
        else:
          self._hyper[name] = self.add_weight(
              name,
              shape=[],
              trainable=False,
              initializer=value,
              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      with self._distribution_strategy_scope():
        self._iterations = self.add_weight(
            "iter",
            shape=[],
            dtype=dtypes.int64,
            trainable=False,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)

  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = math_ops.cast(self._initial_decay, var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t
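
  # With the legacy `decay` keyword argument, `_decayed_lr` applies
  # `lr / (1 + decay * iterations)` on top of any `LearningRateSchedule`,
  # as computed above.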

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
      Python dictionary.
    """
    config = {"name": self._name}
    if self.clipnorm is not None:
      config["clipnorm"] = self.clipnorm
    if self.clipvalue is not None:
      config["clipvalue"] = self.clipvalue
    if self.global_clipnorm is not None:
      config["global_clipnorm"] = self.global_clipnorm
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Args:
      config: A Python dictionary, typically the output of get_config.
      custom_objects: A Python dictionary mapping names to additional Python
        objects used to create this optimizer, such as a function used for a
        hyperparameter.

    Returns:
      An optimizer instance.
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"], custom_objects=custom_objects)
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if tensor_util.is_tf_type(value):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  @property
  def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  def get_weights(self):
    """Returns the current weights of the optimizer.

    The weights of an optimizer are its state (ie, variables).
    This function returns the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they were created. The returned list can in turn
    be used to load state into similarly parameterized optimizers.

    For example, the RMSprop optimizer for this simple model returns a list of
    three values-- the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> len(opt.get_weights())
    3

    Returns:
      Weights values as a list of numpy arrays.
    """
    params = self.weights
    return backend.batch_get_value(params)

  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
    """Set the weights of the optimizer.

    The weights of an optimizer are its state (ie, variables).
    This function takes the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they are created. The passed values are used to set
    the new state of the optimizer.

    For example, the RMSprop optimizer for this simple model takes a list of
    three values-- the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Args:
      weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return _DEFAULT_VALID_DTYPES

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle, apply_state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                               **kwargs):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      **kwargs: May optionally contain `apply_state`

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       **kwargs)

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies([
        gen_resource_variable_ops.ResourceScatterAdd(
            resource=x.handle, indices=i, updates=v)
    ]):
      return x.value()

  def _resource_scatter_update(self, x, i, v):
    with ops.control_dependencies(
        [gen_resource_variable_ops.ResourceScatterUpdate(
            resource=x.handle, indices=i, updates=v)]):
      return x.value()

  @property
  @layer_utils.cached_per_instance
  def _dense_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_dense).args

  @property
  @layer_utils.cached_per_instance
  def _sparse_apply_args(self):
    return tf_inspect.getfullargspec(self._resource_apply_sparse).args

  # ---------------
  # For implementing the trackable interface
  # ---------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        #
        # One notable case is with distribution strategy, which uses variable
        # creator scope but always desires the `variable` and the slot to use
        # the same scope, thus we can safely eagerly create/restore slot
        # variables.
        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
             self._distribution_strategy)):
      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      # Shape is unknown until we read the checkpoint value.
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  @contextlib.contextmanager
  def _distribution_strategy_scope(self):
    """Returns the `tf.distribute.Strategy` this optimizer was created under."""
    if self._distribution_strategy and not distribute_ctx.has_strategy():
      with self._distribution_strategy.scope():
        yield self._distribution_strategy.scope()
    else:
      yield


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")

revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])
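

# A minimal sketch (illustrative only; `MySGD` is a hypothetical subclass, not
# part of this module) of the overrides described in the "Creating a custom
# optimizer" section of the `OptimizerV2` docstring, assuming a plain
# gradient-descent update:
#
#   class MySGD(OptimizerV2):
#
#     def __init__(self, learning_rate=0.01, name="MySGD", **kwargs):
#       super(MySGD, self).__init__(name, **kwargs)
#       self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
#
#     def _resource_apply_dense(self, grad, var, apply_state=None):
#       lr_t = self._decayed_lr(var.dtype.base_dtype)
#       return var.assign_sub(lr_t * grad, use_locking=self._use_locking)
#
#     def get_config(self):
#       config = super(MySGD, self).get_config()
#       config.update(
#           {"learning_rate": self._serialize_hyperparameter("learning_rate")})
#       return config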