# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base class for optimizers."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_filtered_grad_fn(grad_fn):
  # `distributed_context.join()` requires that its arguments are parallel
  # across threads, and in particular that `grads_and_vars` has the same
  # variables in the same order.

  # When computing gradients in eager mode with multiple threads, you
  # can get extra variables with a gradient of `None`. This happens when
  # those variables are accessed in another thread during the gradient
  # computation. To get a consistent set of variables, we filter out
  # those with `None` gradients.
  def filtered_grad_fn(*args, **kwargs):
    return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None]

  return filtered_grad_fn


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).
  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)
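

# Illustrative sketch (comments only, not executed): how duplicate indices are
# summed by `_deduplicate_indexed_slices`, assuming eager execution and a
# `tf` alias for TensorFlow.
#
#   values = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
#   indices = tf.constant([0, 0, 2])
#   summed, unique = _deduplicate_indexed_slices(values, indices)
#   # summed == [[3., 3.], [3., 3.]], unique == [0, 2]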


def _var_key(var):
  """Returns slot key for `var`."""
  # pylint: disable=protected-access
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  if (distribute_utils.is_distributed_variable(var) and
      not ops.executing_eagerly_outside_functions()):
    return (var.graph, var._shared_name)
  if hasattr(var, "op"):
    return (var.op.graph, var.op.name)
  return var._unique_id
  # pylint: enable=protected-access


@six.add_metaclass(abc.ABCMeta)
class _OptimizableVariable(object):
  """Interface for abstracting over variables in the optimizers."""

  @abc.abstractmethod
  def target(self):
    """Returns the optimization target for this variable."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def update_op(self, optimizer, g):
    """Returns the update ops for updating the variable."""
    raise NotImplementedError("Calling an abstract method.")


class _RefVariableProcessor(_OptimizableVariable):
  """Processor for Variable."""

  def __init__(self, v):
    self._v = v

  def __str__(self):
    return "<_RefVariableProcessor(%s)>" % self._v

  def target(self):
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g):
    if isinstance(g, ops.Tensor):
      update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
      if self._v.constraint is not None:
        with ops.control_dependencies([update_op]):
          return self._v.assign(self._v.constraint(self._v))
      else:
        return update_op
    else:
      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
                                                "tensor nor IndexedSlices.")
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, self._v)


class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return optimizer._resource_apply_sparse_duplicate_indices(
          g.values, self._v, g.indices)
    update_op = optimizer._resource_apply_dense(g, self._v)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer. Updating
  the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
    # True if and only if `v` was initialized eagerly.
    return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


@tf_export(v1=["train.Optimizer"])
class Optimizer(
    # Optimizers inherit from Trackable rather than AutoTrackable
    # since they do most of their dependency management themselves (slot
    # variables are special-cased, and non-slot variables are keyed to graphs).
    trackable.Trackable):
  """Base class for optimizers.

  This class defines the API to add Ops to train a model. You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them, you can instead use the optimizer in three steps:

  1.  Compute the gradients with `compute_gradients()`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results. For example, the two gradients of `matmul` depend on the input
  values: with `GATE_NONE` one of the gradients could be applied to one of the
  inputs _before_ the other gradient is computed, resulting in non-reproducible
  results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used. This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used. This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`,
  allocate and manage additional variables associated with the variables to
  train. These are called <i>Slots</i>. Slots have names and you can ask the
  optimizer for the names of the slots that it uses. Once you have a slot name
  you can ask the optimizer for the variable it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
    """
    if not name:
      raise ValueError("Must specify the optimizer name")
    self._use_locking = use_locking
    self._name = name
    # Dictionary of slots.
    #  {slot_name :
    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
    #   ... }
    self._slots = {}
    self._non_slot_dict = {}
    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    # TODO(isaprykin): When using a DistributionStrategy, and when an
    # optimizer is created in each replica, it might be dangerous to
    # rely on some Optimizer methods. When such methods are called on a
    # per-replica optimizer, an exception needs to be thrown. We do
    # allow creation of per-replica optimizers, however, because the
    # compute_gradients()->apply_gradients() sequence is safe.

  def get_name(self):
    return self._name

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes no arguments and computes the value to be minimized. Minimization
    (and gradient computation) is done with respect to the elements of
    `var_list` if not None, else with respect to any trainable variables
    created during the execution of the `loss` function. `gate_gradients`,
    `aggregation_method`, `colocate_gradients_with_ops` and `grad_loss` are
    ignored when eager execution is enabled.
    @end_compatibility
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)
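
  # Illustrative sketch (comments only, not part of the API): calling
  # `minimize()` with eager execution enabled, where `loss` must be a callable.
  # `w`, `x`, and `y` are hypothetical eager tensors/variables and `tf` is
  # assumed to alias TensorFlow.
  #
  #   opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
  #   opt.minimize(lambda: tf.reduce_mean((tf.matmul(x, w) - y) ** 2),
  #                var_list=[w])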

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple replicas.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        loss_value = self._scale_loss(loss_value)

      if var_list is None:
        var_list = tape.watched_variables()
      # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
      # to be executed.
      with ops.control_dependencies([loss_value]):
        grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple replicas.
    loss = self._scale_loss(loss)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  @staticmethod
  def _scale_loss(loss_value):
    ops.get_default_graph()._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
      num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
      if num_replicas > 1:
        loss_value *= (1. / num_replicas)
        ops.get_default_graph()._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
    return loss_value
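
  # Worked example (comments only): with a "mean" loss reduction and
  # `num_replicas_in_sync == 4`, `_scale_loss` multiplies each replica's loss
  # by 1/4 (e.g. a per-replica loss of 2.0 becomes 0.5). The per-replica
  # gradients are later combined with a SUM reduction in `_distributed_apply`,
  # so the net effect is an average over replicas.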

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers. It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # TODO(isaprykin): Get rid of `has_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_strategy():
      # Handle DistributionStrategy case.
      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError("Use `_distributed_apply()` instead of "
                           "`apply_gradients()` in a cross-replica context.")

      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name, skip_on_eager=False) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if (context.executing_eagerly() or
            resource_variable_ops.is_resource_variable(var)
            and not var._in_graph_mode):  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope(
            "update_" + scope_name,
            skip_on_eager=False), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(
                global_step, resource_variable_ops.BaseResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
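
  # Illustrative sketch (comments only): applying externally processed
  # gradients and incrementing a step counter in one op. `loss`, the clip norm,
  # and the `tf` alias are hypothetical/assumed; graph mode is assumed.
  #
  #   global_step = tf.compat.v1.train.get_or_create_global_step()
  #   grads_and_vars = opt.compute_gradients(loss)
  #   clipped = [(tf.clip_by_norm(g, 5.0), v)
  #              for g, v in grads_and_vars if g is not None]
  #   train_op = opt.apply_gradients(clipped, global_step=global_step)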

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-replica context.

    This is a version of `apply_gradients()` for when you are using a
    `DistributionStrategy` and are in a cross-replica context. If in a
    replica context, use `apply_gradients()` as normal.

    Args:
      distribution: A `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`, and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one
        after the variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    # Note that this is called in a cross-replica context.
    with ops.init_scope():
      self._create_slots(var_list)

    def update(v, g):
      """Apply gradients to a replica variable."""
      assert v is not None

      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)

      if context.executing_eagerly() or (
          resource_variable_ops.is_resource_variable(v) and
          not v._in_graph_mode):  # pylint: disable=protected-access
        scope_name = v.name.split(":")[0]
      else:
        scope_name = v.op.name

      # device_policy is set because non-mirrored tensors will be read in
      # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
      # are examples.
      with ops.name_scope("update_" + scope_name):
        return p.update_op(self, g)

    with ops.name_scope(name, self._name) as name:
      self._prepare()

      update_ops = [
          op
          for grad, var in grads_and_vars
          for op in distribution.extended.update(
              var, update, args=(grad,), group=False)
      ]

      def finish(self, update_ops):
        return self._finish(update_ops, "update")

      non_slot_devices = distribution.extended.non_slot_devices(var_list)
      finish_updates = distribution.extended.update_non_slot(
          non_slot_devices, finish, args=(self, update_ops), group=False)
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(finish_updates):
          apply_updates = distribution.extended.update(
              global_step, state_ops.assign_add, args=(1,),
              kwargs={"name": name})

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables. For example
    `Momentum` and `Adagrad` use variables to accumulate updates. This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None
    slot = named_slots.get(_var_key(var), None)
    if (distribute_utils.is_distributed_variable(slot) and
        not distribute_utils.is_distributed_variable(var)):
      # Make sure var and slot are either both DistributedVariable, or both
      # per replica variables.
      slot = slot._get_on_device_or_primary()  # pylint: disable=protected-access
    return slot
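
  # Illustrative sketch (comments only): inspecting a slot after building a
  # training op with a momentum optimizer. `loss` and `var` are hypothetical
  # and `tf` is assumed to alias TensorFlow.
  #
  #   opt = tf.compat.v1.train.MomentumOptimizer(0.1, momentum=0.9)
  #   train_op = opt.minimize(loss, var_list=[var])
  #   # opt.get_slot_names()          -> ["momentum"]
  #   # opt.get_slot(var, "momentum") -> the accumulator Variable for `var`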

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.

    Returns:
      A list of variables.
    """
    current_graph = ops.get_default_graph()

    def _from_current_graph(variable):
      if variable._in_graph_mode:  # pylint: disable=protected-access
        return variable.op.graph is current_graph
      else:
        # No variable.op in eager mode. We don't expect lots of eager graphs,
        # but behavior should be consistent with graph mode.
        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access

    optimizer_variables = [v for v in self._non_slot_variables()
                           if _from_current_graph(v)]
    for _, variable_dict in self._slots.items():
      for _, slot_for_variable in variable_dict.items():
        if _from_current_graph(slot_for_variable):
          optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_trackable()
      distribution_strategy = distribute_ctx.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(
            initial_value, name=name, trainable=False,
            use_resource=resource_variable_ops.is_resource_variable(
                colocate_with))
      # Restore this variable by name if necessary, but don't add a
      # Trackable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, trackable=v)
      self._non_slot_dict[key] = v

    return v
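
  # Illustrative sketch (comments only): how a subclass might keep a scalar
  # accumulator that is shared across all trained variables rather than tied
  # to one of them. The names below are hypothetical.
  #
  #   def _create_slots(self, var_list):
  #     first_var = min(var_list, key=lambda x: x.name)
  #     self._create_non_slot_variable(
  #         initial_value=0.9, name="beta1_power", colocate_with=first_var)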

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific non-slot variables to save."""
    current_graph_non_slot_variables = []
    current_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    for (name, _), variable_object in sorted(self._non_slot_dict.items(),
                                             # Avoid comparing graphs
                                             key=lambda item: item[0][0]):
      if variable_object._graph_key == current_graph_key:  # pylint: disable=protected-access
        current_graph_non_slot_variables.append(
            trackable.TrackableReference(
                name=name, ref=variable_object))
    return (super(Optimizer, self)._checkpoint_dependencies
            + current_graph_non_slot_variables)

  def _lookup_dependency(self, name):
    """From Trackable. Find a non-slot variable in the current graph."""
    unconditional = super(Optimizer, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    graph = None if context.executing_eagerly() else ops.get_default_graph()
    return self._get_non_slot_variable(name, graph=graph)

  def _get_non_slot_variable(self, name, graph=None):
    non_slot = self._non_slot_dict.get((name, graph), None)
    if hasattr(non_slot, "_distributed_container"):
      # This is a mirrored non-slot. In order to enable code like `_finish`
      # to assign to a non-slot, return the current context replica.
      return non_slot.get()
    else:
      return non_slot

  def _non_slot_variables(self):
    """Additional variables created by the `Optimizer`.

    Returns:
      A list or tuple of variables.
    """
    return self._non_slot_dict.values()

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError(
            "Invalid type %r for %s, expected: %s." % (
                dtype, t.name, [v for v in valid_dtypes]))

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _create_slots(self, var_list):
    """Create all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _prepare(self):
    """Create all needed tensors before applying gradients.

    This is called with the name_scope using the "name" that
    users have chosen for the application of gradients.
    """
    pass

  def _apply_dense(self, grad, var):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()
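
  # Illustrative sketch (comments only): the minimal surface a subclass could
  # implement for dense ref-variable updates. This plain SGD update is an
  # assumption for illustration and is not part of this class.
  #
  #   class _SketchGradientDescent(Optimizer):
  #
  #     def __init__(self, learning_rate, use_locking=False, name="SketchGD"):
  #       super(_SketchGradientDescent, self).__init__(use_locking, name)
  #       self._learning_rate = learning_rate
  #
  #     def _apply_dense(self, grad, var):
  #       lr = math_ops.cast(self._learning_rate, var.dtype.base_dtype)
  #       return state_ops.assign_sub(var, lr * grad,
  #                                   use_locking=self._use_locking)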

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    summed_values, unique_indices = _deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var)

  def _apply_sparse(self, grad, var):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, update_ops, name_scope):
    """Do what is needed to finish the update.

    This is called with the `name_scope` using the "name" that
    users have chosen for the application of gradients.

    Args:
      update_ops: List of `Operation` objects to update variables. This list
        contains the values returned by the `_apply_dense()` and
        `_apply_sparse()` calls.
      name_scope: String. Name to use for the returned operation.

    Returns:
      The operation to apply updates.
    """
    return control_flow_ops.group(*update_ops, name=name_scope)

  # --------------
  # Utility methods for subclasses.
  # --------------

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot(var, val, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]
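
  # Illustrative sketch (comments only): typical use of the slot helpers from a
  # subclass's `_create_slots`. The slot names "accumulator" and "ones" are
  # hypothetical.
  #
  #   def _create_slots(self, var_list):
  #     for v in var_list:
  #       self._zeros_slot(v, "accumulator", self._name)
  #       # or, with an explicit initial value:
  #       # self._get_or_make_slot(v, array_ops.ones_like(v), "ones",
  #       #                        self._name)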

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(
          var, op_name, copy_xla_sharding=True)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  # --------------
  # For implementing the Trackable interface.
  # --------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    named_slots = self._slot_dict(slot_name)
    variable_key = _var_key(variable)
    slot_variable = named_slots.get(variable_key, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      # CheckpointInitialValueCallable will ignore the shape and dtype
      # parameters but they must be passed.
      slot_variable = self._get_or_make_slot_with_initializer(
          var=variable,
          initializer=initializer,
          shape=variable.shape,
          dtype=variable.dtype,
          slot_name=slot_name,
          op_name=self._name)
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param