# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest


@six.add_metaclass(abc.ABCMeta)
class _OptimizableVariable(object):
  """Interface for abstracting over variables in the optimizers."""

  @abc.abstractmethod
  def target(self):
    """Returns the optimization target for this variable."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def update_op(self, optimizer, g, *args):
    """Returns the update ops for updating the variable."""
    raise NotImplementedError("Calling an abstract method.")


class _RefVariableProcessor(_OptimizableVariable):
  """Processor for Variable."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g, *args):
    if isinstance(g, ops.Tensor):
      update_op = optimizer._apply_dense(g, self._v, *args)  # pylint: disable=protected-access
      if self._v.constraint is not None:
        with ops.control_dependencies([update_op]):
          return self._v.assign(self._v.constraint(self._v))
      else:
        return update_op
    else:
      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
                                                "tensor nor IndexedSlices.")
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)


class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    # pylint: disable=protected-access
    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return optimizer._resource_apply_sparse_duplicate_indices(
          g.values, self._v, g.indices, *args)
    update_op = optimizer._resource_apply_dense(g, self._v, *args)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer. Updating
  the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


def _var_key_v2(var):
  """Key for representing a primary variable, for looking up slots."""
  # pylint: disable=protected-access
  if hasattr(var, "_distributed_container"):
    distributed_container = var._distributed_container()
    assert distributed_container is not None
    if context.executing_eagerly():
      return distributed_container._unique_id
    return distributed_container._shared_name
  if context.executing_eagerly():
    return var._unique_id
  return var.op.name


def _resolve(value, name):
  if callable(value):
    value = value()
  return ops.convert_to_tensor(value, name=name)


def _is_dynamic(value):
  """Returns true if __init__ arg `value` should be re-evaluated each step."""
  if callable(value):
    return True
  # Don't need to do anything special in graph mode, since dynamic values
  # will propagate correctly automatically.
  # TODO(josh11b): Add per-device caching across steps using variables for
  # truly static values once we add distributed support.
  if context.executing_eagerly() and isinstance(
      value, resource_variable_ops.ResourceVariable):
    return True
  return False


class _OptimizerV2State(object):
  """Holds per-graph and per-step optimizer state.

  Use _init_with_static_hyper() to create the state for a graph, and then
  _copy_with_dynamic_hyper() to convert that to state for a particular step.
  The difference between the two is that the former only has hyper
  parameter values that are static and the latter also has values that
  can change every step (according to _is_dynamic()).
  """

  def __init__(self, op_name):
    self._op_name = op_name

  def _init_with_static_hyper(self, hyper):
    """Initialize a fresh state object from hyper dict."""
    # self._hyper contains a dict from name to a dict with the Tensor values.
    # This dict starts with a single item with key "None" with the hyper
    # parameter value converted to a Tensor. Other items have dtype keys
    # with that Tensor cast to that dtype.
    with ops.init_scope():
      self._hyper = {
          name: {
              None: ops.convert_to_tensor(value, name=name)
          } for name, (dynamic, value) in sorted(hyper.items()) if not dynamic
      }
    self._slots = {}
    self._non_slot_dict = {}
    # Extra state to help Optimizers implement Trackable. Holds information
    # about variables which will be restored as soon as they're created.
    self._deferred_dependencies = {}  # Non-slot variables
    self._deferred_slot_restorations = {}  # Slot variables

  def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
    """Create a new state object for a particular step."""
    ret = _OptimizerV2State(self._op_name)
    # pylint: disable=protected-access
    ret._slots = self._slots
    ret._non_slot_dict = self._non_slot_dict
    ret._deferred_dependencies = self._deferred_dependencies
    ret._deferred_slot_restorations = self._deferred_slot_restorations
    ret._hyper = {
        name: {
            None: _resolve(value, name)
        } for name, (dynamic, value) in sorted(hyper.items()) if dynamic
    }
    ret._hyper.update(self._hyper)
    ret._non_slot_devices = non_slot_devices
    ret._distribution = distribution
    return ret

  def _variables(self):
    """Returns a list of all variables held by self."""
    optimizer_variables = list(self._non_slot_dict.values())
    for variable_dict in self._slots.values():
      for slot_for_variable in variable_dict.values():
        optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def create_slot(self, var, val, slot_name, optional_op_name=None):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that needs to be
        created for the slot.

    Returns:
      A `Variable` object.
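
    Example (an illustrative sketch only; assumes `state` is an existing
    `_OptimizerV2State`, `var` is a variable being trained, and `tf` is
    available in the calling code):

    ```python
    momentum_slot = state.create_slot(
        var, val=tf.zeros_like(var), slot_name="momentum")
    ```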
286 """ 287 named_slots = self._slot_dict(slot_name) 288 var_key = _var_key_v2(var) 289 if var_key not in named_slots: 290 new_slot_variable = slot_creator.create_slot( 291 var, val, optional_op_name or self._op_name) 292 self._restore_slot_variable( 293 slot_name=slot_name, variable=var, slot_variable=new_slot_variable) 294 named_slots[var_key] = new_slot_variable 295 return named_slots[var_key] 296 297 def create_slot_with_initializer(self, 298 var, 299 initializer, 300 shape, 301 dtype, 302 slot_name, 303 optional_op_name=None): 304 """Find or create a slot for a variable, using an Initializer. 305 306 Args: 307 var: A `Variable` object. 308 initializer: An `Initializer`. The initial value of the slot. 309 shape: Shape of the initial value of the slot. 310 dtype: Type of the value of the slot. 311 slot_name: Name for the slot. 312 optional_op_name: Name to use when scoping the Variable that needs to be 313 created for the slot. 314 315 Returns: 316 A `Variable` object. 317 """ 318 named_slots = self._slot_dict(slot_name) 319 var_key = _var_key_v2(var) 320 if var_key not in named_slots: 321 new_slot_variable = slot_creator.create_slot_with_initializer( 322 var, initializer, shape, dtype, optional_op_name or self._op_name) 323 self._restore_slot_variable( 324 slot_name=slot_name, variable=var, slot_variable=new_slot_variable) 325 named_slots[var_key] = new_slot_variable 326 return named_slots[var_key] 327 328 def zeros_slot(self, var, slot_name, optional_op_name=None): 329 """Find or create a slot initialized with 0.0. 330 331 Args: 332 var: A `Variable` object. 333 slot_name: Name for the slot. 334 optional_op_name: Name to use when scoping the Variable that needs to be 335 created for the slot. 336 337 Returns: 338 A `Variable` object. 339 """ 340 named_slots = self._slot_dict(slot_name) 341 var_key = _var_key_v2(var) 342 if var_key not in named_slots: 343 new_slot_variable = slot_creator.create_zeros_slot( 344 var, optional_op_name or self._op_name) 345 self._restore_slot_variable( 346 slot_name=slot_name, variable=var, slot_variable=new_slot_variable) 347 named_slots[var_key] = new_slot_variable 348 return named_slots[var_key] 349 350 def _create_or_restore_slot_variable(self, 351 slot_variable_position, 352 slot_name, 353 variable, 354 optional_op_name=None): 355 """Restore a slot variable's value, possibly creating it. 356 357 Called when a variable which has an associated slot variable is created or 358 restored. When executing eagerly, we create the slot variable with a 359 restoring initializer. 360 361 No new variables are created when graph building. Instead, 362 _restore_slot_variable catches these after normal creation and adds restore 363 ops to the graph. This method is nonetheless important when graph building 364 for the case when a slot variable has already been created but `variable` 365 has just been added to a dependency graph (causing us to realize that the 366 slot variable needs to be restored). 367 368 Args: 369 slot_variable_position: A `trackable._CheckpointPosition` object 370 indicating the slot variable `Trackable` object to be restored. 371 slot_name: The name of this `Optimizer`'s slot to restore into. 372 variable: The variable object this slot is being created for. 373 optional_op_name: Name to use when scoping the Variable that needs to be 374 created for the slot. 
375 """ 376 slot_variable = self.get_slot(var=variable, name=slot_name) 377 if (slot_variable is None and context.executing_eagerly() and 378 slot_variable_position.is_simple_variable() 379 # Defer slot variable creation if there is an active variable creator 380 # scope. Generally we'd like to eagerly create/restore slot variables 381 # when possible, but this may mean that scopes intended to catch 382 # `variable` also catch its eagerly created slot variable 383 # unintentionally (specifically make_template would add a dependency on 384 # a slot variable if not for this case). Deferring is mostly harmless 385 # (aside from double initialization), and makes variable creator scopes 386 # behave the same way they do when graph building. 387 and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access 388 initializer = trackable.CheckpointInitialValue( 389 checkpoint_position=slot_variable_position) 390 slot_variable = self.create_slot( 391 var=variable, 392 val=initializer, 393 slot_name=slot_name, 394 optional_op_name=optional_op_name) 395 # Optimizers do not have unconditional dependencies on their slot 396 # variables (nor do any other objects). They are only saved if the 397 # variables they were created for are also saved. 398 if slot_variable is not None: 399 # If we've either made this slot variable, or if we've pulled out an 400 # existing slot variable, we should restore it. 401 slot_variable_position.restore(slot_variable) 402 else: 403 # We didn't make the slot variable. Defer restoring until it gets created 404 # normally. We keep a list rather than the one with the highest restore 405 # UID in case slot variables have their own dependencies, in which case 406 # those could differ between restores. 407 variable_key = _var_key_v2(variable) 408 self._deferred_slot_restorations.setdefault(slot_name, {}).setdefault( 409 variable_key, []).append(slot_variable_position) 410 411 def get_slot(self, var, name): 412 """Return a slot named `name` created for `var` by the Optimizer. 413 414 Some `Optimizer` subclasses use additional variables. For example 415 `Momentum` and `Adagrad` use variables to accumulate updates. This method 416 gives access to these `Variable` objects if for some reason you need them. 417 418 Use `get_slot_names()` to get the list of slot names created by the 419 `Optimizer`. 420 421 Args: 422 var: A variable passed to `minimize()` or `apply_gradients()`. 423 name: A string. 424 425 Returns: 426 The `Variable` for the slot if it was created, `None` otherwise. 427 """ 428 named_slots = self._slots.get(name, None) 429 if not named_slots: 430 return None 431 return named_slots.get(_var_key_v2(var), None) 432 433 def get_slot_names(self): 434 """Return a list of the names of slots created by the `Optimizer`. 435 436 See `get_slot()`. 437 438 Returns: 439 A list of strings. 440 """ 441 return sorted(self._slots.keys()) 442 443 def create_non_slot(self, initial_value, name, colocate_with=None): 444 """Add an extra variable, not associated with a slot.""" 445 v = self._non_slot_dict.get(name, None) 446 if v is None: 447 if colocate_with is None: 448 colocate_with = self._non_slot_devices 449 with self._distribution.extended.colocate_vars_with(colocate_with): 450 # TODO(josh11b): Use get_variable() except for the legacy Adam use case. 
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      self._non_slot_dict[name] = v
      deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
      for checkpoint_position in sorted(
          deferred_dependencies_list,
          key=lambda restore: restore.checkpoint.restore_uid,
          reverse=True):
        checkpoint_position.restore(v)
    return v

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key_v2(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(
        key=lambda position: position.restore_uid, reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def get_non_slot(self, name):
    """Returns the non-slot variable identified by `name`."""
    return self._non_slot_dict.get(name, None)

  def get_hyper(self, name, dtype=None):
    """Returns the `name` hyper parameter, optionally cast to `dtype`."""
    dtype_dict = self._hyper[name]
    # Do we have the value cast to dtype already cached? This should always
    # succeed when dtype is None.
    if dtype in dtype_dict:
      return dtype_dict[dtype]
    # Not cached, cast to dtype and save the result in the cache.
    result = math_ops.cast(dtype_dict[None], dtype)
    dtype_dict[dtype] = result
    return result


class OptimizerV2(optimizer_v1.Optimizer):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model. You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1. Compute the gradients with `compute_gradients()`.
  2. Process the gradients as you wish.
  3. Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results. For example the two gradients of `matmul` depend on the input
  values: With `GATE_NONE` one of the gradients could be applied to one of the
  inputs _before_ the other gradient is computed, resulting in non-reproducible
  results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used. This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used. This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and
  `AdagradOptimizer`, allocate and manage additional variables associated with
  the variables to train. These are called <i>Slots</i>. Slots have names and
  you can ask the optimizer for the names of the slots that it uses. Once you
  have a slot name you can ask the optimizer for the variable it created to
  hold the slot value.

  This can be useful if you want to log debugging information about a training
  algorithm, report stats about the slots, etc.

  ### Non-slot variables

  Some optimizer subclasses, such as `AdamOptimizer`, have variables that
  are not associated with the variables to train, just the step itself.

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  ### State

  Internal methods are passed a `state` argument with the correct
  values to use for the slot and non-slot variables, and the hyper
  parameters.
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/state.get_hyper()
    facility instead.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
        _create_vars.
    """
    # Note: We intentionally don't call parent __init__.

    # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
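    # Comparing code objects detects whether a subclass still overrides the
    # V1-style _create_slots() hook; that override would be silently ignored
    # by this class, so fail fast and point at _create_vars() instead.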
    if (self.__class__._create_slots.__code__ is not  # pylint: disable=protected-access
        OptimizerV2._create_slots.__code__):
      raise RuntimeError(
          "Override _create_vars instead of _create_slots when "
          "descending from OptimizerV2 (class %s)" % self.__class__.__name__)
    if not name:
      raise ValueError("Must specify the optimizer name")

    self._use_locking = use_locking
    self._name = name
    # Map from graph_key to state for that graph. We use the graph_key
    # since it works in both eager and graph mode, and gives the outer
    # graph inside functions.
    replica_context = distribute_ctx.get_replica_context()
    if replica_context is None:
      # In a cross-replica context for a DistributionStrategy, which means
      # only one Optimizer will be created, not one per replica.
      self._per_graph_state = {}
    else:
      # We use get_replica_context().merge_call() to get a single dict
      # shared across all model replicas when running with a
      # DistributionStrategy.
      self._per_graph_state = replica_context.merge_call(lambda _: {})

    # Hyper parameters, and whether they should be re-evaluated every step.
    self._hyper = {}

  def _set_hyper(self, name, value):
    self._hyper[name] = (_is_dynamic(value), value)

  def minimize(self,
               loss,
               global_step=None,
               var_list=None,
               gate_gradients=GATE_OP,
               aggregation_method=None,
               name=None,
               grad_loss=None,
               stop_gradients=None,
               scale_loss_by_num_replicas=False):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in the
        graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to
        differentiate through.
      scale_loss_by_num_replicas: Optional boolean. If true, scale the loss
        down by the number of replicas. DEPRECATED and generally no longer
        needed.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes elements of `var_list` as arguments and computes the value to be
    minimized. If `var_list` is None, `loss` should take no arguments.
    Minimization (and gradient computation) is done with respect to the
    elements of `var_list` if not None, else with respect to any trainable
    variables created during the execution of the `loss` function.
    `gate_gradients`, `aggregation_method`, and `grad_loss` are ignored when
    eager execution is enabled.
    @end_compatibility
    """
    grads_and_vars = self.compute_gradients(
        loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        grad_loss=grad_loss,
        stop_gradients=stop_gradients,
        scale_loss_by_num_replicas=scale_loss_by_num_replicas)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        grad_loss=None,
                        stop_gradients=None,
                        scale_loss_by_num_replicas=False):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to
        differentiate through.
      scale_loss_by_num_replicas: Optional boolean. If true, scale the loss
        down by the number of replicas. DEPRECATED and generally no longer
        needed.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients` and `aggregation_method`
    are ignored.
    @end_compatibility
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss for number of replicas (callable-loss case).
        loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    if context.executing_eagerly():
      raise RuntimeError("`loss` passed to Optimizer.compute_gradients should "
                         "be a function when eager execution is enabled.")

    # Scale loss for number of replicas (non-callable-loss case).
    loss = self._scale_loss(loss, scale_loss_by_num_replicas)

    if gate_gradients not in [
        optimizer_v1.Optimizer.GATE_NONE, optimizer_v1.Optimizer.GATE_OP,
        optimizer_v1.Optimizer.GATE_GRAPH
    ]:
      raise ValueError(
          "gate_gradients must be one of: Optimizer.GATE_NONE, "
          "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" % gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() + ops.get_collection(
              ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss,
        var_refs,
        grad_ys=grad_loss,
        gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        stop_gradients=stop_gradients)
    if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes([
        v for g, v in grads_and_vars
        if g is not None and v.dtype != dtypes.resource
    ])
    return grads_and_vars

  @staticmethod
  def _scale_loss(loss_value, scale_loss_by_num_replicas):
    """Scale loss for the number of replicas."""
    if scale_loss_by_num_replicas:
      num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
      if num_replicas > 1:
        loss_value *= 1. / num_replicas
    return loss_value

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers. It relies on the subclass implementing the following
    # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().

    # Filter out variables with gradients of `None`.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
    if not filtered:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v in grads_and_vars],))
    return distribute_ctx.get_replica_context().merge_call(
        self._distributed_apply, args=(filtered,),
        kwargs={"global_step": global_step, "name": name})

  def _get_or_create_state(self, var_list=None):
    """Either looks up or creates `_OptimizerV2State`.

    If any variables are available, they should be passed via the `var_list`
    argument, and these will be used to determine the graph to create/retrieve
    state for. Otherwise the returned state is for the current default graph.

    Args:
      var_list: A list of variables to extract a graph from.

    Returns:
      An `_OptimizerV2State` object.
    """
    # Determine the graph_key from the current graph.
    eager_execution = context.executing_eagerly()
    if eager_execution or var_list is None:
      graph = ops.get_default_graph()
    else:
      graph = ops._get_graph_from_inputs(var_list)  # pylint: disable=protected-access
    assert graph is not None
    graph_key = graph._graph_key  # pylint: disable=protected-access

    # Get the per graph state by looking up the graph_key.
    if graph_key in self._per_graph_state:
      per_graph_state = self._per_graph_state[graph_key]
    else:
      per_graph_state = _OptimizerV2State(self._name)
      per_graph_state._init_with_static_hyper(self._hyper)  # pylint: disable=protected-access
      self._per_graph_state[graph_key] = per_graph_state
    return per_graph_state

  def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
    """`apply_gradients` for use with a `DistributionStrategy`."""
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
    eager_execution = context.executing_eagerly()
    if eager_execution:
      # Give a clear error in this case instead of "name not supported
      # for Eager Tensors" when we compute non_slot_devices.
      for v in unwrapped_var_list:
        if isinstance(v, ops.Tensor):
          raise NotImplementedError("Trying to update a Tensor ", v)

    with ops.name_scope(name, self._name) as name:
      per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
      # Include the current value of any dynamic hyper parameters in `state`.
      non_slot_devices = distribution.extended.non_slot_devices(var_list)
      state = per_graph_state._copy_with_dynamic_hyper(  # pylint: disable=protected-access
          self._hyper, distribution, non_slot_devices)

    # Create any slot and non-slot variables we need in `state`.
    with ops.init_scope():
      self._create_vars(var_list, state)

    with ops.name_scope(name):  # Re-enter name_scope created above
      # Give the child class a chance to do something before we start
      # applying gradients.
      self._prepare(state)

      def update(v, g):
        """Update variable `v` using gradient `g`."""
        assert v is not None

        # Convert the grad to Tensor or IndexedSlices if necessary, and
        # look up a processor for each variable's type.
        try:
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError("Gradient must be convertible to a Tensor"
                          " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
        processor = _get_processor(v)

        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        scope_name = "" if eager_execution else v.op.name
        # device_policy is set because non-mirrored tensors will be read in
        # `update_op`.
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        device_policy = context.device_policy(
            context.DEVICE_PLACEMENT_SILENT)
        with ops.name_scope("update_" + scope_name), device_policy:
          return processor.update_op(self, g, state)

      # Use the processors to update the variables.
      update_ops = []
      for grad, var in grads_and_vars:
        update_ops.extend(distribution.extended.update(
            var, update, args=(grad,), group=False))

      # Give the child class a chance to do something after applying
      # gradients
      def finish():
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
          return self._finish(state)

      update_ops = control_flow_ops.group(update_ops)
      with ops.control_dependencies([update_ops]):
        finish_updates = distribution.extended.update_non_slot(
            non_slot_devices, finish, group=False)
      # We said group=False, which means finish_updates is always a tuple.
      # It will be (None,) when finish() returns None.
      if finish_updates == (None,):
        finish_updates = (update_ops,)

      # Update `global_step` (if any).
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(finish_updates):

          def update_global_step(global_step, name):
            return global_step.assign_add(1, read_value=False, name=name)

          apply_updates = distribution.extended.update(
              global_step, update_global_step, args=(name,))

      # Add the training op to the TRAIN_OP graph collection in graph mode.
      if not eager_execution:
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables. For example
    `Momentum` and `Adagrad` use variables to accumulate updates. This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
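
    Example (an illustrative sketch; assumes a Momentum-style subclass that
    created a slot named "momentum" for `var`):

    ```python
    momentum_slot = opt.get_slot(var, "momentum")
    ```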
1034 """ 1035 state = self._get_state_for_var(var) 1036 return state.get_slot(var, name) if state is not None else None 1037 1038 def get_slot_names(self): 1039 """Return a list of the names of slots created by the `Optimizer`. 1040 1041 See `get_slot()`. 1042 1043 Returns: 1044 A list of strings. 1045 """ 1046 state = self._get_per_graph_state() 1047 return state.get_slot_names() if state is not None else [] 1048 1049 def variables(self): 1050 """A list of variables which encode the current state of `Optimizer`. 1051 1052 Includes slot variables and additional global variables created by the 1053 optimizer in the current default graph. 1054 1055 Returns: 1056 A list of variables. 1057 """ 1058 state = self._get_per_graph_state() 1059 return state._variables() if state is not None else [] # pylint: disable=protected-access 1060 1061 # -------------- 1062 # Methods to be implemented by subclasses if they want to use the 1063 # inherited implementation of apply_gradients() or compute_gradients(). 1064 # -------------- 1065 def _create_vars(self, var_list, state): 1066 """Create all slots needed by the variables and any non-slot variables. 1067 1068 Args: 1069 var_list: A list of `Variable` objects. 1070 state: An object with these methods: `create_slot(var, val, slot_name, 1071 optional_op_name)`, `create_slot_with_initializer(` `var, initializer, 1072 shape, dtype, slot_name, optional_op_name)`, `zeros_slot(var, slot_name, 1073 optional_op_name)`, `create_non_slot_variable(initial_value, name, 1074 colocate_with)`, `get_hyper(name)` 1075 """ 1076 # No slots needed by default 1077 pass 1078 1079 def _prepare(self, state): 1080 """Code to execute before applying gradients. 1081 1082 Note that most uses of _prepare() in Optimizer have been subsumed 1083 by explicit support for hyper parameters in OptimizerV2 1084 1085 Args: 1086 state: An object with a `get_hyper(name)` method. 1087 1088 Returns: 1089 Return value will be ignored. 1090 """ 1091 pass 1092 1093 def _apply_dense(self, grad, var, state): 1094 """Add ops to apply dense gradients to `var`. 1095 1096 Args: 1097 grad: A `Tensor`. 1098 var: A `Variable` object. 1099 state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, 1100 and `get_hyper(name)` methods. 1101 1102 Returns: 1103 An `Operation`. 1104 """ 1105 raise NotImplementedError() 1106 1107 def _resource_apply_dense(self, grad, handle, state): 1108 """Add ops to apply dense gradients to the variable `handle`. 1109 1110 Args: 1111 grad: a `Tensor` representing the gradient. 1112 handle: a `Tensor` of dtype `resource` which points to the variable to be 1113 updated. 1114 state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, 1115 and `get_hyper(name)` methods. 1116 1117 Returns: 1118 An `Operation` which updates the value of the variable. 1119 """ 1120 raise NotImplementedError() 1121 1122 def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices, 1123 state): 1124 """Add ops to apply sparse gradients to `handle`, with repeated indices. 1125 1126 Optimizers which override this method must deal with repeated indices. See 1127 the docstring of `_apply_sparse_duplicate_indices` for details. By default 1128 the correct behavior, to sum non-unique indices and their associated 1129 gradients, is enforced by first pre-processing `grad` and `indices` and 1130 passing them on to `_resource_apply_sparse`. Optimizers which deal correctly 1131 with duplicate indices may instead override this method to avoid the 1132 overhead of summing. 

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices may be repeated.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    # pylint: disable=protected-access
    summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
        values=grad, indices=indices)
    # pylint: enable=protected-access
    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
                                       state)

  def _resource_apply_sparse(self, grad, handle, indices, state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable to be
        updated.
      indices: a `Tensor` of integral type representing the indices for which
        the gradient is nonzero. Indices are unique.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var, state):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    # pylint: disable=protected-access
    summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    # pylint: enable=protected-access
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var, state)

  def _apply_sparse(self, grad, var, state):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, state):
    """Do what is needed to finish the update.

    This is called inside a scope colocated with any non-slot variables.

    Args:
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      The operation to apply updates, or None if no updates.
    """
    return None

  # --------------
  # Utility methods for subclasses.
  # --------------
  def _get_per_graph_state(self):
    # pylint: disable=protected-access
    return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)

  def _get_state_for_var(self, var):
    # pylint: disable=protected-access
    return self._per_graph_state.get(var._graph_key, None)

  # --------------
  # Overridden methods from Trackable.
  # --------------

  def _track_trackable(self, *args, **kwargs):
    """Optimizers may not track dependencies. Raises an error."""
    raise NotImplementedError(
        "Optimizers may not have dependencies. File a feature request if this "
        "limitation bothers you.")

  @property
  def _checkpoint_dependencies(self):
    """From Trackable. Gather graph-specific non-slot variables to save."""
    current_graph_non_slot_variables = []
    state = self._get_per_graph_state()
    if state is not None:
      for name, variable_object in sorted(
          state._non_slot_dict.items(),  # pylint: disable=protected-access
          # Avoid comparing variables
          key=lambda item: item[0]):
        current_graph_non_slot_variables.append(
            trackable.TrackableReference(
                name=name, ref=variable_object))
    # Note: ignores super(); Optimizers may not have any dependencies outside
    # of state objects.
    return current_graph_non_slot_variables

  def _lookup_dependency(self, name):
    """From Trackable. Find a non-slot variable in the current graph."""
    state = self._get_per_graph_state()
    if state is None:
      return None
    else:
      return state.get_non_slot(name)

  @property
  def _deferred_dependencies(self):
    """Lets Trackable know where non-slot variables are created.

    If necessary, creates a new state object for the current default graph.
    Trackable will then add entries to that state's deferred dependency
    dictionary. The state object will check that dictionary when creating
    non-slot variables, restoring their value if an entry is found.

    Returns:
      A dictionary which holds deferred dependencies for the current default
      graph.
1308 """ 1309 state = self._get_or_create_state() 1310 return state._deferred_dependencies # pylint: disable=protected-access 1311 1312 def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, 1313 variable): 1314 """Trackable: Restore a slot variable's value, possibly creating it. 1315 1316 Called when a variable which has an associated slot variable is created or 1317 restored. 1318 1319 Args: 1320 slot_variable_position: A `trackable._CheckpointPosition` object 1321 indicating the slot variable `Trackable` object to be restored. 1322 slot_name: The name of this `Optimizer`'s slot to restore into. 1323 variable: The variable object this slot is being created for. 1324 """ 1325 state = self._get_or_create_state(var_list=[variable]) 1326 state._create_or_restore_slot_variable( # pylint: disable=protected-access 1327 slot_variable_position=slot_variable_position, 1328 slot_name=slot_name, 1329 variable=variable, 1330 optional_op_name=self._name) 1331 1332 # -------------- 1333 # Unsupported parent methods 1334 # -------------- 1335 def _slot_dict(self, slot_name): 1336 raise NotImplementedError("_slot_dict() method unsupported in OptimizerV2") 1337 1338 def _get_or_make_slot(self, var, val, slot_name, op_name): 1339 raise NotImplementedError( 1340 "_get_or_make_slot() method unsupported in OptimizerV2") 1341 1342 def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype, 1343 slot_name, op_name): 1344 raise NotImplementedError( 1345 "_get_or_make_slot_with_initializer() method unsupported in " 1346 "OptimizerV2") 1347 1348 def _create_non_slot_variable(self, initial_value, name, colocate_with): 1349 raise NotImplementedError( 1350 "_create_non_slot_variable() method unsupported in OptimizerV2") 1351 1352 def _get_non_slot_variable(self, name, graph=None): 1353 raise NotImplementedError( 1354 "_get_non_slot_variable() method unsupported in OptimizerV2") 1355 1356 def _non_slot_variables(self): 1357 raise NotImplementedError( 1358 "_non_slot_variables() method unsupported in OptimizerV2") 1359