# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base class for optimizers."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpointable
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def _get_variable_for(v):
  """Returns the ResourceVariable responsible for v, or v if not necessary."""
  if context.in_eager_mode():
    return v
  if v.op.type == "VarHandleOp":
    for var in variables.trainable_variables():
      if (isinstance(var, resource_variable_ops.ResourceVariable)
          and var.handle.op is v.op):
        return var
    raise ValueError("Got %s but could not locate source variable." % (str(v)))
  return v


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).
  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
  """
  unique_indices, new_index_positions = array_ops.unique(indices)
  summed_values = math_ops.unsorted_segment_sum(
      values, new_index_positions,
      array_ops.shape(unique_indices)[0])
  return (summed_values, unique_indices)
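

# Illustrative sketch only, kept as a comment so it has no effect at import
# time: given repeated indices, `_deduplicate_indexed_slices` sums the
# corresponding value slices. Assuming the tensors are evaluated (eagerly or
# in a session), one would expect:
#
#   values = ops.convert_to_tensor([[1., 1.], [2., 2.], [3., 3.]])
#   indices = ops.convert_to_tensor([0, 0, 1])
#   summed_values, unique_indices = _deduplicate_indexed_slices(values, indices)
#   # summed_values -> [[3., 3.], [3., 3.]]
#   # unique_indices -> [0, 1]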


def _var_key(var):
  if context.in_eager_mode():
    return var._shared_name  # pylint: disable=protected-access
  return (var.op.graph, var.op.name)


class _OptimizableVariable(object):
  """Interface for abstracting over variables in the optimizers."""

  @abc.abstractmethod
  def target(self):
    """Returns the optimization target for this variable."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def update_op(self, optimizer, g):
    """Returns the update ops for updating the variable."""
    raise NotImplementedError("Calling an abstract method.")


class _RefVariableProcessor(_OptimizableVariable):
  """Processor for Variable."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g):
    if isinstance(g, ops.Tensor):
      update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
      if self._v.constraint is not None:
        with ops.control_dependencies([update_op]):
          return self._v.assign(self._v.constraint(self._v))
      else:
        return update_op
    else:
      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
                                                "tensor nor IndexedSlices.")
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, self._v)


class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return optimizer._resource_apply_sparse_duplicate_indices(
          g.values, self._v, g.indices)
    update_op = optimizer._resource_apply_dense(g, self._v)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _StreamingModelPortProcessor(_OptimizableVariable):
  """Processor for streaming ModelPorts."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    return g


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer. Updating
  the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.in_eager_mode():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if v.op.type == "SubmodelPort":
    return _StreamingModelPortProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


@tf_export("train.Optimizer")
class Optimizer(checkpointable.Checkpointable):
  """Base class for optimizers.

  This class defines the API to add Ops to train a model. You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1. Compute the gradients with `compute_gradients()`.
  2. Process the gradients as you wish.
  3. Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results.
  For example, the two gradients of `matmul` depend on the input
  values: with `GATE_NONE` one of the gradients could be applied to one of the
  inputs _before_ the other gradient is computed, resulting in non-reproducible
  results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used. This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used. This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and
  `AdagradOptimizer`, allocate and manage additional variables associated with
  the variables to train. These are called <i>Slots</i>. Slots have names and
  you can ask the optimizer for the names of the slots that it uses. Once you
  have a slot name you can ask the optimizer for the variable it created to
  hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
    """
    if not name:
      raise ValueError("Must specify the optimizer name")
    self._use_locking = use_locking
    self._name = name
    # Dictionary of slots.
    #  {slot_name :
    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
    #   ... }
    self._slots = {}
    self._non_slot_dict = {}
    # For implementing Checkpointable. Stores information about how to restore
    # slot variables which have not yet been created
    # (checkpointable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

  def get_name(self):
    return self._name

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly
    instead of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes elements of `var_list` as arguments and computes the value to be
    minimized. If `var_list` is None, `loss` should take no arguments.
    Minimization (and gradient computation) is done with respect to the
    elements of `var_list` if not None, else with respect to any trainable
    variables created during the execution of the `loss` function.
    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
    `grad_loss` are ignored when eager execution is enabled.
    @end_compatibility
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable`
        objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()
      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))
    if context.in_eager_mode():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")
    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers. It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots([_get_variable_for(v) for v in var_list])
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        scope_name = var.op.name if context.in_graph_mode() else ""
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if context.in_graph_mode():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables. For example
    `Momentum` and `Adagrad` use variables to accumulate updates. This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None
    return named_slots.get(_var_key(var), None)

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.

    Returns:
      A list of variables.
    """
    executing_eagerly = context.in_eager_mode()
    current_graph = ops.get_default_graph()

    def _from_current_graph(variable):
      if executing_eagerly:
        # No variable.op in eager mode. We don't expect lots of eager graphs,
        # but behavior should be consistent with graph mode.
        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access
      else:
        return variable.op.graph is current_graph

    optimizer_variables = [v for v in self._non_slot_variables()
                           if _from_current_graph(v)]
    for _, variable_dict in self._slots.items():
      for _, slot_for_variable in variable_dict.items():
        if _from_current_graph(slot_for_variable):
          optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    if context.in_graph_mode():
      graph = colocate_with.graph
    else:
      graph = None

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      with ops.colocate_with(colocate_with):
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      self._non_slot_dict[key] = v

    return v

  def _get_non_slot_variable(self, name, graph=None):
    return self._non_slot_dict.get((name, graph), None)

  def _non_slot_variables(self):
    """Additional variables created by the `Optimizer`.

    Returns:
      A list or tuple of variables.
    """
    return self._non_slot_dict.values()

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError(
            "Invalid type %r for %s, expected: %s." % (
                dtype, t.name, [v for v in valid_dtypes]))

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _create_slots(self, var_list):
    """Create all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _prepare(self):
    """Create all needed tensors before applying gradients.

    This is called with the name_scope using the "name" that
    users have chosen for the application of gradients.
    """
    pass

  def _apply_dense(self, grad, var):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    summed_values, unique_indices = _deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var)
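
  # Illustrative sketch only (kept as a comment, not part of this class): a
  # subclass whose sparse update is purely additive can handle repeated
  # indices without the pre-summing above, because scatter-style add/sub ops
  # accumulate contributions for duplicate indices themselves. Such a subclass
  # could override `_apply_sparse_duplicate_indices` directly, e.g. (assuming
  # a hypothetical learning-rate tensor `self._lr_t` created in `_prepare`):
  #
  #   def _apply_sparse_duplicate_indices(self, grad, var):
  #     return state_ops.scatter_sub(
  #         var, grad.indices, self._lr_t * grad.values,
  #         use_locking=self._use_locking)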

  def _apply_sparse(self, grad, var):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, update_ops, name_scope):
    """Do what is needed to finish the update.

    This is called with the `name_scope` using the "name" that
    users have chosen for the application of gradients.

    Args:
      update_ops: List of `Operation` objects to update variables. This list
        contains the values returned by the `_apply_dense()` and
        `_apply_sparse()` calls.
      name_scope: String. Name to use for the returned operation.

    Returns:
      The operation to apply updates.
    """
    return control_flow_ops.group(*update_ops, name=name_scope)

  # --------------
  # Utility methods for subclasses.
  # --------------

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot(var, val, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  # --------------
  # For implementing the Checkpointable interface.
  # --------------

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `checkpointable._CheckpointPosition` object
        indicating the slot variable `Checkpointable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    named_slots = self._slot_dict(slot_name)
    variable_key = _var_key(variable)
    slot_variable = named_slots.get(variable_key, None)
    if (slot_variable is None
        and context.in_eager_mode()
        and slot_variable_position.is_simple_variable()):
      initializer = checkpointable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self._get_or_make_slot(
          var=variable,
          val=initializer,
          slot_name=slot_name,
          op_name=self._name)
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)
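

# The class below is an illustrative sketch only: a minimal Optimizer subclass
# showing how `_prepare`, `_apply_dense` and `_apply_sparse` plug into the
# inherited `minimize()`/`apply_gradients()` machinery. It mirrors what
# `GradientDescentOptimizer` does in simplified form, assumes the learning
# rate is a plain Python float, and omits the `_resource_apply_*` methods that
# full ResourceVariable support would require. It is not exported and is not
# used elsewhere in this module.
class _ExampleGradientDescent(Optimizer):
  """Toy optimizer performing `var -= learning_rate * grad`."""

  def __init__(self, learning_rate, use_locking=False, name="ExampleGD"):
    super(_ExampleGradientDescent, self).__init__(use_locking, name)
    self._learning_rate = learning_rate

  def _prepare(self):
    # Convert the learning rate once so every update op reuses the same tensor.
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")

  def _apply_dense(self, grad, var):
    lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
    return state_ops.assign_sub(var, lr * grad,
                                use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    # `grad.indices` are unique here because the base class de-duplicates
    # repeated indices in `_apply_sparse_duplicate_indices` before calling us.
    lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
    return state_ops.scatter_sub(var, grad.indices, lr * grad.values,
                                 use_locking=self._use_locking)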