# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base class for optimizers."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_filtered_grad_fn(grad_fn):
  # `distributed_context.join()` requires that its arguments are parallel
  # across threads, and in particular that `grads_and_vars` has the same
  # variables in the same order.

  # When computing gradients in eager mode with multiple threads, you
  # can get extra variables with a gradient of `None`. This happens when
  # those variables are accessed in another thread during the gradient
  # computation. To get a consistent set of variables, we filter out
  # those with `None` gradients.
  def filtered_grad_fn(*args, **kwargs):
    return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None]

  return filtered_grad_fn


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).
  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
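
  For instance (a hypothetical illustration, not a doctest): with
  `values = [[1., 2.], [3., 4.], [5., 6.]]` and `indices = [0, 0, 2]`, the
  result is `summed_values = [[4., 6.], [5., 6.]]` and
  `unique_indices = [0, 2]`.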
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)


def _var_key(var):
  # TODO(ashankar): Consolidate handling for eager and graph
  if hasattr(var, "op"):
    return (var.op.graph, var.op.name)
  return var._unique_id  # pylint: disable=protected-access


@six.add_metaclass(abc.ABCMeta)
class _OptimizableVariable(object):
  """Interface for abstracting over variables in the optimizers."""

  @abc.abstractmethod
  def target(self):
    """Returns the optimization target for this variable."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def update_op(self, optimizer, g):
    """Returns the update ops for updating the variable."""
    raise NotImplementedError("Calling an abstract method.")


class _RefVariableProcessor(_OptimizableVariable):
  """Processor for Variable."""

  def __init__(self, v):
    self._v = v

  def __str__(self):
    return "<_RefVariableProcessor(%s)>" % self._v

  def target(self):
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g):
    if isinstance(g, ops.Tensor):
      update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
      if self._v.constraint is not None:
        with ops.control_dependencies([update_op]):
          return self._v.assign(self._v.constraint(self._v))
      else:
        return update_op
    else:
      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
                                                "tensor nor IndexedSlices.")
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, self._v)


class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return optimizer._resource_apply_sparse_duplicate_indices(
          g.values, self._v, g.indices)
    update_op = optimizer._resource_apply_dense(g, self._v)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer.
  Updating the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
    # True if and only if `v` was initialized eagerly.
    return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


@tf_export(v1=["train.Optimizer"])
class Optimizer(
    # Optimizers inherit from Trackable rather than AutoTrackable
    # since they do most of their dependency management themselves (slot
    # variables are special-cased, and non-slot variables are keyed to graphs).
    trackable.Trackable):
  """Base class for optimizers.

  This class defines the API to add Ops to train a model. You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1.  Compute the gradients with `compute_gradients()`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.
  This provides the maximum parallelism in execution, at the cost of some
  non-reproducibility in the results. For example, the two gradients of
  `matmul` depend on the input values: with `GATE_NONE` one of the gradients
  could be applied to one of the inputs _before_ the other gradient is
  computed, resulting in non-reproducible results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used. This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used. This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and
  `AdagradOptimizer`, allocate and manage additional variables associated
  with the variables to train. These are called <i>Slots</i>. Slots have
  names and you can ask the optimizer for the names of the slots that it
  uses. Once you have a slot name you can ask the optimizer for the variable
  it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
    """
    if not name:
      raise ValueError("Must specify the optimizer name")
    self._use_locking = use_locking
    self._name = name
    # Dictionary of slots.
    #  {slot_name :
    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
    #   ... }
    self._slots = {}
    self._non_slot_dict = {}
    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    # TODO(isaprykin): When using a DistributionStrategy, and when an
    # optimizer is created in each replica, it might be dangerous to
    # rely on some Optimizer methods. When such methods are called on a
    # per-replica optimizer, an exception needs to be thrown. We do
    # allow creation of per-replica optimizers however, because the
    # compute_gradients()->apply_gradients() sequence is safe.

  def get_name(self):
    return self._name

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes no arguments and computes the value to be minimized. Minimization
    (and gradient computation) is done with respect to the elements of
    `var_list` if not None, else with respect to any trainable variables
    created during the execution of the `loss` function. `gate_gradients`,
    `aggregation_method`, `colocate_gradients_with_ops` and `grad_loss` are
    ignored when eager execution is enabled.
    @end_compatibility
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional.
        A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable`
        objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

      if var_list is None:
        var_list = tape.watched_variables()
      # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
      # to be executed.
      with ops.control_dependencies([loss_value]):
        grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
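
    A typical graph-mode flow, shown as a brief sketch (assumes `loss` and
    `global_step` already exist; `GradientDescentOptimizer` stands in for any
    concrete subclass):

    ```python
    opt = GradientDescentOptimizer(learning_rate=0.1)
    grads_and_vars = opt.compute_gradients(loss)
    # Optionally inspect or transform the gradients here.
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    ```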
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers. It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # TODO(isaprykin): Get rid of `has_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_strategy():
      # Handle DistributionStrategy case.
      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError("Use `_distributed_apply()` instead of "
                           "`apply_gradients()` in a cross-replica context.")

      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
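              # `assign_add_variable_op` increments the resource in place
              # without producing an output tensor, so no extra read of the
              # updated counter is forced here.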
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-replica context.

    This is a version of `apply_gradients()` for when you are using a
    `DistributionStrategy` and are in a cross-replica context. If in a
    replica context, use `apply_gradients()` as normal.

    Args:
      distribution: A `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`, and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one
        after the variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    # Note that this is called in a cross-replica context.
    with ops.init_scope():
      self._create_slots(var_list)

    def update(v, g):
      """Apply gradients to a replica variable."""
      assert v is not None

      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)

      if context.executing_eagerly() or (
          resource_variable_ops.is_resource_variable(v) and
          not v._in_graph_mode):  # pylint: disable=protected-access
        scope_name = v.name.split(":")[0]
      else:
        scope_name = v.op.name

      # device_policy is set because non-mirrored tensors will be read in
      # `update_op`. `_resource_apply_dense` reading `lr_t`, `beta1_t` and
      # `beta2_t` is one example.
      with ops.name_scope("update_" + scope_name):
        return p.update_op(self, g)

    with ops.name_scope(name, self._name) as name:
      self._prepare()

      update_ops = [
          op
          for grad, var in grads_and_vars
          for op in distribution.extended.update(
              var, update, args=(grad,), group=False)
      ]

      def finish(self, update_ops):
        return self._finish(update_ops, "update")

      non_slot_devices = distribution.extended.non_slot_devices(var_list)
      finish_updates = distribution.extended.update_non_slot(
          non_slot_devices, finish, args=(self, update_ops), group=False)
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(finish_updates):
          apply_updates = distribution.extended.update(
              global_step, state_ops.assign_add, args=(1,),
              kwargs={"name": name})

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables. For example,
    `Momentum` and `Adagrad` use variables to accumulate updates. This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    # pylint: disable=protected-access
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None

    if hasattr(var, "_distributed_container"):
      # NOTE: If this isn't patched, then there is no `handle` in
      # `_resource_apply_dense`.
      distributed_container = var._distributed_container()
      assert distributed_container is not None
      if ops.executing_eagerly_outside_functions():
        key = distributed_container._unique_id
      else:
        key = (distributed_container.graph, distributed_container._shared_name)
      # pylint: enable=protected-access
      mirrored_slot = named_slots.get(key, None)
      if mirrored_slot is None: return None
      return mirrored_slot.get(device=var.device)

    return named_slots.get(_var_key(var), None)

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.

    Returns:
      A list of variables.
    """
    current_graph = ops.get_default_graph()

    def _from_current_graph(variable):
      if variable._in_graph_mode:  # pylint: disable=protected-access
        return variable.op.graph is current_graph
      else:
        # No variable.op in eager mode. We don't expect lots of eager graphs,
        # but behavior should be consistent with graph mode.
        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access

    optimizer_variables = [v for v in self._non_slot_variables()
                           if _from_current_graph(v)]
    for _, variable_dict in self._slots.items():
      for _, slot_for_variable in variable_dict.items():
        if _from_current_graph(slot_for_variable):
          optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_trackable()
      distribution_strategy = distribute_ctx.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(
            initial_value, name=name, trainable=False,
            use_resource=resource_variable_ops.is_resource_variable(
                colocate_with))
      # Restore this variable by name if necessary, but don't add a
      # Trackable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, trackable=v)
      self._non_slot_dict[key] = v

    return v

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific non-slot variables to save."""
    current_graph_non_slot_variables = []
    current_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    for (name, _), variable_object in sorted(self._non_slot_dict.items(),
                                             # Avoid comparing graphs
                                             key=lambda item: item[0][0]):
      if variable_object._graph_key == current_graph_key:  # pylint: disable=protected-access
        current_graph_non_slot_variables.append(
            trackable.TrackableReference(
                name=name, ref=variable_object))
    return (super(Optimizer, self)._checkpoint_dependencies
            + current_graph_non_slot_variables)

  def _lookup_dependency(self, name):
    """From Trackable. Find a non-slot variable in the current graph."""
    unconditional = super(Optimizer, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    graph = None if context.executing_eagerly() else ops.get_default_graph()
    return self._get_non_slot_variable(name, graph=graph)

  def _get_non_slot_variable(self, name, graph=None):
    non_slot = self._non_slot_dict.get((name, graph), None)
    if hasattr(non_slot, "_distributed_container"):
      # This is a mirrored non-slot. In order to enable code like `_finish`
      # to assign to a non-slot, return the current context replica.
      return non_slot.get()
    else:
      return non_slot

  def _non_slot_variables(self):
    """Additional variables created by the `Optimizer`.

    Returns:
      A list or tuple of variables.
    """
    return self._non_slot_dict.values()

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError(
            "Invalid type %r for %s, expected: %s." % (
                dtype, t.name, [v for v in valid_dtypes]))

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _create_slots(self, var_list):
    """Create all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _prepare(self):
    """Create all needed tensors before applying gradients.

    This is called with the name_scope using the "name" that
    users have chosen for the application of gradients.
    """
    pass

  def _apply_dense(self, grad, var):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    summed_values, unique_indices = _deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var)

  def _apply_sparse(self, grad, var):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, update_ops, name_scope):
    """Do what is needed to finish the update.

    This is called with the `name_scope` using the "name" that
    users have chosen for the application of gradients.

    Args:
      update_ops: List of `Operation` objects to update variables. This list
        contains the values returned by the `_apply_dense()` and
        `_apply_sparse()` calls.
      name_scope: String. Name to use for the returned operation.

    Returns:
      The operation to apply updates.
    """
    return control_flow_ops.group(*update_ops, name=name_scope)

  # --------------
  # Utility methods for subclasses.
  # --------------

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot(var, val, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  # --------------
  # For implementing the Trackable interface.
  # --------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    named_slots = self._slot_dict(slot_name)
    variable_key = _var_key(variable)
    slot_variable = named_slots.get(variable_key, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
      initializer = trackable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self._get_or_make_slot(
          var=variable,
          val=initializer,
          slot_name=slot_name,
          op_name=self._name)
      # Slot variables are not owned by any one object (because we don't want
      # to save the slot variable if the optimizer is saved without the
      # non-slot variable, or if the non-slot variable is saved without the
      # optimizer; it's a dependency hypergraph with edges of the form
      # (optimizer, non-slot variable, variable)). So we don't _track_ slot
      # variables anywhere, and instead special-case this dependency and
      # otherwise pretend it's a normal graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param
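

# The class below is an illustrative sketch, not part of this module's public
# API and not an existing TensorFlow optimizer: it shows how a subclass
# typically wires together the hooks used by the default `apply_gradients()`
# (`_prepare()`, `_apply_dense()`, `_resource_apply_dense()`). The class name,
# its defaults, and the plain gradient-descent update rule are assumptions
# made for the example; sparse updates are left to the base class hooks,
# which raise `NotImplementedError` unless overridden.
class _ExampleGradientDescentOptimizer(Optimizer):
  """Illustrative subclass implementing plain gradient descent."""

  def __init__(self, learning_rate, use_locking=False, name="ExampleSGD"):
    super(_ExampleGradientDescentOptimizer, self).__init__(use_locking, name)
    self._learning_rate = learning_rate

  def _prepare(self):
    # Called once per apply_gradients(); convert the Python number to a
    # tensor so every per-variable update reuses the same constant.
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")

  def _apply_dense(self, grad, var):
    # Ref-variable path: var <- var - lr * grad.
    lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
    return state_ops.assign_sub(
        var, lr * grad, use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, handle):
    # Resource-variable path; `handle.handle` is the underlying resource
    # tensor, mirroring how built-in optimizers dispatch resource updates.
    lr = math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype)
    return resource_variable_ops.assign_sub_variable_op(
        handle.handle, lr * grad)


# Optimizers that keep per-variable state (momentum, Adagrad accumulators,
# etc.) would additionally override `_create_slots()` and fetch the slots in
# their `_apply_*` methods via `self.get_slot(var, slot_name)`.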