1"""Utilities for saving/loading Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16 17import abc 18import collections 19import functools 20import glob 21import os 22import threading 23import time 24import weakref 25 26from tensorflow.core.protobuf import trackable_object_graph_pb2 27from tensorflow.python.checkpoint import checkpoint_management 28from tensorflow.python.checkpoint import checkpoint_options 29from tensorflow.python.checkpoint import functional_saver 30from tensorflow.python.checkpoint import graph_view as graph_view_lib 31from tensorflow.python.checkpoint import restore as restore_lib 32from tensorflow.python.checkpoint import save_util_v1 33from tensorflow.python.checkpoint import util 34from tensorflow.python.client import session as session_lib 35from tensorflow.python.eager import context 36from tensorflow.python.eager import def_function 37from tensorflow.python.eager import executor 38from tensorflow.python.framework import constant_op 39from tensorflow.python.framework import dtypes 40from tensorflow.python.framework import errors_impl 41from tensorflow.python.framework import ops 42from tensorflow.python.framework import tensor_shape 43from tensorflow.python.framework import tensor_util 44from tensorflow.python.lib.io import file_io 45from tensorflow.python.ops import array_ops 46from tensorflow.python.ops import gen_io_ops as io_ops 47from tensorflow.python.ops import init_ops 48from tensorflow.python.ops import variable_scope 49from tensorflow.python.ops import variables 50from tensorflow.python.platform import gfile 51from tensorflow.python.platform import tf_logging as logging 52from tensorflow.python.saved_model import utils_impl 53from tensorflow.python.saved_model.pywrap_saved_model import metrics 54from tensorflow.python.trackable import autotrackable 55from tensorflow.python.trackable import base 56from tensorflow.python.trackable import data_structures 57from tensorflow.python.training import py_checkpoint_reader 58from tensorflow.python.training import saver as v1_saver_lib 59from tensorflow.python.training.saving import saveable_object_util 60from tensorflow.python.util import compat 61from tensorflow.python.util import deprecation 62from tensorflow.python.util import object_identity 63from tensorflow.python.util import tf_contextlib 64from tensorflow.python.util import tf_inspect 65from tensorflow.python.util.tf_export import tf_export 66 67# The callable that provide Keras default session that is needed for saving. 68_SESSION_PROVIDER = None 69 70# Captures the timestamp of the first Checkpoint instantiation or end of a write 71# operation. Can be accessed by multiple Checkpoint instances. 72_END_TIME_OF_LAST_WRITE = None 73_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock() 74 75# API labels for cell names used in checkpoint metrics. 
_CHECKPOINT_V1 = "checkpoint_v1"
_CHECKPOINT_V2 = "checkpoint_v2"

# Async thread used for asynchronous checkpoint.
_ASYNC_CHECKPOINT_THREAD = None


def _get_duration_microseconds(start_time_seconds, end_time_seconds):
  if end_time_seconds < start_time_seconds:
    # Avoid returning a negative value in case of clock skew.
    return 0
  return round((end_time_seconds - start_time_seconds) * 1000000)


@tf_export("__internal__.tracking.register_session_provider", v1=[])
def register_session_provider(session_provider):
  global _SESSION_PROVIDER
  # TODO(scottzhu): Change this back to only allowing the session provider to
  # be set once, after the keras repo split is finished.
  # if _SESSION_PROVIDER is None:
  _SESSION_PROVIDER = session_provider


def get_session():
  # Prefer TF's default session since get_session from Keras has side-effects.
  session = ops.get_default_session()
  if session is None:
    global _SESSION_PROVIDER
    if _SESSION_PROVIDER is not None:
      session = _SESSION_PROVIDER()  # pylint: disable=not-callable
  return session


def _get_checkpoint_size(prefix):
  """Calculates the file size of a checkpoint, given its prefix."""
  size = 0
  # Gather all files beginning with the prefix (.index plus sharded data
  # files).
  files = glob.glob("{}*".format(prefix))
  for file in files:
    # Use TensorFlow's C++ FileSystem API.
    size += metrics.CalculateFileSize(file)
  return size
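
# Illustrative sketch of how the size helper composes with a written
# checkpoint; the "/tmp/ckpt/model" prefix is hypothetical:
#
#   save_path = tf.train.Checkpoint(model=model).write("/tmp/ckpt/model")
#   # `save_path` is the checkpoint prefix; the glob above picks up files
#   # such as "model.index" and "model.data-00000-of-00001".
#   size_bytes = _get_checkpoint_size(save_path)
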
class ObjectGraphProtoPrettyPrinter:
  """Lazily traverses an object graph proto to pretty print names.

  If no calls to `node_names` are made, this object has no performance
  overhead. It traverses the object graph at most once, so repeated naming
  is cheap after the first lookup.
  """

  __slots__ = ["_object_graph_proto", "_node_name_cache"]

  def __init__(self, object_graph_proto):
    self._object_graph_proto = object_graph_proto
    self._node_name_cache = None

  @property
  def node_names(self):
    """Lazily creates a mapping from node id to ("path", "to", "root")."""
    if self._node_name_cache is not None:
      return self._node_name_cache
    path_to_root = {}
    path_to_root[0] = ("(root)",)
    to_visit = collections.deque([0])
    while to_visit:
      node_id = to_visit.popleft()
      obj = self._object_graph_proto.nodes[node_id]
      for child in obj.children:
        if child.node_id not in path_to_root:
          path_to_root[child.node_id] = (
              path_to_root[node_id] + (child.local_name,))
          to_visit.append(child.node_id)

    node_names = {}
    for node_id, path in path_to_root.items():
      node_names[node_id] = ".".join(path)

    for node_id, node in enumerate(self._object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        node_names[slot_reference.slot_variable_node_id] = (
            f"{node_names[node_id]}'s state '{slot_reference.slot_name}' for "
            f"{node_names[slot_reference.original_variable_node_id]}")
    self._node_name_cache = node_names
    return node_names
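
# Illustrative sketch of the pretty printer on a freshly parsed object graph
# proto (assumes `save_path` points at an object-based checkpoint):
#
#   proto = object_metadata(save_path)  # Defined later in this module.
#   printer = ObjectGraphProtoPrettyPrinter(proto)
#   printer.node_names[0]  # -> "(root)"
#   printer.node_names[2]  # -> e.g. "(root).model.dense.kernel"
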
class _CheckpointRestoreCoordinatorDeleter:
  """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__()."""

  __slots__ = [
      "expect_partial", "object_graph_proto", "matched_proto_ids",
      "unused_attributes"
  ]

  def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
               unused_attributes):
    self.expect_partial = expect_partial
    self.object_graph_proto = object_graph_proto
    self.matched_proto_ids = matched_proto_ids
    self.unused_attributes = unused_attributes

  def set_expect_partial(self, expect_partial):
    self.expect_partial = expect_partial

  def __del__(self):
    if self.expect_partial:
      return
    if logging is None:
      # The logging module may have been unloaded when __del__ is called.
      log_fn = print
    else:
      log_fn = logging.warning
    unused_nodes_in_checkpoint = []
    unrestored_attributes_in_object = []
    pretty_printer = ObjectGraphProtoPrettyPrinter(self.object_graph_proto)
    for node_id, node in enumerate(self.object_graph_proto.nodes):
      if not node.attributes:
        continue
      if node_id not in self.matched_proto_ids:
        unused_nodes_in_checkpoint.append(pretty_printer.node_names[node_id])
    for node_id, attribute_name in self.unused_attributes.items():
      unrestored_attributes_in_object.append((
          pretty_printer.node_names[node_id], attribute_name))
    if unused_nodes_in_checkpoint or unrestored_attributes_in_object:
      # pylint:disable=line-too-long
      log_fn("Detecting that an object or model or tf.train.Checkpoint is "
             "being deleted with unrestored values. See the following logs "
             "for the specific values in question. To silence these warnings, "
             "use `status.expect_partial()`. See "
             "https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restore "
             "for details about the status object returned by the restore "
             "function.")
      # pylint:enable=line-too-long
      for node_path in unused_nodes_in_checkpoint:
        log_fn("Value in checkpoint could not be found in the restored "
               f"object: {node_path}")
      for node_path, attr in unrestored_attributes_in_object:
        log_fn("An attribute in the restored object could not be found in the "
               f"checkpoint. Object: {node_path}, attribute: {attr}")


class _CheckpointRestoreCoordinator:
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path, save_path_tensor, reader,
               restore_op_cache, graph_view, options, saveables_cache):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The TrackableObjectGraph protocol buffer associated
        with this checkpoint.
      save_path: A string, the path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
      save_path_tensor: A string `Tensor` which contains or will be fed the
        save path.
      reader: A `CheckpointReader` for `save_path`. If None,
        `_CheckpointRestoreCoordinator` will initialize one itself.
      restore_op_cache: A dictionary shared between
        `_CheckpointRestoreCoordinator`s for the same Python objects, used to
        look up restore ops by name to avoid re-creating them across multiple
        `restore()` calls.
      graph_view: A graph_view_lib.ObjectGraphView object for the restored
        objects.
      options: A CheckpointOptions object.
      saveables_cache: An optional cache storing previously created
        SaveableObjects for each Trackable. Maps Trackables to a dictionary
        of attribute names to SaveableObjects.
    """
    self.options = options
    self.object_graph_proto = object_graph_proto
    self.restore_uid = ops.uid()
    # Maps from proto ids to lists of attributes which were in the checkpoint
    # but not loaded into any object, for error checking.
    self.unused_attributes = {}
    # Dictionary mapping from an id in the protocol buffer flat array to
    # Trackable Python objects. This mapping may be deferred if a
    # checkpoint is restored before all dependencies have been tracked. Uses
    # weak references so that partial restorations don't create reference
    # cycles (as objects with deferred dependencies will generally have
    # references to this object).
    self.object_by_proto_id = weakref.WeakValueDictionary()
    self.matched_proto_ids = set()
    # A set of all Python objects we've seen as dependencies, even if we didn't
    # use them (for example because of inconsistent references when
    # loading). Used to make status assertions fail when loading checkpoints
    # that don't quite match.
    self.all_python_objects = object_identity.ObjectIdentityWeakSet()
    self.save_path_tensor = save_path_tensor
    self.save_path_string = save_path
    # A NewCheckpointReader for the most recent checkpoint, for streaming
    # Python state restoration.
    self.reader = reader
    if self.reader is None:
      self.reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    self.dtype_map = self.reader.get_variable_to_dtype_map()
    self.shape_map = self.reader.get_variable_to_shape_map()
    # When graph building, contains a list of ops to run to restore objects
    # from this checkpoint.
    self.restore_ops = []
    self.restore_ops_by_name = restore_op_cache
    self.graph_view = graph_view
    self.new_restore_ops_callback = None
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables
    # whose regular variables have already been created, and only for
    # optimizer objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over to
    # deferred_slot_restorations if the optimizer hasn't been created when
    # that happens.
    self.slot_restorations = {}
    # Controls whether errors are printed in __del__ if some objects did not
    # match.
    self.expect_partial_attr = False
    for node_index, node in enumerate(self.object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        # `node` refers to an `Optimizer`, since only these have slot
        # variables.
        self.slot_restorations.setdefault(
            slot_reference.original_variable_node_id, []).append(
                base._SlotVariableRestoration(  # pylint: disable=protected-access
                    optimizer_id=node_index,
                    slot_variable_id=slot_reference.slot_variable_node_id,
                    slot_name=slot_reference.slot_name))

    self._deleter = _CheckpointRestoreCoordinatorDeleter(
        self.expect_partial_attr,
        self.object_graph_proto,
        self.matched_proto_ids,
        self.unused_attributes)

    self.saveables_cache = saveables_cache

  @property
  def expect_partial(self):
    return self.expect_partial_attr

  @expect_partial.setter
  def expect_partial(self, expect_partial):
    self.expect_partial_attr = expect_partial
    self._deleter.set_expect_partial(expect_partial)

  def new_restore_ops(self, new_ops):
    self.restore_ops.extend(new_ops)
    if self.new_restore_ops_callback:
      self.new_restore_ops_callback(new_ops)  # pylint: disable=not-callable

  def restore_saveables(self,
                        tensor_saveables,
                        python_positions,
                        registered_savers=None):
    """Run or build restore operations for SaveableObjects.

    Args:
      tensor_saveables: `SaveableObject`s which correspond to Tensors.
      python_positions: List of CheckpointPositions bound to `PythonState`
        objects which must be restored eagerly.
      registered_savers: A dict mapping saver names -> object names ->
        Trackables.

    Returns:
      When graph building, a list of restore operations, either cached or
      newly created, to restore `tensor_saveables`.
    """
    restore_ops = []
    # Eagerly run restorations for Python state.
    for position in python_positions:
      key = position.object_proto.attributes[0].checkpoint_key
      position.trackable.deserialize(self.reader.get_tensor(key))

    # If we have new SaveableObjects, extract and cache restore ops.
    if tensor_saveables or registered_savers:
      flat_saveables = saveable_object_util.validate_and_slice_inputs(
          tensor_saveables)
      new_restore_ops = functional_saver.MultiDeviceSaver(
          flat_saveables,
          registered_savers).restore(self.save_path_tensor, self.options)
      if not context.executing_eagerly():
        for name, restore_op in sorted(new_restore_ops.items()):
          restore_ops.append(restore_op)
          assert name not in self.restore_ops_by_name
          self.restore_ops_by_name[name] = restore_op
    return restore_ops
class _NameBasedRestoreCoordinator:
  """Keeps the status of a name-based checkpoint restore."""

  def __init__(self, save_path, dtype_map=None):
    self.save_path = save_path
    self.dtype_map = dtype_map
    # A map from trackable objects to unused attribute names. We don't have
    # proto IDs when doing a name-based restore, so the map keys differ from
    # those in _CheckpointRestoreCoordinator.
    self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary()
    self.restore_uid = ops.uid()

  def globally_named_object_attributes(self, trackable):
    """Create globally named SaveableObjects from attributes.

    If an object's attribute has no global name specified (default
    construction for the SaveableObject factory), records the failure in
    `self.unused_attributes` (which can then be used to make status assertions
    fail; see `NameBasedSaverStatus`).

    Args:
      trackable: An object to save.

    Yields:
      SaveableObjects for `trackable`'s attributes.
    """
    for attribute_name, saveable_factory in (
        trackable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
      if callable(saveable_factory):
        try:
          # A saveable object factory with no default name= argument cannot be
          # saved/restored using a name-based checkpoint. Ignore the resulting
          # TypeError for now and make sure assert_consumed() fails.
          saveable = saveable_factory()
        except TypeError:
          self.unused_attributes.setdefault(trackable,
                                            []).append(attribute_name)
          continue
      else:
        saveable = saveable_factory
      names_to_saveables = saveable_object_util.op_list_to_dict(
          [saveable], convert_variable_to_tensor=False)
      for name, op in names_to_saveables.items():
        for saveable_object in saveable_object_util.saveable_objects_for_op(
            op=op, name=name):
          yield saveable_object

  def eager_restore(self, trackable):
    """Runs restore ops for `trackable`'s attributes."""
    # When graph building, we don't add any restore ops to the graph until
    # run_restore_ops/initialize_or_restore is called on the status object for
    # name-based checkpoints.
    assert context.executing_eagerly()
    for saveable in self.globally_named_object_attributes(trackable):
      restored_tensors = []
      tensor_missing = False
      for spec in saveable.specs:
        if spec.name in self.dtype_map:
          with ops.device("cpu:0"):
            restored, = io_ops.restore_v2(
                prefix=self.save_path,
                tensor_names=[spec.name],
                shape_and_slices=[""],
                dtypes=[self.dtype_map[spec.name]],
                name="%s_checkpoint_read" % (spec.name,))
          restored_tensors.append(array_ops.identity(restored))
        else:
          tensor_missing = True

      if tensor_missing:
        # Record that this variable didn't match so assertions will fail.
        self.unused_attributes.setdefault(trackable, []).append(saveable.name)
      else:
        # Ignores values missing from the checkpoint, as with object-based
        # restore. Status assertions can be used to check exact matches,
        # although it's unlikely to ever happen for name-based checkpoints.
        saveable.restore(
            restored_tensors=restored_tensors, restored_shapes=None)


# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# or consolidating the implementation with get_variable.
def _default_getter(name,
                    shape,
                    dtype,
                    initializer=None,
                    partition_info=None,
                    **kwargs):
  """A pared-down version of get_variable which does not reuse variables."""
  dtype = dtypes.as_dtype(dtype)
  shape_object = tensor_shape.as_shape(shape)
  with ops.init_scope():
    if initializer is None:
      initializer, initializing_from_value = (
          variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
              name=name,
              shape=shape_object,
              dtype=dtype))
    else:
      initializing_from_value = not callable(initializer)
    # Same logic as get_variable
    variable_dtype = dtype.base_dtype
    if initializing_from_value:
      if shape is not None:
        raise ValueError("If initializer is a constant, do not specify shape.")
      initial_value = initializer
    else:
      # Instantiate initializer if the provided initializer is a type object.
      if isinstance(initializer, type(init_ops.Initializer)):
        initializer = initializer(dtype=dtype)
      shape_list = None if shape is None else shape_object.as_list()
      if "partition_info" in tf_inspect.getargspec(initializer).args:
        initial_value = functools.partial(initializer,
                                          shape_list,
                                          dtype=dtype,
                                          partition_info=partition_info)
      else:
        initial_value = functools.partial(initializer,
                                          shape_list,
                                          dtype=dtype)

    return variables.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=variable_dtype,
        use_resource=True,
        **kwargs)


def add_variable(trackable,
                 name,
                 shape=None,
                 dtype=dtypes.float32,
                 initializer=None,
                 trainable=True):
  """Add a variable to a Trackable with no scope influence."""
  return trackable._add_variable_with_custom_getter(  # pylint: disable=protected-access
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      getter=_default_getter,
      trainable=trainable)
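
# Illustrative sketch of `add_variable` on a plain AutoTrackable container;
# the object and names here are hypothetical:
#
#   obj = autotrackable.AutoTrackable()
#   v = add_variable(obj, name="v", shape=[2],
#                    initializer=init_ops.zeros_initializer())
#   # `v` is now a dependency of `obj` under the local name "v", so it is
#   # saved and restored with object-based checkpoints of `obj`.
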
def object_metadata(save_path):
  """Retrieves information about the objects in a checkpoint.

  Example usage:

  ```python
  object_graph = tf.contrib.checkpoint.object_metadata(
      tf.train.latest_checkpoint(checkpoint_directory))
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  ```

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`.

  Returns:
    A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.

  Raises:
    ValueError: If an object graph was not found in the checkpoint.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    raise ValueError(
        f"The specified checkpoint \"{save_path}\" does not appear to be "
        "object-based (saved with TF2) since it is missing the key "
        f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
        "TF1 name-based saver and does not contain an object dependency "
        "graph.")
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  return object_graph_proto


def list_objects(root_trackable):
  """Traverse the object graph and list all accessible objects.

  Looks for `Trackable` objects which are dependencies of
  `root_trackable`. Includes slot variables only if the variable they are
  slotting for and the optimizer are dependencies of `root_trackable`
  (i.e. if they would be saved with a checkpoint).

  Args:
    root_trackable: A `Trackable` object whose dependencies should be
      flattened.

  Returns:
    A flat list of objects.
  """
  return util.list_objects(graph_view_lib.ObjectGraphView(root_trackable))


def gather_initializers(root_trackable):
  """Traverse the object graph and find initialization ops.

  Looks for `Trackable` objects which are dependencies of `root_trackable`
  and which have an `initializer` property. Includes initializers for slot
  variables only if the variable they are slotting for and the optimizer are
  dependencies of `root_trackable` (i.e. if they would be saved with a
  checkpoint).

  Args:
    root_trackable: A `Trackable` object to gather initializers for.

  Returns:
    A list of initialization ops.
  """
  trackable_objects = list_objects(root_trackable)
  return [
      c.initializer
      for c in trackable_objects
      if hasattr(c, "initializer") and c.initializer is not None
  ]
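
# Illustrative graph-mode sketch of gathering and running initializers (the
# objects here are hypothetical):
#
#   with ops.Graph().as_default():
#     root = autotrackable.AutoTrackable()
#     root.v = variables.Variable(1.)
#     with session_lib.Session() as session:
#       session.run(gather_initializers(root))  # Runs root.v.initializer.
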
@tf_contextlib.contextmanager
def capture_dependencies(template):
  """Capture variables created within this scope as `Template` dependencies.

  Requires that `template.variable_scope` is active.

  This scope is intended as a compatibility measure, allowing a trackable
  object to add dependencies on variables created in a block of code which is
  not aware of object-based saving (and instead uses variable names
  heavily). This is how `Template` objects add dependencies on variables and
  sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.

  Args:
    template: The `Template` object to register dependencies with.

  Yields:
    None (when used as a context manager).
  """
  name_prefix = template.variable_scope.name

  def _trackable_custom_creator(next_creator,
                                name,
                                initial_value,
                                trackable_parent=None,
                                **kwargs):
    """A variable creation hook which adds Trackable dependencies.

    Set for example during a `Template`'s first wrapped function
    execution. Ensures that (a) `template` depends on any trackable objects
    which open their own `capture_dependencies` scope inside this scope and
    create variables, and (b) that any variables not in a more deeply nested
    scope are added as dependencies directly.

    The `trackable_parent` argument is passed between custom creators but
    ignored when the variable object itself is created. This argument
    indicates (if not `None`) that a more deeply nested scope has already
    added the variable as a dependency, and that parent scopes should add a
    dependency on that object rather than on the variable directly.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      name: The (full, scope-influenced) name of the variable. The
        `name_prefix` itself is stripped for the purposes of object-based
        dependency tracking, but scopes opened within this scope are
        respected.
      initial_value: See `variable_scope.variable_creator_scope`. Taken
        explicitly so the argument can be re-named and used with
        `Trackable._add_variable_with_custom_getter`.
      trackable_parent: If not None, a more deeply nested trackable object and
        its name prefix which were passed to `capture_dependencies` to add a
        dependency on (rather than depending on the variable directly).
      **kwargs: Passed through to the next creator.

    Returns:
      The output of `next_creator`: the fetched/created variable object.
    """

    def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
      # Pop the scope-stripped name, which we don't want to propagate to the
      # next creator.
      inner_kwargs.pop("name")
      return next_creator(initial_value=initializer, name=name, **inner_kwargs)

    if name is not None and name.startswith(name_prefix):
      scope_stripped_name = name[len(name_prefix) + 1:]
      if not trackable_parent:
        return template._add_variable_with_custom_getter(  # pylint: disable=protected-access
            initializer=initial_value,
            name=scope_stripped_name,
            getter=_call_next_creator_renaming_initializer,
            # Disable error checking for Trackable. Exceptions are instead
            # raised if necessary when the object-based saver tries to
            # save/restore the object.
            overwrite=True,
            trackable_parent=(template, name_prefix),
            **kwargs)
      else:
        parent_object, parent_name_prefix = trackable_parent
        template._track_trackable(  # pylint: disable=protected-access
            parent_object,
            name=parent_name_prefix[len(name_prefix) + 1:],
            overwrite=True)
    return next_creator(
        name=name,
        initial_value=initial_value,
        trackable_parent=(template, name_prefix),
        **kwargs)

  with variable_scope.variable_creator_scope(_trackable_custom_creator):
    yield


class _LoadStatus:
  """Abstract base for load status callbacks."""

  @abc.abstractmethod
  def assert_consumed(self):
    """Raises an exception unless a non-trivial restoration has completed."""
    pass

  @abc.abstractmethod
  def assert_existing_objects_matched(self):
    """Raises an exception unless existing Python objects have been matched."""
    pass

  @abc.abstractmethod
  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched."""
    pass

  @abc.abstractmethod
  def run_restore_ops(self, session=None):
    """Runs restore ops from the checkpoint. Requires a valid checkpoint."""
    pass

  @abc.abstractmethod
  def initialize_or_restore(self, session=None):
    """Runs restore ops from the checkpoint, or initializes variables."""
    pass

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores."""
    return self
@tf_export("__internal__.tracking.streaming_restore", v1=[])
def streaming_restore(status, session=None):
  """When graph building, runs restore ops as soon as they come in.

  Args:
    status: A _LoadStatus object from an object-based saver's restore().
      Streaming restore from name-based checkpoints is not currently
      supported.
    session: A session to run new restore ops in.
  """
  if context.executing_eagerly():
    # Streaming restore is the default/only behavior when executing eagerly.
    return
  if session is None:
    session = get_session()
  if isinstance(status, NameBasedSaverStatus):
    raise NotImplementedError(
        "Streaming restore not supported from name-based checkpoints when "
        "graph building. File a feature request if this limitation bothers "
        "you. As a workaround, consider either using tf.train.Checkpoint to "
        "load name-based checkpoints or enabling eager execution.")
  status.run_restore_ops(session=session)
  # pylint: disable=protected-access
  status._checkpoint.new_restore_ops_callback = (
      lambda ops: session.run(ops, feed_dict=status._feed_dict))
  # pylint: enable=protected-access


def _objects_with_attributes(full_list):
  """Filters out objects with no direct variable dependencies for assertions."""
  return [o for o in full_list if o._gather_saveables_for_checkpoint()]  # pylint: disable=protected-access
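
# Illustrative graph-mode sketch of streaming restore (the saver, path, and
# session are hypothetical):
#
#   status = saver.restore(save_path)  # `saver` is a TrackableSaver.
#   streaming_restore(status, session)
#   # Restore ops created by later deferred restorations now run in `session`
#   # as soon as they are built, instead of waiting for an explicit
#   # status.run_restore_ops(session) call.
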
class CheckpointLoadStatus(_LoadStatus):
  """Checks the status of checkpoint loading and manages restore ops.

  Returned from `Saver.restore`. Since `restore` may defer the loading of
  values in the checkpoint which don't yet have corresponding Python objects,
  `CheckpointLoadStatus` provides a callback to verify that checkpoint loading
  is complete (`assert_consumed`).

  When graph building, `restore` does not run restore ops itself since their
  creation may be deferred. The `run_restore_ops` method must be called once
  all Python objects with values to restore have been created and added to the
  dependency graph (this does not necessarily have to be the whole checkpoint;
  calling `run_restore_ops` while `assert_consumed` fails is supported and
  will partially restore the checkpoint).

  See `Saver.restore` for usage examples.
  """

  def __init__(self, checkpoint, feed_dict, graph_view):
    self._checkpoint = checkpoint
    self._feed_dict = feed_dict
    self._object_graph_view = graph_view
    # Keep a reference to the root, since object_graph_view might only have a
    # weakref.
    self._root = graph_view.root

  def assert_consumed(self):
    """Asserts that all objects in the checkpoint have been created/matched.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If there are any Python objects in the dependency graph
        which have not been restored from this checkpoint or a later
        `restore`, or if there are any checkpointed values which have not been
        matched to Python objects.
    """
    pretty_printer = ObjectGraphProtoPrettyPrinter(
        self._checkpoint.object_graph_proto)
    self.assert_existing_objects_matched()
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      if not node.attributes:
        # Only raise exceptions for the nodes with attributes themselves.
        # Either they're ultimately not important, or they have a child with
        # an attribute.
        continue
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if trackable is None:
        raise AssertionError(
            "Unresolved object in checkpoint "
            f"{pretty_printer.node_names[node_id]}: {node}")
    if self._checkpoint.slot_restorations:
      # Sanity check; this collection should be clear if everything has been
      # restored.
      raise AssertionError(
          f"Unresolved slot restorations: {self._checkpoint.slot_restorations}")
    if self._checkpoint.unused_attributes:
      unused_attribute_messages = []
      for node_id, attribute in self._checkpoint.unused_attributes.items():
        obj = self._checkpoint.object_by_proto_id[node_id]
        unused_attribute_messages.append(
            f"{pretty_printer.node_names[node_id]} ({obj}): {attribute}")
      joined_attribute_messages = "\n".join(unused_attribute_messages)
      raise AssertionError(
          "Unused attributes in these objects (the attributes exist in the "
          f"checkpoint but were not restored):\n{joined_attribute_messages}")
    return self
  def assert_existing_objects_matched(self):
    """Asserts that trackable Python objects have been matched.

    Note that this is a weaker assertion than `assert_consumed`. It will only
    fail for existing Python objects which are (transitive) dependencies of
    the root object and which do not have an entry in the checkpoint.

    It will not fail, for example, if a `tf.keras.Layer` object has not yet
    been built and so has not created any `tf.Variable` objects.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If a Python object exists in the transitive dependencies
        of the root object but does not have a value in the checkpoint.
    """
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if (trackable is not None and
          trackable._update_uid < self._checkpoint.restore_uid):  # pylint: disable=protected-access
        raise AssertionError(
            f"Object {node} not assigned a value from checkpoint.")
    for trackable_object in util.list_objects(self._object_graph_view):
      # Remove data structures that do not contain any variables from
      # restoration checks.
      if (isinstance(trackable_object,
                     data_structures.TrackableDataStructure) and
          not trackable_object._trackable_children()):  # pylint: disable=protected-access
        continue
      self._checkpoint.all_python_objects.add(trackable_object)
    unused_python_objects = (
        object_identity.ObjectIdentitySet(
            _objects_with_attributes(
                self._checkpoint.all_python_objects)) -
        object_identity.ObjectIdentitySet(
            self._checkpoint.object_by_proto_id.values()))
    if unused_python_objects:
      num_unused_python_objects = len(list(unused_python_objects))
      # Display at most 10 unmatched objects in the error message.
      num_variables_to_show = min(10, num_unused_python_objects)
      raise AssertionError(
          f"Found {num_unused_python_objects} Python objects that were "
          "not bound to checkpointed values, likely due to changes in the "
          f"Python program. Showing {num_variables_to_show} of "
          f"{num_unused_python_objects} unmatched objects: "
          f"{list(unused_python_objects)[:num_variables_to_show]}")
    return self

  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched."""
    for trackable_object in util.list_objects(self._object_graph_view):
      self._checkpoint.all_python_objects.add(trackable_object)
    if len(self._checkpoint.object_by_proto_id) <= 1:
      unused_python_objects = (
          object_identity.ObjectIdentitySet(
              _objects_with_attributes(self._checkpoint.all_python_objects)) -
          object_identity.ObjectIdentitySet(
              self._checkpoint.object_by_proto_id.values()))
      if unused_python_objects:
        raise AssertionError(
            "Nothing except the root object matched a checkpointed value. "
            "Typically this means that the checkpoint does not match the "
            "Python program. The following objects have no matching "
            f"checkpointed value: {list(unused_python_objects)}")
      else:
        raise AssertionError(
            "Nothing to load. No dependencies have been added to "
            f"{self._object_graph_view.root} yet.")
    return self

  def run_restore_ops(self, session=None):
    """Run operations to restore objects in the dependency graph."""
    if context.executing_eagerly():
      return  # Run eagerly
    if session is None:
      session = get_session()
    session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)

  def initialize_or_restore(self, session=None):
    """Run operations to initialize or restore objects in the dependency graph.

    Any objects in the dependency graph which have initializers but are not in
    the checkpoint will have those initializers run, unless those variables
    are being restored by a later call to `tf.train.Checkpoint.restore()`.

    This method has a sibling in `InitializationOnlyStatus` which instead
    initializes variables. That type is returned if no checkpoint is specified
    in `Saver.restore`.

    Args:
      session: The session to run init/restore ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # Initialization and restoration ops are run eagerly
    if session is None:
      session = get_session()
    all_objects = util.list_objects(self._object_graph_view)
    already_initialized_objects = object_identity.ObjectIdentitySet(
        self._checkpoint.object_by_proto_id.values())
    initializers_for_non_restored_variables = [
        c.initializer for c in all_objects
        if hasattr(c, "initializer")
        and c not in already_initialized_objects
        and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1)
             < self._checkpoint.restore_uid)
    ]
    self.run_restore_ops(session=session)
    session.run(initializers_for_non_restored_variables)

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores."""
    self._checkpoint.expect_partial = True
    return self
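
# Illustrative sketch of the two common ways to use the load status (the
# saver and path are hypothetical):
#
#   # Intentionally partial restore; suppress unrestored-value warnings that
#   # would otherwise be logged on __del__:
#   saver.restore(save_path).expect_partial()
#
#   # Full restore with verification; when graph building, also run the ops:
#   saver.restore(save_path).assert_consumed().run_restore_ops()
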
class InitializationOnlyStatus(_LoadStatus):
  """Returned from `Saver.restore` when no checkpoint has been specified.

  Objects of this type have the same `assert_consumed` method as
  `CheckpointLoadStatus`, but it always fails. However,
  `initialize_or_restore` works on objects of both types, and will
  initialize variables in `InitializationOnlyStatus` objects or restore them
  otherwise.
  """

  def __init__(self, object_graph_view, restore_uid):
    self._restore_uid = restore_uid
    self._object_graph_view = object_graph_view
    # Keep a reference to the root, since graph_view might only have a
    # weakref.
    self._root = object_graph_view.root

  def assert_consumed(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_existing_objects_matched(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def assert_nontrivial_match(self):
    """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
    raise AssertionError(
        "No checkpoint specified (save_path=None); nothing is being restored.")

  def run_restore_ops(self, session=None):
    """For consistency with `CheckpointLoadStatus`.

    Use `initialize_or_restore` for initializing if no checkpoint was passed
    to `Saver.restore` and restoring otherwise.

    Args:
      session: Not used.
    """
    raise AssertionError(
        "No checkpoint specified, so no restore ops are available "
        "(save_path=None to Saver.restore).")

  def initialize_or_restore(self, session=None):
    """Runs initialization ops for variables.

    Objects which would be saved by `Saver.save` will be initialized, unless
    those variables are being restored by a later call to
    `tf.train.Checkpoint.restore()`.

    This method does nothing when executing eagerly (initializers get run
    eagerly).

    Args:
      session: The session to run initialization ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # run eagerly
    if session is None:
      session = get_session()
    trackable_objects = util.list_objects(self._object_graph_view)
    initializers = [
        c.initializer for c in trackable_objects
        if hasattr(c, "initializer") and c.initializer is not None
        and (getattr(c, "_update_uid", self._restore_uid - 1)
             < self._restore_uid)
    ]
    session.run(initializers)
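
# Illustrative graph-mode sketch: restoring with save_path=None falls back to
# initialization only (the saver is hypothetical):
#
#   status = saver.restore(save_path=None)  # InitializationOnlyStatus.
#   with session_lib.Session() as session:
#     status.initialize_or_restore(session)  # Runs variable initializers.
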
_DEPRECATED_RESTORE_INSTRUCTIONS = (
    "Restoring a name-based tf.train.Saver checkpoint using the object-based "
    "restore API. This mode uses global names to match variables, and so is "
    "somewhat fragile. It also adds new restore ops to the graph each time "
    "it is called when graph building. Prefer re-encoding training "
    "checkpoints in the object-based format: run save() on the object-based "
    "saver (the same one this message is coming from) and use that "
    "checkpoint in the future.")


class NameBasedSaverStatus(_LoadStatus):
  """Status for loading a name-based training checkpoint."""

  # Ideally this deprecation decorator would be on the class, but that
  # interferes with isinstance checks.
  @deprecation.deprecated(
      date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
  def __init__(self, checkpoint, object_graph_view):
    self._checkpoint = checkpoint
    self._object_graph_view = object_graph_view
    self._optionally_restored = []
    # Keep a reference to the root, since graph_view might only have a
    # weakref.
    self._root = object_graph_view.root

  def add_to_optionally_restored(self, var):
    """Add a variable to the list of optionally restored variables.

    There are situations where certain variables should be ignored in
    assertions such as assert_existing_objects_matched(). One example is that
    of a checkpoint saved with train.Saver(), and restored with
    train.Checkpoint(): it is possible for the train.Saver() checkpoint to be
    missing the internal `save_counter` variable, which we want to ignore on
    restore.

    Args:
      var: The variable to treat as optionally restored.
    """
    self._optionally_restored.append(var)

  def assert_consumed(self):
    """Raises an exception if any variables are unmatched."""
    unused_attributes = list(self._checkpoint.unused_attributes.items())
    unused_attributes = [
        a for a in unused_attributes
        if all(a[0] is not x for x in self._optionally_restored)
    ]
    if unused_attributes:
      unused_attribute_strings = [
          f"\n    {obj}: {attributes}" for obj, attributes in unused_attributes
      ]
      raise AssertionError(
          "Some objects had attributes which were not restored: "
          f"{unused_attribute_strings}")
    for trackable in util.list_objects(self._object_graph_view):
      # pylint: disable=protected-access
      trackable._maybe_initialize_trackable()
      if trackable._update_uid < self._checkpoint.restore_uid:
        raise AssertionError(f"Object not restored: {trackable}")
      # pylint: enable=protected-access
    return self

  def assert_existing_objects_matched(self):
    """Raises an exception if currently created objects are unmatched."""
    # For name-based checkpoints there's no object information in the
    # checkpoint, so there's no distinction between
    # assert_existing_objects_matched and assert_consumed (and both are less
    # useful since we don't touch Python objects or Python state).
    return self.assert_consumed()

  def assert_nontrivial_match(self):
    """Raises an exception if currently created objects are unmatched."""
    # For name-based checkpoints there's no object information in the
    # checkpoint, so there's no distinction between
    # assert_nontrivial_match and assert_consumed (and both are less useful
    # since we don't touch Python objects or Python state).
    return self.assert_consumed()

  def _gather_saveable_objects(self):
    """Walk the object graph, using global names for SaveableObjects."""
    objects = util.list_objects(self._object_graph_view)
    saveable_objects = []
    for trackable in objects:
      # pylint: disable=protected-access
      trackable._maybe_initialize_trackable()
      if trackable._update_uid < self._checkpoint.restore_uid:
        trackable._update_uid = self._checkpoint.restore_uid
      else:
        continue
      # pylint: enable=protected-access
      saveable_objects.extend(
          self._checkpoint.globally_named_object_attributes(trackable))
    return saveable_objects
  def run_restore_ops(self, session=None):
    """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
    if context.executing_eagerly():
      return  # Nothing to do, variables are restored on creation.
    if session is None:
      session = get_session()
    with ops.device("/cpu:0"):
      saveables = self._gather_saveable_objects()
      v1_saver_lib.Saver(saveables).restore(
          sess=session, save_path=self._checkpoint.save_path)

  def initialize_or_restore(self, session=None):
    """Alias for `run_restore_ops`."""
    self.run_restore_ops(session=session)
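
# Illustrative sketch of loading a TF1 name-based checkpoint through the
# object-based API (the saver and path are hypothetical):
#
#   status = saver.restore("/tmp/tf1_ckpt/model")  # NameBasedSaverStatus.
#   with session_lib.Session() as session:
#     status.initialize_or_restore(session)  # Matches variables by name.
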
class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
  """Pretends to be a session, inserts extra feeds on run()."""

  def __init__(self, session, feed_additions):
    self._wrapped_session = session
    self._feed_additions = feed_additions

  def run(self, fetches, feed_dict=None, **kwargs):
    if feed_dict is None:
      feed_dict = {}
    else:
      feed_dict = feed_dict.copy()
    feed_dict.update(self._feed_additions)
    return self._wrapped_session.run(
        fetches=fetches, feed_dict=feed_dict, **kwargs)


class TrackableSaver:
  """Saves and restores a `Trackable` object and its dependencies.

  See `Trackable` for details of dependency management. `Saver` wraps
  `tf.compat.v1.train.Saver` for saving, including extra information about
  the graph of dependencies between Python objects. When restoring, it uses
  this information about the save-time dependency graph to more robustly
  match objects with their checkpointed values. When executing eagerly, it
  supports restoring variables on object creation (see `Saver.restore`).

  Values in a checkpoint are mapped to `Trackable` Python objects
  (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the
  checkpoint was written. To avoid breaking existing checkpoints when
  modifying a class, dependency names (the names of attributes to which
  `Trackable` objects are assigned) may not change. These names are local to
  objects, in contrast to the `Variable.name`-based save/restore from
  `tf.compat.v1.train.Saver`, and so allow additional program
  transformations.
  """

  def __init__(self, graph_view):
    """Configure saving.

    Args:
      graph_view: An `ObjectGraphView` object containing a description of the
        object graph to save.
    """
    self._graph_view = graph_view

    # The following attributes are used when graph building.

    # Saveables caching: A dictionary mapping `Trackable` objects ->
    # attribute names -> SaveableObjects, used to avoid re-creating
    # SaveableObjects when graph building.
    if context.executing_eagerly():
      self._saveables_cache = None
    else:
      self._saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()

    # The file prefix placeholder is created lazily when graph building (and
    # not at all when executing eagerly) to avoid creating ops in the
    # constructor (when they may never be necessary).
    self._file_prefix_placeholder = None

    # Op caching for save
    self._object_graph_feed_tensor = None
    self._last_save_object_graph = None
    self._file_prefix_feed_tensor = None
    self._cached_save_operation = None

    # Op caching for restore, shared between _CheckpointRestoreCoordinators
    self._restore_op_cache = {}

    # Object map used for checkpoint. This attribute is to be overridden by a
    # Checkpoint subclass, e.g., AsyncCheckpoint, to replace the trackable
    # objects for checkpoint saving.
    self._object_map = None

  def _gather_saveables(self, object_graph_tensor=None):
    """Wraps _serialize_object_graph to include the object graph proto."""
    named_saveable_objects, graph_proto, feed_additions, registered_savers = (
        save_util_v1.serialize_gathered_objects(
            graph_view=self._graph_view,
            object_map=self._object_map,
            saveables_cache=self._saveables_cache))

    if object_graph_tensor is None:
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
    else:
      feed_additions.update(
          {object_graph_tensor: graph_proto.SerializeToString()})
    assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
    named_saveable_objects.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
    return (named_saveable_objects, graph_proto, feed_additions,
            registered_savers)
  def _save_cached_when_graph_building(self,
                                       file_prefix,
                                       object_graph_tensor,
                                       options,
                                       write_done_callback=None):
    """Create or retrieve save ops.

    Args:
      file_prefix: The prefix for saved checkpoint files.
      object_graph_tensor: A `Tensor` to which the current object graph will
        be fed.
      options: `CheckpointOptions` object.
      write_done_callback: Optional callback function to be executed once the
        underlying checkpoint saving is finished. Example usage includes
        updating the checkpoint's internal state.

    Returns:
      A two-element tuple with a filename tensor and a feed_dict of tensors
      to feed when running it (if graph building). The feed dict contains the
      current object graph and any Python state to be saved in the
      checkpoint. When executing eagerly only the first argument is
      meaningful.
    """
    (named_saveable_objects, graph_proto, feed_additions,
     registered_savers) = self._gather_saveables(
         object_graph_tensor=object_graph_tensor)

    def _run_save():
      """Create and execute the SaveOp for the checkpoint."""
      if (self._last_save_object_graph != graph_proto
          # When executing eagerly, we need to re-create SaveableObjects each
          # time save() is called so they pick up new Tensors passed to their
          # constructors. That means the Saver needs to be copied with a new
          # var_list.
          or context.executing_eagerly() or ops.inside_function()):
        saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
                                                  registered_savers)
        save_op = saver.save(file_prefix, options=options)
        with ops.device("/cpu:0"):
          with ops.control_dependencies([save_op]):
            self._cached_save_operation = array_ops.identity(file_prefix)
        self._last_save_object_graph = graph_proto
      return self._cached_save_operation, feed_additions

    def _copy_tensors():
      """Copy the tensors to the host CPU device."""
      for saveable in named_saveable_objects:
        # Pin the device according to the SaveableObject's device location to
        # avoid unnecessary data copies when reading the variables. This is
        # aligned with the behavior in MultiDeviceSaver.save().
        original_device = saveable.device
        with ops.device(original_device):
          for spec in saveable.specs:
            tensor = spec.tensor
            device = spec.device
            if tensor is not None:
              with ops.device(saveable_object_util.set_cpu0(device)):
                spec._tensor = array_ops.identity(tensor)  # pylint: disable=protected-access
              # Modify the device info accordingly now that the tensors have
              # been copied to the host CPU device.
              spec.device = saveable_object_util.set_cpu0(device)

    def _async_save_fn():
      """The thread function for executing async checkpoint save."""
      with context.executor_scope(
          executor.new_executor(
              enable_async=False, enable_streaming_enqueue=False)):
        _run_save()
        # Update the internal checkpoint state if the checkpoint event is
        # triggered from Checkpoint.save().
        if write_done_callback:
          write_done_callback(_convert_file_name_tensor_to_string(file_prefix))

    if options.experimental_enable_async_checkpoint:
      # Execute the checkpoint asynchronously.

      # Step 1: Explicitly copy the tensors to their host CPU device.
      _copy_tensors()

      # Step 2: Execute the rest of the checkpoint operations on the host
      # device using an async executor.
      global _ASYNC_CHECKPOINT_THREAD
      if _ASYNC_CHECKPOINT_THREAD is not None:
        _ASYNC_CHECKPOINT_THREAD.join()
      _ASYNC_CHECKPOINT_THREAD = threading.Thread(target=_async_save_fn)
      _ASYNC_CHECKPOINT_THREAD.start()

      # Step 3: Return the expected checkpoint file path, though the save op
      # may not have finished.
      self._cached_save_operation = file_prefix
      return self._cached_save_operation, feed_additions

    # Execute the normal, i.e. synchronous, checkpoint.
    return _run_save()
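
  # Illustrative sketch of the async path above (assumes this build exposes
  # `experimental_enable_async_checkpoint` on CheckpointOptions, as the code
  # above checks):
  #
  #   options = checkpoint_options.CheckpointOptions(
  #       experimental_enable_async_checkpoint=True)
  #   path = saver.save("/tmp/ckpt/model", options=options)
  #   # `path` is returned immediately; the tensors were copied to the host
  #   # CPU synchronously, and the file write completes on a background
  #   # thread.
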
  def save(self, file_prefix, checkpoint_number=None, session=None,
           options=None, write_done_callback=None):
    """Save a training checkpoint.

    The saved checkpoint includes variables created by this object and any
    Trackable objects it depends on at the time `Saver.save()` is called.

    Args:
      file_prefix: A prefix to use for the checkpoint filenames
        (/path/to/directory/and_a_prefix). Names are generated based on this
        prefix and `checkpoint_number`, if provided.
      checkpoint_number: An integer variable or Tensor, used to number
        checkpoints. Typically this value is saved along with other variables
        in training checkpoints, which will happen automatically if it was
        created by `root_trackable` or one of its dependencies (via
        `Trackable._add_variable`).
      session: The session to evaluate variables in. Ignored when executing
        eagerly. If not provided when graph building, the default session is
        used.
      options: Optional `tf.train.CheckpointOptions` object.
      write_done_callback: Optional callback function to be executed once the
        underlying checkpoint saving is finished. Example usage includes
        updating the checkpoint's internal state.

    Returns:
      The full path to the checkpoint.
    """
    options = options or checkpoint_options.CheckpointOptions()
    feed_dict = {}
    use_session = (not context.executing_eagerly() and
                   not ops.inside_function())
    if checkpoint_number:
      file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
    if use_session:
      if self._object_graph_feed_tensor is None:
        with ops.device("/cpu:0"):
          self._object_graph_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
          self._file_prefix_feed_tensor = constant_op.constant(
              "", dtype=dtypes.string)
      object_graph_tensor = self._object_graph_feed_tensor
      file_prefix_tensor = self._file_prefix_feed_tensor
      feed_dict[file_prefix_tensor] = file_prefix
    else:
      with ops.device("/cpu:0"):
        file_prefix_tensor = ops.convert_to_tensor(
            file_prefix, dtype=dtypes.string)
      object_graph_tensor = None

    if not tensor_util.is_tensor(file_prefix):
      file_io.recursive_create_dir(os.path.dirname(file_prefix))

    save_path, new_feed_additions = self._save_cached_when_graph_building(
        file_prefix_tensor, object_graph_tensor, options, write_done_callback)

    if new_feed_additions:
      feed_dict.update(new_feed_additions)
    if not use_session:
      session = None
    elif session is None:
      session = get_session()

    if session:
      return session.run(save_path, feed_dict=feed_dict)
    else:
      return save_path

  def restore(self, save_path, options=None):
    """Restore a training checkpoint.

    Restores `root_trackable` and any objects that it tracks
    (transitively). Either assigns values immediately if variables to restore
    have been created already, or defers restoration until the variables are
    created. Dependencies added to the `root_trackable` passed to the
    constructor after this call will be matched if they have a corresponding
    object in the checkpoint.

    When building a graph, restorations are added to the graph but not run.

    ```python
    saver = Saver(root)
    saver.restore(path)
    ```

    To ensure that loading is complete and no more deferred restorations will
    take place, you can use the `assert_consumed()` method of the status
    object returned by the `restore` call.

    The assert will raise an exception unless every object was matched and
    all checkpointed values have a matching variable object.

    ```python
    saver = Saver(root)
    saver.restore(path).assert_consumed()
    ```

    When graph building, `assert_consumed()` indicates that all of the
    restore ops which will be created for this checkpoint have been created.
    They can be run via the `run_restore_ops()` function of the status
    object:

    ```python
    saver.restore(path).assert_consumed().run_restore_ops()
    ```

    If the checkpoint has not been consumed completely, then the list of
    restore ops will grow as more objects are added to the dependency graph.

    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using
    this method. There is no deferred loading, and names are used to match
    variables. No restore ops are created/run until `run_restore_ops()` or
    `initialize_or_restore()` are called on the returned status object, even
    when executing eagerly. Re-encode name-based checkpoints using this
    object-based `Saver.save` as soon as possible.
1409 1410 Args: 1411 save_path: The path to the checkpoint, as returned by `save` or 1412 `tf.train.latest_checkpoint`. If None (as when there is no latest 1413 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1414 object which may run initializers for objects in the dependency graph. 1415 If the checkpoint was written by the name-based 1416 `tf.compat.v1.train.Saver`, names are used to match variables. 1417 options: Optional `tf.train.CheckpointOptions` object. 1418 1419 Returns: 1420 A load status object, which can be used to make assertions about the 1421 status of checkpoint restoration and run initialization/restore ops 1422 (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if 1423 `save_path` is `None`). 1424 1425 If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` 1426 object is returned which runs restore ops from a name-based saver. 1427 1428 Raises: 1429 RuntimeError: When a checkpoint file saved by async checkpoint is not 1430 available upon restore(). 1431 """ 1432 options = options or checkpoint_options.CheckpointOptions() 1433 if save_path is None: 1434 return InitializationOnlyStatus(self._graph_view, ops.uid()) 1435 1436 # Wait for any ongoing checkpoint write to finish. 1437 # TODO(chienchunh): Allow the file to be loaded while other checkpoint 1438 # events are still ongoing. Need to add a timeout mechanism along 1439 # with condition variables to notify when the checkpoint 1440 # file is ready. 1441 global _ASYNC_CHECKPOINT_THREAD 1442 if _ASYNC_CHECKPOINT_THREAD is not None: 1443 _ASYNC_CHECKPOINT_THREAD.join() 1444 1445 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 1446 graph_building = not context.executing_eagerly() 1447 if graph_building: 1448 dtype_map = None 1449 else: 1450 dtype_map = reader.get_variable_to_dtype_map() 1451 try: 1452 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 1453 except errors_impl.NotFoundError: 1454 # The object graph proto does not exist in this checkpoint. Try the 1455 # name-based compatibility mode.
1456 restore_coordinator = _NameBasedRestoreCoordinator( 1457 save_path=save_path, 1458 dtype_map=dtype_map) 1459 if not graph_building: 1460 for existing_trackable in util.list_objects(self._graph_view): 1461 # pylint: disable=protected-access 1462 existing_trackable._maybe_initialize_trackable() 1463 existing_trackable._name_based_restores.add(restore_coordinator) 1464 existing_trackable._name_based_attribute_restore(restore_coordinator) 1465 # pylint: enable=protected-access 1466 return NameBasedSaverStatus( 1467 restore_coordinator, 1468 object_graph_view=self._graph_view) 1469 1470 if graph_building: 1471 if self._file_prefix_placeholder is None: 1472 with ops.device("/cpu:0"): 1473 self._file_prefix_placeholder = constant_op.constant("model") 1474 file_prefix_tensor = self._file_prefix_placeholder 1475 file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} 1476 else: 1477 with ops.device("/cpu:0"): 1478 file_prefix_tensor = constant_op.constant(save_path) 1479 file_prefix_feed_dict = None 1480 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 1481 object_graph_proto.ParseFromString(object_graph_string) 1482 checkpoint = _CheckpointRestoreCoordinator( 1483 object_graph_proto=object_graph_proto, 1484 save_path=save_path, 1485 save_path_tensor=file_prefix_tensor, 1486 reader=reader, 1487 restore_op_cache=self._restore_op_cache, 1488 graph_view=self._graph_view, 1489 options=options, 1490 saveables_cache=self._saveables_cache) 1491 restore_lib.CheckpointPosition( 1492 checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root) 1493 1494 # Attached dependencies are not attached to the root, so should be restored 1495 # separately. 1496 if self._graph_view.attached_dependencies: 1497 for ref in self._graph_view.attached_dependencies: 1498 if ref.name == "root": 1499 # Root dependency is automatically added to attached dependencies -- 1500 # this can be ignored since it maps back to the root object. 1501 continue 1502 proto_id = None 1503 # Find proto ID of attached dependency (if it is in the proto). 1504 for proto_ref in object_graph_proto.nodes[0].children: 1505 if proto_ref.local_name == ref.name: 1506 proto_id = proto_ref.node_id 1507 break 1508 1509 if proto_id in checkpoint.object_by_proto_id: 1510 # Object has already been restored. This can happen when there's an 1511 # indirect connection from the attached object to the root. 1512 continue 1513 1514 if proto_id is None: 1515 # Could not find attached dependency in proto. 1516 continue 1517 1518 restore_lib.CheckpointPosition( 1519 checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref) 1520 1521 load_status = CheckpointLoadStatus( 1522 checkpoint, 1523 graph_view=self._graph_view, 1524 feed_dict=file_prefix_feed_dict) 1525 return load_status 1526 1527 1528def frozen_saver(root_trackable): 1529 """Creates a static `tf.compat.v1.train.Saver` from a trackable object. 1530 1531 The returned `Saver` saves object-based checkpoints, but these checkpoints 1532 will no longer reflect structural changes to the object graph, only changes to 1533 the values of `Variable`s added as dependencies of the root object before 1534 `freeze` was called. 1535 1536 `restore` works on the returned `Saver`, but requires that the object graph of 1537 the checkpoint being loaded exactly matches the object graph when `freeze` was 1538 called. 
This is in contrast to the object-based restore performed by 1539 `tf.train.Checkpoint`, which attempts a fuzzy matching between a checkpoint's 1540 object graph and the current Python object graph. 1541 1542 Args: 1543 root_trackable: A trackable object to save. 1544 1545 Returns: 1546 A saver which saves object-based checkpoints for the object graph frozen at 1547 the time `frozen_saver` was called. 1548 """ 1549 named_saveable_objects, registered_savers = ( 1550 save_util_v1.frozen_saveables_and_savers( 1551 graph_view_lib.ObjectGraphView(root_trackable))) 1552 return functional_saver.MultiDeviceSaver(named_saveable_objects, 1553 registered_savers) 1554 1555 1556 def _assert_trackable(obj, name): 1557 if not isinstance( 1558 obj, (base.Trackable, def_function.Function)): 1559 raise ValueError( 1560 f"`Checkpoint` was expecting {name} to be a trackable object (an " 1561 f"object derived from `Trackable`), got {obj}. If you believe this " 1562 "object should be trackable (i.e. it is part of the " 1563 "TensorFlow Python API and manages state), please open an issue.") 1564 1565 1566 def _update_checkpoint_state_internal(file_path): 1567 """Update internal checkpoint state.""" 1568 checkpoint_management.update_checkpoint_state_internal( 1569 save_dir=os.path.dirname(file_path), 1570 model_checkpoint_path=file_path, 1571 all_model_checkpoint_paths=[file_path], 1572 save_relative_paths=True) 1573 1574 1575 def _convert_file_name_tensor_to_string(tensor): 1576 """Convert file name tensor to string.""" 1577 output = tensor 1578 if tensor_util.is_tf_type(output): 1579 # Convert to numpy if not `tf.function` building. 1580 if context.executing_eagerly(): 1581 output = compat.as_str(output.numpy()) 1582 else: 1583 # Graph + Session, so the value was already evaluated via session.run. 1584 output = compat.as_str(output) 1585 return output 1586 1587 1588 # Mentions graph building / Sessions. The v2 version is below. 1589 @tf_export(v1=["train.Checkpoint"]) 1590 class CheckpointV1(autotrackable.AutoTrackable): 1591 """Groups trackable objects, saving and restoring them. 1592 1593 `Checkpoint`'s constructor accepts keyword arguments whose values are types 1594 that contain trackable state, such as `tf.compat.v1.train.Optimizer` 1595 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 1596 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 1597 maintains a `save_counter` for numbering checkpoints. 1598 1599 Example usage when graph building: 1600 1601 ```python 1602 import tensorflow as tf 1603 import os 1604 1605 checkpoint_directory = "/tmp/training_checkpoints" 1606 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1607 1608 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1609 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1610 train_op = optimizer.minimize( ... ) 1611 status.assert_consumed() # Optional sanity checks. 1612 with tf.compat.v1.Session() as session: 1613 # Use the Session to restore variables, or initialize them if 1614 # tf.train.latest_checkpoint returned None.
1615 status.initialize_or_restore(session) 1616 for _ in range(num_training_steps): 1617 session.run(train_op) 1618 checkpoint.save(file_prefix=checkpoint_prefix) 1619 ``` 1620 1621 Example usage with eager execution enabled: 1622 1623 ```python 1624 import tensorflow as tf 1625 import os 1626 1627 tf.compat.v1.enable_eager_execution() 1628 1629 checkpoint_directory = "/tmp/training_checkpoints" 1630 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1631 1632 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1633 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1634 for _ in range(num_training_steps): 1635 optimizer.minimize( ... ) # Variables will be restored on creation. 1636 status.assert_consumed() # Optional sanity checks. 1637 checkpoint.save(file_prefix=checkpoint_prefix) 1638 ``` 1639 1640 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 1641 checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads 1642 `variable.name` based checkpoints. Object-based checkpointing saves a graph of 1643 dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, 1644 etc.) with named edges, and this graph is used to match variables when 1645 restoring a checkpoint. It can be more robust to changes in the Python 1646 program, and helps to support restore-on-create for variables when executing 1647 eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new 1648 code. 1649 1650 `Checkpoint` objects have dependencies on the objects passed as keyword 1651 arguments to their constructors, and each dependency is given a name that is 1652 identical to the name of the keyword argument for which it was created. 1653 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1654 dependencies on their variables (e.g. "kernel" and "bias" for 1655 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1656 dependencies easy in user-defined classes, since `Model` hooks into attribute 1657 assignment. For example: 1658 1659 ```python 1660 class Regress(tf.keras.Model): 1661 1662 def __init__(self): 1663 super().__init__() 1664 self.input_transform = tf.keras.layers.Dense(10) 1665 # ... 1666 1667 def call(self, inputs): 1668 x = self.input_transform(inputs) 1669 # ... 1670 ``` 1671 1672 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1673 which in turn depends on its variables. As a result, saving an instance of 1674 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1675 by the `Dense` layer. 1676 1677 When variables are assigned to multiple workers, each worker writes its own 1678 section of the checkpoint. These sections are then merged/re-indexed to behave 1679 as a single checkpoint. This avoids copying all variables to one worker, but 1680 does require that all workers see a common filesystem. 1681 1682 While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the 1683 same format, note that the root of the resulting checkpoint is the object the 1684 save method is attached to. This means saving a `tf.keras.Model` using 1685 `save_weights` and loading into a `tf.train.Checkpoint` with a `Model` 1686 attached (or vice versa) will not match the `Model`'s variables. See the 1687 [guide to training 1688 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 1689 details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for 1690 training checkpoints. 
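For example (a sketch, assuming eager execution and a built `model`; the path is illustrative):

```python
model.save_weights("/tmp/weights")  # The root of this checkpoint is `model`.
checkpoint = tf.train.Checkpoint(model=model)  # Here the root is `checkpoint`.
# Variables were saved under names rooted at `model`, but this object graph
# expects them under a child edge named "model", so they will not be matched.
status = checkpoint.restore("/tmp/weights")
```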
1691 1692 Attributes: 1693 save_counter: Incremented when `save()` is called. Used to number 1694 checkpoints. 1695 """ 1696 1697 def __init__(self, **kwargs): 1698 """Group objects into a training checkpoint. 1699 1700 Args: 1701 **kwargs: Keyword arguments are set as attributes of this object, and are 1702 saved with the checkpoint. Values must be trackable objects. 1703 1704 Raises: 1705 ValueError: If objects in `kwargs` are not trackable. 1706 """ 1707 super().__init__() 1708 global _END_TIME_OF_LAST_WRITE 1709 with _END_TIME_OF_LAST_WRITE_LOCK: 1710 if _END_TIME_OF_LAST_WRITE is None: 1711 _END_TIME_OF_LAST_WRITE = time.time() 1712 1713 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1714 setattr(self, k, v) 1715 if not isinstance( 1716 getattr(self, k), (base.Trackable, def_function.Function)): 1717 raise ValueError( 1718 "`Checkpoint` was expecting a trackable object (an object " 1719 f"derived from `Trackable`), got {v}. If you believe this " 1720 "object should be trackable (i.e. it is part of the " 1721 "TensorFlow Python API and manages state), please open an issue.") 1722 self._save_counter = None # Created lazily for restore-on-create. 1723 self._save_assign_op = None 1724 self._saver = TrackableSaver(graph_view_lib.ObjectGraphView(self)) 1725 1726 def _maybe_create_save_counter(self): 1727 """Create a save counter if it does not yet exist.""" 1728 if self._save_counter is None: 1729 # Initialized to 0 and incremented before saving. 1730 with ops.device("/cpu:0"): 1731 # add_variable creates a dependency named "save_counter"; NoDependency 1732 # prevents creating a second dependency named "_save_counter". 1733 self._save_counter = data_structures.NoDependency( 1734 add_variable( 1735 self, 1736 name="save_counter", 1737 initializer=0, 1738 dtype=dtypes.int64, 1739 trainable=False)) 1740 1741 def write(self, file_prefix, session=None): 1742 """Writes a training checkpoint. 1743 1744 The checkpoint includes variables created by this object and any 1745 trackable objects it depends on at the time `Checkpoint.write()` is 1746 called. 1747 1748 `write` does not number checkpoints, increment `save_counter`, or update the 1749 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1750 use by higher level checkpoint management utilities. `save` provides a very 1751 basic implementation of these features. 1752 1753 Args: 1754 file_prefix: A prefix to use for the checkpoint filenames 1755 (/path/to/directory/and_a_prefix). 1756 session: The session to evaluate variables in. Ignored when executing 1757 eagerly. If not provided when graph building, the default session is 1758 used. 1759 1760 Returns: 1761 The full path to the checkpoint (i.e. `file_prefix`). 1762 """ 1763 return self._write(file_prefix, session) 1764 1765 def _write(self, file_prefix, session=None, write_done_callback=None): 1766 """Writes a training checkpoint. 1767 1768 The checkpoint includes variables created by this object and any 1769 trackable objects it depends on at the time `Checkpoint.write()` is 1770 called. 1771 1772 `write` does not number checkpoints, increment `save_counter`, or update the 1773 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1774 use by higher level checkpoint management utilities. `save` provides a very 1775 basic implementation of these features. 1776 1777 Args: 1778 file_prefix: A prefix to use for the checkpoint filenames 1779 (/path/to/directory/and_a_prefix). 1780 session: The session to evaluate variables in. 
Ignored when executing 1781 eagerly. If not provided when graph building, the default session is 1782 used. 1783 write_done_callback: Optional callback function to be executed once 1784 the underlying checkpoint saving is finished. Example usage includes 1785 updating the checkpoint internal state. 1786 1787 Returns: 1788 The full path to the checkpoint (i.e. `file_prefix`). 1789 """ 1790 start_time = time.time() 1791 output = self._saver.save(file_prefix=file_prefix, session=session) 1792 end_time = time.time() 1793 1794 metrics.AddCheckpointWriteDuration( 1795 api_label=_CHECKPOINT_V1, 1796 microseconds=_get_duration_microseconds(start_time, end_time)) 1797 1798 global _END_TIME_OF_LAST_WRITE 1799 with _END_TIME_OF_LAST_WRITE_LOCK: 1800 metrics.AddTrainingTimeSaved( 1801 api_label=_CHECKPOINT_V1, 1802 microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE, 1803 end_time)) 1804 _END_TIME_OF_LAST_WRITE = end_time 1805 1806 if tensor_util.is_tf_type(output): 1807 # Convert to numpy if not `tf.function` building. 1808 if context.executing_eagerly(): 1809 output = compat.as_str(output.numpy()) 1810 else: 1811 # Graph + Session, so the value was already evaluated via session.run. 1812 output = compat.as_str(output) 1813 1814 if write_done_callback: 1815 write_done_callback(output) 1816 1817 metrics.RecordCheckpointSize( 1818 api_label=_CHECKPOINT_V1, filesize=_get_checkpoint_size(output)) 1819 return output 1820 1821 @property 1822 def save_counter(self): 1823 """An integer variable which starts at zero and is incremented on save. 1824 1825 Used to number checkpoints. 1826 1827 Returns: 1828 The save counter variable. 1829 """ 1830 self._maybe_create_save_counter() 1831 return self._save_counter 1832 1833 def save(self, file_prefix, session=None): 1834 """Saves a training checkpoint and provides basic checkpoint management. 1835 1836 The saved checkpoint includes variables created by this object and any 1837 trackable objects it depends on at the time `Checkpoint.save()` is 1838 called. 1839 1840 `save` is a basic convenience wrapper around the `write` method, 1841 sequentially numbering checkpoints using `save_counter` and updating the 1842 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 1843 management, for example garbage collection and custom numbering, may be 1844 provided by other utilities which also wrap `write` 1845 (`tf.train.CheckpointManager` for example). 1846 1847 Args: 1848 file_prefix: A prefix to use for the checkpoint filenames 1849 (/path/to/directory/and_a_prefix). Names are generated based on this 1850 prefix and `Checkpoint.save_counter`. 1851 session: The session to evaluate variables in. Ignored when executing 1852 eagerly. If not provided when graph building, the default session is 1853 used. 1854 1855 Returns: 1856 The full path to the checkpoint. 1857 """ 1858 graph_building = not context.executing_eagerly() 1859 if graph_building: 1860 if ops.inside_function(): 1861 raise NotImplementedError( 1862 "Calling tf.train.Checkpoint.save() from a function is not " 1863 "supported, as save() modifies saving metadata in ways not " 1864 "supported by TensorFlow Operations. Consider using " 1865 "tf.train.Checkpoint.write(), a lower-level API which does not " 1866 "update metadata.
tf.train.latest_checkpoint and related APIs will " 1867 "not see this checkpoint.") 1868 if session is None: 1869 session = get_session() 1870 if self._save_counter is None: 1871 # When graph building, if this is a new save counter variable then it 1872 # needs to be initialized before assign_add. This is only an issue if 1873 # restore() has not been called first. 1874 session.run(self.save_counter.initializer) 1875 if not graph_building or self._save_assign_op is None: 1876 with ops.colocate_with(self.save_counter): 1877 assign_op = self.save_counter.assign_add(1, read_value=True) 1878 if graph_building: 1879 self._save_assign_op = data_structures.NoDependency(assign_op) 1880 if graph_building: 1881 checkpoint_number = session.run(self._save_assign_op) 1882 else: 1883 checkpoint_number = assign_op.numpy() 1884 file_path = self.write( 1885 "%s-%d" % (file_prefix, checkpoint_number), session=session) 1886 checkpoint_management.update_checkpoint_state_internal( 1887 save_dir=os.path.dirname(file_prefix), 1888 model_checkpoint_path=file_path, 1889 all_model_checkpoint_paths=[file_path], 1890 save_relative_paths=True) 1891 return file_path 1892 1893 def restore(self, save_path): 1894 """Restore a training checkpoint. 1895 1896 Restores this `Checkpoint` and any objects it depends on. 1897 1898 When executing eagerly, either assigns values immediately if variables to 1899 restore have been created already, or defers restoration until the variables 1900 are created. Dependencies added after this call will be matched if they have 1901 a corresponding object in the checkpoint (the restore request will queue in 1902 any trackable object waiting for the expected dependency to be added). 1903 1904 When graph building, restoration ops are added to the graph but not run 1905 immediately. 1906 1907 ```python 1908 checkpoint = tf.train.Checkpoint( ... ) 1909 checkpoint.restore(path) 1910 ``` 1911 1912 To ensure that loading is complete and no more deferred restorations will 1913 take place, you can use the `assert_consumed()` method of the status object 1914 returned by `restore`. 1915 The assert will raise an exception if any Python objects in the dependency 1916 graph were not found in the checkpoint, or if any checkpointed values do not 1917 have a matching Python object: 1918 1919 ```python 1920 checkpoint = tf.train.Checkpoint( ... ) 1921 checkpoint.restore(path).assert_consumed() 1922 ``` 1923 1924 When graph building, `assert_consumed()` indicates that all of the restore 1925 ops that will be created for this checkpoint have been created. They can be 1926 run via the `run_restore_ops()` method of the status object: 1927 1928 ```python 1929 checkpoint.restore(path).assert_consumed().run_restore_ops() 1930 ``` 1931 1932 If the checkpoint has not been consumed completely, then the list of restore 1933 ops will grow as more objects are added to the dependency graph. 1934 1935 To check that all variables in the Python object have restored values from 1936 checkpoint, use `assert_existing_objects_matched()`. This assertion is 1937 useful when called after the variables in your graph have been created. 1938 1939 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 1940 method. Names are used to match variables. No restore ops are created/run 1941 until `run_restore_ops()` or `initialize_or_restore()` are called on the 1942 returned status object when graph building, but there is restore-on-creation 1943 when executing eagerly. 
Re-encode name-based checkpoints using 1944 `tf.train.Checkpoint.save` as soon as possible. 1945 1946 Args: 1947 save_path: The path to the checkpoint, as returned by `save` or 1948 `tf.train.latest_checkpoint`. If None (as when there is no latest 1949 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1950 object which may run initializers for objects in the dependency graph. 1951 If the checkpoint was written by the name-based 1952 `tf.compat.v1.train.Saver`, names are used to match variables. 1953 1954 Returns: 1955 A load status object, which can be used to make assertions about the 1956 status of a checkpoint restoration and run initialization/restore ops. 1957 1958 The returned status object has the following methods: 1959 1960 * `assert_consumed()`: 1961 Raises an exception if any variables are unmatched: either 1962 checkpointed values which don't have a matching Python object or 1963 Python objects in the dependency graph with no values in the 1964 checkpoint. This method returns the status object, and so may be 1965 chained with `initialize_or_restore` or `run_restore_ops`. 1966 1967 * `assert_existing_objects_matched()`: 1968 Raises an exception if any existing Python objects in the dependency 1969 graph are unmatched. Unlike `assert_consumed`, this assertion will 1970 pass if values in the checkpoint have no corresponding Python 1971 objects. For example a `tf.keras.Layer` object which has not yet been 1972 built, and so has not created any variables, will pass this assertion 1973 but will fail `assert_consumed`. Useful when loading part of a larger 1974 checkpoint into a new Python program, e.g. a training checkpoint with 1975 a `tf.compat.v1.train.Optimizer` was saved but only the state required 1976 for inference is being loaded. This method returns the status object, 1977 and so may be chained with `initialize_or_restore` or 1978 `run_restore_ops`. 1979 1980 * `assert_nontrivial_match()`: Asserts that something aside from the root 1981 object was matched. This is a very weak assertion, but is useful for 1982 sanity checking in library code where objects may exist in the 1983 checkpoint which haven't been created in Python and some Python 1984 objects may not have a checkpointed value. 1985 1986 * `expect_partial()`: Silence warnings about incomplete checkpoint 1987 restores. Warnings are otherwise printed for unused parts of the 1988 checkpoint file or object when the `Checkpoint` object is deleted 1989 (often at program shutdown). 1990 1991 * `initialize_or_restore(session=None)`: 1992 When graph building, runs variable initializers if `save_path` is 1993 `None`, but otherwise runs restore operations. If no `session` is 1994 explicitly specified, the default session is used. No effect when 1995 executing eagerly (variables are initialized or restored eagerly). 1996 1997 * `run_restore_ops(session=None)`: 1998 When graph building, runs restore operations. If no `session` is 1999 explicitly specified, the default session is used. No effect when 2000 executing eagerly (restore operations are run eagerly). May only be 2001 called when `save_path` is not `None`. 2002 """ 2003 start_time = time.time() 2004 status = self._saver.restore(save_path=save_path) 2005 # Create the save counter now so it gets initialized with other variables 2006 # when graph building. Creating it earlier would lead to errors when using, 2007 # say, train.Saver() to save the model before initializing it. 
2008 self._maybe_create_save_counter() 2009 if isinstance(status, NameBasedSaverStatus): 2010 status.add_to_optionally_restored(self.save_counter) 2011 2012 metrics.AddCheckpointReadDuration( 2013 api_label=_CHECKPOINT_V1, 2014 microseconds=_get_duration_microseconds(start_time, time.time())) 2015 return status 2016 2017 2018 @tf_export("train.Checkpoint", v1=[]) 2019 class Checkpoint(autotrackable.AutoTrackable): 2020 """Manages saving/restoring trackable values to disk. 2021 2022 TensorFlow objects may contain trackable state, such as `tf.Variable`s, 2023 `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators, 2024 `tf.keras.Layer` implementations, or `tf.keras.Model` implementations. 2025 These are called **trackable objects**. 2026 2027 A `Checkpoint` object can be constructed to save either a single trackable 2028 object or a group of trackable objects to a checkpoint file. It maintains a 2029 `save_counter` for numbering checkpoints. 2030 2031 Example: 2032 2033 ```python 2034 model = tf.keras.Model(...) 2035 checkpoint = tf.train.Checkpoint(model) 2036 2037 # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time 2038 # checkpoint.save is called, the save counter is increased. 2039 save_path = checkpoint.save('/tmp/training_checkpoints') 2040 2041 # Restore the checkpointed values to the `model` object. 2042 checkpoint.restore(save_path) 2043 ``` 2044 2045 Example 2: 2046 2047 ```python 2048 import tensorflow as tf 2049 import os 2050 2051 checkpoint_directory = "/tmp/training_checkpoints" 2052 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 2053 2054 # Create a Checkpoint that will manage two objects with trackable state, 2055 # one we name "optimizer" and the other we name "model". 2056 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 2057 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 2058 for _ in range(num_training_steps): 2059 optimizer.minimize( ... ) # Variables will be restored on creation. 2060 status.assert_consumed() # Optional sanity checks. 2061 checkpoint.save(file_prefix=checkpoint_prefix) 2062 ``` 2063 2064 `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based 2065 checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which 2066 writes and 2067 reads `variable.name` based checkpoints. Object-based checkpointing saves a 2068 graph of dependencies between Python objects (`Layer`s, `Optimizer`s, 2069 `Variable`s, etc.) with named edges, and this graph is used to match variables 2070 when restoring a checkpoint. It can be more robust to changes in the Python 2071 program, and helps to support restore-on-create for variables. 2072 2073 `Checkpoint` objects have dependencies on the objects passed as keyword 2074 arguments to their constructors, and each dependency is given a name that is 2075 identical to the name of the keyword argument for which it was created. 2076 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 2077 dependencies on their own variables (e.g. "kernel" and "bias" for 2078 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 2079 dependencies easy in user-defined classes, since `Model` hooks into attribute 2080 assignment. For example: 2081 2082 ```python 2083 class Regress(tf.keras.Model): 2084 2085 def __init__(self): 2086 super().__init__() 2087 self.input_transform = tf.keras.layers.Dense(10) 2088 # ... 2089 2090 def call(self, inputs): 2091 x = self.input_transform(inputs) 2092 # ...
2093 ``` 2094 2095 This `Model` has a dependency named "input_transform" on its `Dense` layer, 2096 which in turn depends on its variables. As a result, saving an instance of 2097 `Regress` using `tf.train.Checkpoint` will also save all the variables created 2098 by the `Dense` layer. 2099 2100 When variables are assigned to multiple workers, each worker writes its own 2101 section of the checkpoint. These sections are then merged/re-indexed to behave 2102 as a single checkpoint. This avoids copying all variables to one worker, but 2103 does require that all workers see a common filesystem. 2104 2105 This function differs slightly from the Keras Model `save_weights` function. 2106 `tf.keras.Model.save_weights` creates a checkpoint file with the name 2107 specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints, 2108 using `filepath` as the prefix for the checkpoint file names. Aside from this, 2109 `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent. 2110 2111 See the [guide to training 2112 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 2113 details. 2114 2115 Attributes: 2116 save_counter: Incremented when `save()` is called. Used to number 2117 checkpoints. 2118 """ 2119 2120 def __init__(self, root=None, **kwargs): 2121 """Creates a training checkpoint for a single or group of objects. 2122 2123 Args: 2124 root: The root object to checkpoint. `root` may be a trackable object or 2125 `WeakRef` of a trackable object. 2126 **kwargs: Keyword arguments are set as attributes of this object, and are 2127 saved with the checkpoint. All `kwargs` must be trackable objects, or a 2128 nested structure of trackable objects (`list`, `dict`, or `tuple`). 2129 2130 Raises: 2131 ValueError: If `root` or the objects in `kwargs` are not trackable. A 2132 `ValueError` is also raised if the `root` object tracks different 2133 objects from the ones listed in attributes in kwargs (e.g. 2134 `root.child = A` and `tf.train.Checkpoint(root, child=B)` are 2135 incompatible). 2136 2137 """ 2138 super().__init__() 2139 global _END_TIME_OF_LAST_WRITE 2140 with _END_TIME_OF_LAST_WRITE_LOCK: 2141 if _END_TIME_OF_LAST_WRITE is None: 2142 _END_TIME_OF_LAST_WRITE = time.time() 2143 2144 attached_dependencies = None 2145 self._save_counter = None # Created lazily for restore-on-create. 2146 self._save_assign_op = None 2147 2148 if root: 2149 trackable_root = root() if isinstance(root, weakref.ref) else root 2150 _assert_trackable(trackable_root, "root") 2151 attached_dependencies = [] 2152 2153 # All keyword arguments (including root itself) are set as children 2154 # of root. 2155 kwargs["root"] = root 2156 trackable_root._maybe_initialize_trackable() 2157 2158 self._save_counter = data_structures.NoDependency( 2159 trackable_root._lookup_dependency("save_counter")) 2160 2161 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 2162 setattr(self, k, v) 2163 2164 # Call getattr instead of directly using v because setattr converts 2165 # v to a Trackable data structure when v is a list/dict/tuple. 
2166 converted_v = getattr(self, k) 2167 if isinstance(converted_v, weakref.ref): 2168 converted_v = converted_v() 2169 _assert_trackable(converted_v, k) 2170 2171 if root: 2172 # Make sure that root doesn't already have dependencies with these names. 2173 child = trackable_root._lookup_dependency(k) 2174 if child is None: 2175 attached_dependencies.append( 2176 base.WeakTrackableReference(k, converted_v)) 2177 elif child != converted_v: 2178 raise ValueError( 2179 f"Cannot create a Checkpoint with keyword argument {k} if " 2180 f"root.{k} already exists.") 2181 2182 self._saver = TrackableSaver( 2183 graph_view_lib.ObjectGraphView( 2184 root if root else self, 2185 attached_dependencies=attached_dependencies)) 2186 self._attached_dependencies = data_structures.NoDependency( 2187 attached_dependencies) 2188 2189 def _maybe_create_save_counter(self): 2190 """Create a save counter if it does not yet exist.""" 2191 if self._save_counter is None: 2192 # Initialized to 0 and incremented before saving. 2193 with ops.device("/cpu:0"): 2194 # add_variable creates a dependency named "save_counter"; NoDependency 2195 # prevents creating a second dependency named "_save_counter". 2196 self._save_counter = data_structures.NoDependency( 2197 add_variable( 2198 self, 2199 name="save_counter", 2200 initializer=0, 2201 dtype=dtypes.int64, 2202 trainable=False)) 2203 if self._attached_dependencies is not None: 2204 self._attached_dependencies.append( 2205 # Store a strong reference to the `save_counter`, so that if the 2206 # `Checkpoint` object is deleted, the `save_counter` does not get 2207 # deleted immediately. (The LoadStatus object needs to indirectly 2208 # reference the counter through the ObjectGraphView). 2209 base.TrackableReference("save_counter", self._save_counter)) 2210 # When loading a checkpoint, the save counter is created after 2211 # the checkpoint has been loaded, so it must be handled in a deferred 2212 # manner. 2213 if isinstance(self.root, weakref.ref): 2214 root = self.root() 2215 else: 2216 root = self.root 2217 restore = root._deferred_dependencies.pop("save_counter", ()) # pylint: disable=protected-access 2218 if restore: 2219 restore[0].restore(self._save_counter) 2220 2221 def write(self, file_prefix, options=None): 2222 """Writes a training checkpoint. 2223 2224 The checkpoint includes variables created by this object and any 2225 trackable objects it depends on at the time `Checkpoint.write()` is 2226 called. 2227 2228 `write` does not number checkpoints, increment `save_counter`, or update the 2229 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 2230 use by higher level checkpoint management utilities. `save` provides a very 2231 basic implementation of these features. 2232 2233 Checkpoints written with `write` must be read with `read`. 2234 2235 Example usage: 2236 2237 ``` 2238 step = tf.Variable(0, name="step") 2239 checkpoint = tf.train.Checkpoint(step=step) 2240 checkpoint.write("/tmp/ckpt") 2241 2242 # Later, read the checkpoint with read() 2243 checkpoint.read("/tmp/ckpt") 2244 2245 # You can also pass options to write() and read(). For example, this 2246 # runs the IO ops on the localhost: 2247 options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 2248 checkpoint.write("/tmp/ckpt", options=options) 2249 2250 # Later, read the checkpoint with read() 2251 checkpoint.read("/tmp/ckpt", options=options) 2252 ``` 2253 2254 Args: 2255 file_prefix: A prefix to use for the checkpoint filenames 2256 (/path/to/directory/and_a_prefix).
2257 options: Optional `tf.train.CheckpointOptions` object. 2258 2259 Returns: 2260 The full path to the checkpoint (i.e. `file_prefix`). 2261 """ 2262 if isinstance(file_prefix, os.PathLike): 2263 file_prefix = os.fspath(file_prefix) 2264 return self._write(file_prefix, options) 2265 2266 def _write(self, file_prefix, options=None, write_done_callback=None): 2267 """Internal method that implements Checkpoint.write(). 2268 2269 Args: 2270 file_prefix: A prefix to use for the checkpoint filenames 2271 (/path/to/directory/and_a_prefix). 2272 options: Optional `tf.train.CheckpointOptions` object. 2273 write_done_callback: Optional callback function to be executed once 2274 the underlying checkpoint saving is finished. Example usage includes 2275 updating the checkpoint internal state. 2276 2277 Returns: 2278 The full path to the checkpoint (i.e. `file_prefix`). 2279 """ 2280 start_time = time.time() 2281 options = options or checkpoint_options.CheckpointOptions() 2282 output = self._saver.save( 2283 file_prefix=file_prefix, 2284 options=options, 2285 write_done_callback=write_done_callback) 2286 output = _convert_file_name_tensor_to_string(output) 2287 # For the async-checkpoint-v1 scenario, the callback function is triggered 2288 # in TrackableSaver. For others, it is executed here. 2289 if not options.experimental_enable_async_checkpoint and write_done_callback: 2290 write_done_callback(output) 2291 end_time = time.time() 2292 2293 # This records the time checkpoint._write() blocks on the main thread. 2294 metrics.AddCheckpointWriteDuration( 2295 api_label=_CHECKPOINT_V2, 2296 microseconds=_get_duration_microseconds(start_time, end_time)) 2297 2298 global _END_TIME_OF_LAST_WRITE 2299 with _END_TIME_OF_LAST_WRITE_LOCK: 2300 metrics.AddTrainingTimeSaved( 2301 api_label=_CHECKPOINT_V2, 2302 microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE, 2303 end_time)) 2304 _END_TIME_OF_LAST_WRITE = end_time 2305 2306 # Async-checkpoint-v1 may not have finished yet, so we can't measure its 2307 # checkpoint size now. For other scenarios (sync-checkpoint and async-v2), 2308 # this metric recording works fine here. 2309 if not options.experimental_enable_async_checkpoint: 2310 metrics.RecordCheckpointSize( 2311 api_label=_CHECKPOINT_V2, filesize=_get_checkpoint_size(output)) 2312 return output 2313 2314 @property 2315 def save_counter(self): 2316 """An integer variable which starts at zero and is incremented on save. 2317 2318 Used to number checkpoints. 2319 2320 Returns: 2321 The save counter variable. 2322 """ 2323 self._maybe_create_save_counter() 2324 return self._save_counter 2325 2326 def save(self, file_prefix, options=None): 2327 # pylint:disable=line-too-long 2328 """Saves a training checkpoint and provides basic checkpoint management. 2329 2330 The saved checkpoint includes variables created by this object and any 2331 trackable objects it depends on at the time `Checkpoint.save()` is 2332 called. 2333 2334 `save` is a basic convenience wrapper around the `write` method, 2335 sequentially numbering checkpoints using `save_counter` and updating the 2336 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 2337 management, for example garbage collection and custom numbering, may be 2338 provided by other utilities which also wrap `write` and `read` 2339 (`tf.train.CheckpointManager`, for example).
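For instance, `tf.train.CheckpointManager` numbers checkpoints and keeps only the most recent ones (a minimal sketch; the directory is illustrative):

```
checkpoint = tf.train.Checkpoint(step=tf.Variable(0))
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/ckpt_dir", max_to_keep=3)
manager.save()  # Increments the checkpoint number and deletes stale files.
checkpoint.restore(manager.latest_checkpoint)
```

Basic usage of `save`: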
2340 2341 ``` 2342 step = tf.Variable(0, name="step") 2343 checkpoint = tf.train.Checkpoint(step=step) 2344 checkpoint.save("/tmp/ckpt") 2345 2346 # Later, restore the checkpoint with restore() 2347 checkpoint.restore("/tmp/ckpt-1") 2348 2349 # You can also pass options to save() and restore(). For example, this 2350 # runs the IO ops on the localhost: 2351 options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 2352 checkpoint.save("/tmp/ckpt", options=options) 2353 2354 # Later, restore the checkpoint with restore() 2355 checkpoint.restore("/tmp/ckpt-1", options=options) 2356 ``` 2357 2358 Args: 2359 file_prefix: A prefix to use for the checkpoint filenames 2360 (/path/to/directory/and_a_prefix). Names are generated based on this 2361 prefix and `Checkpoint.save_counter`. 2362 options: Optional `tf.train.CheckpointOptions` object. 2363 2364 Returns: 2365 The full path to the checkpoint. 2366 """ 2367 if isinstance(file_prefix, os.PathLike): 2368 file_prefix = os.fspath(file_prefix) 2369 # pylint:enable=line-too-long 2370 options = options or checkpoint_options.CheckpointOptions() 2371 graph_building = not context.executing_eagerly() 2372 if graph_building: 2373 # Assert that async checkpoint is not used for non-eager mode. 2374 if options.experimental_enable_async_checkpoint: 2375 raise NotImplementedError( 2376 "Async checkpoint is not supported for non-eager mode.") 2377 2378 if ops.inside_function(): 2379 raise NotImplementedError( 2380 "Calling tf.train.Checkpoint.save() from a function is not " 2381 "supported, as save() modifies saving metadata in ways not " 2382 "supported by TensorFlow Operations. Consider using " 2383 "tf.train.Checkpoint.write(), a lower-level API which does not " 2384 "update metadata. tf.train.latest_checkpoint and related APIs will " 2385 "not see this checkpoint.") 2386 session = get_session() 2387 if self._save_counter is None: 2388 # When graph building, if this is a new save counter variable then it 2389 # needs to be initialized before assign_add. This is only an issue if 2390 # restore() has not been called first. 2391 session.run(self.save_counter.initializer) 2392 2393 if not graph_building or self._save_assign_op is None: 2394 with ops.colocate_with(self.save_counter): 2395 assign_op = self.save_counter.assign_add(1, read_value=True) 2396 if graph_building: 2397 self._save_assign_op = data_structures.NoDependency(assign_op) 2398 2399 if graph_building: 2400 checkpoint_number = session.run(self._save_assign_op) 2401 else: 2402 checkpoint_number = assign_op.numpy() 2403 2404 file_path = self._write( 2405 "%s-%d" % (file_prefix, checkpoint_number), 2406 options=options, 2407 write_done_callback=_update_checkpoint_state_internal) 2408 2409 # Ensure save operations have completed, but only when running eagerly with 2410 # synchronous (non-async) checkpointing. 2411 if not graph_building and not options.experimental_enable_async_checkpoint: 2412 context.async_wait() 2413 2414 return file_path 2415 2416 def read(self, save_path, options=None): 2417 """Reads a training checkpoint written with `write`. 2418 2419 Reads this `Checkpoint` and any objects it depends on. 2420 2421 This method is just like `restore()` but does not expect the `save_counter` 2422 variable in the checkpoint. It only restores the objects that the checkpoint 2423 already depends on.
2424 2425 The method is primarily intended for use by higher level checkpoint 2426 management utilities that use `write()` instead of `save()` and have their 2427 own mechanisms to number and track checkpoints. 2428 2429 Example usage: 2430 2431 ```python 2432 # Create a checkpoint with write() 2433 ckpt = tf.train.Checkpoint(v=tf.Variable(1.)) 2434 path = ckpt.write('/tmp/my_checkpoint') 2435 2436 # Later, load the checkpoint with read() 2437 # With restore(), assert_consumed() would have failed. 2438 ckpt.read(path).assert_consumed() 2439 2440 # You can also pass options to read(). For example, this 2441 # runs the IO ops on the localhost: 2442 options = tf.train.CheckpointOptions( 2443 experimental_io_device="/job:localhost") 2444 ckpt.read(path, options=options) 2445 ``` 2446 2447 Args: 2448 save_path: The path to the checkpoint as returned by `write`. 2449 options: Optional `tf.train.CheckpointOptions` object. 2450 2451 Returns: 2452 A load status object, which can be used to make assertions about the 2453 status of a checkpoint restoration. See `restore` for details. 2454 """ 2455 start_time = time.time() 2456 if isinstance(save_path, os.PathLike): 2457 save_path = os.fspath(save_path) 2458 options = options or checkpoint_options.CheckpointOptions() 2459 result = self._saver.restore(save_path=save_path, options=options) 2460 metrics.AddCheckpointReadDuration( 2461 api_label=_CHECKPOINT_V2, 2462 microseconds=_get_duration_microseconds(start_time, time.time())) 2463 return result 2464 2465 def restore(self, save_path, options=None): 2466 """Restores a training checkpoint. 2467 2468 Restores this `Checkpoint` and any objects it depends on. 2469 2470 This method is intended to be used to load checkpoints created by `save()`. 2471 For checkpoints created by `write()` use the `read()` method which does not 2472 expect the `save_counter` variable added by `save()`. 2473 2474 `restore()` either assigns values immediately if variables to restore have 2475 been created already, or defers restoration until the variables are 2476 created. Dependencies added after this call will be matched if they have a 2477 corresponding object in the checkpoint (the restore request will queue in 2478 any trackable object waiting for the expected dependency to be added). 2479 2480 ```python 2481 checkpoint = tf.train.Checkpoint( ... ) 2482 checkpoint.restore(path) 2483 2484 # You can additionally pass options to restore(): 2485 options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 2486 checkpoint.restore(path, options=options) 2487 ``` 2488 2489 To ensure that loading is complete and no more deferred restorations will 2490 take place, use the `assert_consumed()` method of the status object returned 2491 by `restore()`: 2492 2493 ```python 2494 checkpoint.restore(path, options=options).assert_consumed() 2495 ``` 2496 2497 The assert will raise an error if any Python objects in the dependency graph 2498 were not found in the checkpoint, or if any checkpointed values do not have 2499 a matching Python object. 2500 2501 Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be 2502 loaded using this method. Names are used to match variables. Re-encode 2503 name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible. 2504 2505 **Loading from SavedModel checkpoints** 2506 2507 To load values from a SavedModel, just pass the SavedModel directory 2508 to checkpoint.restore: 2509 2510 ```python 2511 model = tf.keras.Model(...)
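# Note: `path` is a SavedModel directory, not a checkpoint file prefix.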
2512 tf.saved_model.save(model, path) # or model.save(path, save_format='tf') 2513 2514 checkpoint = tf.train.Checkpoint(model) 2515 checkpoint.restore(path).expect_partial() 2516 ``` 2517 2518 This example calls `expect_partial()` on the loaded status, since 2519 SavedModels saved from Keras often generate extra keys in the checkpoint. 2520 Otherwise, the program prints a lot of warnings about unused keys at exit 2521 time. 2522 2523 Args: 2524 save_path: The path to the checkpoint, as returned by `save` or 2525 `tf.train.latest_checkpoint`. If the checkpoint was written by the 2526 name-based `tf.compat.v1.train.Saver`, names are used to match 2527 variables. This path may also be a SavedModel directory. 2528 options: Optional `tf.train.CheckpointOptions` object. 2529 2530 Returns: 2531 A load status object, which can be used to make assertions about the 2532 status of a checkpoint restoration. 2533 2534 The returned status object has the following methods: 2535 2536 * `assert_consumed()`: 2537 Raises an exception if any variables are unmatched: either 2538 checkpointed values which don't have a matching Python object or 2539 Python objects in the dependency graph with no values in the 2540 checkpoint. This method returns the status object, and so may be 2541 chained with other assertions. 2542 2543 * `assert_existing_objects_matched()`: 2544 Raises an exception if any existing Python objects in the dependency 2545 graph are unmatched. Unlike `assert_consumed`, this assertion will 2546 pass if values in the checkpoint have no corresponding Python 2547 objects. For example a `tf.keras.Layer` object which has not yet been 2548 built, and so has not created any variables, will pass this assertion 2549 but fail `assert_consumed`. Useful when loading part of a larger 2550 checkpoint into a new Python program, e.g. a training checkpoint with 2551 a `tf.compat.v1.train.Optimizer` was saved but only the state required 2552 for 2553 inference is being loaded. This method returns the status object, and 2554 so may be chained with other assertions. 2555 2556 * `assert_nontrivial_match()`: Asserts that something aside from the root 2557 object was matched. This is a very weak assertion, but is useful for 2558 sanity checking in library code where objects may exist in the 2559 checkpoint which haven't been created in Python and some Python 2560 objects may not have a checkpointed value. 2561 2562 * `expect_partial()`: Silence warnings about incomplete checkpoint 2563 restores. Warnings are otherwise printed for unused parts of the 2564 checkpoint file or object when the `Checkpoint` object is deleted 2565 (often at program shutdown). 2566 2567 Raises: 2568 NotFoundError: if a checkpoint or SavedModel cannot be found at 2569 `save_path`. 2570 """ 2571 orig_save_path = save_path 2572 if isinstance(save_path, os.PathLike): 2573 save_path = os.fspath(save_path) 2574 2575 if save_path is not None and gfile.IsDirectory(save_path) and ( 2576 (gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or 2577 gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))): 2578 save_path = utils_impl.get_variables_path(save_path) 2579 2580 try: 2581 status = self.read(save_path, options=options) 2582 if context.executing_eagerly(): 2583 context.async_wait() # Ensure restore operations have completed.
2584 except errors_impl.NotFoundError as e: 2585 raise errors_impl.NotFoundError( 2586 None, None, 2587 f"Error when restoring from checkpoint or SavedModel at " 2588 f"{orig_save_path}: {e.message}" 2589 f"\nPlease double-check that the path is correct. You may be missing " 2590 "the checkpoint suffix (e.g. the '-1' in 'path/to/ckpt-1').") 2591 # Create the save counter now so it gets initialized with other variables 2592 # when graph building. Creating it earlier would lead to errors when using, 2593 # say, train.Saver() to save the model before initializing it. 2594 self._maybe_create_save_counter() 2595 if isinstance(status, NameBasedSaverStatus): 2596 status.add_to_optionally_restored(self.save_counter) 2597 return status 2598