# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import functools

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


@six.add_metaclass(abc.ABCMeta)
@keras_export("keras.optimizers.Optimizer")
class OptimizerV2(trackable.Trackable):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.
  You never use this class directly, but instead instantiate one of its
  subclasses such as `tf.keras.optimizers.SGD` or `tf.keras.optimizers.Adam`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
  # In graph mode, returns an op that minimizes the loss by updating the
  # listed variables.
  opt_op = opt.minimize(loss, var_list=[var1, var2])
  opt_op.run()
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Custom training loop with Keras models

  In Keras models, variables are sometimes created when the model is first
  called, instead of at construction time. Examples include 1) sequential
  models without a pre-defined input shape, or 2) subclassed models. Pass
  `var_list` as a callable in these cases.

  Example:

  ```python
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
  var_list_fn = lambda: model.trainable_weights
  for input, output in data:
    opt.minimize(loss_fn, var_list_fn)
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1. Compute the gradients with `tf.GradientTape`.
  2. Process the gradients as you wish.
  3. Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  with tf.GradientTape() as tape:
    loss = <call_loss_function>
  vars = <list_of_variables>
  grads = tape.gradient(loss, vars)

  # Process the gradients, for example cap them, etc.
  # capped_grads = [MyCapper(g) for g in grads]
  processed_grads = [process_gradient(g) for g in grads]

  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, vars))
  ```

  ### Use with `tf.distribute.Strategy`.

  This optimizer class is `tf.distribute.Strategy` aware, which means it
  automatically sums gradients across all replicas. To average gradients,
  you divide your loss by the global batch size, which is done
  automatically if you use `tf.keras` built-in training or evaluation loops.
  See the `reduction` argument of your loss: it should be set to
  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
  `tf.keras.losses.Reduction.SUM` for summing.

  If you are not using these and you want to average gradients, you should use
  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
  global batch size. Note that when using `tf.distribute.Strategy`, the first
  component of a tensor's shape is the *replica-local* batch size, which is off
  by a factor equal to the number of replicas being used to compute a single
  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
  resulting in gradients that can be many times too big.
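
  For example, with a tensor of per-example losses on each replica, the manual
  scaling might look like the following sketch (`per_example_losses` and
  `global_batch_size` stand in for your own values):

  ```python
  # `per_example_losses` has shape [per_replica_batch_size].
  loss = tf.reduce_sum(per_example_losses) * (1. / global_batch_size)
  ```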

  ### Variable Constraint

  All Keras optimizers respect variable constraints. If a constraint function
  is passed to any variable, the constraint will be applied to the variable
  after the gradient has been applied to it.
  Important: If the gradient is a sparse tensor, variable constraints are not
  supported.

  ### Thread Compatibility

  The entire optimizer is currently thread compatible, not thread-safe. The
  user needs to perform synchronization if necessary.

  ### Slots

  Many optimizer subclasses, such as `Adam` and `Adagrad`, allocate and manage
  additional variables associated with the variables to train. These are
  called <i>Slots</i>. Slots have names and you can ask the optimizer for the
  names of the slots that it uses. Once you have a slot name you can ask the
  optimizer for the variable it created to hold the slot value.

  This can be useful if you want to log debug information about a training
  algorithm, report stats about the slots, etc.

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  Hyper parameters can be overwritten through user code:

  Example:

  ```python
  # Create an optimizer with the desired parameters.
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  # `loss` is a callable that takes no argument and returns the value
  # to minimize.
  loss = lambda: 3 * var1 + 2 * var2
  # In eager mode, simply call minimize to update the list of variables.
  opt.minimize(loss, var_list=[var1, var2])
  # update learning rate
  opt.learning_rate = 0.05
  opt.minimize(loss, var_list=[var1, var2])
  ```

  ### Write a customized optimizer.

  If you intend to create your own optimization algorithm, simply inherit from
  this class and override the following methods (a minimal sketch follows the
  list):

  - `_resource_apply_dense` (update variable given gradient tensor is dense)
  - `_resource_apply_sparse` (update variable given gradient tensor is sparse)
  - `_create_slots` (if your optimizer algorithm requires additional variables)
  - `get_config` (serialization of the optimizer, include all hyper parameters)
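
  For example, a bare-bones gradient descent optimizer might be sketched as
  follows (`MySGD` is illustrative only: it ignores `apply_state` and does not
  support sparse gradients):

  ```python
  class MySGD(tf.keras.optimizers.Optimizer):

    def __init__(self, learning_rate=0.01, name="MySGD", **kwargs):
      super(MySGD, self).__init__(name, **kwargs)
      self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))

    def _resource_apply_dense(self, grad, var, apply_state=None):
      # Plain gradient descent: var <- var - lr * grad.
      lr_t = self._get_hyper("learning_rate", var.dtype.base_dtype)
      return var.assign_sub(lr_t * grad)

    def get_config(self):
      config = super(MySGD, self).get_config()
      config.update(
          {"learning_rate": self._serialize_hyperparameter("learning_rate")})
      return config
  ```

  An optimizer with per-variable state (momentum, moving averages, etc.) would
  also override `_create_slots` and create its extra variables with
  `self.add_slot(var, "slot_name")`, and would override
  `_resource_apply_sparse` to support sparse gradients.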
  """

  def __init__(self, name, **kwargs):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/_get_hyper()
    facility instead.

    This class is stateful and thread-compatible.

    Args:
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for
        backward compatibility; it is recommended to use `learning_rate`
        instead.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
        _create_vars.
    """
    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay"}
    for k in kwargs:
      if k not in allowed_kwargs:
        raise TypeError("Unexpected keyword argument "
                        "passed to optimizer: " + str(k))
      # checks that all keyword arguments are non-negative.
      if kwargs[k] < 0:
        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

    self._use_locking = True
    self._init_set_name(name)
    self._hyper = {}
    # dict: {variable name : {slot name : variable}}
    self._slots = {}
    self._slot_names = []
    self._weights = []
    self._iterations = None

    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    decay = kwargs.pop("decay", 0.0)
    if decay < 0.:
      raise ValueError("decay cannot be less than 0: {}".format(decay))
    self._initial_decay = decay
    if "clipnorm" in kwargs:
      self.clipnorm = kwargs.pop("clipnorm")
    if "clipvalue" in kwargs:
      self.clipvalue = kwargs.pop("clipvalue")

    self._hypers_created = False

  def minimize(self, loss, var_list, grad_loss=None, name=None):
    """Minimize `loss` by updating `var_list`.

    This method simply computes gradients using `tf.GradientTape` and calls
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `tf.GradientTape` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize` since the variables are created the first time `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      name: Optional name for the returned operation.

    Returns:
      An `Operation` that updates the variables in `var_list`. The `iterations`
      will be automatically increased by 1.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    """
    grads_and_vars = self._compute_gradients(
        loss, var_list=var_list, grad_loss=grad_loss)

    return self.apply_gradients(grads_and_vars, name=name)

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A callable taking no arguments which returns the value to minimize.
      var_list: list or tuple of `Variable` objects to update to minimize
        `loss`, or a callable returning the list or tuple of `Variable` objects.
        Use a callable when the variable list would otherwise be incomplete
        before `minimize` and the variables are created the first time `loss`
        is called.
      grad_loss: Optional. A `Tensor` holding the gradient computed for
        `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid, or var_list is None.
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    with backprop.GradientTape() as tape:
      if not callable(var_list):
        tape.watch(var_list)
      loss_value = loss()
    if callable(var_list):
      var_list = var_list()
    var_list = nest.flatten(var_list)
    with backend.name_scope(self._name + "/gradients"):
      grads = tape.gradient(loss_value, var_list, grad_loss)

      if hasattr(self, "clipnorm"):
        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
      if hasattr(self, "clipvalue"):
        grads = [
            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
            for g in grads
        ]

    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])

    return grads_and_vars

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` with respect to `params`.

    Arguments:
      loss: Loss tensor.
      params: List of variables.

    Returns:
      List of gradient tensors.

    Raises:
      ValueError: In case any gradient cannot be computed (e.g. if gradient
        function not implemented).
    """
    params = nest.flatten(params)
    with backend.get_graph().as_default(), backend.name_scope(self._name +
                                                              "/gradients"):
      grads = gradients.gradients(loss, params)
      for grad, param in zip(grads, params):
        if grad is None:
          raise ValueError("Variable {} has `None` for gradient. "
                           "Please make sure that all of your ops have a "
                           "gradient defined (i.e. are differentiable). "
                           "Common ops without gradient: "
                           "K.argmax, K.round, K.eval.".format(param))
      if hasattr(self, "clipnorm"):
        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
      if hasattr(self, "clipvalue"):
        grads = [
            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
            for g in grads
        ]
    return grads

  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with backend.name_scope(self._name):
      # Create iteration if necessary.
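      # `ops.init_scope` lifts the variable creation below out of any graph or
      # `tf.function` currently being built, so `iterations`, hyper-parameter
      # variables and slot variables are created once in the outermost context
      # and simply reused on later calls.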
      with ops.init_scope():
        _ = self.iterations
        self._create_hypers()
        self._create_slots(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()
      apply_state = self._prepare(var_list)
      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={"name": name})

  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    """`apply_gradients` using a `DistributionStrategy`."""
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    with ops.name_scope(name or self._name, skip_on_eager=True):
      for grad, var in grads_and_vars:
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with distribution.extended.colocate_vars_with(var):
          with ops.name_scope("update" if eagerly_outside_functions else
                              "update_" + var.op.name, skip_on_eager=True):
            update_ops.extend(distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with ops._get_graph_from_inputs(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies(update_ops):
            return self._iterations.assign_add(1).op

      return self._iterations.assign_add(1)

  def get_updates(self, loss, params):
    grads = self.get_gradients(loss, params)
    grads_and_vars = list(zip(grads, params))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return [self.apply_gradients(grads_and_vars)]

  def _set_hyper(self, name, value):
    """Set hyper `name` to value. Value can be callable, tensor, numeric."""
    if isinstance(value, trackable.Trackable):
      self._track_trackable(value, name, overwrite=True)
    if name not in self._hyper:
      self._hyper[name] = value
    else:
      prev_value = self._hyper[name]
      if (callable(prev_value)
          or isinstance(prev_value,
                        (ops.Tensor, int, float,
                         learning_rate_schedule.LearningRateSchedule))
          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
        self._hyper[name] = value
      else:
        backend.set_value(self._hyper[name], value)

  def _get_hyper(self, name, dtype=None):
    if not self._hypers_created:
      self._create_hypers()
    value = self._hyper[name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return value
    if callable(value):
      value = value()
    if dtype:
      return math_ops.cast(value, dtype)
    else:
      return value

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._get_hyper(name)
      raise e

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

  def get_slot_names(self):
    """A list of names for this optimizer's slots."""
    return self._slot_names
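
  # Subclasses typically create their slot variables in `_create_slots`, with
  # one `add_slot` call per (variable, slot name) pair. An illustrative sketch
  # (not part of this base class):
  #
  #   def _create_slots(self, var_list):
  #     for var in var_list:
  #       self.add_slot(var, "momentum")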
  def add_slot(self, var, slot_name, initializer="zeros"):
    """Add a new slot variable for `var`."""
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        initial_value = functools.partial(
            initializer, shape=var.shape, dtype=var.dtype)
      else:
        initial_value = initializer
      strategy = distribute_ctx.get_strategy()
      if not strategy.extended.variable_created_in_scope(var):
        raise ValueError(
            "Trying to create optimizer slot variable under the scope for "
            "tf.distribute.Strategy ({}), which is different from the scope "
            "used for the original variable ({}). Make sure the slot "
            "variables are created under the same strategy scope. This may "
            "happen if you're restoring from a checkpoint outside the scope"
            .format(strategy, var))

      with strategy.extended.colocate_vars_with(var):
        weight = tf_variables.Variable(
            name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
            dtype=var.dtype,
            trainable=False,
            initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight

  def get_slot(self, var, slot_name):
    var_key = _var_key(var)
    slot_dict = self._slots[var_key]
    return slot_dict[slot_name]

  def _prepare(self, var_list):
    keys = set()
    for var in var_list:
      var_devices = (getattr(var, "devices", None) or  # Distributed
                     [var.device])                     # Regular
      var_dtype = var.dtype.base_dtype
      for var_device in var_devices:
        keys.add((var_device, var_dtype))

    apply_state = {}
    for var_device, var_dtype in keys:
      apply_state[(var_device, var_dtype)] = {}
      with ops.device(var_device):
        self._prepare_local(var_device, var_dtype, apply_state)

    return apply_state

  def _prepare_local(self, var_device, var_dtype, apply_state):
    if "learning_rate" in self._hyper:
      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t

  def _fallback_apply_state(self, var_device, var_dtype):
    """Compatibility for subclasses that don't pass apply_state through."""
    apply_state = {(var_device, var_dtype): {}}
    self._prepare_local(var_device, var_dtype, apply_state)
    return apply_state[(var_device, var_dtype)]

  def _create_hypers(self):
    if self._hypers_created:
      return
    # Iterate hyper values deterministically.
    for name, value in sorted(self._hyper.items()):
      if isinstance(
          value, (ops.Tensor, tf_variables.Variable)) or callable(value):
        continue
      else:
        self._hyper[name] = self.add_weight(
            name,
            shape=[],
            trainable=False,
            initializer=value,
            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    self._hypers_created = True

  @property
  def iterations(self):
    """Variable. The number of training steps this Optimizer has run."""
    if self._iterations is None:
      self._iterations = self.add_weight(
          "iter",
          shape=[],
          dtype=dtypes.int64,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self._weights.append(self._iterations)
    return self._iterations

  @iterations.setter
  def iterations(self, variable):
    if self._iterations is not None:
      raise RuntimeError("Cannot set `iterations` to a new Variable after "
                         "the Optimizer weights have been created")
    self._iterations = variable
    self._weights.append(self._iterations)
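
  # When the legacy `decay` constructor argument is non-zero, `_decayed_lr`
  # applies inverse time decay on top of the (possibly scheduled) learning
  # rate: roughly lr_t = learning_rate / (1 + decay * iterations).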
  def _decayed_lr(self, var_dtype):
    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
    lr_t = self._get_hyper("learning_rate", var_dtype)
    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
      local_step = math_ops.cast(self.iterations, var_dtype)
      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    if self._initial_decay > 0.:
      local_step = math_ops.cast(self.iterations, var_dtype)
      decay_t = self._get_hyper("decay", var_dtype)
      lr_t = lr_t / (1. + decay_t * local_step)
    return lr_t

  @abc.abstractmethod
  def get_config(self):
    """Returns the config of the optimizer.

    An optimizer config is a Python dictionary (serializable)
    containing the configuration of an optimizer.
    The same optimizer can be reinstantiated later
    (without any saved state) from this configuration.

    Returns:
      Python dictionary.
    """
    config = {"name": self._name}
    if hasattr(self, "clipnorm"):
      config["clipnorm"] = self.clipnorm
    if hasattr(self, "clipvalue"):
      config["clipvalue"] = self.clipvalue
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates an optimizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same optimizer from the config
    dictionary.

    Arguments:
      config: A Python dictionary, typically the output of get_config.
      custom_objects: A Python dictionary mapping names to additional Python
        objects used to create this optimizer, such as a function used for a
        hyperparameter.

    Returns:
      An optimizer instance.
    """
    if "lr" in config:
      config["learning_rate"] = config.pop("lr")
    if "learning_rate" in config:
      if isinstance(config["learning_rate"], dict):
        config["learning_rate"] = learning_rate_schedule.deserialize(
            config["learning_rate"], custom_objects=custom_objects)
    return cls(**config)

  def _serialize_hyperparameter(self, hyperparameter_name):
    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
    value = self._hyper[hyperparameter_name]
    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
      return learning_rate_schedule.serialize(value)
    if callable(value):
      return value()
    if tensor_util.is_tensor(value):
      return backend.get_value(value)
    return value

  def variables(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  @property
  def weights(self):
    """Returns variables of this Optimizer based on the order created."""
    return self._weights

  def get_weights(self):
    """Returns the current weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function returns the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they were created. The returned list can in turn
    be used to load state into similarly parameterized optimizers.

    For example, the RMSprop optimizer for this simple model returns a list of
    three values-- the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> len(opt.get_weights())
    3

    Returns:
      Weights values as a list of numpy arrays.
    """
    params = self.weights
    return backend.batch_get_value(params)
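
  # Illustrative sketch (not part of the API) of round-tripping optimizer
  # state by value with the methods above:
  #
  #   config = opt.get_config()
  #   weights = opt.get_weights()
  #   new_opt = opt.__class__.from_config(config)
  #   # `set_weights` requires `new_opt` to have created its variables first,
  #   # e.g. by calling `apply_gradients` once on the same model variables.
  #   new_opt.set_weights(weights)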

  # TODO(tanzheny): Maybe share this logic with base_layer.
  def set_weights(self, weights):
    """Set the weights of the optimizer.

    The weights of an optimizer are its state (i.e., variables).
    This function takes the weight values associated with this
    optimizer as a list of Numpy arrays. The first value is always the
    iterations count of the optimizer, followed by the optimizer's state
    variables in the order they are created. The passed values are used to set
    the new state of the optimizer.

    For example, the RMSprop optimizer for this simple model takes a list of
    three values-- the iteration count, followed by the root-mean-square value
    of the kernel and bias of the single Dense layer:

    >>> opt = tf.keras.optimizers.RMSprop()
    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> m.compile(opt, loss='mse')
    >>> data = np.arange(100).reshape(5, 20)
    >>> labels = np.zeros(5)
    >>> print('Training'); results = m.fit(data, labels)
    Training ...
    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
    >>> opt.set_weights(new_weights)
    >>> opt.iterations
    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>

    Arguments:
      weights: weight values as a list of numpy arrays.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          "You called `set_weights(weights)` on optimizer " + self._name +
          " with a weight list of length " + str(len(weights)) +
          ", but the optimizer was expecting " + str(len(params)) +
          " weights. Provided weights: " + str(weights)[:50] + "...")
    if not params:
      return
    weight_value_tuples = []
    param_values = backend.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError("Optimizer weight shape " + str(pv.shape) +
                         " not compatible with "
                         "provided weight shape " + str(w.shape))
      weight_value_tuples.append((p, w))
    backend.batch_set_value(weight_value_tuples)

  def add_weight(self,
                 name,
                 shape,
                 dtype=None,
                 initializer="zeros",
                 trainable=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE):

    if dtype is None:
      dtype = dtypes.float32
    if isinstance(initializer, six.string_types) or callable(initializer):
      initializer = initializers.get(initializer)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            "Synchronization value can be set to "
            "VariableSynchronization.ON_READ only for non-trainable variables. "
            "You have specified trainable=True and "
            "synchronization=VariableSynchronization.ON_READ.")
      else:
        # Set trainable to be false when variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        getter=base_layer_utils.make_variable,
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        trainable=trainable,
        use_resource=True,
        synchronization=synchronization,
        aggregation=aggregation)
    backend.track_variable(variable)

    return variable

  def _init_set_name(self, name, zero_based=True):
    if not name:
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(self.__class__.__name__),
          zero_based=zero_based)
    else:
      self._name = name

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (dtype, t.name, [v for v in valid_dtypes]))

  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set([
        dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
        dtypes.complex64, dtypes.complex128
    ])

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param

  def _resource_apply_dense(self, grad, handle, apply_state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
                                               **kwargs):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      **kwargs: May optionally contain `apply_state`

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       **kwargs)

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      apply_state: A dict which is used across multiple apply calls.

    Returns:
      An `Operation` which updates the value of the variable.
976 """ 977 raise NotImplementedError() 978 979 def _resource_scatter_add(self, x, i, v): 980 with ops.control_dependencies( 981 [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): 982 return x.value() 983 984 def _resource_scatter_update(self, x, i, v): 985 with ops.control_dependencies( 986 [resource_variable_ops.resource_scatter_update(x.handle, i, v)]): 987 return x.value() 988 989 @property 990 @tracking.cached_per_instance 991 def _dense_apply_args(self): 992 return tf_inspect.getfullargspec(self._resource_apply_dense).args 993 994 @property 995 @tracking.cached_per_instance 996 def _sparse_apply_args(self): 997 return tf_inspect.getfullargspec(self._resource_apply_sparse).args 998 999 # --------------- 1000 # For implementing the trackable interface 1001 # --------------- 1002 1003 def _restore_slot_variable(self, slot_name, variable, slot_variable): 1004 """Restore a newly created slot variable's value.""" 1005 variable_key = _var_key(variable) 1006 deferred_restorations = self._deferred_slot_restorations.get( 1007 slot_name, {}).pop(variable_key, []) 1008 # Iterate over restores, highest restore UID first to minimize the number 1009 # of assignments. 1010 deferred_restorations.sort(key=lambda position: position.restore_uid, 1011 reverse=True) 1012 for checkpoint_position in deferred_restorations: 1013 checkpoint_position.restore(slot_variable) 1014 1015 def _create_or_restore_slot_variable( 1016 self, slot_variable_position, slot_name, variable): 1017 """Restore a slot variable's value, possibly creating it. 1018 1019 Called when a variable which has an associated slot variable is created or 1020 restored. When executing eagerly, we create the slot variable with a 1021 restoring initializer. 1022 1023 No new variables are created when graph building. Instead, 1024 _restore_slot_variable catches these after normal creation and adds restore 1025 ops to the graph. This method is nonetheless important when graph building 1026 for the case when a slot variable has already been created but `variable` 1027 has just been added to a dependency graph (causing us to realize that the 1028 slot variable needs to be restored). 1029 1030 Args: 1031 slot_variable_position: A `trackable._CheckpointPosition` object 1032 indicating the slot variable `Trackable` object to be restored. 1033 slot_name: The name of this `Optimizer`'s slot to restore into. 1034 variable: The variable object this slot is being created for. 1035 """ 1036 variable_key = _var_key(variable) 1037 slot_dict = self._slots.get(variable_key, {}) 1038 slot_variable = slot_dict.get(slot_name, None) 1039 if (slot_variable is None and context.executing_eagerly() and 1040 slot_variable_position.is_simple_variable() 1041 # Defer slot variable creation if there is an active variable creator 1042 # scope. Generally we'd like to eagerly create/restore slot variables 1043 # when possible, but this may mean that scopes intended to catch 1044 # `variable` also catch its eagerly created slot variable 1045 # unintentionally (specifically make_template would add a dependency on 1046 # a slot variable if not for this case). Deferring is mostly harmless 1047 # (aside from double initialization), and makes variable creator scopes 1048 # behave the same way they do when graph building. 
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)


def _filter_grads(grads_and_vars):
  """Filter out (grad, var) pairs whose gradient is None."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars
  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)
  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        ("Gradients do not exist for variables %s when minimizing the loss."),
        ([v.name for v in vars_with_empty_grads]))
  return filtered


def _var_key(var):
  """Key for representing a primary variable, for looking up slots.

  In graph mode the name is derived from the var shared name.
  In eager mode the name is derived from the var unique id.
  If distribution strategy exists, get the primary variable first.

  Args:
    var: the variable.

  Returns:
    the unique name of the variable.
  """

  # pylint: disable=protected-access
  # Get the distributed variable if it exists.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if var._in_graph_mode:
    return var._shared_name
  return var._unique_id


def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""

  name = _var_key(var)
  return name + "/" + slot_name


class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    self._hypers_created = True

  def get_config(self):
    # TODO(allenl): Save and restore the Optimizer's config
    raise NotImplementedError(
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")


revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])