1"""Utilities for saving/loading Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22import os 23import weakref 24 25import six 26 27from tensorflow.core.protobuf import trackable_object_graph_pb2 28from tensorflow.python.client import session as session_lib 29from tensorflow.python.eager import context 30from tensorflow.python.eager import def_function 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import errors_impl 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import tensor_shape 36from tensorflow.python.framework import tensor_util 37from tensorflow.python.lib.io import file_io 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import gen_io_ops as io_ops 40from tensorflow.python.ops import init_ops 41from tensorflow.python.ops import variable_scope 42from tensorflow.python.ops import variables 43from tensorflow.python.platform import tf_logging as logging 44from tensorflow.python.training import checkpoint_management 45from tensorflow.python.training import py_checkpoint_reader 46from tensorflow.python.training import saver as v1_saver_lib 47from tensorflow.python.training.saving import functional_saver 48from tensorflow.python.training.saving import saveable_object_util 49from tensorflow.python.training.tracking import base 50from tensorflow.python.training.tracking import data_structures 51from tensorflow.python.training.tracking import graph_view as graph_view_lib 52from tensorflow.python.training.tracking import tracking 53from tensorflow.python.util import compat 54from tensorflow.python.util import deprecation 55from tensorflow.python.util import lazy_loader 56from tensorflow.python.util import object_identity 57from tensorflow.python.util import tf_contextlib 58from tensorflow.python.util.tf_export import tf_export 59 60 61# Loaded lazily due to a circular dependency. 62keras_backend = lazy_loader.LazyLoader( 63 "keras_backend", globals(), 64 "tensorflow.python.keras.backend") 65 66 67def get_session(): 68 # Prefer TF's default session since get_session from Keras has side-effects. 69 session = ops.get_default_session() 70 if session is None: 71 session = keras_backend.get_session() 72 return session 73 74 75class _ObjectGraphProtoPrettyPrinter(object): 76 """Lazily traverses an object graph proto to pretty print names. 77 78 If no calls to `node_names` are made this object has no performance 79 overhead. On the other hand, it will only traverse the object graph once, so 80 repeated naming is cheap after the first. 81 """ 82 83 def __init__(self, object_graph_proto): 84 self._object_graph_proto = object_graph_proto 85 self._node_name_cache = None 86 87 @property 88 def node_names(self): 89 """Lazily creates a mapping from node id to ("path", "to", "root").""" 90 if self._node_name_cache is not None: 91 return self._node_name_cache 92 path_to_root = {} 93 path_to_root[0] = ("(root)",) 94 to_visit = collections.deque([0]) 95 while to_visit: 96 node_id = to_visit.popleft() 97 obj = self._object_graph_proto.nodes[node_id] 98 for child in obj.children: 99 if child.node_id not in path_to_root: 100 path_to_root[child.node_id] = ( 101 path_to_root[node_id] + (child.local_name,)) 102 to_visit.append(child.node_id) 103 104 node_names = {} 105 for node_id, path_to_root in path_to_root.items(): 106 node_names[node_id] = ".".join(path_to_root) 107 108 for node_id, node in enumerate(self._object_graph_proto.nodes): 109 for slot_reference in node.slot_variables: 110 node_names[slot_reference.slot_variable_node_id] = ( 111 "{}'s state '{}' for {}".format( 112 node_names[node_id], slot_reference.slot_name, 113 node_names[slot_reference.original_variable_node_id])) 114 self._node_name_cache = node_names 115 return node_names 116 117 118class _CheckpointRestoreCoordinatorDeleter(object): 119 """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__().""" 120 121 def __init__(self, expect_partial, object_graph_proto, matched_proto_ids, 122 unused_attributes): 123 self.expect_partial = expect_partial 124 self.object_graph_proto = object_graph_proto 125 self.matched_proto_ids = matched_proto_ids 126 self.unused_attributes = unused_attributes 127 128 def set_expect_partial(self, expect_partial): 129 self.expect_partial = expect_partial 130 131 def __del__(self): 132 if self.expect_partial: 133 return 134 if logging is None: 135 # The logging module may have been unloaded when __del__ is called. 136 log_fn = print 137 else: 138 log_fn = logging.warning 139 printed_warning = False 140 pretty_printer = _ObjectGraphProtoPrettyPrinter(self.object_graph_proto) 141 for node_id in range(len(self.object_graph_proto.nodes)): 142 if node_id not in self.matched_proto_ids: 143 log_fn("Unresolved object in checkpoint: {}" 144 .format(pretty_printer.node_names[node_id])) 145 printed_warning = True 146 for node_id, attribute_name in self.unused_attributes.items(): 147 log_fn(("Unused attribute in object {}: {}" 148 .format(pretty_printer.node_names[node_id], attribute_name))) 149 printed_warning = True 150 if printed_warning: 151 log_fn( 152 "A checkpoint was restored (e.g. tf.train.Checkpoint.restore or " 153 "tf.keras.Model.load_weights) but not all checkpointed values were " 154 "used. See above for specific issues. Use expect_partial() on the " 155 "load status object, e.g. " 156 "tf.train.Checkpoint.restore(...).expect_partial(), to silence these " 157 "warnings, or use assert_consumed() to make the check explicit. See " 158 "https://www.tensorflow.org/guide/checkpoint#loading_mechanics" 159 " for details.") 160 161 162class _CheckpointRestoreCoordinator(object): 163 """Holds the status of an object-based checkpoint load.""" 164 165 def __init__(self, object_graph_proto, save_path, save_path_tensor, 166 restore_op_cache, graph_view): 167 """Specify the checkpoint being loaded. 168 169 Args: 170 object_graph_proto: The TrackableObjectGraph protocol buffer associated 171 with this checkpoint. 172 save_path: A string, the path to the checkpoint, as returned by 173 `tf.train.latest_checkpoint`. 174 save_path_tensor: A string `Tensor` which contains or will be fed the save 175 path. 176 restore_op_cache: A dictionary shared between 177 `_CheckpointRestoreCoordinator`s for the same Python objects, used to 178 look up restore ops by name to avoid re-creating them across multiple 179 `restore()` calls. 180 graph_view: A graph_view_lib.ObjectGraphView object for the restored 181 objects. 182 """ 183 self.object_graph_proto = object_graph_proto 184 self.restore_uid = ops.uid() 185 # Maps from proto ids to lists of attributes which were in the checkpoint 186 # but not loaded into any object, for error checking. 187 self.unused_attributes = {} 188 # Dictionary mapping from an id in the protocol buffer flat array to 189 # Trackable Python objects. This mapping may be deferred if a 190 # checkpoint is restored before all dependencies have been tracked. Uses 191 # weak references so that partial restorations don't create reference cycles 192 # (as objects with deferred dependencies will generally have references to 193 # this object). 194 self.object_by_proto_id = weakref.WeakValueDictionary() 195 self.matched_proto_ids = set() 196 # A set of all Python objects we've seen as dependencies, even if we didn't 197 # use them (for example because of inconsistent references when 198 # loading). Used to make status assertions fail when loading checkpoints 199 # that don't quite match. 200 self.all_python_objects = object_identity.ObjectIdentityWeakSet() 201 self.save_path_tensor = save_path_tensor 202 self.save_path_string = save_path 203 self.dtype_map = py_checkpoint_reader.NewCheckpointReader( 204 save_path).get_variable_to_dtype_map() 205 # A NewCheckpointReader for the most recent checkpoint, for streaming Python 206 # state restoration. 207 # When graph building, contains a list of ops to run to restore objects from 208 # this checkpoint. 209 self.restore_ops = [] 210 self.restore_ops_by_name = restore_op_cache 211 self.graph_view = graph_view 212 self.new_restore_ops_callback = None 213 # A mapping from optimizer proto ids to lists of slot variables to be 214 # restored when the optimizer is tracked. Only includes slot variables whose 215 # regular variables have already been created, and only for optimizer 216 # objects which have not yet been created/tracked. 217 self.deferred_slot_restorations = {} 218 # A mapping from variable proto ids to lists of slot variables to be 219 # restored when the variable is created/tracked. These get shifted over to 220 # deferred_slot_restorations if the optimizer hasn't been created when that 221 # happens. 222 self.slot_restorations = {} 223 # Controls whether errors are printed in __del__ if some objects did not 224 # match. 225 self.expect_partial_attr = False 226 for node_index, node in enumerate(self.object_graph_proto.nodes): 227 for slot_reference in node.slot_variables: 228 # `node` refers to an `Optimizer`, since only these have slot variables. 229 self.slot_restorations.setdefault( 230 slot_reference.original_variable_node_id, []).append( 231 base._SlotVariableRestoration( # pylint: disable=protected-access 232 optimizer_id=node_index, 233 slot_variable_id=slot_reference.slot_variable_node_id, 234 slot_name=slot_reference.slot_name)) 235 236 self._deleter = _CheckpointRestoreCoordinatorDeleter( 237 self.expect_partial_attr, 238 self.object_graph_proto, 239 self.matched_proto_ids, 240 self.unused_attributes) 241 242 @property 243 def expect_partial(self): 244 return self.expect_partial_attr 245 246 @expect_partial.setter 247 def expect_partial(self, expect_partial): 248 self.expect_partial_attr = expect_partial 249 self._deleter.set_expect_partial(expect_partial) 250 251 def new_restore_ops(self, new_ops): 252 self.restore_ops.extend(new_ops) 253 if self.new_restore_ops_callback: 254 self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable 255 256 def restore_saveables(self, tensor_saveables, python_saveables): 257 """Run or build restore operations for SaveableObjects. 258 259 Args: 260 tensor_saveables: `SaveableObject`s which correspond to Tensors. 261 python_saveables: `PythonStateSaveable`s which correspond to Python 262 values. 263 264 Returns: 265 When graph building, a list of restore operations, either cached or newly 266 created, to restore `tensor_saveables`. 267 """ 268 restore_ops = [] 269 # Eagerly run restorations for Python state. 270 reader = None 271 for saveable in python_saveables: 272 if reader is None: 273 # Lazily create the NewCheckpointReader, since this requires file access 274 # and we may not have any Python saveables. 275 reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string) 276 spec_names = [spec.name for spec in saveable.specs] 277 saveable.python_restore([reader.get_tensor(name) for name in spec_names]) 278 279 # If we have new SaveableObjects, extract and cache restore ops. 280 if tensor_saveables: 281 validated_saveables = saveable_object_util.validate_and_slice_inputs( 282 tensor_saveables) 283 validated_names = set(saveable.name for saveable in validated_saveables) 284 if set(tensor_saveables.keys()) != validated_names: 285 raise AssertionError( 286 ("Saveable keys changed when validating. Got back %s, was " 287 "expecting %s") % (tensor_saveables.keys(), validated_names)) 288 new_restore_ops = functional_saver.MultiDeviceSaver( 289 validated_saveables).restore(self.save_path_tensor) 290 if not context.executing_eagerly(): 291 for name, restore_op in sorted(new_restore_ops.items()): 292 restore_ops.append(restore_op) 293 assert name not in self.restore_ops_by_name 294 self.restore_ops_by_name[name] = restore_op 295 return restore_ops 296 297 298class _NameBasedRestoreCoordinator(object): 299 """Keeps the status of a name-based checkpoint restore.""" 300 301 def __init__(self, save_path, dtype_map=None): 302 self.save_path = save_path 303 self.dtype_map = dtype_map 304 # A map from trackable objects to unused attribute names. We don't have 305 # proto IDs when doing a name-based restore, so the map keys differ from 306 # those in _CheckpointRestoreCoordinator. 307 self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary() 308 self.restore_uid = ops.uid() 309 310 def globally_named_object_attributes(self, trackable): 311 """Create globally named SaveableObjects from attributes. 312 313 If an object's attribute has no global name specified (default construction 314 for the SaveableObject factory), records the failure in 315 `self.unused_attributes` (which can then be used to make status assertions 316 fail; see `NameBasedSaverStatus`). 317 318 Args: 319 trackable: An object to save. 320 321 Yields: 322 SaveableObjects for `trackable`'s attributes. 323 """ 324 for attribute_name, saveable_factory in ( 325 trackable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access 326 if callable(saveable_factory): 327 try: 328 # This saveable object factory does not have a default name= argument, 329 # which means there's no way to save/restore it using a name-based 330 # checkpoint. Ignore the error now and make sure assert_consumed() 331 # fails. 332 saveable = saveable_factory() 333 except TypeError: 334 # Even if we can't name this object, we should construct it and check 335 # whether it's optional to restore it. If it's optional we don't need 336 # to make assertions fail. 337 if not saveable_factory("").optional_restore: 338 self.unused_attributes.setdefault(trackable, 339 []).append(attribute_name) 340 continue 341 else: 342 saveable = saveable_factory 343 names_to_saveables = saveable_object_util.op_list_to_dict( 344 [saveable], convert_variable_to_tensor=False) 345 for name, op in names_to_saveables.items(): 346 for saveable_object in saveable_object_util.saveable_objects_for_op( 347 op=op, name=name): 348 yield saveable_object 349 350 def eager_restore(self, trackable): 351 """Runs restore ops for `trackable`'s attributes.""" 352 # When graph building, we don't add any restore ops to the graph until 353 # run_restore_ops/initialize_or_restore on the status object for name-based 354 # checkpoints. 355 assert context.executing_eagerly() 356 for saveable in self.globally_named_object_attributes(trackable): 357 restored_tensors = [] 358 tensor_missing = False 359 for spec in saveable.specs: 360 if spec.name in self.dtype_map: 361 with ops.device("cpu:0"): 362 restored, = io_ops.restore_v2( 363 prefix=self.save_path, 364 tensor_names=[spec.name], 365 shape_and_slices=[""], 366 dtypes=[self.dtype_map[spec.name]], 367 name="%s_checkpoint_read" % (spec.name,)) 368 restored_tensors.append(array_ops.identity(restored)) 369 else: 370 tensor_missing = True 371 372 if tensor_missing: 373 # Record that this variable didn't match so assertions will fail. 374 self.unused_attributes.setdefault(trackable, []).append(saveable.name) 375 else: 376 # Ignores values missing from the checkpoint, as with object-based 377 # restore. Status assertions can be used to check exact matches, 378 # although it's unlikely to ever happen for name-based checkpoints. 379 saveable.restore( 380 restored_tensors=restored_tensors, restored_shapes=None) 381 382 383# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange 384# or consolidating the implementation with get_variable. 385def _default_getter(name, 386 shape, 387 dtype, 388 initializer=None, 389 partition_info=None, 390 **kwargs): 391 """A pared-down version of get_variable which does not reuse variables.""" 392 dtype = dtypes.as_dtype(dtype) 393 shape_object = tensor_shape.as_shape(shape) 394 with ops.init_scope(): 395 if initializer is None: 396 initializer, initializing_from_value = ( 397 variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access 398 name=name, 399 shape=shape_object, 400 dtype=dtype)) 401 else: 402 initializing_from_value = not callable(initializer) 403 # Same logic as get_variable 404 variable_dtype = dtype.base_dtype 405 if initializing_from_value: 406 if shape is not None: 407 raise ValueError("If initializer is a constant, do not specify shape.") 408 initial_value = initializer 409 else: 410 # Instantiate initializer if provided initializer is a type object. 411 if isinstance(initializer, type(init_ops.Initializer)): 412 initializer = initializer(dtype=dtype) 413 414 def initial_value(): 415 return initializer( 416 shape_object.as_list(), dtype=dtype, partition_info=partition_info) 417 418 return variables.VariableV1( 419 initial_value=initial_value, 420 name=name, 421 dtype=variable_dtype, 422 use_resource=True, 423 **kwargs) 424 425 426def add_variable(trackable, 427 name, 428 shape=None, 429 dtype=dtypes.float32, 430 initializer=None, 431 trainable=True): 432 """Add a variable to a Trackable with no scope influence.""" 433 return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access 434 name=name, 435 shape=shape, 436 dtype=dtype, 437 initializer=initializer, 438 getter=_default_getter, 439 trainable=trainable) 440 441 442def object_metadata(save_path): 443 """Retrieves information about the objects in a checkpoint. 444 445 Example usage: 446 447 ```python 448 object_graph = tf.contrib.checkpoint.object_metadata( 449 tf.train.latest_checkpoint(checkpoint_directory)) 450 ckpt_variable_names = set() 451 for node in object_graph.nodes: 452 for attribute in node.attributes: 453 ckpt_variable_names.add(attribute.full_name) 454 ``` 455 456 Args: 457 save_path: The path to the checkpoint, as returned by `save` or 458 `tf.train.latest_checkpoint`. 459 460 Returns: 461 A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer. 462 Raises: 463 ValueError: If an object graph was not found in the checkpoint. 464 """ 465 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 466 try: 467 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 468 except errors_impl.NotFoundError: 469 raise ValueError( 470 ('The specified checkpoint "%s" does not appear to be object-based (it ' 471 'is missing the key "%s"). Likely it was created with a name-based ' 472 "saver and does not contain an object dependency graph.") % 473 (save_path, base.OBJECT_GRAPH_PROTO_KEY)) 474 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 475 object_graph_proto.ParseFromString(object_graph_string) 476 return object_graph_proto 477 478 479def list_objects(root_trackable): 480 """Traverse the object graph and list all accessible objects. 481 482 Looks for `Trackable` objects which are dependencies of 483 `root_trackable`. Includes slot variables only if the variable they are 484 slotting for and the optimizer are dependencies of `root_trackable` 485 (i.e. if they would be saved with a checkpoint). 486 487 Args: 488 root_trackable: A `Trackable` object whose dependencies should be flattened. 489 490 Returns: 491 A flat list of objects. 492 """ 493 return graph_view_lib.ObjectGraphView(root_trackable).list_objects() 494 495 496def gather_initializers(root_trackable): 497 """Traverse the object graph and find initialization ops. 498 499 Looks for `Trackable` objects which are dependencies of 500 `root_trackable` and which have an `initializer` property. Includes 501 initializers for slot variables only if the variable they are slotting for and 502 the optimizer are dependencies of `root_trackable` (i.e. if they would be 503 saved with a checkpoint). 504 505 Args: 506 root_trackable: A `Trackable` object to gather initializers for. 507 508 Returns: 509 A list of initialization ops. 510 """ 511 trackable_objects = list_objects(root_trackable) 512 return [ 513 c.initializer 514 for c in trackable_objects 515 if hasattr(c, "initializer") and c.initializer is not None 516 ] 517 518 519@tf_contextlib.contextmanager 520def capture_dependencies(template): 521 """Capture variables created within this scope as `Template` dependencies. 522 523 Requires that `template.variable_scope` is active. 524 525 This scope is intended as a compatibility measure, allowing a trackable 526 object to add dependencies on variables created in a block of code which is 527 not aware of object-based saving (and instead uses variable names 528 heavily). This is how `Template` objects add dependencies on variables and 529 sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly. 530 531 Args: 532 template: The `Template` object to register dependencies with. 533 534 Yields: 535 None (when used as a context manager). 536 """ 537 name_prefix = template.variable_scope.name 538 539 def _trackable_custom_creator(next_creator, 540 name, 541 initial_value, 542 trackable_parent=None, 543 **kwargs): 544 """A variable creation hook which adds Trackable dependencies. 545 546 Set for example during a `Template`'s first wrapped function 547 execution. Ensures that (a) `template` depends on any trackable 548 objects using their own `capture_dependencies` scope inside this scope which 549 create variables, and (b) that any variables not in a more deeply nested 550 scope are added as dependencies directly. 551 552 The `trackable_parent` argument is passed between custom creators but 553 ignored when the variable object itself is created. This argument indicates 554 (if not `None`) that a more deeply nested scope has already added the 555 variable as a dependency, and that parent scopes should add a dependency on 556 that object rather than on the variable directly. 557 558 Args: 559 next_creator: See `variable_scope.variable_creator_scope`; the next 560 creator in the chain. 561 name: The (full, scope-influenced) name of the variable. The `name_prefix` 562 itself is stripped for the purposes of object-based dependency tracking, 563 but scopes opened within this scope are respected. 564 initial_value: See `variable_scope.variable_creator_scope`. Taken 565 explicitly so the argument can be re-named and used with 566 `Trackable._add_variable_with_custom_getter`. 567 trackable_parent: If not None, a more deeply nested trackable object and 568 its name prefix which were passed to `capture_dependencies` to add a 569 dependency on (rather than depending on the variable directly). 570 **kwargs: Passed through to the next creator. 571 572 Returns: 573 The output of `next_creator`: the fetched/created variable object. 574 """ 575 576 def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): 577 inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which 578 # we don't want to propagate. 579 return next_creator(initial_value=initializer, name=name, **inner_kwargs) 580 581 if name is not None and name.startswith(name_prefix): 582 scope_stripped_name = name[len(name_prefix) + 1:] 583 if not trackable_parent: 584 return template._add_variable_with_custom_getter( # pylint: disable=protected-access 585 initializer=initial_value, 586 name=scope_stripped_name, 587 getter=_call_next_creator_renaming_initializer, 588 # Disable error checking for Trackable. Exceptions are instead 589 # raised if necessary when the object-based saver tries to 590 # save/restore the object. 591 overwrite=True, 592 trackable_parent=(template, name_prefix), 593 **kwargs) 594 else: 595 parent_object, parent_name_prefix = trackable_parent 596 template._track_trackable( # pylint: disable=protected-access 597 parent_object, 598 name=parent_name_prefix[len(name_prefix) + 1:], 599 overwrite=True) 600 return next_creator( 601 name=name, 602 initial_value=initial_value, 603 trackable_parent=(template, name_prefix), 604 **kwargs) 605 606 with variable_scope.variable_creator_scope(_trackable_custom_creator): 607 yield 608 609 610class _LoadStatus(object): 611 """Abstract base for load status callbacks.""" 612 613 @abc.abstractmethod 614 def assert_consumed(self): 615 """Raises an exception unless a non-trivial restoration has completed.""" 616 pass 617 618 @abc.abstractmethod 619 def assert_existing_objects_matched(self): 620 """Raises an exception unless existing Python objects have been matched.""" 621 pass 622 623 @abc.abstractmethod 624 def assert_nontrivial_match(self): 625 """Raises an exception if only the root object matched.""" 626 pass 627 628 @abc.abstractmethod 629 def run_restore_ops(self, session=None): 630 """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" 631 pass 632 633 @abc.abstractmethod 634 def initialize_or_restore(self, session=None): 635 """Runs restore ops from the checkpoint, or initializes variables.""" 636 pass 637 638 def expect_partial(self): 639 """Silence warnings about incomplete checkpoint restores.""" 640 return self 641 642 643def streaming_restore(status, session=None): 644 """When graph building, runs restore ops as soon as they come in. 645 646 Args: 647 status: A _LoadStatus objects from an object-based saver's restore(). 648 Streaming restore from name-based checkpoints is not currently supported. 649 session: A session to run new restore ops in. 650 """ 651 if context.executing_eagerly(): 652 # Streaming restore is the default/only behavior when executing eagerly. 653 return 654 if session is None: 655 session = get_session() 656 if isinstance(status, NameBasedSaverStatus): 657 raise NotImplementedError( 658 "Streaming restore not supported from name-based checkpoints when " 659 "graph building. File a feature request if this limitation bothers " 660 "you. As a workaround, consider either using tf.train.Checkpoint to " 661 "load name-based checkpoints or enabling eager execution.") 662 status.run_restore_ops(session=session) 663 # pylint: disable=protected-access 664 status._checkpoint.new_restore_ops_callback = ( 665 lambda ops: session.run(ops, feed_dict=status._feed_dict)) 666 # pylint: enable=protected-access 667 668 669def _objects_with_attributes(full_list): 670 """Filters out objects with no direct variable dependencies for assertions.""" 671 return [o for o in full_list if o._gather_saveables_for_checkpoint()] # pylint: disable=protected-access 672 673 674class CheckpointLoadStatus(_LoadStatus): 675 """Checks the status of checkpoint loading and manages restore ops. 676 677 Returned from `Saver.restore`. Since `restore` may defer the loading of values 678 in the checkpoint which don't yet have corresponding Python objects, 679 `CheckpointLoadStatus` provides a callback to verify that checkpoint loading 680 is complete (`assert_consumed`). 681 682 When graph building, `restore` does not run restore ops itself since their 683 creation may be deferred. The `run_restore_ops` method must be called once all 684 Python objects with values to restore have been created and added to the 685 dependency graph (this does not necessarily have to be the whole checkpoint; 686 calling `run_restore_ops` while `assert_consumed` fails is supported and will 687 partially restore the checkpoint). 688 689 See `Saver.restore` for usage examples. 690 """ 691 692 def __init__(self, checkpoint, feed_dict, graph_view): 693 self._checkpoint = checkpoint 694 self._feed_dict = feed_dict 695 self._graph_view = graph_view 696 # Keep a reference to the root, since graph_view might only have a weakref. 697 self._root = graph_view.root 698 699 def assert_consumed(self): 700 """Asserts that all objects in the checkpoint have been created/matched. 701 702 Returns: 703 `self` for chaining. 704 Raises: 705 AssertionError: If there are any Python objects in the dependency graph 706 which have not been restored from this checkpoint or a later `restore`, 707 or if there are any checkpointed values which have not been matched to 708 Python objects. 709 """ 710 pretty_printer = _ObjectGraphProtoPrettyPrinter( 711 self._checkpoint.object_graph_proto) 712 self.assert_existing_objects_matched() 713 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 714 if not node.attributes: 715 # Only raise exceptions for the nodes with attributes themselves. Either 716 # they're ultimately not important, or they have a child with an 717 # attribute. 718 continue 719 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 720 if trackable is None: 721 raise AssertionError("Unresolved object in checkpoint {}: {}" 722 .format(pretty_printer.node_names[node_id], node)) 723 if self._checkpoint.slot_restorations: 724 # Sanity check; this collection should be clear if everything has been 725 # restored. 726 raise AssertionError("Unresolved slot restorations: %s" % 727 (self._checkpoint.slot_restorations,)) 728 if self._checkpoint.unused_attributes: 729 unused_attribute_messages = [] 730 for node_id, attribute in six.iteritems( 731 self._checkpoint.unused_attributes): 732 obj = self._checkpoint.object_by_proto_id[node_id] 733 unused_attribute_messages.append( 734 "{} ({}): {}" 735 .format(pretty_printer.node_names[node_id], obj, attribute)) 736 raise AssertionError( 737 ("Unused attributes in these objects (the attributes exist in the " 738 "checkpoint but were not restored):\n{}") 739 .format("\n".join(unused_attribute_messages))) 740 return self 741 742 def assert_existing_objects_matched(self): 743 """Asserts that trackable Python objects have been matched. 744 745 Note that this is a weaker assertion than `assert_consumed`. It will only 746 fail for existing Python objects which are (transitive) dependencies of the 747 root object and which do not have an entry in the checkpoint. 748 749 It will not fail, for example, if a `tf.keras.Layer` object has not yet been 750 built and so has not created any `tf.Variable` objects. 751 752 Returns: 753 `self` for chaining. 754 755 Raises: 756 AssertionError: If a Python object exists in the transitive dependencies 757 of the root object but does not have a value in the checkpoint. 758 """ 759 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 760 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 761 if (trackable is not None and 762 trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access 763 raise AssertionError("Object not assigned a value from checkpoint: %s" % 764 (node,)) 765 for trackable_object in self._graph_view.list_objects(): 766 # Remove data structures that do not contain any variables from 767 # restoration checks. 768 if (isinstance(trackable_object, 769 data_structures.TrackableDataStructure) and 770 not trackable_object._checkpoint_dependencies): 771 continue 772 self._checkpoint.all_python_objects.add(trackable_object) 773 unused_python_objects = ( 774 object_identity.ObjectIdentitySet( 775 _objects_with_attributes( 776 self._checkpoint.all_python_objects)) - 777 object_identity.ObjectIdentitySet( 778 self._checkpoint.object_by_proto_id.values())) 779 if unused_python_objects: 780 raise AssertionError( 781 ("Some Python objects were not bound to checkpointed values, likely " 782 "due to changes in the Python program: %s") % 783 (list(unused_python_objects),)) 784 return self 785 786 def assert_nontrivial_match(self): 787 """Raises an exception if only the root object matched.""" 788 for trackable_object in self._graph_view.list_objects(): 789 self._checkpoint.all_python_objects.add(trackable_object) 790 if len(self._checkpoint.object_by_proto_id) <= 1: 791 unused_python_objects = ( 792 object_identity.ObjectIdentitySet( 793 _objects_with_attributes(self._checkpoint.all_python_objects)) 794 - object_identity.ObjectIdentitySet( 795 self._checkpoint.object_by_proto_id.values())) 796 if unused_python_objects: 797 raise AssertionError( 798 ("Nothing except the root object matched a checkpointed value. " 799 "Typically this means that the checkpoint does not match the " 800 "Python program. The following objects have no matching " 801 "checkpointed value: %s") % (list(unused_python_objects),)) 802 else: 803 raise AssertionError( 804 "Nothing to load. No dependencies have been added to %s yet." % 805 (self._graph_view.root,)) 806 return self 807 808 def run_restore_ops(self, session=None): 809 """Run operations to restore objects in the dependency graph.""" 810 if context.executing_eagerly(): 811 return # Run eagerly 812 if session is None: 813 session = get_session() 814 session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) 815 816 def initialize_or_restore(self, session=None): 817 """Run operations to initialize or restore objects in the dependency graph. 818 819 Any objects in the dependency graph which have initializers but are not in 820 the checkpoint will have those initializers run, unless those variables are 821 being restored by a later call to `tf.train.Checkpoint.restore()`. 822 823 This method has a sibling in `InitializationOnlyStatus` which instead 824 initializes variables. That type is returned if no checkpoint is specified 825 in `Saver.restore`. 826 827 Args: 828 session: The session to run init/restore ops in. If `None`, uses the 829 default session. 830 """ 831 if context.executing_eagerly(): 832 return # Initialization and restoration ops are run eagerly 833 if session is None: 834 session = get_session() 835 all_objects = self._graph_view.list_objects() 836 already_initialized_objects = object_identity.ObjectIdentitySet( 837 self._checkpoint.object_by_proto_id.values()) 838 initializers_for_non_restored_variables = [ 839 c.initializer for c in all_objects 840 if hasattr(c, "initializer") 841 and c not in already_initialized_objects 842 and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1) 843 < self._checkpoint.restore_uid)] 844 self.run_restore_ops(session=session) 845 session.run(initializers_for_non_restored_variables) 846 847 def expect_partial(self): 848 """Silence warnings about incomplete checkpoint restores.""" 849 self._checkpoint.expect_partial = True 850 return self 851 852 853class InitializationOnlyStatus(_LoadStatus): 854 """Returned from `Saver.restore` when no checkpoint has been specified. 855 856 Objects of this type have the same `assert_consumed` method as 857 `CheckpointLoadStatus`, but it always fails. However, 858 `initialize_or_restore` works on objects of both types, and will 859 initialize variables in `InitializationOnlyStatus` objects or restore them 860 otherwise. 861 """ 862 863 def __init__(self, graph_view, restore_uid): 864 self._restore_uid = restore_uid 865 self._graph_view = graph_view 866 # Keep a reference to the root, since graph_view might only have a weakref. 867 self._root = graph_view.root 868 869 def assert_consumed(self): 870 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 871 raise AssertionError( 872 "No checkpoint specified (save_path=None); nothing is being restored.") 873 874 def assert_existing_objects_matched(self): 875 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 876 raise AssertionError( 877 "No checkpoint specified (save_path=None); nothing is being restored.") 878 879 def assert_nontrivial_match(self): 880 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 881 raise AssertionError( 882 "No checkpoint specified (save_path=None); nothing is being restored.") 883 884 def run_restore_ops(self, session=None): 885 """For consistency with `CheckpointLoadStatus`. 886 887 Use `initialize_or_restore` for initializing if no checkpoint was passed 888 to `Saver.restore` and restoring otherwise. 889 890 Args: 891 session: Not used. 892 """ 893 raise AssertionError( 894 "No checkpoint specified, so no restore ops are available " 895 "(save_path=None to Saver.restore).") 896 897 def initialize_or_restore(self, session=None): 898 """Runs initialization ops for variables. 899 900 Objects which would be saved by `Saver.save` will be initialized, unless 901 those variables are being restored by a later call to 902 `tf.train.Checkpoint.restore()`. 903 904 This method does nothing when executing eagerly (initializers get run 905 eagerly). 906 907 Args: 908 session: The session to run initialization ops in. If `None`, uses the 909 default session. 910 """ 911 if context.executing_eagerly(): 912 return # run eagerly 913 if session is None: 914 session = get_session() 915 trackable_objects = self._graph_view.list_objects() 916 initializers = [ 917 c.initializer for c in trackable_objects 918 if hasattr(c, "initializer") and c.initializer is not None 919 and (getattr(c, "_update_uid", self._restore_uid - 1) 920 < self._restore_uid)] 921 session.run(initializers) 922 923 924_DEPRECATED_RESTORE_INSTRUCTIONS = ( 925 "Restoring a name-based tf.train.Saver checkpoint using the object-based " 926 "restore API. This mode uses global names to match variables, and so is " 927 "somewhat fragile. It also adds new restore ops to the graph each time it " 928 "is called when graph building. Prefer re-encoding training checkpoints in " 929 "the object-based format: run save() on the object-based saver (the same " 930 "one this message is coming from) and use that checkpoint in the future.") 931 932 933class NameBasedSaverStatus(_LoadStatus): 934 """Status for loading a name-based training checkpoint.""" 935 936 # Ideally this deprecation decorator would be on the class, but that 937 # interferes with isinstance checks. 938 @deprecation.deprecated( 939 date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) 940 def __init__(self, checkpoint, graph_view): 941 self._checkpoint = checkpoint 942 self._graph_view = graph_view 943 self._optionally_restored = [] 944 # Keep a reference to the root, since graph_view might only have a weakref. 945 self._root = graph_view.root 946 947 def add_to_optionally_restored(self, var): 948 """Add a variable to the list of optionally restored variables. 949 950 There are situations where certain variables should be ignored in assertions 951 such as assert_existing_objects_matched(). One example is that of a 952 checkpoint saved with train.Saver(), and restored with train.Checkpoint(): 953 it is possible for the train.Saver() checkpoint to be missing the internal 954 `save_counter` variable, which we want to ignore on restore. 955 956 Args: 957 var: The variable to treat as optionally restored. 958 """ 959 self._optionally_restored.append(var) 960 961 def assert_consumed(self): 962 """Raises an exception if any variables are unmatched.""" 963 unused_attributes = list(self._checkpoint.unused_attributes.items()) 964 unused_attributes = [ 965 a for a in unused_attributes 966 if all(a[0] is not x for x in self._optionally_restored) 967 ] 968 if unused_attributes: 969 unused_attribute_strings = [ 970 "\n {}: {}".format(obj, attributes) 971 for obj, attributes in unused_attributes 972 ] 973 raise AssertionError( 974 "Some objects had attributes which were not restored:{}".format( 975 "".join(unused_attribute_strings))) 976 for trackable in self._graph_view.list_objects(): 977 # pylint: disable=protected-access 978 trackable._maybe_initialize_trackable() 979 if trackable._update_uid < self._checkpoint.restore_uid: 980 raise AssertionError("Object not restored: %s" % (trackable,)) 981 # pylint: enable=protected-access 982 return self 983 984 def assert_existing_objects_matched(self): 985 """Raises an exception if currently created objects are unmatched.""" 986 # For name-based checkpoints there's no object information in the 987 # checkpoint, so there's no distinction between 988 # assert_existing_objects_matched and assert_consumed (and both are less 989 # useful since we don't touch Python objects or Python state). 990 return self.assert_consumed() 991 992 def assert_nontrivial_match(self): 993 """Raises an exception if currently created objects are unmatched.""" 994 # For name-based checkpoints there's no object information in the 995 # checkpoint, so there's no distinction between 996 # assert_nontrivial_match and assert_consumed (and both are less 997 # useful since we don't touch Python objects or Python state). 998 return self.assert_consumed() 999 1000 def _gather_saveable_objects(self): 1001 """Walk the object graph, using global names for SaveableObjects.""" 1002 objects = self._graph_view.list_objects() 1003 saveable_objects = [] 1004 for trackable in objects: 1005 # pylint: disable=protected-access 1006 trackable._maybe_initialize_trackable() 1007 if trackable._update_uid < self._checkpoint.restore_uid: 1008 trackable._update_uid = self._checkpoint.restore_uid 1009 else: 1010 continue 1011 # pylint: enable=protected-access 1012 saveable_objects.extend( 1013 self._checkpoint.globally_named_object_attributes(trackable)) 1014 return saveable_objects 1015 1016 def run_restore_ops(self, session=None): 1017 """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`.""" 1018 if context.executing_eagerly(): 1019 return # Nothing to do, variables are restored on creation. 1020 if session is None: 1021 session = get_session() 1022 with ops.device("/cpu:0"): 1023 saveables = self._gather_saveable_objects() 1024 v1_saver_lib.Saver(saveables).restore( 1025 sess=session, save_path=self._checkpoint.save_path) 1026 1027 def initialize_or_restore(self, session=None): 1028 """Alias for `run_restore_ops`.""" 1029 self.run_restore_ops(session=session) 1030 1031 1032class _SessionWithFeedDictAdditions(session_lib.SessionInterface): 1033 """Pretends to be a session, inserts extra feeds on run().""" 1034 1035 def __init__(self, session, feed_additions): 1036 self._wrapped_session = session 1037 self._feed_additions = feed_additions 1038 1039 def run(self, fetches, feed_dict=None, **kwargs): 1040 if feed_dict is None: 1041 feed_dict = {} 1042 else: 1043 feed_dict = feed_dict.copy() 1044 feed_dict.update(self._feed_additions) 1045 return self._wrapped_session.run( 1046 fetches=fetches, feed_dict=feed_dict, **kwargs) 1047 1048 1049class TrackableSaver(object): 1050 """Saves and restores a `Trackable` object and its dependencies. 1051 1052 See `Trackable` for details of dependency management. `Saver` wraps 1053 `tf.compat.v1.train.Saver` for saving, including extra information about the 1054 graph of 1055 dependencies between Python objects. When restoring, it uses this information 1056 about the save-time dependency graph to more robustly match objects with their 1057 checkpointed values. When executing eagerly, it supports restoring variables 1058 on object creation (see `Saver.restore`). 1059 1060 Values in a checkpoint are mapped to `Trackable` Python objects 1061 (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the 1062 checkpoint was written. To avoid breaking existing checkpoints when modifying 1063 a class, dependency names (the names of attributes to which `Trackable` 1064 objects are assigned) may not change. These names are local to objects, in 1065 contrast to the `Variable.name`-based save/restore from 1066 `tf.compat.v1.train.Saver`, and 1067 so allow additional program transformations. 1068 """ 1069 1070 def __init__(self, graph_view): 1071 """Configure saving. 1072 1073 Args: 1074 graph_view: A `GraphView` object containing a description of the object 1075 graph to save. 1076 """ 1077 # The file prefix placeholder is created lazily when graph building (and not 1078 # at all when executing eagerly) to avoid creating ops in the constructor 1079 # (when they may never be necessary). 1080 self._file_prefix_placeholder = None 1081 1082 # Op caching for save 1083 self._object_graph_feed_tensor = None 1084 self._last_save_object_graph = None 1085 self._file_prefix_feed_tensor = None 1086 self._cached_save_operation = None 1087 1088 # Op caching for restore, shared between _CheckpointRestoreCoordinators 1089 self._restore_op_cache = {} 1090 self._graph_view = graph_view 1091 1092 def _gather_saveables(self, object_graph_tensor=None): 1093 """Wraps _serialize_object_graph to include the object graph proto.""" 1094 (named_saveable_objects, graph_proto, 1095 feed_additions) = self._graph_view.serialize_object_graph() 1096 if object_graph_tensor is None: 1097 with ops.device("/cpu:0"): 1098 object_graph_tensor = constant_op.constant( 1099 graph_proto.SerializeToString(), dtype=dtypes.string) 1100 else: 1101 feed_additions.update( 1102 {object_graph_tensor: graph_proto.SerializeToString()}) 1103 assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects 1104 named_saveable_objects.append( 1105 base.NoRestoreSaveable( 1106 tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)) 1107 return named_saveable_objects, graph_proto, feed_additions 1108 1109 def _save_cached_when_graph_building(self, 1110 file_prefix, 1111 object_graph_tensor=None): 1112 """Create or retrieve save ops. 1113 1114 Args: 1115 file_prefix: The prefix for saved checkpoint files. 1116 object_graph_tensor: A `Tensor` to which the current object graph will be 1117 fed. 1118 1119 Returns: 1120 A two-element tuple with a filename tensor and a feed_dict of tensors to 1121 feed when running it (if graph building). The feed dict contains the 1122 current object graph and any Python state to be saved in the 1123 checkpoint. When executing eagerly only the first argument is meaningful. 1124 """ 1125 (named_saveable_objects, graph_proto, 1126 feed_additions) = self._gather_saveables( 1127 object_graph_tensor=object_graph_tensor) 1128 if (self._last_save_object_graph != graph_proto 1129 # When executing eagerly, we need to re-create SaveableObjects each time 1130 # save() is called so they pick up new Tensors passed to their 1131 # constructors. That means the Saver needs to be copied with a new 1132 # var_list. 1133 or context.executing_eagerly() or ops.inside_function()): 1134 saver = functional_saver.MultiDeviceSaver(named_saveable_objects) 1135 save_op = saver.save(file_prefix) 1136 with ops.device("/cpu:0"): 1137 with ops.control_dependencies([save_op]): 1138 self._cached_save_operation = array_ops.identity(file_prefix) 1139 self._last_save_object_graph = graph_proto 1140 return self._cached_save_operation, feed_additions 1141 1142 def save(self, file_prefix, checkpoint_number=None, session=None): 1143 """Save a training checkpoint. 1144 1145 The saved checkpoint includes variables created by this object and any 1146 Trackable objects it depends on at the time `Saver.save()` is called. 1147 1148 Args: 1149 file_prefix: A prefix to use for the checkpoint filenames 1150 (/path/to/directory/and_a_prefix). Names are generated based on this 1151 prefix and `checkpoint_number`, if provided. 1152 checkpoint_number: An integer variable or Tensor, used to number 1153 checkpoints. Typically this value is saved along with other variables in 1154 training checkpoints, which will happen automatically if it was created 1155 by `root_trackable` or one of its dependencies (via 1156 `Trackable._add_variable`). 1157 session: The session to evaluate variables in. Ignored when executing 1158 eagerly. If not provided when graph building, the default session is 1159 used. 1160 1161 Returns: 1162 The full path to the checkpoint. 1163 """ 1164 feed_dict = {} 1165 use_session = (not context.executing_eagerly() and 1166 not ops.inside_function()) 1167 if checkpoint_number: 1168 file_prefix = "%s-%d" % (file_prefix, checkpoint_number) 1169 if use_session: 1170 if self._object_graph_feed_tensor is None: 1171 with ops.device("/cpu:0"): 1172 self._object_graph_feed_tensor = constant_op.constant( 1173 "", dtype=dtypes.string) 1174 self._file_prefix_feed_tensor = constant_op.constant( 1175 "", dtype=dtypes.string) 1176 object_graph_tensor = self._object_graph_feed_tensor 1177 file_prefix_tensor = self._file_prefix_feed_tensor 1178 feed_dict[file_prefix_tensor] = file_prefix 1179 else: 1180 with ops.device("/cpu:0"): 1181 file_prefix_tensor = constant_op.constant( 1182 file_prefix, dtype=dtypes.string) 1183 object_graph_tensor = None 1184 1185 file_io.recursive_create_dir(os.path.dirname(file_prefix)) 1186 save_path, new_feed_additions = self._save_cached_when_graph_building( 1187 file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor) 1188 if new_feed_additions: 1189 feed_dict.update(new_feed_additions) 1190 if not use_session: 1191 session = None 1192 elif session is None: 1193 session = get_session() 1194 1195 if session: 1196 return session.run(save_path, feed_dict=feed_dict) 1197 else: 1198 return save_path 1199 1200 def restore(self, save_path): 1201 """Restore a training checkpoint. 1202 1203 Restores `root_trackable` and any objects that it tracks 1204 (transitive). Either assigns values immediately if variables to restore have 1205 been created already, or defers restoration until the variables are 1206 created. Dependencies added to the `root_trackable` passed to the 1207 constructor after this call will be matched if they have a corresponding 1208 object in the checkpoint. 1209 1210 When building a graph, restorations are added to the graph but not run. 1211 1212 To disallow deferred loading, assert immediately that all checkpointed 1213 variables have been matched to variable objects: 1214 1215 ```python 1216 saver = Saver(root) 1217 saver.restore(path).assert_consumed() 1218 ``` 1219 1220 An exception will be raised unless every object was matched and its 1221 variables already exist. 1222 1223 When graph building, `assert_consumed()` indicates that all of the restore 1224 ops which will be created for this checkpoint have been created. They can be 1225 run via the `run_restore_ops()` function of the status object: 1226 1227 ```python 1228 saver.restore(path).assert_consumed().run_restore_ops() 1229 ``` 1230 1231 If the checkpoint has not been consumed completely, then the list of restore 1232 ops will grow as more objects are added to the dependency graph. 1233 1234 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 1235 method. There is no deferred loading, and names are used to match 1236 variables. No restore ops are created/run until `run_restore_ops()` or 1237 `initialize_or_restore()` are called on the returned status object, even 1238 when executing eagerly. Re-encode name-based checkpoints using this 1239 object-based `Saver.save` as soon as possible. 1240 1241 Args: 1242 save_path: The path to the checkpoint, as returned by `save` or 1243 `tf.train.latest_checkpoint`. If None (as when there is no latest 1244 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1245 object which may run initializers for objects in the dependency graph. 1246 If the checkpoint was written by the name-based 1247 `tf.compat.v1.train.Saver`, names are used to match variables. 1248 1249 Returns: 1250 A load status object, which can be used to make assertions about the 1251 status of checkpoint restoration and run initialization/restore ops 1252 (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if 1253 `save_path` is `None`). 1254 1255 If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` 1256 object is returned which runs restore ops from a name-based saver. 1257 """ 1258 if save_path is None: 1259 return InitializationOnlyStatus(self._graph_view, ops.uid()) 1260 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 1261 graph_building = not context.executing_eagerly() 1262 if graph_building: 1263 dtype_map = None 1264 else: 1265 dtype_map = reader.get_variable_to_dtype_map() 1266 try: 1267 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 1268 except errors_impl.NotFoundError: 1269 # The object graph proto does not exist in this checkpoint. Try the 1270 # name-based compatibility mode. 1271 restore_coordinator = _NameBasedRestoreCoordinator( 1272 save_path=save_path, 1273 dtype_map=dtype_map) 1274 if not graph_building: 1275 for existing_trackable in self._graph_view.list_objects(): 1276 # pylint: disable=protected-access 1277 existing_trackable._maybe_initialize_trackable() 1278 existing_trackable._name_based_restores.add(restore_coordinator) 1279 existing_trackable._name_based_attribute_restore(restore_coordinator) 1280 # pylint: enable=protected-access 1281 return NameBasedSaverStatus( 1282 restore_coordinator, 1283 graph_view=self._graph_view) 1284 1285 if graph_building: 1286 if self._file_prefix_placeholder is None: 1287 with ops.device("/cpu:0"): 1288 self._file_prefix_placeholder = constant_op.constant("model") 1289 file_prefix_tensor = self._file_prefix_placeholder 1290 file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} 1291 else: 1292 with ops.device("/cpu:0"): 1293 file_prefix_tensor = constant_op.constant(save_path) 1294 file_prefix_feed_dict = None 1295 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 1296 object_graph_proto.ParseFromString(object_graph_string) 1297 checkpoint = _CheckpointRestoreCoordinator( 1298 object_graph_proto=object_graph_proto, 1299 save_path=save_path, 1300 save_path_tensor=file_prefix_tensor, 1301 restore_op_cache=self._restore_op_cache, 1302 graph_view=self._graph_view) 1303 base.CheckpointPosition( 1304 checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root) 1305 load_status = CheckpointLoadStatus( 1306 checkpoint, 1307 graph_view=self._graph_view, 1308 feed_dict=file_prefix_feed_dict) 1309 return load_status 1310 1311 1312def frozen_saver(root_trackable): 1313 """Creates a static `tf.compat.v1.train.Saver` from a trackable object. 1314 1315 The returned `Saver` saves object-based checkpoints, but these checkpoints 1316 will no longer reflect structural changes to the object graph, only changes to 1317 the values of `Variable`s added as dependencies of the root object before 1318 `freeze` was called. 1319 1320 `restore` works on the returned `Saver`, but requires that the object graph of 1321 the checkpoint being loaded exactly matches the object graph when `freeze` was 1322 called. This is in contrast the object-based restore performed by 1323 `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's 1324 object graph and the current Python object graph. 1325 1326 Args: 1327 root_trackable: A trackable object to save. 1328 1329 Returns: 1330 A saver which saves object-based checkpoints for the object graph frozen at 1331 the time `frozen_saver` was called. 1332 """ 1333 named_saveable_objects = graph_view_lib.ObjectGraphView( 1334 root_trackable).frozen_saveable_objects() 1335 return functional_saver.MultiDeviceSaver(named_saveable_objects) 1336 1337 1338def saver_with_op_caching(obj): 1339 """A TrackableSaver with a SaveableObject cache when graph building.""" 1340 if context.executing_eagerly(): 1341 saveables_cache = None 1342 else: 1343 saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary() 1344 return TrackableSaver( 1345 graph_view_lib.ObjectGraphView( 1346 weakref.ref(obj), saveables_cache=saveables_cache)) 1347 1348 1349# Mentions graph building / Sessions. The v2 version is below. 1350@tf_export(v1=["train.Checkpoint"]) 1351class CheckpointV1(tracking.AutoTrackable): 1352 """Groups trackable objects, saving and restoring them. 1353 1354 `Checkpoint`'s constructor accepts keyword arguments whose values are types 1355 that contain trackable state, such as `tf.compat.v1.train.Optimizer` 1356 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 1357 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 1358 maintains a `save_counter` for numbering checkpoints. 1359 1360 Example usage when graph building: 1361 1362 ```python 1363 import tensorflow as tf 1364 import os 1365 1366 checkpoint_directory = "/tmp/training_checkpoints" 1367 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1368 1369 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1370 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1371 train_op = optimizer.minimize( ... ) 1372 status.assert_consumed() # Optional sanity checks. 1373 with tf.compat.v1.Session() as session: 1374 # Use the Session to restore variables, or initialize them if 1375 # tf.train.latest_checkpoint returned None. 1376 status.initialize_or_restore(session) 1377 for _ in range(num_training_steps): 1378 session.run(train_op) 1379 checkpoint.save(file_prefix=checkpoint_prefix) 1380 ``` 1381 1382 Example usage with eager execution enabled: 1383 1384 ```python 1385 import tensorflow as tf 1386 import os 1387 1388 tf.compat.v1.enable_eager_execution() 1389 1390 checkpoint_directory = "/tmp/training_checkpoints" 1391 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1392 1393 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1394 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1395 for _ in range(num_training_steps): 1396 optimizer.minimize( ... ) # Variables will be restored on creation. 1397 status.assert_consumed() # Optional sanity checks. 1398 checkpoint.save(file_prefix=checkpoint_prefix) 1399 ``` 1400 1401 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 1402 checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads 1403 `variable.name` based checkpoints. Object-based checkpointing saves a graph of 1404 dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, 1405 etc.) with named edges, and this graph is used to match variables when 1406 restoring a checkpoint. It can be more robust to changes in the Python 1407 program, and helps to support restore-on-create for variables when executing 1408 eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new 1409 code. 1410 1411 `Checkpoint` objects have dependencies on the objects passed as keyword 1412 arguments to their constructors, and each dependency is given a name that is 1413 identical to the name of the keyword argument for which it was created. 1414 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1415 dependencies on their variables (e.g. "kernel" and "bias" for 1416 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1417 dependencies easy in user-defined classes, since `Model` hooks into attribute 1418 assignment. For example: 1419 1420 ```python 1421 class Regress(tf.keras.Model): 1422 1423 def __init__(self): 1424 super(Regress, self).__init__() 1425 self.input_transform = tf.keras.layers.Dense(10) 1426 # ... 1427 1428 def call(self, inputs): 1429 x = self.input_transform(inputs) 1430 # ... 1431 ``` 1432 1433 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1434 which in turn depends on its variables. As a result, saving an instance of 1435 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1436 by the `Dense` layer. 1437 1438 When variables are assigned to multiple workers, each worker writes its own 1439 section of the checkpoint. These sections are then merged/re-indexed to behave 1440 as a single checkpoint. This avoids copying all variables to one worker, but 1441 does require that all workers see a common filesystem. 1442 1443 While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the 1444 same format, note that the root of the resulting checkpoint is the object the 1445 save method is attached to. This means saving a `tf.keras.Model` using 1446 `save_weights` and loading into a `tf.train.Checkpoint` with a `Model` 1447 attached (or vice versa) will not match the `Model`'s variables. See the 1448 [guide to training 1449 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 1450 details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for 1451 training checkpoints. 1452 1453 Attributes: 1454 save_counter: Incremented when `save()` is called. Used to number 1455 checkpoints. 1456 """ 1457 1458 def __init__(self, **kwargs): 1459 """Group objects into a training checkpoint. 1460 1461 Args: 1462 **kwargs: Keyword arguments are set as attributes of this object, and are 1463 saved with the checkpoint. Values must be trackable objects. 1464 1465 Raises: 1466 ValueError: If objects in `kwargs` are not trackable. 1467 """ 1468 super(CheckpointV1, self).__init__() 1469 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1470 setattr(self, k, v) 1471 if not isinstance( 1472 getattr(self, k), (base.Trackable, def_function.Function)): 1473 raise ValueError( 1474 ("`Checkpoint` was expecting a trackable object (an object " 1475 "derived from `TrackableBase`), got %s. If you believe this " 1476 "object should be trackable (i.e. it is part of the " 1477 "TensorFlow Python API and manages state), please open an issue.") 1478 % (v,)) 1479 self._save_counter = None # Created lazily for restore-on-create. 1480 self._save_assign_op = None 1481 self._saver = saver_with_op_caching(self) 1482 1483 def _maybe_create_save_counter(self): 1484 """Create a save counter if it does not yet exist.""" 1485 if self._save_counter is None: 1486 # Initialized to 0 and incremented before saving. 1487 with ops.device("/cpu:0"): 1488 # add_variable creates a dependency named "save_counter"; NoDependency 1489 # prevents creating a second dependency named "_save_counter". 1490 self._save_counter = data_structures.NoDependency( 1491 add_variable( 1492 self, 1493 name="save_counter", 1494 initializer=0, 1495 dtype=dtypes.int64, 1496 trainable=False)) 1497 1498 def write(self, file_prefix, session=None): 1499 """Writes a training checkpoint. 1500 1501 The checkpoint includes variables created by this object and any 1502 trackable objects it depends on at the time `Checkpoint.write()` is 1503 called. 1504 1505 `write` does not number checkpoints, increment `save_counter`, or update the 1506 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1507 use by higher level checkpoint management utilities. `save` provides a very 1508 basic implementation of these features. 1509 1510 Args: 1511 file_prefix: A prefix to use for the checkpoint filenames 1512 (/path/to/directory/and_a_prefix). 1513 session: The session to evaluate variables in. Ignored when executing 1514 eagerly. If not provided when graph building, the default session is 1515 used. 1516 1517 Returns: 1518 The full path to the checkpoint (i.e. `file_prefix`). 1519 """ 1520 output = self._saver.save(file_prefix=file_prefix, session=session) 1521 if tensor_util.is_tensor(output): 1522 if context.executing_eagerly(): 1523 return compat.as_str(output.numpy()) 1524 else: 1525 # Function building 1526 return output 1527 else: 1528 # Graph + Session, so we already session.ran it. 1529 return compat.as_str(output) 1530 1531 @property 1532 def save_counter(self): 1533 """An integer variable which starts at zero and is incremented on save. 1534 1535 Used to number checkpoints. 1536 1537 Returns: 1538 The save counter variable. 1539 """ 1540 self._maybe_create_save_counter() 1541 return self._save_counter 1542 1543 def save(self, file_prefix, session=None): 1544 """Saves a training checkpoint and provides basic checkpoint management. 1545 1546 The saved checkpoint includes variables created by this object and any 1547 trackable objects it depends on at the time `Checkpoint.save()` is 1548 called. 1549 1550 `save` is a basic convenience wrapper around the `write` method, 1551 sequentially numbering checkpoints using `save_counter` and updating the 1552 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 1553 management, for example garbage collection and custom numbering, may be 1554 provided by other utilities which also wrap `write` 1555 (`tf.train.CheckpointManager` for example). 1556 1557 Args: 1558 file_prefix: A prefix to use for the checkpoint filenames 1559 (/path/to/directory/and_a_prefix). Names are generated based on this 1560 prefix and `Checkpoint.save_counter`. 1561 session: The session to evaluate variables in. Ignored when executing 1562 eagerly. If not provided when graph building, the default session is 1563 used. 1564 1565 Returns: 1566 The full path to the checkpoint. 1567 """ 1568 graph_building = not context.executing_eagerly() 1569 if graph_building: 1570 if ops.inside_function(): 1571 raise NotImplementedError( 1572 "Calling tf.train.Checkpoint.save() from a function is not " 1573 "supported, as save() modifies saving metadata in ways not " 1574 "supported by TensorFlow Operations. Consider using " 1575 "tf.train.Checkpoint.write(), a lower-level API which does not " 1576 "update metadata. tf.train.latest_checkpoint and related APIs will " 1577 "not see this checkpoint.") 1578 if session is None: 1579 session = get_session() 1580 if self._save_counter is None: 1581 # When graph building, if this is a new save counter variable then it 1582 # needs to be initialized before assign_add. This is only an issue if 1583 # restore() has not been called first. 1584 session.run(self.save_counter.initializer) 1585 if not graph_building or self._save_assign_op is None: 1586 with ops.colocate_with(self.save_counter): 1587 assign_op = self.save_counter.assign_add(1, read_value=True) 1588 if graph_building: 1589 self._save_assign_op = data_structures.NoDependency(assign_op) 1590 if graph_building: 1591 checkpoint_number = session.run(self._save_assign_op) 1592 else: 1593 checkpoint_number = assign_op.numpy() 1594 file_path = self.write( 1595 "%s-%d" % (file_prefix, checkpoint_number), session=session) 1596 checkpoint_management.update_checkpoint_state_internal( 1597 save_dir=os.path.dirname(file_prefix), 1598 model_checkpoint_path=file_path, 1599 all_model_checkpoint_paths=[file_path], 1600 save_relative_paths=True) 1601 return file_path 1602 1603 def restore(self, save_path): 1604 """Restore a training checkpoint. 1605 1606 Restores this `Checkpoint` and any objects it depends on. 1607 1608 When executing eagerly, either assigns values immediately if variables to 1609 restore have been created already, or defers restoration until the variables 1610 are created. Dependencies added after this call will be matched if they have 1611 a corresponding object in the checkpoint (the restore request will queue in 1612 any trackable object waiting for the expected dependency to be added). 1613 1614 When graph building, restoration ops are added to the graph but not run 1615 immediately. 1616 1617 To ensure that loading is complete and no more assignments will take place, 1618 use the `assert_consumed()` method of the status object returned by 1619 `restore`: 1620 1621 ```python 1622 checkpoint = tf.train.Checkpoint( ... ) 1623 checkpoint.restore(path).assert_consumed() 1624 ``` 1625 1626 An exception will be raised if any Python objects in the dependency graph 1627 were not found in the checkpoint, or if any checkpointed values do not have 1628 a matching Python object. 1629 1630 When graph building, `assert_consumed()` indicates that all of the restore 1631 ops that will be created for this checkpoint have been created. They can be 1632 run via the `run_restore_ops()` method of the status object: 1633 1634 ```python 1635 checkpoint.restore(path).assert_consumed().run_restore_ops() 1636 ``` 1637 1638 If the checkpoint has not been consumed completely, then the list of restore 1639 ops will grow as more objects are added to the dependency graph. 1640 1641 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 1642 method. Names are used to match variables. No restore ops are created/run 1643 until `run_restore_ops()` or `initialize_or_restore()` are called on the 1644 returned status object when graph building, but there is restore-on-creation 1645 when executing eagerly. Re-encode name-based checkpoints using 1646 `tf.train.Checkpoint.save` as soon as possible. 1647 1648 Args: 1649 save_path: The path to the checkpoint, as returned by `save` or 1650 `tf.train.latest_checkpoint`. If None (as when there is no latest 1651 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1652 object which may run initializers for objects in the dependency graph. 1653 If the checkpoint was written by the name-based 1654 `tf.compat.v1.train.Saver`, names are used to match variables. 1655 1656 Returns: 1657 A load status object, which can be used to make assertions about the 1658 status of a checkpoint restoration and run initialization/restore ops. 1659 1660 The returned status object has the following methods: 1661 1662 * `assert_consumed()`: 1663 Raises an exception if any variables are unmatched: either 1664 checkpointed values which don't have a matching Python object or 1665 Python objects in the dependency graph with no values in the 1666 checkpoint. This method returns the status object, and so may be 1667 chained with `initialize_or_restore` or `run_restore_ops`. 1668 1669 * `assert_existing_objects_matched()`: 1670 Raises an exception if any existing Python objects in the dependency 1671 graph are unmatched. Unlike `assert_consumed`, this assertion will 1672 pass if values in the checkpoint have no corresponding Python 1673 objects. For example a `tf.keras.Layer` object which has not yet been 1674 built, and so has not created any variables, will pass this assertion 1675 but fail `assert_consumed`. Useful when loading part of a larger 1676 checkpoint into a new Python program, e.g. a training checkpoint with 1677 a `tf.compat.v1.train.Optimizer` was saved but only the state required 1678 for 1679 inference is being loaded. This method returns the status object, and 1680 so may be chained with `initialize_or_restore` or `run_restore_ops`. 1681 1682 * `assert_nontrivial_match()`: Asserts that something aside from the root 1683 object was matched. This is a very weak assertion, but is useful for 1684 sanity checking in library code where objects may exist in the 1685 checkpoint which haven't been created in Python and some Python 1686 objects may not have a checkpointed value. 1687 1688 * `expect_partial()`: Silence warnings about incomplete checkpoint 1689 restores. Warnings are otherwise printed for unused parts of the 1690 checkpoint file or object when the `Checkpoint` object is deleted 1691 (often at program shutdown). 1692 1693 * `initialize_or_restore(session=None)`: 1694 When graph building, runs variable initializers if `save_path` is 1695 `None`, but otherwise runs restore operations. If no `session` is 1696 explicitly specified, the default session is used. No effect when 1697 executing eagerly (variables are initialized or restored eagerly). 1698 1699 * `run_restore_ops(session=None)`: 1700 When graph building, runs restore operations. If no `session` is 1701 explicitly specified, the default session is used. No effect when 1702 executing eagerly (restore operations are run eagerly). May only be 1703 called when `save_path` is not `None`. 1704 """ 1705 status = self._saver.restore(save_path=save_path) 1706 # Create the save counter now so it gets initialized with other variables 1707 # when graph building. Creating it earlier would lead to errors when using, 1708 # say, train.Saver() to save the model before initializing it. 1709 self._maybe_create_save_counter() 1710 if isinstance(status, NameBasedSaverStatus): 1711 status.add_to_optionally_restored(self.save_counter) 1712 return status 1713 1714 1715@tf_export("train.Checkpoint", v1=[]) 1716class Checkpoint(tracking.AutoTrackable): 1717 """Groups trackable objects, saving and restoring them. 1718 1719 `Checkpoint`'s constructor accepts keyword arguments whose values are types 1720 that contain trackable state, such as `tf.keras.optimizers.Optimizer` 1721 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 1722 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 1723 maintains a `save_counter` for numbering checkpoints. 1724 1725 Example usage: 1726 1727 ```python 1728 import tensorflow as tf 1729 import os 1730 1731 checkpoint_directory = "/tmp/training_checkpoints" 1732 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1733 1734 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1735 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1736 for _ in range(num_training_steps): 1737 optimizer.minimize( ... ) # Variables will be restored on creation. 1738 status.assert_consumed() # Optional sanity checks. 1739 checkpoint.save(file_prefix=checkpoint_prefix) 1740 ``` 1741 1742 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 1743 checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which 1744 writes and 1745 reads `variable.name` based checkpoints. Object-based checkpointing saves a 1746 graph of dependencies between Python objects (`Layer`s, `Optimizer`s, 1747 `Variable`s, etc.) with named edges, and this graph is used to match variables 1748 when restoring a checkpoint. It can be more robust to changes in the Python 1749 program, and helps to support restore-on-create for variables. 1750 1751 `Checkpoint` objects have dependencies on the objects passed as keyword 1752 arguments to their constructors, and each dependency is given a name that is 1753 identical to the name of the keyword argument for which it was created. 1754 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1755 dependencies on their variables (e.g. "kernel" and "bias" for 1756 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1757 dependencies easy in user-defined classes, since `Model` hooks into attribute 1758 assignment. For example: 1759 1760 ```python 1761 class Regress(tf.keras.Model): 1762 1763 def __init__(self): 1764 super(Regress, self).__init__() 1765 self.input_transform = tf.keras.layers.Dense(10) 1766 # ... 1767 1768 def call(self, inputs): 1769 x = self.input_transform(inputs) 1770 # ... 1771 ``` 1772 1773 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1774 which in turn depends on its variables. As a result, saving an instance of 1775 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1776 by the `Dense` layer. 1777 1778 When variables are assigned to multiple workers, each worker writes its own 1779 section of the checkpoint. These sections are then merged/re-indexed to behave 1780 as a single checkpoint. This avoids copying all variables to one worker, but 1781 does require that all workers see a common filesystem. 1782 1783 While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the 1784 same format, note that the root of the resulting checkpoint is the object the 1785 save method is attached to. This means saving a `tf.keras.Model` using 1786 `save_weights` and loading into a `tf.train.Checkpoint` with a `Model` 1787 attached (or vice versa) will not match the `Model`'s variables. See the 1788 [guide to training 1789 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 1790 details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for 1791 training checkpoints. 1792 1793 Attributes: 1794 save_counter: Incremented when `save()` is called. Used to number 1795 checkpoints. 1796 """ 1797 1798 def __init__(self, **kwargs): 1799 """Group objects into a training checkpoint. 1800 1801 Args: 1802 **kwargs: Keyword arguments are set as attributes of this object, and are 1803 saved with the checkpoint. Values must be trackable objects. 1804 1805 Raises: 1806 ValueError: If objects in `kwargs` are not trackable. 1807 """ 1808 super(Checkpoint, self).__init__() 1809 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1810 setattr(self, k, v) 1811 if not isinstance( 1812 getattr(self, k), (base.Trackable, def_function.Function)): 1813 raise ValueError( 1814 ("`Checkpoint` was expecting a trackable object (an object " 1815 "derived from `TrackableBase`), got %s. If you believe this " 1816 "object should be trackable (i.e. it is part of the " 1817 "TensorFlow Python API and manages state), please open an issue.") 1818 % (v,)) 1819 self._save_counter = None # Created lazily for restore-on-create. 1820 self._save_assign_op = None 1821 self._saver = saver_with_op_caching(self) 1822 1823 def _maybe_create_save_counter(self): 1824 """Create a save counter if it does not yet exist.""" 1825 if self._save_counter is None: 1826 # Initialized to 0 and incremented before saving. 1827 with ops.device("/cpu:0"): 1828 # add_variable creates a dependency named "save_counter"; NoDependency 1829 # prevents creating a second dependency named "_save_counter". 1830 self._save_counter = data_structures.NoDependency( 1831 add_variable( 1832 self, 1833 name="save_counter", 1834 initializer=0, 1835 dtype=dtypes.int64, 1836 trainable=False)) 1837 1838 def write(self, file_prefix): 1839 """Writes a training checkpoint. 1840 1841 The checkpoint includes variables created by this object and any 1842 trackable objects it depends on at the time `Checkpoint.write()` is 1843 called. 1844 1845 `write` does not number checkpoints, increment `save_counter`, or update the 1846 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1847 use by higher level checkpoint management utilities. `save` provides a very 1848 basic implementation of these features. 1849 1850 Args: 1851 file_prefix: A prefix to use for the checkpoint filenames 1852 (/path/to/directory/and_a_prefix). 1853 1854 Returns: 1855 The full path to the checkpoint (i.e. `file_prefix`). 1856 """ 1857 output = self._saver.save(file_prefix=file_prefix) 1858 if tensor_util.is_tensor(output): 1859 if context.executing_eagerly(): 1860 return compat.as_str(output.numpy()) 1861 else: 1862 # Function building 1863 return output 1864 else: 1865 # Graph + Session, so we already session.ran it. 1866 return compat.as_str(output) 1867 1868 @property 1869 def save_counter(self): 1870 """An integer variable which starts at zero and is incremented on save. 1871 1872 Used to number checkpoints. 1873 1874 Returns: 1875 The save counter variable. 1876 """ 1877 self._maybe_create_save_counter() 1878 return self._save_counter 1879 1880 def save(self, file_prefix): 1881 """Saves a training checkpoint and provides basic checkpoint management. 1882 1883 The saved checkpoint includes variables created by this object and any 1884 trackable objects it depends on at the time `Checkpoint.save()` is 1885 called. 1886 1887 `save` is a basic convenience wrapper around the `write` method, 1888 sequentially numbering checkpoints using `save_counter` and updating the 1889 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 1890 management, for example garbage collection and custom numbering, may be 1891 provided by other utilities which also wrap `write` 1892 (`tf.train.CheckpointManager` for example). 1893 1894 Args: 1895 file_prefix: A prefix to use for the checkpoint filenames 1896 (/path/to/directory/and_a_prefix). Names are generated based on this 1897 prefix and `Checkpoint.save_counter`. 1898 1899 Returns: 1900 The full path to the checkpoint. 1901 """ 1902 graph_building = not context.executing_eagerly() 1903 if graph_building: 1904 if ops.inside_function(): 1905 raise NotImplementedError( 1906 "Calling tf.train.Checkpoint.save() from a function is not " 1907 "supported, as save() modifies saving metadata in ways not " 1908 "supported by TensorFlow Operations. Consider using " 1909 "tf.train.Checkpoint.write(), a lower-level API which does not " 1910 "update metadata. tf.train.latest_checkpoint and related APIs will " 1911 "not see this checkpoint.") 1912 session = get_session() 1913 if self._save_counter is None: 1914 # When graph building, if this is a new save counter variable then it 1915 # needs to be initialized before assign_add. This is only an issue if 1916 # restore() has not been called first. 1917 session.run(self.save_counter.initializer) 1918 if not graph_building or self._save_assign_op is None: 1919 with ops.colocate_with(self.save_counter): 1920 assign_op = self.save_counter.assign_add(1, read_value=True) 1921 if graph_building: 1922 self._save_assign_op = data_structures.NoDependency(assign_op) 1923 if graph_building: 1924 checkpoint_number = session.run(self._save_assign_op) 1925 else: 1926 checkpoint_number = assign_op.numpy() 1927 file_path = self.write("%s-%d" % (file_prefix, checkpoint_number)) 1928 checkpoint_management.update_checkpoint_state_internal( 1929 save_dir=os.path.dirname(file_prefix), 1930 model_checkpoint_path=file_path, 1931 all_model_checkpoint_paths=[file_path], 1932 save_relative_paths=True) 1933 return file_path 1934 1935 def restore(self, save_path): 1936 """Restore a training checkpoint. 1937 1938 Restores this `Checkpoint` and any objects it depends on. 1939 1940 Either assigns values immediately if variables to restore have been created 1941 already, or defers restoration until the variables are created. Dependencies 1942 added after this call will be matched if they have a corresponding object in 1943 the checkpoint (the restore request will queue in any trackable object 1944 waiting for the expected dependency to be added). 1945 1946 To ensure that loading is complete and no more assignments will take place, 1947 use the `assert_consumed()` method of the status object returned by 1948 `restore`: 1949 1950 ```python 1951 checkpoint = tf.train.Checkpoint( ... ) 1952 checkpoint.restore(path).assert_consumed() 1953 ``` 1954 1955 An exception will be raised if any Python objects in the dependency graph 1956 were not found in the checkpoint, or if any checkpointed values do not have 1957 a matching Python object. 1958 1959 Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be 1960 loaded 1961 using this method. Names are used to match variables. Re-encode name-based 1962 checkpoints using `tf.train.Checkpoint.save` as soon as possible. 1963 1964 Args: 1965 save_path: The path to the checkpoint, as returned by `save` or 1966 `tf.train.latest_checkpoint`. If None (as when there is no latest 1967 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1968 object which may run initializers for objects in the dependency graph. 1969 If the checkpoint was written by the name-based 1970 `tf.compat.v1.train.Saver`, names are used to match variables. 1971 1972 Returns: 1973 A load status object, which can be used to make assertions about the 1974 status of a checkpoint restoration. 1975 1976 The returned status object has the following methods: 1977 1978 * `assert_consumed()`: 1979 Raises an exception if any variables are unmatched: either 1980 checkpointed values which don't have a matching Python object or 1981 Python objects in the dependency graph with no values in the 1982 checkpoint. This method returns the status object, and so may be 1983 chained with other assertions. 1984 1985 * `assert_existing_objects_matched()`: 1986 Raises an exception if any existing Python objects in the dependency 1987 graph are unmatched. Unlike `assert_consumed`, this assertion will 1988 pass if values in the checkpoint have no corresponding Python 1989 objects. For example a `tf.keras.Layer` object which has not yet been 1990 built, and so has not created any variables, will pass this assertion 1991 but fail `assert_consumed`. Useful when loading part of a larger 1992 checkpoint into a new Python program, e.g. a training checkpoint with 1993 a `tf.compat.v1.train.Optimizer` was saved but only the state required 1994 for 1995 inference is being loaded. This method returns the status object, and 1996 so may be chained with other assertions. 1997 1998 * `assert_nontrivial_match()`: Asserts that something aside from the root 1999 object was matched. This is a very weak assertion, but is useful for 2000 sanity checking in library code where objects may exist in the 2001 checkpoint which haven't been created in Python and some Python 2002 objects may not have a checkpointed value. 2003 2004 * `expect_partial()`: Silence warnings about incomplete checkpoint 2005 restores. Warnings are otherwise printed for unused parts of the 2006 checkpoint file or object when the `Checkpoint` object is deleted 2007 (often at program shutdown). 2008 """ 2009 status = self._saver.restore(save_path=save_path) 2010 # Create the save counter now so it gets initialized with other variables 2011 # when graph building. Creating it earlier would lead to errors when using, 2012 # say, train.Saver() to save the model before initializing it. 2013 self._maybe_create_save_counter() 2014 if isinstance(status, NameBasedSaverStatus): 2015 status.add_to_optionally_restored(self.save_counter) 2016 return status 2017