1"""Utilities for saving/loading Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import os 22import weakref 23 24from tensorflow.core.protobuf import trackable_object_graph_pb2 25from tensorflow.python import pywrap_tensorflow 26from tensorflow.python.client import session as session_lib 27from tensorflow.python.eager import context 28from tensorflow.python.eager import def_function 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import errors_impl 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import tensor_util 35from tensorflow.python.lib.io import file_io 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import gen_io_ops as io_ops 38from tensorflow.python.ops import init_ops 39from tensorflow.python.ops import variable_scope 40from tensorflow.python.ops import variables 41from tensorflow.python.training import checkpoint_management 42from tensorflow.python.training import saver as v1_saver_lib 43from tensorflow.python.training.saving import functional_saver 44from tensorflow.python.training.saving import saveable_object_util 45from tensorflow.python.training.tracking import base 46from tensorflow.python.training.tracking import data_structures 47from tensorflow.python.training.tracking import graph_view as graph_view_lib 48from tensorflow.python.training.tracking import object_identity 49from tensorflow.python.training.tracking import tracking 50from tensorflow.python.util import compat 51from tensorflow.python.util import deprecation 52from tensorflow.python.util import tf_contextlib 53from tensorflow.python.util.tf_export import tf_export 54 55 56class _CheckpointRestoreCoordinator(object): 57 """Holds the status of an object-based checkpoint load.""" 58 59 def __init__(self, object_graph_proto, save_path, save_path_tensor, 60 restore_op_cache, graph_view): 61 """Specify the checkpoint being loaded. 62 63 Args: 64 object_graph_proto: The TrackableObjectGraph protocol buffer 65 associated with this checkpoint. 66 save_path: A string, the path to the checkpoint, as returned by 67 `tf.train.latest_checkpoint`. 68 save_path_tensor: A string `Tensor` which contains or will be fed the save 69 path. 70 restore_op_cache: A dictionary shared between 71 `_CheckpointRestoreCoordinator`s for the same Python objects, used to 72 look up restore ops by name to avoid re-creating them across multiple 73 `restore()` calls. 74 graph_view: A graph_view_lib.ObjectGraphView object for the restored 75 objects. 76 """ 77 self.object_graph_proto = object_graph_proto 78 self.restore_uid = ops.uid() 79 # Maps from objects to lists of attributes which were in the checkpoint but 80 # not loaded into any object, for error checking. 81 self.unused_attributes = weakref.WeakKeyDictionary() 82 # Dictionary mapping from an id in the protocol buffer flat array to 83 # Trackable Python objects. This mapping may be deferred if a 84 # checkpoint is restored before all dependencies have been tracked. Uses 85 # weak references so that partial restorations don't create reference cycles 86 # (as objects with deferred dependencies will generally have references to 87 # this object). 88 self.object_by_proto_id = weakref.WeakValueDictionary() 89 # A set of all Python objects we've seen as dependencies, even if we didn't 90 # use them (for example because of inconsistent references when 91 # loading). Used to make status assertions fail when loading checkpoints 92 # that don't quite match. 93 self.all_python_objects = object_identity.ObjectIdentityWeakSet() 94 self.save_path_tensor = save_path_tensor 95 self.save_path_string = save_path 96 self.dtype_map = pywrap_tensorflow.NewCheckpointReader( 97 save_path).get_variable_to_dtype_map() 98 # A NewCheckpointReader for the most recent checkpoint, for streaming Python 99 # state restoration. 100 # When graph building, contains a list of ops to run to restore objects from 101 # this checkpoint. 102 self.restore_ops = [] 103 self.restore_ops_by_name = restore_op_cache 104 self.graph_view = graph_view 105 self.new_restore_ops_callback = None 106 # A mapping from optimizer proto ids to lists of slot variables to be 107 # restored when the optimizer is tracked. Only includes slot variables whose 108 # regular variables have already been created, and only for optimizer 109 # objects which have not yet been created/tracked. 110 self.deferred_slot_restorations = {} 111 # A mapping from variable proto ids to lists of slot variables to be 112 # restored when the variable is created/tracked. These get shifted over to 113 # deferred_slot_restorations if the optimizer hasn't been created when that 114 # happens. 115 self.slot_restorations = {} 116 for node_index, node in enumerate(self.object_graph_proto.nodes): 117 for slot_reference in node.slot_variables: 118 # `node` refers to an `Optimizer`, since only these have slot variables. 119 self.slot_restorations.setdefault( 120 slot_reference.original_variable_node_id, []).append( 121 base._SlotVariableRestoration( # pylint: disable=protected-access 122 optimizer_id=node_index, 123 slot_variable_id=slot_reference.slot_variable_node_id, 124 slot_name=slot_reference.slot_name)) 125 126 def new_restore_ops(self, new_ops): 127 self.restore_ops.extend(new_ops) 128 if self.new_restore_ops_callback: 129 self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable 130 131 def restore_saveables(self, tensor_saveables, python_saveables): 132 """Run or build restore operations for SaveableObjects. 133 134 Args: 135 tensor_saveables: `SaveableObject`s which correspond to Tensors. 136 python_saveables: `PythonStateSaveable`s which correspond to Python 137 values. 138 139 Returns: 140 When graph building, a list of restore operations, either cached or newly 141 created, to restore `tensor_saveables`. 142 """ 143 restore_ops = [] 144 # Eagerly run restorations for Python state. 145 reader = pywrap_tensorflow.NewCheckpointReader( 146 self.save_path_string) 147 for saveable in python_saveables: 148 spec_names = [spec.name for spec in saveable.specs] 149 saveable.python_restore( 150 [reader.get_tensor(name) for name in spec_names]) 151 152 # If we have new SaveableObjects, extract and cache restore ops. 153 if tensor_saveables: 154 validated_saveables = saveable_object_util.validate_and_slice_inputs( 155 tensor_saveables) 156 validated_names = set(saveable.name for saveable in validated_saveables) 157 if set(tensor_saveables.keys()) != validated_names: 158 raise AssertionError( 159 ("Saveable keys changed when validating. Got back %s, was " 160 "expecting %s") % (tensor_saveables.keys(), validated_names)) 161 new_restore_ops = functional_saver.restore_from_saveable_objects( 162 self.save_path_tensor, validated_saveables) 163 if not context.executing_eagerly(): 164 restore_ops.extend(new_restore_ops) 165 for saveable, restore_op in zip(validated_saveables, new_restore_ops): 166 assert saveable.name not in self.restore_ops_by_name 167 self.restore_ops_by_name[saveable.name] = restore_op 168 return restore_ops 169 170 171class _NameBasedRestoreCoordinator(object): 172 """Keeps the status of a name-based checkpoint restore.""" 173 174 def __init__(self, save_path, dtype_map=None): 175 self.save_path = save_path 176 self.dtype_map = dtype_map 177 self.unused_attributes = weakref.WeakKeyDictionary() 178 self.restore_uid = ops.uid() 179 180 def globally_named_object_attributes(self, trackable): 181 """Create globally named SaveableObjects from attributes. 182 183 If an object's attribute has no global name specified (default construction 184 for the SaveableObject factory), records the failure in 185 `self.unused_attributes` (which can then be used to make status assertions 186 fail; see `NameBasedSaverStatus`). 187 188 Args: 189 trackable: An object to save. 190 191 Yields: 192 SaveableObjects for `trackable`'s attributes. 193 """ 194 for attribute_name, saveable_factory in ( 195 trackable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access 196 if callable(saveable_factory): 197 try: 198 # This saveable object factory does not have a default name= argument, 199 # which means there's no way to save/restore it using a name-based 200 # checkpoint. Ignore the error now and make sure assert_consumed() 201 # fails. 202 saveable = saveable_factory() 203 except TypeError: 204 # Even if we can't name this object, we should construct it and check 205 # whether it's optional to restore it. If it's optional we don't need 206 # to make assertions fail. 207 if not saveable_factory("").optional_restore: 208 self.unused_attributes.setdefault(trackable, []).append( 209 attribute_name) 210 continue 211 else: 212 saveable = saveable_factory 213 names_to_saveables = saveable_object_util.op_list_to_dict( 214 [saveable], 215 convert_variable_to_tensor=False) 216 for name, op in names_to_saveables.items(): 217 for saveable_object in saveable_object_util.saveable_objects_for_op( 218 op=op, name=name): 219 yield saveable_object 220 221 def eager_restore(self, trackable): 222 """Runs restore ops for `trackable`'s attributes.""" 223 # When graph building, we don't add any restore ops to the graph until 224 # run_restore_ops/initialize_or_restore on the status object for name-based 225 # checkpoints. 226 assert context.executing_eagerly() 227 for saveable in self.globally_named_object_attributes( 228 trackable): 229 restored_tensors = [] 230 tensor_missing = False 231 for spec in saveable.specs: 232 if spec.name in self.dtype_map: 233 with ops.device("cpu:0"): 234 restored, = io_ops.restore_v2( 235 prefix=self.save_path, 236 tensor_names=[spec.name], 237 shape_and_slices=[""], 238 dtypes=[self.dtype_map[spec.name]], 239 name="%s_checkpoint_read" % (spec.name,)) 240 restored_tensors.append(array_ops.identity(restored)) 241 else: 242 tensor_missing = True 243 244 if tensor_missing: 245 # Record that this variable didn't match so assertions will fail. 246 self.unused_attributes.setdefault(trackable, []).append(saveable.name) 247 else: 248 # Ignores values missing from the checkpoint, as with object-based 249 # restore. Status assertions can be used to check exact matches, 250 # although it's unlikely to ever happen for name-based checkpoints. 251 saveable.restore(restored_tensors=restored_tensors, 252 restored_shapes=None) 253 254 255# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange 256# or consolidating the implementation with get_variable. 257def _default_getter(name, shape, dtype, initializer=None, 258 partition_info=None, **kwargs): 259 """A pared-down version of get_variable which does not reuse variables.""" 260 dtype = dtypes.as_dtype(dtype) 261 shape_object = tensor_shape.as_shape(shape) 262 with ops.init_scope(): 263 if initializer is None: 264 initializer, initializing_from_value = ( 265 variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access 266 name=name, shape=shape_object, dtype=dtype)) 267 else: 268 initializing_from_value = not callable(initializer) 269 # Same logic as get_variable 270 variable_dtype = dtype.base_dtype 271 if initializing_from_value: 272 if shape is not None: 273 raise ValueError("If initializer is a constant, do not specify shape.") 274 initial_value = initializer 275 else: 276 # Instantiate initializer if provided initializer is a type object. 277 if isinstance(initializer, type(init_ops.Initializer)): 278 initializer = initializer(dtype=dtype) 279 def initial_value(): 280 return initializer( 281 shape_object.as_list(), dtype=dtype, partition_info=partition_info) 282 return variables.VariableV1( 283 initial_value=initial_value, 284 name=name, 285 dtype=variable_dtype, 286 use_resource=True, 287 **kwargs 288 ) 289 290 291def add_variable(trackable, name, shape=None, dtype=dtypes.float32, 292 initializer=None): 293 """Add a variable to a Trackable with no scope influence.""" 294 return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access 295 name=name, shape=shape, dtype=dtype, 296 initializer=initializer, getter=_default_getter) 297 298 299def object_metadata(save_path): 300 """Retrieves information about the objects in a checkpoint. 301 302 Example usage: 303 304 ```python 305 object_graph = tf.contrib.checkpoint.object_metadata( 306 tf.train.latest_checkpoint(checkpoint_directory)) 307 ckpt_variable_names = set() 308 for node in object_graph.nodes: 309 for attribute in node.attributes: 310 ckpt_variable_names.add(attribute.full_name) 311 ``` 312 313 Args: 314 save_path: The path to the checkpoint, as returned by `save` or 315 `tf.train.latest_checkpoint`. 316 Returns: 317 A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer. 318 Raises: 319 ValueError: If an object graph was not found in the checkpoint. 320 """ 321 reader = pywrap_tensorflow.NewCheckpointReader(save_path) 322 try: 323 object_graph_string = reader.get_tensor( 324 base.OBJECT_GRAPH_PROTO_KEY) 325 except errors_impl.NotFoundError: 326 raise ValueError( 327 ('The specified checkpoint "%s" does not appear to be object-based (it ' 328 'is missing the key "%s"). Likely it was created with a name-based ' 329 'saver and does not contain an object dependency graph.') % ( 330 save_path, base.OBJECT_GRAPH_PROTO_KEY)) 331 object_graph_proto = ( 332 trackable_object_graph_pb2.TrackableObjectGraph()) 333 object_graph_proto.ParseFromString(object_graph_string) 334 return object_graph_proto 335 336 337def list_objects(root_trackable): 338 """Traverse the object graph and list all accessible objects. 339 340 Looks for `Trackable` objects which are dependencies of 341 `root_trackable`. Includes slot variables only if the variable they are 342 slotting for and the optimizer are dependencies of `root_trackable` 343 (i.e. if they would be saved with a checkpoint). 344 345 Args: 346 root_trackable: A `Trackable` object whose dependencies should be 347 flattened. 348 Returns: 349 A flat list of objects. 350 """ 351 return graph_view_lib.ObjectGraphView(root_trackable).list_objects() 352 353 354def gather_initializers(root_trackable): 355 """Traverse the object graph and find initialization ops. 356 357 Looks for `Trackable` objects which are dependencies of 358 `root_trackable` and which have an `initializer` property. Includes 359 initializers for slot variables only if the variable they are slotting for and 360 the optimizer are dependencies of `root_trackable` (i.e. if they would be 361 saved with a checkpoint). 362 363 Args: 364 root_trackable: A `Trackable` object to gather initializers for. 365 Returns: 366 A list of initialization ops. 367 """ 368 trackable_objects = list_objects(root_trackable) 369 return [c.initializer for c in trackable_objects 370 if hasattr(c, "initializer") and c.initializer is not None] 371 372 373@tf_contextlib.contextmanager 374def capture_dependencies(template): 375 """Capture variables created within this scope as `Template` dependencies. 376 377 Requires that `template.variable_scope` is active. 378 379 This scope is intended as a compatibility measure, allowing a trackable 380 object to add dependencies on variables created in a block of code which is 381 not aware of object-based saving (and instead uses variable names 382 heavily). This is how `Template` objects add dependencies on variables and 383 sub-`Template`s. Where possible, use `tf.make_template` directly. 384 385 Args: 386 template: The `Template` object to register dependencies with. 387 388 Yields: 389 None (when used as a context manager). 390 """ 391 name_prefix = template.variable_scope.name 392 393 def _trackable_custom_creator(next_creator, name, initial_value, 394 trackable_parent=None, **kwargs): 395 """A variable creation hook which adds Trackable dependencies. 396 397 Set for example during a `Template`'s first wrapped function 398 execution. Ensures that (a) `template` depends on any trackable 399 objects using their own `capture_dependencies` scope inside this scope which 400 create variables, and (b) that any variables not in a more deeply nested 401 scope are added as dependencies directly. 402 403 The `trackable_parent` argument is passed between custom creators but 404 ignored when the variable object itself is created. This argument indicates 405 (if not `None`) that a more deeply nested scope has already added the 406 variable as a dependency, and that parent scopes should add a dependency on 407 that object rather than on the variable directly. 408 409 Args: 410 next_creator: See `variable_scope.variable_creator_scope`; the next 411 creator in the chain. 412 name: The (full, scope-influenced) name of the variable. The `name_prefix` 413 itself is stripped for the purposes of object-based dependency tracking, 414 but scopes opened within this scope are respected. 415 initial_value: See `variable_scope.variable_creator_scope`. Taken 416 explicitly so the argument can be re-named and used with 417 `Trackable._add_variable_with_custom_getter`. 418 trackable_parent: If not None, a more deeply nested trackable 419 object and its name prefix which were passed to `capture_dependencies` 420 to add a dependency on (rather than depending on the variable directly). 421 **kwargs: Passed through to the next creator. 422 423 Returns: 424 The output of `next_creator`: the fetched/created variable object. 425 """ 426 def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): 427 inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which 428 # we don't want to propagate. 429 return next_creator( 430 initial_value=initializer, 431 name=name, 432 **inner_kwargs) 433 if name is not None and name.startswith(name_prefix): 434 scope_stripped_name = name[len(name_prefix) + 1:] 435 if not trackable_parent: 436 return template._add_variable_with_custom_getter( # pylint: disable=protected-access 437 initializer=initial_value, 438 name=scope_stripped_name, 439 getter=_call_next_creator_renaming_initializer, 440 # Disable error checking for Trackable. Exceptions are instead 441 # raised if necessary when the object-based saver tries to 442 # save/restore the object. 443 overwrite=True, 444 trackable_parent=(template, name_prefix), 445 **kwargs) 446 else: 447 parent_object, parent_name_prefix = trackable_parent 448 template._track_trackable( # pylint: disable=protected-access 449 parent_object, 450 name=parent_name_prefix[len(name_prefix) + 1:], 451 overwrite=True) 452 return next_creator( 453 name=name, initial_value=initial_value, 454 trackable_parent=(template, name_prefix), **kwargs) 455 456 with variable_scope.variable_creator_scope(_trackable_custom_creator): 457 yield 458 459 460class _LoadStatus(object): 461 """Abstract base for load status callbacks.""" 462 463 @abc.abstractmethod 464 def assert_consumed(self): 465 """Raises an exception unless a non-trivial restoration has completed.""" 466 pass 467 468 @abc.abstractmethod 469 def assert_existing_objects_matched(self): 470 """Raises an exception unless existing Python objects have been matched.""" 471 pass 472 473 @abc.abstractmethod 474 def assert_nontrivial_match(self): 475 """Raises an exception if only the root object matched.""" 476 pass 477 478 @abc.abstractmethod 479 def run_restore_ops(self, session=None): 480 """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" 481 pass 482 483 @abc.abstractmethod 484 def initialize_or_restore(self, session=None): 485 """Runs restore ops from the checkpoint, or initializes variables.""" 486 pass 487 488 489def streaming_restore(status, session=None): 490 """When graph building, runs restore ops as soon as they come in. 491 492 Args: 493 status: A _LoadStatus objects from an object-based saver's 494 restore(). Streaming restore from name-based checkpoints is not currently 495 supported. 496 session: A session to run new restore ops in. 497 """ 498 if context.executing_eagerly(): 499 # Streaming restore is the default/only behavior when executing eagerly. 500 return 501 if session is None: 502 session = ops.get_default_session() 503 if isinstance(status, NameBasedSaverStatus): 504 raise NotImplementedError( 505 "Streaming restore not supported from name-based checkpoints. File a " 506 "feature request if this limitation bothers you.") 507 status.run_restore_ops(session=session) 508 # pylint: disable=protected-access 509 status._checkpoint.new_restore_ops_callback = ( 510 lambda ops: session.run(ops, feed_dict=status._feed_dict)) 511 # pylint: enable=protected-access 512 513 514class CheckpointLoadStatus(_LoadStatus): 515 """Checks the status of checkpoint loading and manages restore ops. 516 517 Returned from `Saver.restore`. Since `restore` may defer the loading of values 518 in the checkpoint which don't yet have corresponding Python objects, 519 `CheckpointLoadStatus` provides a callback to verify that checkpoint loading 520 is complete (`assert_consumed`). 521 522 When graph building, `restore` does not run restore ops itself since their 523 creation may be deferred. The `run_restore_ops` method must be called once all 524 Python objects with values to restore have been created and added to the 525 dependency graph (this does not necessarily have to be the whole checkpoint; 526 calling `run_restore_ops` while `assert_consumed` fails is supported and will 527 partially restore the checkpoint). 528 529 See `Saver.restore` for usage examples. 530 """ 531 532 def __init__(self, checkpoint, feed_dict, graph_view): 533 self._checkpoint = checkpoint 534 self._feed_dict = feed_dict 535 self._graph_view = graph_view 536 537 def assert_consumed(self): 538 """Asserts that all objects in the checkpoint have been created/matched. 539 540 Returns: 541 `self` for chaining. 542 Raises: 543 AssertionError: If there are any Python objects in the dependency graph 544 which have not been restored from this checkpoint or a later `restore`, 545 or if there are any checkpointed values which have not been matched to 546 Python objects. 547 """ 548 self.assert_existing_objects_matched() 549 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 550 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 551 if trackable is None: 552 raise AssertionError("Unresolved object in checkpoint: %s" % (node,)) 553 if self._checkpoint.slot_restorations: 554 # Sanity check; this collection should be clear if everything has been 555 # restored. 556 raise AssertionError("Unresolved slot restorations: %s" % ( 557 self._checkpoint.slot_restorations,)) 558 if self._checkpoint.unused_attributes: 559 raise AssertionError( 560 ("Unused attributes in these objects (the attributes exist in the " 561 "checkpoint but not in the objects): %s") % ( 562 list(self._checkpoint.unused_attributes.items()),)) 563 return self 564 565 def assert_existing_objects_matched(self): 566 """Asserts that trackable Python objects have been matched. 567 568 Note that this is a weaker assertion than `assert_consumed`. It will only 569 fail for existing Python objects which are (transitive) dependencies of the 570 root object and which do not have an entry in the checkpoint. 571 572 It will not fail, for example, if a `tf.keras.Layer` object has not yet been 573 built and so has not created any `tf.Variable` objects. 574 575 Returns: 576 `self` for chaining. 577 578 Raises: 579 AssertionError: If a Python object exists in the transitive dependencies 580 of the root object but does not have a value in the checkpoint. 581 """ 582 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 583 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 584 if (trackable is not None 585 and trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access 586 raise AssertionError( 587 "Object not assigned a value from checkpoint: %s" % (node,)) 588 for trackable_object in self._graph_view.list_objects(): 589 # Remove data structures that do not contain any variables from 590 # restoration checks. 591 if (isinstance(trackable_object, 592 data_structures.TrackableDataStructure) and 593 not trackable_object._checkpoint_dependencies): 594 continue 595 self._checkpoint.all_python_objects.add(trackable_object) 596 unused_python_objects = ( 597 object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects) 598 - object_identity.ObjectIdentitySet( 599 self._checkpoint.object_by_proto_id.values())) 600 if unused_python_objects: 601 raise AssertionError( 602 ("Some Python objects were not bound to checkpointed values, likely " 603 "due to changes in the Python program: %s") 604 % (list(unused_python_objects),)) 605 return self 606 607 def assert_nontrivial_match(self): 608 """Raises an exception if only the root object matched.""" 609 for trackable_object in self._graph_view.list_objects(): 610 self._checkpoint.all_python_objects.add(trackable_object) 611 if len(self._checkpoint.object_by_proto_id) <= 1: 612 unused_python_objects = ( 613 object_identity.ObjectIdentitySet( 614 self._checkpoint.all_python_objects) 615 - object_identity.ObjectIdentitySet( 616 self._checkpoint.object_by_proto_id.values())) 617 if unused_python_objects: 618 raise AssertionError( 619 ("Nothing except the root object matched a checkpointed value. " 620 "Typically this means that the checkpoint does not match the " 621 "Python program. The following objects have no matching " 622 "checkpointed value: %s") % (list(unused_python_objects),)) 623 else: 624 raise AssertionError( 625 "Nothing to load. No dependencies have been added to %s yet." % ( 626 self._graph_view.root,)) 627 return self 628 629 def run_restore_ops(self, session=None): 630 """Run operations to restore objects in the dependency graph.""" 631 if context.executing_eagerly(): 632 return # Run eagerly 633 if session is None: 634 session = ops.get_default_session() 635 session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) 636 637 def initialize_or_restore(self, session=None): 638 """Run operations to initialize or restore objects in the dependency graph. 639 640 Any objects in the dependency graph which have initializers but are not in 641 the checkpoint will have those initializers run, unless those variables are 642 being restored by a later call to `tf.train.Checkpoint.restore()`. 643 644 This method has a sibling in `InitializationOnlyStatus` which instead 645 initializes variables. That type is returned if no checkpoint is specified 646 in `Saver.restore`. 647 648 Args: 649 session: The session to run init/restore ops in. If `None`, uses the 650 default session. 651 """ 652 if context.executing_eagerly(): 653 return # Initialization and restoration ops are run eagerly 654 if session is None: 655 session = ops.get_default_session() 656 all_objects = self._graph_view.list_objects() 657 already_initialized_objects = object_identity.ObjectIdentitySet( 658 self._checkpoint.object_by_proto_id.values()) 659 initializers_for_non_restored_variables = [ 660 c.initializer for c in all_objects 661 if hasattr(c, "initializer") 662 and c not in already_initialized_objects 663 and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1) 664 < self._checkpoint.restore_uid)] 665 self.run_restore_ops(session=session) 666 session.run(initializers_for_non_restored_variables) 667 668 669class InitializationOnlyStatus(_LoadStatus): 670 """Returned from `Saver.restore` when no checkpoint has been specified. 671 672 Objects of this type have the same `assert_consumed` method as 673 `CheckpointLoadStatus`, but it always fails. However, 674 `initialize_or_restore` works on objects of both types, and will 675 initialize variables in `InitializationOnlyStatus` objects or restore them 676 otherwise. 677 """ 678 679 def __init__(self, graph_view, restore_uid): 680 self._restore_uid = restore_uid 681 self._graph_view = graph_view 682 683 def assert_consumed(self): 684 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 685 raise AssertionError( 686 "No checkpoint specified (save_path=None); nothing is being restored.") 687 688 def assert_existing_objects_matched(self): 689 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 690 raise AssertionError( 691 "No checkpoint specified (save_path=None); nothing is being restored.") 692 693 def assert_nontrivial_match(self): 694 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 695 raise AssertionError( 696 "No checkpoint specified (save_path=None); nothing is being restored.") 697 698 def run_restore_ops(self, session=None): 699 """For consistency with `CheckpointLoadStatus`. 700 701 Use `initialize_or_restore` for initializing if no checkpoint was passed 702 to `Saver.restore` and restoring otherwise. 703 704 Args: 705 session: Not used. 706 """ 707 raise AssertionError( 708 "No checkpoint specified, so no restore ops are available " 709 "(save_path=None to Saver.restore).") 710 711 def initialize_or_restore(self, session=None): 712 """Runs initialization ops for variables. 713 714 Objects which would be saved by `Saver.save` will be initialized, unless 715 those variables are being restored by a later call to 716 `tf.train.Checkpoint.restore()`. 717 718 This method does nothing when executing eagerly (initializers get run 719 eagerly). 720 721 Args: 722 session: The session to run initialization ops in. If `None`, uses the 723 default session. 724 """ 725 if context.executing_eagerly(): 726 return # run eagerly 727 if session is None: 728 session = ops.get_default_session() 729 trackable_objects = self._graph_view.list_objects() 730 initializers = [ 731 c.initializer for c in trackable_objects 732 if hasattr(c, "initializer") and c.initializer is not None 733 and (getattr(c, "_update_uid", self._restore_uid - 1) 734 < self._restore_uid)] 735 session.run(initializers) 736 737 738_DEPRECATED_RESTORE_INSTRUCTIONS = ( 739 "Restoring a name-based tf.train.Saver checkpoint using the object-based " 740 "restore API. This mode uses global names to match variables, and so is " 741 "somewhat fragile. It also adds new restore ops to the graph each time it " 742 "is called when graph building. Prefer re-encoding training checkpoints in " 743 "the object-based format: run save() on the object-based saver (the same " 744 "one this message is coming from) and use that checkpoint in the future.") 745 746 747class NameBasedSaverStatus(_LoadStatus): 748 """Status for loading a name-based training checkpoint.""" 749 750 # Ideally this deprecation decorator would be on the class, but that 751 # interferes with isinstance checks. 752 @deprecation.deprecated( 753 date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) 754 def __init__(self, checkpoint, graph_view): 755 self._checkpoint = checkpoint 756 self._graph_view = graph_view 757 758 def assert_consumed(self): 759 """Raises an exception if any variables/objects are unmatched.""" 760 unused_attributes = dict(self._checkpoint.unused_attributes) 761 if unused_attributes: 762 raise AssertionError( 763 "Some objects had attributes which were not restored: %s" 764 % (unused_attributes,)) 765 for trackable in self._graph_view.list_objects(): 766 # pylint: disable=protected-access 767 trackable._maybe_initialize_trackable() 768 if trackable._update_uid < self._checkpoint.restore_uid: 769 raise AssertionError("Object not restored: %s" % (trackable,)) 770 # pylint: enable=protected-access 771 return self 772 773 def assert_existing_objects_matched(self): 774 """Raises an exception if currently created objects are unmatched.""" 775 # For name-based checkpoints there's no object information in the 776 # checkpoint, so there's no distinction between 777 # assert_existing_objects_matched and assert_consumed (and both are less 778 # useful since we don't touch Python objects or Python state). 779 return self.assert_consumed() 780 781 def assert_nontrivial_match(self): 782 """Raises an exception if currently created objects are unmatched.""" 783 # For name-based checkpoints there's no object information in the 784 # checkpoint, so there's no distinction between 785 # assert_nontrivial_match and assert_consumed (and both are less 786 # useful since we don't touch Python objects or Python state). 787 return self.assert_consumed() 788 789 def _gather_saveable_objects(self): 790 """Walk the object graph, using global names for SaveableObjects.""" 791 objects = self._graph_view.list_objects() 792 saveable_objects = [] 793 for trackable in objects: 794 # pylint: disable=protected-access 795 trackable._maybe_initialize_trackable() 796 if trackable._update_uid < self._checkpoint.restore_uid: 797 trackable._update_uid = self._checkpoint.restore_uid 798 else: 799 continue 800 # pylint: enable=protected-access 801 saveable_objects.extend( 802 self._checkpoint.globally_named_object_attributes( 803 trackable)) 804 return saveable_objects 805 806 def run_restore_ops(self, session=None): 807 """Load the name-based training checkpoint using a new `tf.train.Saver`.""" 808 if context.executing_eagerly(): 809 return # Nothing to do, variables are restored on creation. 810 if session is None: 811 session = ops.get_default_session() 812 with ops.device("/cpu:0"): 813 saveables = self._gather_saveable_objects() 814 v1_saver_lib.Saver(saveables).restore( 815 sess=session, save_path=self._checkpoint.save_path) 816 817 def initialize_or_restore(self, session=None): 818 """Alias for `run_restore_ops`.""" 819 self.run_restore_ops(session=session) 820 821 822class _SessionWithFeedDictAdditions(session_lib.SessionInterface): 823 """Pretends to be a session, inserts extra feeds on run().""" 824 825 def __init__(self, session, feed_additions): 826 self._wrapped_session = session 827 self._feed_additions = feed_additions 828 829 def run(self, fetches, feed_dict=None, **kwargs): 830 if feed_dict is None: 831 feed_dict = {} 832 else: 833 feed_dict = feed_dict.copy() 834 feed_dict.update(self._feed_additions) 835 return self._wrapped_session.run( 836 fetches=fetches, feed_dict=feed_dict, **kwargs) 837 838 839class TrackableSaver(object): 840 """Saves and restores a `Trackable` object and its dependencies. 841 842 See `Trackable` for details of dependency management. `Saver` wraps 843 `tf.train.Saver` for saving, including extra information about the graph of 844 dependencies between Python objects. When restoring, it uses this information 845 about the save-time dependency graph to more robustly match objects with their 846 checkpointed values. When executing eagerly, it supports restoring variables 847 on object creation (see `Saver.restore`). 848 849 Values in a checkpoint are mapped to `Trackable` Python objects 850 (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the 851 checkpoint was written. To avoid breaking existing checkpoints when modifying 852 a class, dependency names (the names of attributes to which `Trackable` 853 objects are assigned) may not change. These names are local to objects, in 854 contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and 855 so allow additional program transformations. 856 """ 857 858 def __init__(self, graph_view): 859 """Configure saving. 860 861 Args: 862 graph_view: A `GraphView` object containing a description of the object 863 graph to save. 864 """ 865 # The file prefix placeholder is created lazily when graph building (and not 866 # at all when executing eagerly) to avoid creating ops in the constructor 867 # (when they may never be necessary). 868 self._file_prefix_placeholder = None 869 870 # Op caching for save 871 self._object_graph_feed_tensor = None 872 self._last_save_object_graph = None 873 self._file_prefix_feed_tensor = None 874 self._cached_save_operation = None 875 876 # Op caching for restore, shared between _CheckpointRestoreCoordinators 877 self._restore_op_cache = {} 878 self._graph_view = graph_view 879 880 def _gather_saveables( 881 self, object_graph_tensor=None): 882 """Wraps _serialize_object_graph to include the object graph proto.""" 883 (named_saveable_objects, graph_proto, 884 feed_additions) = self._graph_view.serialize_object_graph() 885 if object_graph_tensor is None: 886 with ops.device("/cpu:0"): 887 object_graph_tensor = constant_op.constant( 888 graph_proto.SerializeToString(), dtype=dtypes.string) 889 else: 890 feed_additions.update( 891 {object_graph_tensor: graph_proto.SerializeToString()}) 892 assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects 893 named_saveable_objects.append( 894 base.NoRestoreSaveable( 895 tensor=object_graph_tensor, 896 name=base.OBJECT_GRAPH_PROTO_KEY)) 897 return named_saveable_objects, graph_proto, feed_additions 898 899 def _save_cached_when_graph_building( 900 self, 901 file_prefix, 902 object_graph_tensor=None): 903 """Create or retrieve save ops. 904 905 Args: 906 file_prefix: The prefix for saved checkpoint files. 907 object_graph_tensor: A `Tensor` to which the current object graph will be 908 fed. 909 910 Returns: 911 A two-element tuple with a filename tensor and a feed_dict of tensors to 912 feed when running it (if graph building). The feed dict contains the 913 current object graph and any Python state to be saved in the 914 checkpoint. When executing eagerly only the first argument is meaningful. 915 """ 916 (named_saveable_objects, graph_proto, 917 feed_additions) = self._gather_saveables( 918 object_graph_tensor=object_graph_tensor) 919 if (self._last_save_object_graph != graph_proto 920 # When executing eagerly, we need to re-create SaveableObjects each time 921 # save() is called so they pick up new Tensors passed to their 922 # constructors. That means the Saver needs to be copied with a new 923 # var_list. 924 or context.executing_eagerly() 925 or ops.inside_function()): 926 saver = functional_saver.Saver(named_saveable_objects) 927 with ops.device("/cpu:0"): 928 self._cached_save_operation = saver.save(file_prefix) 929 self._last_save_object_graph = graph_proto 930 return self._cached_save_operation, feed_additions 931 932 def save(self, file_prefix, checkpoint_number=None, session=None): 933 """Save a training checkpoint. 934 935 The saved checkpoint includes variables created by this object and any 936 Trackable objects it depends on at the time `Saver.save()` is called. 937 938 Args: 939 file_prefix: A prefix to use for the checkpoint filenames 940 (/path/to/directory/and_a_prefix). Names are generated based on this 941 prefix and `checkpoint_number`, if provided. 942 checkpoint_number: An integer variable or Tensor, used to number 943 checkpoints. Typically this value is saved along with other variables in 944 training checkpoints, which will happen automatically if it was created 945 by `root_trackable` or one of its dependencies (via 946 `Trackable._add_variable`). 947 session: The session to evaluate variables in. Ignored when executing 948 eagerly. If not provided when graph building, the default session is 949 used. 950 951 Returns: 952 The full path to the checkpoint. 953 """ 954 feed_dict = {} 955 use_session = (not context.executing_eagerly() 956 and not ops.inside_function()) 957 if checkpoint_number: 958 file_prefix = "%s-%d" % (file_prefix, checkpoint_number) 959 if use_session: 960 if self._object_graph_feed_tensor is None: 961 with ops.device("/cpu:0"): 962 self._object_graph_feed_tensor = constant_op.constant( 963 "", dtype=dtypes.string) 964 self._file_prefix_feed_tensor = constant_op.constant( 965 "", dtype=dtypes.string) 966 object_graph_tensor = self._object_graph_feed_tensor 967 file_prefix_tensor = self._file_prefix_feed_tensor 968 feed_dict[file_prefix_tensor] = file_prefix 969 else: 970 with ops.device("/cpu:0"): 971 file_prefix_tensor = constant_op.constant( 972 file_prefix, dtype=dtypes.string) 973 object_graph_tensor = None 974 975 file_io.recursive_create_dir(os.path.dirname(file_prefix)) 976 save_path, new_feed_additions = self._save_cached_when_graph_building( 977 file_prefix=file_prefix_tensor, 978 object_graph_tensor=object_graph_tensor) 979 if new_feed_additions: 980 feed_dict.update(new_feed_additions) 981 if not use_session: 982 session = None 983 elif session is None: 984 session = ops.get_default_session() 985 986 if session: 987 return session.run(save_path, feed_dict=feed_dict) 988 else: 989 return save_path 990 991 def restore(self, save_path): 992 """Restore a training checkpoint. 993 994 Restores `root_trackable` and any objects that it tracks 995 (transitive). Either assigns values immediately if variables to restore have 996 been created already, or defers restoration until the variables are 997 created. Dependencies added to the `root_trackable` passed to the 998 constructor after this call will be matched if they have a corresponding 999 object in the checkpoint. 1000 1001 When building a graph, restorations are added to the graph but not run. 1002 1003 To disallow deferred loading, assert immediately that all checkpointed 1004 variables have been matched to variable objects: 1005 1006 ```python 1007 saver = Saver(root) 1008 saver.restore(path).assert_consumed() 1009 ``` 1010 1011 An exception will be raised unless every object was matched and its 1012 variables already exist. 1013 1014 When graph building, `assert_consumed()` indicates that all of the restore 1015 ops which will be created for this checkpoint have been created. They can be 1016 run via the `run_restore_ops()` function of the status object: 1017 1018 ```python 1019 saver.restore(path).assert_consumed().run_restore_ops() 1020 ``` 1021 1022 If the checkpoint has not been consumed completely, then the list of restore 1023 ops will grow as more objects are added to the dependency graph. 1024 1025 Name-based `tf.train.Saver` checkpoints can be loaded using this 1026 method. There is no deferred loading, and names are used to match 1027 variables. No restore ops are created/run until `run_restore_ops()` or 1028 `initialize_or_restore()` are called on the returned status object, even 1029 when executing eagerly. Re-encode name-based checkpoints using this 1030 object-based `Saver.save` as soon as possible. 1031 1032 Args: 1033 save_path: The path to the checkpoint, as returned by `save` or 1034 `tf.train.latest_checkpoint`. If None (as when there is no latest 1035 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1036 object which may run initializers for objects in the dependency 1037 graph. If the checkpoint was written by the name-based `tf.train.Saver`, 1038 names are used to match variables. 1039 1040 Returns: 1041 A load status object, which can be used to make assertions about the 1042 status of checkpoint restoration and run initialization/restore ops 1043 (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if 1044 `save_path` is `None`). 1045 1046 If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` 1047 object is returned which runs restore ops from a name-based saver. 1048 """ 1049 if save_path is None: 1050 return InitializationOnlyStatus(self._graph_view, ops.uid()) 1051 reader = pywrap_tensorflow.NewCheckpointReader(save_path) 1052 graph_building = not context.executing_eagerly() 1053 if graph_building: 1054 dtype_map = None 1055 else: 1056 dtype_map = reader.get_variable_to_dtype_map() 1057 try: 1058 object_graph_string = reader.get_tensor( 1059 base.OBJECT_GRAPH_PROTO_KEY) 1060 except errors_impl.NotFoundError: 1061 # The object graph proto does not exist in this checkpoint. Try the 1062 # name-based compatibility mode. 1063 restore_coordinator = _NameBasedRestoreCoordinator( 1064 save_path=save_path, dtype_map=dtype_map) 1065 if not graph_building: 1066 for existing_trackable in self._graph_view.list_objects(): 1067 # pylint: disable=protected-access 1068 existing_trackable._maybe_initialize_trackable() 1069 existing_trackable._name_based_restores.add(restore_coordinator) 1070 existing_trackable._name_based_attribute_restore( 1071 restore_coordinator) 1072 # pylint: enable=protected-access 1073 return NameBasedSaverStatus( 1074 restore_coordinator, graph_view=self._graph_view) 1075 1076 if graph_building: 1077 if self._file_prefix_placeholder is None: 1078 with ops.device("/cpu:0"): 1079 self._file_prefix_placeholder = constant_op.constant("model") 1080 file_prefix_tensor = self._file_prefix_placeholder 1081 file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} 1082 else: 1083 with ops.device("/cpu:0"): 1084 file_prefix_tensor = constant_op.constant(save_path) 1085 file_prefix_feed_dict = None 1086 object_graph_proto = ( 1087 trackable_object_graph_pb2.TrackableObjectGraph()) 1088 object_graph_proto.ParseFromString(object_graph_string) 1089 checkpoint = _CheckpointRestoreCoordinator( 1090 object_graph_proto=object_graph_proto, 1091 save_path=save_path, 1092 save_path_tensor=file_prefix_tensor, 1093 restore_op_cache=self._restore_op_cache, 1094 graph_view=self._graph_view) 1095 base.CheckpointPosition(checkpoint=checkpoint, proto_id=0).restore( 1096 self._graph_view.root) 1097 load_status = CheckpointLoadStatus( 1098 checkpoint, 1099 graph_view=self._graph_view, 1100 feed_dict=file_prefix_feed_dict) 1101 return load_status 1102 1103 1104def frozen_saver(root_trackable): 1105 """Creates a static `tf.train.Saver` from a trackable object. 1106 1107 The returned `Saver` saves object-based checkpoints, but these checkpoints 1108 will no longer reflect structural changes to the object graph, only changes to 1109 the values of `Variable`s added as dependencies of the root object before 1110 `freeze` was called. 1111 1112 `restore` works on the returned `Saver`, but requires that the object graph of 1113 the checkpoint being loaded exactly matches the object graph when `freeze` was 1114 called. This is in contrast the object-based restore performed by 1115 `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's 1116 object graph and the current Python object graph. 1117 1118 Args: 1119 root_trackable: A trackable object to save. 1120 1121 Returns: 1122 A saver which saves object-based checkpoints for the object graph frozen at 1123 the time `frozen_saver` was called. 1124 """ 1125 named_saveable_objects = graph_view_lib.ObjectGraphView( 1126 root_trackable).frozen_saveable_objects() 1127 return functional_saver.Saver(named_saveable_objects) 1128 1129 1130def saver_with_op_caching(obj): 1131 """A TrackableSaver with a SaveableObject cache when graph building.""" 1132 if context.executing_eagerly(): 1133 saveables_cache = None 1134 else: 1135 saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary() 1136 return TrackableSaver(graph_view_lib.ObjectGraphView( 1137 weakref.ref(obj), 1138 saveables_cache=saveables_cache)) 1139 1140 1141# Mentions graph building / Sessions. The v2 version is below. 1142@tf_export(v1=["train.Checkpoint"]) 1143class CheckpointV1(tracking.AutoTrackable): 1144 """Groups trackable objects, saving and restoring them. 1145 1146 `Checkpoint`'s constructor accepts keyword arguments whose values are types 1147 that contain trackable state, such as `tf.train.Optimizer` 1148 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 1149 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 1150 maintains a `save_counter` for numbering checkpoints. 1151 1152 Example usage when graph building: 1153 1154 ```python 1155 import tensorflow as tf 1156 import os 1157 1158 checkpoint_directory = "/tmp/training_checkpoints" 1159 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1160 1161 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1162 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1163 train_op = optimizer.minimize( ... ) 1164 status.assert_consumed() # Optional sanity checks. 1165 with tf.Session() as session: 1166 # Use the Session to restore variables, or initialize them if 1167 # tf.train.latest_checkpoint returned None. 1168 status.initialize_or_restore(session) 1169 for _ in range(num_training_steps): 1170 session.run(train_op) 1171 checkpoint.save(file_prefix=checkpoint_prefix) 1172 ``` 1173 1174 Example usage with eager execution enabled: 1175 1176 ```python 1177 import tensorflow as tf 1178 import os 1179 1180 tf.enable_eager_execution() 1181 1182 checkpoint_directory = "/tmp/training_checkpoints" 1183 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1184 1185 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1186 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1187 for _ in range(num_training_steps): 1188 optimizer.minimize( ... ) # Variables will be restored on creation. 1189 status.assert_consumed() # Optional sanity checks. 1190 checkpoint.save(file_prefix=checkpoint_prefix) 1191 ``` 1192 1193 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 1194 checkpoints, in contrast to `tf.train.Saver` which writes and reads 1195 `variable.name` based checkpoints. Object-based checkpointing saves a graph of 1196 dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, 1197 etc.) with named edges, and this graph is used to match variables when 1198 restoring a checkpoint. It can be more robust to changes in the Python 1199 program, and helps to support restore-on-create for variables when executing 1200 eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code. 1201 1202 `Checkpoint` objects have dependencies on the objects passed as keyword 1203 arguments to their constructors, and each dependency is given a name that is 1204 identical to the name of the keyword argument for which it was created. 1205 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1206 dependencies on their variables (e.g. "kernel" and "bias" for 1207 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1208 dependencies easy in user-defined classes, since `Model` hooks into attribute 1209 assignment. For example: 1210 1211 ```python 1212 class Regress(tf.keras.Model): 1213 1214 def __init__(self): 1215 super(Regress, self).__init__() 1216 self.input_transform = tf.keras.layers.Dense(10) 1217 # ... 1218 1219 def call(self, inputs): 1220 x = self.input_transform(inputs) 1221 # ... 1222 ``` 1223 1224 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1225 which in turn depends on its variables. As a result, saving an instance of 1226 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1227 by the `Dense` layer. 1228 1229 Attributes: 1230 save_counter: Incremented when `save()` is called. Used to number 1231 checkpoints. 1232 """ 1233 1234 def __init__(self, **kwargs): 1235 """Group objects into a training checkpoint. 1236 1237 Args: 1238 **kwargs: Keyword arguments are set as attributes of this object, and are 1239 saved with the checkpoint. Values must be trackable objects. 1240 Raises: 1241 ValueError: If objects in `kwargs` are not trackable. 1242 """ 1243 super(CheckpointV1, self).__init__() 1244 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1245 if not isinstance(v, (base.Trackable, def_function.Function)): 1246 raise ValueError( 1247 ("`Checkpoint` was expecting a trackable object (an object " 1248 "derived from `TrackableBase`), got %s. If you believe this " 1249 "object should be trackable (i.e. it is part of the " 1250 "TensorFlow Python API and manages state), please open an issue.") 1251 % (v,)) 1252 setattr(self, k, v) 1253 self._save_counter = None # Created lazily for restore-on-create. 1254 self._save_assign_op = None 1255 self._saver = saver_with_op_caching(self) 1256 1257 def _maybe_create_save_counter(self): 1258 """Create a save counter if it does not yet exist.""" 1259 if self._save_counter is None: 1260 # Initialized to 0 and incremented before saving. 1261 with ops.device("/cpu:0"): 1262 # add_variable creates a dependency named "save_counter"; NoDependency 1263 # prevents creating a second dependency named "_save_counter". 1264 self._save_counter = data_structures.NoDependency( 1265 add_variable(self, name="save_counter", initializer=0, 1266 dtype=dtypes.int64)) 1267 1268 def write(self, file_prefix, session=None): 1269 """Writes a training checkpoint. 1270 1271 The checkpoint includes variables created by this object and any 1272 trackable objects it depends on at the time `Checkpoint.write()` is 1273 called. 1274 1275 `write` does not number checkpoints, increment `save_counter`, or update the 1276 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1277 use by higher level checkpoint management utilities. `save` provides a very 1278 basic implementation of these features. 1279 1280 Args: 1281 file_prefix: A prefix to use for the checkpoint filenames 1282 (/path/to/directory/and_a_prefix). 1283 session: The session to evaluate variables in. Ignored when executing 1284 eagerly. If not provided when graph building, the default session is 1285 used. 1286 1287 Returns: 1288 The full path to the checkpoint (i.e. `file_prefix`). 1289 """ 1290 output = self._saver.save( 1291 file_prefix=file_prefix, 1292 session=session) 1293 if tensor_util.is_tensor(output): 1294 if context.executing_eagerly(): 1295 return compat.as_str(output.numpy()) 1296 else: 1297 # Function building 1298 return output 1299 else: 1300 # Graph + Session, so we already session.ran it. 1301 return compat.as_str(output) 1302 1303 @property 1304 def save_counter(self): 1305 """An integer variable which starts at zero and is incremented on save. 1306 1307 Used to number checkpoints. 1308 1309 Returns: 1310 The save counter variable. 1311 """ 1312 self._maybe_create_save_counter() 1313 return self._save_counter 1314 1315 def save(self, file_prefix, session=None): 1316 """Saves a training checkpoint and provides basic checkpoint management. 1317 1318 The saved checkpoint includes variables created by this object and any 1319 trackable objects it depends on at the time `Checkpoint.save()` is 1320 called. 1321 1322 `save` is a basic convenience wrapper around the `write` method, 1323 sequentially numbering checkpoints using `save_counter` and updating the 1324 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 1325 management, for example garbage collection and custom numbering, may be 1326 provided by other utilities which also wrap `write` 1327 (`tf.contrib.checkpoint.CheckpointManager` for example). 1328 1329 Args: 1330 file_prefix: A prefix to use for the checkpoint filenames 1331 (/path/to/directory/and_a_prefix). Names are generated based on this 1332 prefix and `Checkpoint.save_counter`. 1333 session: The session to evaluate variables in. Ignored when executing 1334 eagerly. If not provided when graph building, the default session is 1335 used. 1336 1337 Returns: 1338 The full path to the checkpoint. 1339 """ 1340 graph_building = not context.executing_eagerly() 1341 if graph_building: 1342 if ops.inside_function(): 1343 raise NotImplementedError( 1344 "Calling tf.train.Checkpoint.save() from a function is not " 1345 "supported, as save() modifies saving metadata in ways not " 1346 "supported by TensorFlow Operations. Consider using " 1347 "tf.train.Checkpoint.write(), a lower-level API which does not " 1348 "update metadata. tf.train.latest_checkpoint and related APIs will " 1349 "not see this checkpoint.") 1350 if session is None: 1351 session = ops.get_default_session() 1352 if self._save_counter is None: 1353 # When graph building, if this is a new save counter variable then it 1354 # needs to be initialized before assign_add. This is only an issue if 1355 # restore() has not been called first. 1356 session.run(self.save_counter.initializer) 1357 if not graph_building or self._save_assign_op is None: 1358 with ops.colocate_with(self.save_counter): 1359 assign_op = self.save_counter.assign_add(1, read_value=True) 1360 if graph_building: 1361 self._save_assign_op = data_structures.NoDependency(assign_op) 1362 if graph_building: 1363 checkpoint_number = session.run(self._save_assign_op) 1364 else: 1365 checkpoint_number = assign_op.numpy() 1366 file_path = self.write("%s-%d" % (file_prefix, checkpoint_number), 1367 session=session) 1368 checkpoint_management.update_checkpoint_state_internal( 1369 save_dir=os.path.dirname(file_prefix), 1370 model_checkpoint_path=file_path, 1371 all_model_checkpoint_paths=[file_path], 1372 save_relative_paths=True) 1373 return file_path 1374 1375 def restore(self, save_path): 1376 """Restore a training checkpoint. 1377 1378 Restores this `Checkpoint` and any objects it depends on. 1379 1380 When executing eagerly, either assigns values immediately if variables to 1381 restore have been created already, or defers restoration until the variables 1382 are created. Dependencies added after this call will be matched if they have 1383 a corresponding object in the checkpoint (the restore request will queue in 1384 any trackable object waiting for the expected dependency to be added). 1385 1386 When graph building, restoration ops are added to the graph but not run 1387 immediately. 1388 1389 To ensure that loading is complete and no more assignments will take place, 1390 use the `assert_consumed()` method of the status object returned by 1391 `restore`: 1392 1393 ```python 1394 checkpoint = tf.train.Checkpoint( ... ) 1395 checkpoint.restore(path).assert_consumed() 1396 ``` 1397 1398 An exception will be raised if any Python objects in the dependency graph 1399 were not found in the checkpoint, or if any checkpointed values do not have 1400 a matching Python object. 1401 1402 When graph building, `assert_consumed()` indicates that all of the restore 1403 ops that will be created for this checkpoint have been created. They can be 1404 run via the `run_restore_ops()` method of the status object: 1405 1406 ```python 1407 checkpoint.restore(path).assert_consumed().run_restore_ops() 1408 ``` 1409 1410 If the checkpoint has not been consumed completely, then the list of restore 1411 ops will grow as more objects are added to the dependency graph. 1412 1413 Name-based `tf.train.Saver` checkpoints can be loaded using this 1414 method. Names are used to match variables. No restore ops are created/run 1415 until `run_restore_ops()` or `initialize_or_restore()` are called on the 1416 returned status object when graph building, but there is restore-on-creation 1417 when executing eagerly. Re-encode name-based checkpoints using 1418 `tf.train.Checkpoint.save` as soon as possible. 1419 1420 Args: 1421 save_path: The path to the checkpoint, as returned by `save` or 1422 `tf.train.latest_checkpoint`. If None (as when there is no latest 1423 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1424 object which may run initializers for objects in the dependency 1425 graph. If the checkpoint was written by the name-based `tf.train.Saver`, 1426 names are used to match variables. 1427 1428 Returns: 1429 A load status object, which can be used to make assertions about the 1430 status of a checkpoint restoration and run initialization/restore ops. 1431 1432 The returned status object has the following methods: 1433 1434 * `assert_consumed()`: 1435 Raises an exception if any variables/objects are unmatched: either 1436 checkpointed values which don't have a matching Python object or 1437 Python objects in the dependency graph with no values in the 1438 checkpoint. This method returns the status object, and so may be 1439 chained with `initialize_or_restore` or `run_restore_ops`. 1440 1441 * `assert_existing_objects_matched()`: 1442 Raises an exception if any existing Python objects in the dependency 1443 graph are unmatched. Unlike `assert_consumed`, this assertion will 1444 pass if values in the checkpoint have no corresponding Python 1445 objects. For example a `tf.keras.Layer` object which has not yet been 1446 built, and so has not created any variables, will pass this assertion 1447 but fail `assert_consumed`. Useful when loading part of a larger 1448 checkpoint into a new Python program, e.g. a training checkpoint with 1449 a `tf.train.Optimizer` was saved but only the state required for 1450 inference is being loaded. This method returns the status object, and 1451 so may be chained with `initialize_or_restore` or `run_restore_ops`. 1452 1453 * `assert_nontrivial_match()`: Asserts that something aside from the root 1454 object was matched. This is a very weak assertion, but is useful for 1455 sanity checking in library code where objects may exist in the 1456 checkpoint which haven't been created in Python and some Python 1457 objects may not have a checkpointed value. 1458 1459 * `initialize_or_restore(session=None)`: 1460 When graph building, runs variable initializers if `save_path` is 1461 `None`, but otherwise runs restore operations. If no `session` is 1462 explicitly specified, the default session is used. No effect when 1463 executing eagerly (variables are initialized or restored eagerly). 1464 1465 * `run_restore_ops(session=None)`: 1466 When graph building, runs restore operations. If no `session` is 1467 explicitly specified, the default session is used. No effect when 1468 executing eagerly (restore operations are run eagerly). May only be 1469 called when `save_path` is not `None`. 1470 """ 1471 status = self._saver.restore(save_path=save_path) 1472 # Create the save counter now so it gets initialized with other variables 1473 # when graph building. Creating it earlier would lead to double 1474 # initialization when executing eagerly. 1475 self._maybe_create_save_counter() 1476 return status 1477 1478 1479@tf_export("train.Checkpoint", v1=[]) 1480class Checkpoint(tracking.AutoTrackable): 1481 """Groups trackable objects, saving and restoring them. 1482 1483 `Checkpoint`'s constructor accepts keyword arguments whose values are types 1484 that contain trackable state, such as `tf.train.Optimizer` 1485 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 1486 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 1487 maintains a `save_counter` for numbering checkpoints. 1488 1489 Example usage: 1490 1491 ```python 1492 import tensorflow as tf 1493 import os 1494 1495 checkpoint_directory = "/tmp/training_checkpoints" 1496 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1497 1498 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1499 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1500 for _ in range(num_training_steps): 1501 optimizer.minimize( ... ) # Variables will be restored on creation. 1502 status.assert_consumed() # Optional sanity checks. 1503 checkpoint.save(file_prefix=checkpoint_prefix) 1504 ``` 1505 1506 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 1507 checkpoints, in contrast to TensorFlow 1.x's `tf.train.Saver` which writes and 1508 reads `variable.name` based checkpoints. Object-based checkpointing saves a 1509 graph of dependencies between Python objects (`Layer`s, `Optimizer`s, 1510 `Variable`s, etc.) with named edges, and this graph is used to match variables 1511 when restoring a checkpoint. It can be more robust to changes in the Python 1512 program, and helps to support restore-on-create for variables. 1513 1514 `Checkpoint` objects have dependencies on the objects passed as keyword 1515 arguments to their constructors, and each dependency is given a name that is 1516 identical to the name of the keyword argument for which it was created. 1517 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1518 dependencies on their variables (e.g. "kernel" and "bias" for 1519 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1520 dependencies easy in user-defined classes, since `Model` hooks into attribute 1521 assignment. For example: 1522 1523 ```python 1524 class Regress(tf.keras.Model): 1525 1526 def __init__(self): 1527 super(Regress, self).__init__() 1528 self.input_transform = tf.keras.layers.Dense(10) 1529 # ... 1530 1531 def call(self, inputs): 1532 x = self.input_transform(inputs) 1533 # ... 1534 ``` 1535 1536 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1537 which in turn depends on its variables. As a result, saving an instance of 1538 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1539 by the `Dense` layer. 1540 1541 Attributes: 1542 save_counter: Incremented when `save()` is called. Used to number 1543 checkpoints. 1544 """ 1545 1546 def __init__(self, **kwargs): 1547 """Group objects into a training checkpoint. 1548 1549 Args: 1550 **kwargs: Keyword arguments are set as attributes of this object, and are 1551 saved with the checkpoint. Values must be trackable objects. 1552 Raises: 1553 ValueError: If objects in `kwargs` are not trackable. 1554 """ 1555 super(Checkpoint, self).__init__() 1556 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1557 if not isinstance(v, (base.Trackable, def_function.Function)): 1558 raise ValueError( 1559 ("`Checkpoint` was expecting a trackable object (an object " 1560 "derived from `TrackableBase`), got %s. If you believe this " 1561 "object should be trackable (i.e. it is part of the " 1562 "TensorFlow Python API and manages state), please open an issue.") 1563 % (v,)) 1564 setattr(self, k, v) 1565 self._save_counter = None # Created lazily for restore-on-create. 1566 self._save_assign_op = None 1567 self._saver = saver_with_op_caching(self) 1568 1569 def _maybe_create_save_counter(self): 1570 """Create a save counter if it does not yet exist.""" 1571 if self._save_counter is None: 1572 # Initialized to 0 and incremented before saving. 1573 with ops.device("/cpu:0"): 1574 # add_variable creates a dependency named "save_counter"; NoDependency 1575 # prevents creating a second dependency named "_save_counter". 1576 self._save_counter = data_structures.NoDependency( 1577 add_variable(self, name="save_counter", initializer=0, 1578 dtype=dtypes.int64)) 1579 1580 def write(self, file_prefix): 1581 """Writes a training checkpoint. 1582 1583 The checkpoint includes variables created by this object and any 1584 trackable objects it depends on at the time `Checkpoint.write()` is 1585 called. 1586 1587 `write` does not number checkpoints, increment `save_counter`, or update the 1588 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1589 use by higher level checkpoint management utilities. `save` provides a very 1590 basic implementation of these features. 1591 1592 Args: 1593 file_prefix: A prefix to use for the checkpoint filenames 1594 (/path/to/directory/and_a_prefix). 1595 1596 Returns: 1597 The full path to the checkpoint (i.e. `file_prefix`). 1598 """ 1599 output = self._saver.save( 1600 file_prefix=file_prefix) 1601 if tensor_util.is_tensor(output): 1602 if context.executing_eagerly(): 1603 return compat.as_str(output.numpy()) 1604 else: 1605 # Function building 1606 return output 1607 else: 1608 # Graph + Session, so we already session.ran it. 1609 return compat.as_str(output) 1610 1611 @property 1612 def save_counter(self): 1613 """An integer variable which starts at zero and is incremented on save. 1614 1615 Used to number checkpoints. 1616 1617 Returns: 1618 The save counter variable. 1619 """ 1620 self._maybe_create_save_counter() 1621 return self._save_counter 1622 1623 def save(self, file_prefix): 1624 """Saves a training checkpoint and provides basic checkpoint management. 1625 1626 The saved checkpoint includes variables created by this object and any 1627 trackable objects it depends on at the time `Checkpoint.save()` is 1628 called. 1629 1630 `save` is a basic convenience wrapper around the `write` method, 1631 sequentially numbering checkpoints using `save_counter` and updating the 1632 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 1633 management, for example garbage collection and custom numbering, may be 1634 provided by other utilities which also wrap `write` 1635 (`tf.contrib.checkpoint.CheckpointManager` for example). 1636 1637 Args: 1638 file_prefix: A prefix to use for the checkpoint filenames 1639 (/path/to/directory/and_a_prefix). Names are generated based on this 1640 prefix and `Checkpoint.save_counter`. 1641 1642 Returns: 1643 The full path to the checkpoint. 1644 """ 1645 graph_building = not context.executing_eagerly() 1646 if graph_building: 1647 if ops.inside_function(): 1648 raise NotImplementedError( 1649 "Calling tf.train.Checkpoint.save() from a function is not " 1650 "supported, as save() modifies saving metadata in ways not " 1651 "supported by TensorFlow Operations. Consider using " 1652 "tf.train.Checkpoint.write(), a lower-level API which does not " 1653 "update metadata. tf.train.latest_checkpoint and related APIs will " 1654 "not see this checkpoint.") 1655 session = ops.get_default_session() 1656 if self._save_counter is None: 1657 # When graph building, if this is a new save counter variable then it 1658 # needs to be initialized before assign_add. This is only an issue if 1659 # restore() has not been called first. 1660 session.run(self.save_counter.initializer) 1661 if not graph_building or self._save_assign_op is None: 1662 with ops.colocate_with(self.save_counter): 1663 assign_op = self.save_counter.assign_add(1, read_value=True) 1664 if graph_building: 1665 self._save_assign_op = data_structures.NoDependency(assign_op) 1666 if graph_building: 1667 checkpoint_number = session.run(self._save_assign_op) 1668 else: 1669 checkpoint_number = assign_op.numpy() 1670 file_path = self.write("%s-%d" % (file_prefix, checkpoint_number)) 1671 checkpoint_management.update_checkpoint_state_internal( 1672 save_dir=os.path.dirname(file_prefix), 1673 model_checkpoint_path=file_path, 1674 all_model_checkpoint_paths=[file_path], 1675 save_relative_paths=True) 1676 return file_path 1677 1678 def restore(self, save_path): 1679 """Restore a training checkpoint. 1680 1681 Restores this `Checkpoint` and any objects it depends on. 1682 1683 Either assigns values immediately if variables to restore have been created 1684 already, or defers restoration until the variables are created. Dependencies 1685 added after this call will be matched if they have a corresponding object in 1686 the checkpoint (the restore request will queue in any trackable object 1687 waiting for the expected dependency to be added). 1688 1689 To ensure that loading is complete and no more assignments will take place, 1690 use the `assert_consumed()` method of the status object returned by 1691 `restore`: 1692 1693 ```python 1694 checkpoint = tf.train.Checkpoint( ... ) 1695 checkpoint.restore(path).assert_consumed() 1696 ``` 1697 1698 An exception will be raised if any Python objects in the dependency graph 1699 were not found in the checkpoint, or if any checkpointed values do not have 1700 a matching Python object. 1701 1702 Name-based `tf.train.Saver` checkpoints from TensorFlow 1.x can be loaded 1703 using this method. Names are used to match variables. Re-encode name-based 1704 checkpoints using `tf.train.Checkpoint.save` as soon as possible. 1705 1706 Args: 1707 save_path: The path to the checkpoint, as returned by `save` or 1708 `tf.train.latest_checkpoint`. If None (as when there is no latest 1709 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1710 object which may run initializers for objects in the dependency 1711 graph. If the checkpoint was written by the name-based `tf.train.Saver`, 1712 names are used to match variables. 1713 1714 Returns: 1715 A load status object, which can be used to make assertions about the 1716 status of a checkpoint restoration. 1717 1718 The returned status object has the following methods: 1719 1720 * `assert_consumed()`: 1721 Raises an exception if any variables/objects are unmatched: either 1722 checkpointed values which don't have a matching Python object or 1723 Python objects in the dependency graph with no values in the 1724 checkpoint. This method returns the status object, and so may be 1725 chained with other assertions. 1726 1727 * `assert_existing_objects_matched()`: 1728 Raises an exception if any existing Python objects in the dependency 1729 graph are unmatched. Unlike `assert_consumed`, this assertion will 1730 pass if values in the checkpoint have no corresponding Python 1731 objects. For example a `tf.keras.Layer` object which has not yet been 1732 built, and so has not created any variables, will pass this assertion 1733 but fail `assert_consumed`. Useful when loading part of a larger 1734 checkpoint into a new Python program, e.g. a training checkpoint with 1735 a `tf.train.Optimizer` was saved but only the state required for 1736 inference is being loaded. This method returns the status object, and 1737 so may be chained with other assertions. 1738 1739 * `assert_nontrivial_match()`: Asserts that something aside from the root 1740 object was matched. This is a very weak assertion, but is useful for 1741 sanity checking in library code where objects may exist in the 1742 checkpoint which haven't been created in Python and some Python 1743 objects may not have a checkpointed value. 1744 """ 1745 status = self._saver.restore(save_path=save_path) 1746 # Create the save counter now so it gets initialized with other variables 1747 # when graph building. Creating it earlier would lead to double 1748 # initialization when executing eagerly. 1749 self._maybe_create_save_counter() 1750 return status 1751