1"""Utilities for saving/loading Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22import functools 23import os 24import weakref 25 26import six 27 28from tensorflow.core.protobuf import trackable_object_graph_pb2 29from tensorflow.python.client import session as session_lib 30from tensorflow.python.eager import context 31from tensorflow.python.eager import def_function 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors_impl 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import tensor_shape 37from tensorflow.python.framework import tensor_util 38from tensorflow.python.lib.io import file_io 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import gen_io_ops as io_ops 41from tensorflow.python.ops import init_ops 42from tensorflow.python.ops import variable_scope 43from tensorflow.python.ops import variables 44from tensorflow.python.platform import gfile 45from tensorflow.python.platform import tf_logging as logging 46from tensorflow.python.saved_model import utils_impl 47from tensorflow.python.training import checkpoint_management 48from tensorflow.python.training import py_checkpoint_reader 49from tensorflow.python.training import saver as v1_saver_lib 50from tensorflow.python.training.saving import checkpoint_options 51from tensorflow.python.training.saving import functional_saver 52from tensorflow.python.training.saving import saveable_object_util 53from tensorflow.python.training.tracking import base 54from tensorflow.python.training.tracking import data_structures 55from tensorflow.python.training.tracking import graph_view as graph_view_lib 56from tensorflow.python.training.tracking import tracking 57from tensorflow.python.util import compat 58from tensorflow.python.util import deprecation 59from tensorflow.python.util import object_identity 60from tensorflow.python.util import tf_contextlib 61from tensorflow.python.util import tf_inspect 62from tensorflow.python.util.tf_export import tf_export 63 64 65# The callable that provide Keras default session that is needed for saving. 66_SESSION_PROVIDER = None 67 68 69def register_session_provider(session_provider): 70 global _SESSION_PROVIDER 71 if _SESSION_PROVIDER is None: 72 _SESSION_PROVIDER = session_provider 73 74 75def get_session(): 76 # Prefer TF's default session since get_session from Keras has side-effects. 77 session = ops.get_default_session() 78 if session is None: 79 global _SESSION_PROVIDER 80 if _SESSION_PROVIDER is not None: 81 session = _SESSION_PROVIDER() # pylint: disable=not-callable 82 return session 83 84 85class _ObjectGraphProtoPrettyPrinter(object): 86 """Lazily traverses an object graph proto to pretty print names. 87 88 If no calls to `node_names` are made this object has no performance 89 overhead. On the other hand, it will only traverse the object graph once, so 90 repeated naming is cheap after the first. 91 """ 92 93 __slots__ = ["_object_graph_proto", "_node_name_cache"] 94 95 def __init__(self, object_graph_proto): 96 self._object_graph_proto = object_graph_proto 97 self._node_name_cache = None 98 99 @property 100 def node_names(self): 101 """Lazily creates a mapping from node id to ("path", "to", "root").""" 102 if self._node_name_cache is not None: 103 return self._node_name_cache 104 path_to_root = {} 105 path_to_root[0] = ("(root)",) 106 to_visit = collections.deque([0]) 107 while to_visit: 108 node_id = to_visit.popleft() 109 obj = self._object_graph_proto.nodes[node_id] 110 for child in obj.children: 111 if child.node_id not in path_to_root: 112 path_to_root[child.node_id] = ( 113 path_to_root[node_id] + (child.local_name,)) 114 to_visit.append(child.node_id) 115 116 node_names = {} 117 for node_id, path_to_root in path_to_root.items(): 118 node_names[node_id] = ".".join(path_to_root) 119 120 for node_id, node in enumerate(self._object_graph_proto.nodes): 121 for slot_reference in node.slot_variables: 122 node_names[slot_reference.slot_variable_node_id] = ( 123 "{}'s state '{}' for {}".format( 124 node_names[node_id], slot_reference.slot_name, 125 node_names[slot_reference.original_variable_node_id])) 126 self._node_name_cache = node_names 127 return node_names 128 129 130class _CheckpointRestoreCoordinatorDeleter(object): 131 """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__().""" 132 133 __slots__ = [ 134 "expect_partial", "object_graph_proto", "matched_proto_ids", 135 "unused_attributes" 136 ] 137 138 def __init__(self, expect_partial, object_graph_proto, matched_proto_ids, 139 unused_attributes): 140 self.expect_partial = expect_partial 141 self.object_graph_proto = object_graph_proto 142 self.matched_proto_ids = matched_proto_ids 143 self.unused_attributes = unused_attributes 144 145 def set_expect_partial(self, expect_partial): 146 self.expect_partial = expect_partial 147 148 def __del__(self): 149 if self.expect_partial: 150 return 151 if logging is None: 152 # The logging module may have been unloaded when __del__ is called. 153 log_fn = print 154 else: 155 log_fn = logging.warning 156 printed_warning = False 157 pretty_printer = _ObjectGraphProtoPrettyPrinter(self.object_graph_proto) 158 for node_id in range(len(self.object_graph_proto.nodes)): 159 if node_id not in self.matched_proto_ids: 160 log_fn("Unresolved object in checkpoint: {}" 161 .format(pretty_printer.node_names[node_id])) 162 printed_warning = True 163 for node_id, attribute_name in self.unused_attributes.items(): 164 log_fn(("Unused attribute in object {}: {}" 165 .format(pretty_printer.node_names[node_id], attribute_name))) 166 printed_warning = True 167 if printed_warning: 168 log_fn( 169 "A checkpoint was restored (e.g. tf.train.Checkpoint.restore or " 170 "tf.keras.Model.load_weights) but not all checkpointed values were " 171 "used. See above for specific issues. Use expect_partial() on the " 172 "load status object, e.g. " 173 "tf.train.Checkpoint.restore(...).expect_partial(), to silence these " 174 "warnings, or use assert_consumed() to make the check explicit. See " 175 "https://www.tensorflow.org/guide/checkpoint#loading_mechanics" 176 " for details.") 177 178 179class _CheckpointRestoreCoordinator(object): 180 """Holds the status of an object-based checkpoint load.""" 181 182 def __init__(self, object_graph_proto, save_path, save_path_tensor, 183 restore_op_cache, graph_view, options): 184 """Specify the checkpoint being loaded. 185 186 Args: 187 object_graph_proto: The TrackableObjectGraph protocol buffer associated 188 with this checkpoint. 189 save_path: A string, the path to the checkpoint, as returned by 190 `tf.train.latest_checkpoint`. 191 save_path_tensor: A string `Tensor` which contains or will be fed the save 192 path. 193 restore_op_cache: A dictionary shared between 194 `_CheckpointRestoreCoordinator`s for the same Python objects, used to 195 look up restore ops by name to avoid re-creating them across multiple 196 `restore()` calls. 197 graph_view: A graph_view_lib.ObjectGraphView object for the restored 198 objects. 199 options: A CheckpointOptions object. 200 """ 201 self.options = options 202 self.object_graph_proto = object_graph_proto 203 self.restore_uid = ops.uid() 204 # Maps from proto ids to lists of attributes which were in the checkpoint 205 # but not loaded into any object, for error checking. 206 self.unused_attributes = {} 207 # Dictionary mapping from an id in the protocol buffer flat array to 208 # Trackable Python objects. This mapping may be deferred if a 209 # checkpoint is restored before all dependencies have been tracked. Uses 210 # weak references so that partial restorations don't create reference cycles 211 # (as objects with deferred dependencies will generally have references to 212 # this object). 213 self.object_by_proto_id = weakref.WeakValueDictionary() 214 self.matched_proto_ids = set() 215 # A set of all Python objects we've seen as dependencies, even if we didn't 216 # use them (for example because of inconsistent references when 217 # loading). Used to make status assertions fail when loading checkpoints 218 # that don't quite match. 219 self.all_python_objects = object_identity.ObjectIdentityWeakSet() 220 self.save_path_tensor = save_path_tensor 221 self.save_path_string = save_path 222 self.dtype_map = py_checkpoint_reader.NewCheckpointReader( 223 save_path).get_variable_to_dtype_map() 224 # A NewCheckpointReader for the most recent checkpoint, for streaming Python 225 # state restoration. 226 # When graph building, contains a list of ops to run to restore objects from 227 # this checkpoint. 228 self.restore_ops = [] 229 self.restore_ops_by_name = restore_op_cache 230 self.graph_view = graph_view 231 self.new_restore_ops_callback = None 232 # A mapping from optimizer proto ids to lists of slot variables to be 233 # restored when the optimizer is tracked. Only includes slot variables whose 234 # regular variables have already been created, and only for optimizer 235 # objects which have not yet been created/tracked. 236 self.deferred_slot_restorations = {} 237 # A mapping from variable proto ids to lists of slot variables to be 238 # restored when the variable is created/tracked. These get shifted over to 239 # deferred_slot_restorations if the optimizer hasn't been created when that 240 # happens. 241 self.slot_restorations = {} 242 # Controls whether errors are printed in __del__ if some objects did not 243 # match. 244 self.expect_partial_attr = False 245 for node_index, node in enumerate(self.object_graph_proto.nodes): 246 for slot_reference in node.slot_variables: 247 # `node` refers to an `Optimizer`, since only these have slot variables. 248 self.slot_restorations.setdefault( 249 slot_reference.original_variable_node_id, []).append( 250 base._SlotVariableRestoration( # pylint: disable=protected-access 251 optimizer_id=node_index, 252 slot_variable_id=slot_reference.slot_variable_node_id, 253 slot_name=slot_reference.slot_name)) 254 255 self._deleter = _CheckpointRestoreCoordinatorDeleter( 256 self.expect_partial_attr, 257 self.object_graph_proto, 258 self.matched_proto_ids, 259 self.unused_attributes) 260 261 @property 262 def expect_partial(self): 263 return self.expect_partial_attr 264 265 @expect_partial.setter 266 def expect_partial(self, expect_partial): 267 self.expect_partial_attr = expect_partial 268 self._deleter.set_expect_partial(expect_partial) 269 270 def new_restore_ops(self, new_ops): 271 self.restore_ops.extend(new_ops) 272 if self.new_restore_ops_callback: 273 self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable 274 275 def restore_saveables(self, tensor_saveables, python_saveables): 276 """Run or build restore operations for SaveableObjects. 277 278 Args: 279 tensor_saveables: `SaveableObject`s which correspond to Tensors. 280 python_saveables: `PythonStateSaveable`s which correspond to Python 281 values. 282 283 Returns: 284 When graph building, a list of restore operations, either cached or newly 285 created, to restore `tensor_saveables`. 286 """ 287 restore_ops = [] 288 # Eagerly run restorations for Python state. 289 reader = None 290 for saveable in python_saveables: 291 if reader is None: 292 # Lazily create the NewCheckpointReader, since this requires file access 293 # and we may not have any Python saveables. 294 reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string) 295 spec_names = [spec.name for spec in saveable.specs] 296 saveable.python_restore([reader.get_tensor(name) for name in spec_names]) 297 298 # If we have new SaveableObjects, extract and cache restore ops. 299 if tensor_saveables: 300 validated_saveables = saveable_object_util.validate_and_slice_inputs( 301 tensor_saveables) 302 validated_names = set(saveable.name for saveable in validated_saveables) 303 if set(tensor_saveables.keys()) != validated_names: 304 raise AssertionError( 305 ("Saveable keys changed when validating. Got back %s, was " 306 "expecting %s") % (tensor_saveables.keys(), validated_names)) 307 new_restore_ops = functional_saver.MultiDeviceSaver( 308 validated_saveables).restore(self.save_path_tensor, self.options) 309 if not context.executing_eagerly(): 310 for name, restore_op in sorted(new_restore_ops.items()): 311 restore_ops.append(restore_op) 312 assert name not in self.restore_ops_by_name 313 self.restore_ops_by_name[name] = restore_op 314 return restore_ops 315 316 317class _NameBasedRestoreCoordinator(object): 318 """Keeps the status of a name-based checkpoint restore.""" 319 320 def __init__(self, save_path, dtype_map=None): 321 self.save_path = save_path 322 self.dtype_map = dtype_map 323 # A map from trackable objects to unused attribute names. We don't have 324 # proto IDs when doing a name-based restore, so the map keys differ from 325 # those in _CheckpointRestoreCoordinator. 326 self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary() 327 self.restore_uid = ops.uid() 328 329 def globally_named_object_attributes(self, trackable): 330 """Create globally named SaveableObjects from attributes. 331 332 If an object's attribute has no global name specified (default construction 333 for the SaveableObject factory), records the failure in 334 `self.unused_attributes` (which can then be used to make status assertions 335 fail; see `NameBasedSaverStatus`). 336 337 Args: 338 trackable: An object to save. 339 340 Yields: 341 SaveableObjects for `trackable`'s attributes. 342 """ 343 for attribute_name, saveable_factory in ( 344 trackable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access 345 if callable(saveable_factory): 346 try: 347 # This saveable object factory does not have a default name= argument, 348 # which means there's no way to save/restore it using a name-based 349 # checkpoint. Ignore the error now and make sure assert_consumed() 350 # fails. 351 saveable = saveable_factory() 352 except TypeError: 353 # Even if we can't name this object, we should construct it and check 354 # whether it's optional to restore it. If it's optional we don't need 355 # to make assertions fail. 356 if not saveable_factory("").optional_restore: 357 self.unused_attributes.setdefault(trackable, 358 []).append(attribute_name) 359 continue 360 else: 361 saveable = saveable_factory 362 names_to_saveables = saveable_object_util.op_list_to_dict( 363 [saveable], convert_variable_to_tensor=False) 364 for name, op in names_to_saveables.items(): 365 for saveable_object in saveable_object_util.saveable_objects_for_op( 366 op=op, name=name): 367 yield saveable_object 368 369 def eager_restore(self, trackable): 370 """Runs restore ops for `trackable`'s attributes.""" 371 # When graph building, we don't add any restore ops to the graph until 372 # run_restore_ops/initialize_or_restore on the status object for name-based 373 # checkpoints. 374 assert context.executing_eagerly() 375 for saveable in self.globally_named_object_attributes(trackable): 376 restored_tensors = [] 377 tensor_missing = False 378 for spec in saveable.specs: 379 if spec.name in self.dtype_map: 380 with ops.device("cpu:0"): 381 restored, = io_ops.restore_v2( 382 prefix=self.save_path, 383 tensor_names=[spec.name], 384 shape_and_slices=[""], 385 dtypes=[self.dtype_map[spec.name]], 386 name="%s_checkpoint_read" % (spec.name,)) 387 restored_tensors.append(array_ops.identity(restored)) 388 else: 389 tensor_missing = True 390 391 if tensor_missing: 392 # Record that this variable didn't match so assertions will fail. 393 self.unused_attributes.setdefault(trackable, []).append(saveable.name) 394 else: 395 # Ignores values missing from the checkpoint, as with object-based 396 # restore. Status assertions can be used to check exact matches, 397 # although it's unlikely to ever happen for name-based checkpoints. 398 saveable.restore( 399 restored_tensors=restored_tensors, restored_shapes=None) 400 401 402# TODO(allenl): If this ends up in a public API, consider adding LINT.If Change 403# or consolidating the implementation with get_variable. 404def _default_getter(name, 405 shape, 406 dtype, 407 initializer=None, 408 partition_info=None, 409 **kwargs): 410 """A pared-down version of get_variable which does not reuse variables.""" 411 dtype = dtypes.as_dtype(dtype) 412 shape_object = tensor_shape.as_shape(shape) 413 with ops.init_scope(): 414 if initializer is None: 415 initializer, initializing_from_value = ( 416 variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access 417 name=name, 418 shape=shape_object, 419 dtype=dtype)) 420 else: 421 initializing_from_value = not callable(initializer) 422 # Same logic as get_variable 423 variable_dtype = dtype.base_dtype 424 if initializing_from_value: 425 if shape is not None: 426 raise ValueError("If initializer is a constant, do not specify shape.") 427 initial_value = initializer 428 else: 429 # Instantiate initializer if provided initializer is a type object. 430 if isinstance(initializer, type(init_ops.Initializer)): 431 initializer = initializer(dtype=dtype) 432 shape_list = None if shape is None else shape_object.as_list() 433 if "partition_info" in tf_inspect.getargspec(initializer).args: 434 initial_value = functools.partial(initializer, 435 shape_list, 436 dtype=dtype, 437 partition_info=partition_info) 438 else: 439 initial_value = functools.partial(initializer, 440 shape_list, 441 dtype=dtype) 442 443 return variables.VariableV1( 444 initial_value=initial_value, 445 name=name, 446 dtype=variable_dtype, 447 use_resource=True, 448 **kwargs) 449 450 451def add_variable(trackable, 452 name, 453 shape=None, 454 dtype=dtypes.float32, 455 initializer=None, 456 trainable=True): 457 """Add a variable to a Trackable with no scope influence.""" 458 return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access 459 name=name, 460 shape=shape, 461 dtype=dtype, 462 initializer=initializer, 463 getter=_default_getter, 464 trainable=trainable) 465 466 467def object_metadata(save_path): 468 """Retrieves information about the objects in a checkpoint. 469 470 Example usage: 471 472 ```python 473 object_graph = tf.contrib.checkpoint.object_metadata( 474 tf.train.latest_checkpoint(checkpoint_directory)) 475 ckpt_variable_names = set() 476 for node in object_graph.nodes: 477 for attribute in node.attributes: 478 ckpt_variable_names.add(attribute.full_name) 479 ``` 480 481 Args: 482 save_path: The path to the checkpoint, as returned by `save` or 483 `tf.train.latest_checkpoint`. 484 485 Returns: 486 A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer. 487 Raises: 488 ValueError: If an object graph was not found in the checkpoint. 489 """ 490 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 491 try: 492 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 493 except errors_impl.NotFoundError: 494 raise ValueError( 495 ('The specified checkpoint "%s" does not appear to be object-based (it ' 496 'is missing the key "%s"). Likely it was created with a name-based ' 497 "saver and does not contain an object dependency graph.") % 498 (save_path, base.OBJECT_GRAPH_PROTO_KEY)) 499 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 500 object_graph_proto.ParseFromString(object_graph_string) 501 return object_graph_proto 502 503 504def list_objects(root_trackable): 505 """Traverse the object graph and list all accessible objects. 506 507 Looks for `Trackable` objects which are dependencies of 508 `root_trackable`. Includes slot variables only if the variable they are 509 slotting for and the optimizer are dependencies of `root_trackable` 510 (i.e. if they would be saved with a checkpoint). 511 512 Args: 513 root_trackable: A `Trackable` object whose dependencies should be flattened. 514 515 Returns: 516 A flat list of objects. 517 """ 518 return graph_view_lib.ObjectGraphView(root_trackable).list_objects() 519 520 521def gather_initializers(root_trackable): 522 """Traverse the object graph and find initialization ops. 523 524 Looks for `Trackable` objects which are dependencies of 525 `root_trackable` and which have an `initializer` property. Includes 526 initializers for slot variables only if the variable they are slotting for and 527 the optimizer are dependencies of `root_trackable` (i.e. if they would be 528 saved with a checkpoint). 529 530 Args: 531 root_trackable: A `Trackable` object to gather initializers for. 532 533 Returns: 534 A list of initialization ops. 535 """ 536 trackable_objects = list_objects(root_trackable) 537 return [ 538 c.initializer 539 for c in trackable_objects 540 if hasattr(c, "initializer") and c.initializer is not None 541 ] 542 543 544@tf_contextlib.contextmanager 545def capture_dependencies(template): 546 """Capture variables created within this scope as `Template` dependencies. 547 548 Requires that `template.variable_scope` is active. 549 550 This scope is intended as a compatibility measure, allowing a trackable 551 object to add dependencies on variables created in a block of code which is 552 not aware of object-based saving (and instead uses variable names 553 heavily). This is how `Template` objects add dependencies on variables and 554 sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly. 555 556 Args: 557 template: The `Template` object to register dependencies with. 558 559 Yields: 560 None (when used as a context manager). 561 """ 562 name_prefix = template.variable_scope.name 563 564 def _trackable_custom_creator(next_creator, 565 name, 566 initial_value, 567 trackable_parent=None, 568 **kwargs): 569 """A variable creation hook which adds Trackable dependencies. 570 571 Set for example during a `Template`'s first wrapped function 572 execution. Ensures that (a) `template` depends on any trackable 573 objects using their own `capture_dependencies` scope inside this scope which 574 create variables, and (b) that any variables not in a more deeply nested 575 scope are added as dependencies directly. 576 577 The `trackable_parent` argument is passed between custom creators but 578 ignored when the variable object itself is created. This argument indicates 579 (if not `None`) that a more deeply nested scope has already added the 580 variable as a dependency, and that parent scopes should add a dependency on 581 that object rather than on the variable directly. 582 583 Args: 584 next_creator: See `variable_scope.variable_creator_scope`; the next 585 creator in the chain. 586 name: The (full, scope-influenced) name of the variable. The `name_prefix` 587 itself is stripped for the purposes of object-based dependency tracking, 588 but scopes opened within this scope are respected. 589 initial_value: See `variable_scope.variable_creator_scope`. Taken 590 explicitly so the argument can be re-named and used with 591 `Trackable._add_variable_with_custom_getter`. 592 trackable_parent: If not None, a more deeply nested trackable object and 593 its name prefix which were passed to `capture_dependencies` to add a 594 dependency on (rather than depending on the variable directly). 595 **kwargs: Passed through to the next creator. 596 597 Returns: 598 The output of `next_creator`: the fetched/created variable object. 599 """ 600 601 def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): 602 inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which 603 # we don't want to propagate. 604 return next_creator(initial_value=initializer, name=name, **inner_kwargs) 605 606 if name is not None and name.startswith(name_prefix): 607 scope_stripped_name = name[len(name_prefix) + 1:] 608 if not trackable_parent: 609 return template._add_variable_with_custom_getter( # pylint: disable=protected-access 610 initializer=initial_value, 611 name=scope_stripped_name, 612 getter=_call_next_creator_renaming_initializer, 613 # Disable error checking for Trackable. Exceptions are instead 614 # raised if necessary when the object-based saver tries to 615 # save/restore the object. 616 overwrite=True, 617 trackable_parent=(template, name_prefix), 618 **kwargs) 619 else: 620 parent_object, parent_name_prefix = trackable_parent 621 template._track_trackable( # pylint: disable=protected-access 622 parent_object, 623 name=parent_name_prefix[len(name_prefix) + 1:], 624 overwrite=True) 625 return next_creator( 626 name=name, 627 initial_value=initial_value, 628 trackable_parent=(template, name_prefix), 629 **kwargs) 630 631 with variable_scope.variable_creator_scope(_trackable_custom_creator): 632 yield 633 634 635class _LoadStatus(object): 636 """Abstract base for load status callbacks.""" 637 638 @abc.abstractmethod 639 def assert_consumed(self): 640 """Raises an exception unless a non-trivial restoration has completed.""" 641 pass 642 643 @abc.abstractmethod 644 def assert_existing_objects_matched(self): 645 """Raises an exception unless existing Python objects have been matched.""" 646 pass 647 648 @abc.abstractmethod 649 def assert_nontrivial_match(self): 650 """Raises an exception if only the root object matched.""" 651 pass 652 653 @abc.abstractmethod 654 def run_restore_ops(self, session=None): 655 """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" 656 pass 657 658 @abc.abstractmethod 659 def initialize_or_restore(self, session=None): 660 """Runs restore ops from the checkpoint, or initializes variables.""" 661 pass 662 663 def expect_partial(self): 664 """Silence warnings about incomplete checkpoint restores.""" 665 return self 666 667 668def streaming_restore(status, session=None): 669 """When graph building, runs restore ops as soon as they come in. 670 671 Args: 672 status: A _LoadStatus objects from an object-based saver's restore(). 673 Streaming restore from name-based checkpoints is not currently supported. 674 session: A session to run new restore ops in. 675 """ 676 if context.executing_eagerly(): 677 # Streaming restore is the default/only behavior when executing eagerly. 678 return 679 if session is None: 680 session = get_session() 681 if isinstance(status, NameBasedSaverStatus): 682 raise NotImplementedError( 683 "Streaming restore not supported from name-based checkpoints when " 684 "graph building. File a feature request if this limitation bothers " 685 "you. As a workaround, consider either using tf.train.Checkpoint to " 686 "load name-based checkpoints or enabling eager execution.") 687 status.run_restore_ops(session=session) 688 # pylint: disable=protected-access 689 status._checkpoint.new_restore_ops_callback = ( 690 lambda ops: session.run(ops, feed_dict=status._feed_dict)) 691 # pylint: enable=protected-access 692 693 694def _objects_with_attributes(full_list): 695 """Filters out objects with no direct variable dependencies for assertions.""" 696 return [o for o in full_list if o._gather_saveables_for_checkpoint()] # pylint: disable=protected-access 697 698 699class CheckpointLoadStatus(_LoadStatus): 700 """Checks the status of checkpoint loading and manages restore ops. 701 702 Returned from `Saver.restore`. Since `restore` may defer the loading of values 703 in the checkpoint which don't yet have corresponding Python objects, 704 `CheckpointLoadStatus` provides a callback to verify that checkpoint loading 705 is complete (`assert_consumed`). 706 707 When graph building, `restore` does not run restore ops itself since their 708 creation may be deferred. The `run_restore_ops` method must be called once all 709 Python objects with values to restore have been created and added to the 710 dependency graph (this does not necessarily have to be the whole checkpoint; 711 calling `run_restore_ops` while `assert_consumed` fails is supported and will 712 partially restore the checkpoint). 713 714 See `Saver.restore` for usage examples. 715 """ 716 717 def __init__(self, checkpoint, feed_dict, graph_view): 718 self._checkpoint = checkpoint 719 self._feed_dict = feed_dict 720 self._graph_view = graph_view 721 # Keep a reference to the root, since graph_view might only have a weakref. 722 self._root = graph_view.root 723 724 def assert_consumed(self): 725 """Asserts that all objects in the checkpoint have been created/matched. 726 727 Returns: 728 `self` for chaining. 729 Raises: 730 AssertionError: If there are any Python objects in the dependency graph 731 which have not been restored from this checkpoint or a later `restore`, 732 or if there are any checkpointed values which have not been matched to 733 Python objects. 734 """ 735 pretty_printer = _ObjectGraphProtoPrettyPrinter( 736 self._checkpoint.object_graph_proto) 737 self.assert_existing_objects_matched() 738 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 739 if not node.attributes: 740 # Only raise exceptions for the nodes with attributes themselves. Either 741 # they're ultimately not important, or they have a child with an 742 # attribute. 743 continue 744 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 745 if trackable is None: 746 raise AssertionError("Unresolved object in checkpoint {}: {}" 747 .format(pretty_printer.node_names[node_id], node)) 748 if self._checkpoint.slot_restorations: 749 # Sanity check; this collection should be clear if everything has been 750 # restored. 751 raise AssertionError("Unresolved slot restorations: %s" % 752 (self._checkpoint.slot_restorations,)) 753 if self._checkpoint.unused_attributes: 754 unused_attribute_messages = [] 755 for node_id, attribute in six.iteritems( 756 self._checkpoint.unused_attributes): 757 obj = self._checkpoint.object_by_proto_id[node_id] 758 unused_attribute_messages.append( 759 "{} ({}): {}" 760 .format(pretty_printer.node_names[node_id], obj, attribute)) 761 raise AssertionError( 762 ("Unused attributes in these objects (the attributes exist in the " 763 "checkpoint but were not restored):\n{}") 764 .format("\n".join(unused_attribute_messages))) 765 return self 766 767 def assert_existing_objects_matched(self): 768 """Asserts that trackable Python objects have been matched. 769 770 Note that this is a weaker assertion than `assert_consumed`. It will only 771 fail for existing Python objects which are (transitive) dependencies of the 772 root object and which do not have an entry in the checkpoint. 773 774 It will not fail, for example, if a `tf.keras.Layer` object has not yet been 775 built and so has not created any `tf.Variable` objects. 776 777 Returns: 778 `self` for chaining. 779 780 Raises: 781 AssertionError: If a Python object exists in the transitive dependencies 782 of the root object but does not have a value in the checkpoint. 783 """ 784 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 785 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 786 if (trackable is not None and 787 trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access 788 raise AssertionError("Object not assigned a value from checkpoint: %s" % 789 (node,)) 790 for trackable_object in self._graph_view.list_objects(): 791 # Remove data structures that do not contain any variables from 792 # restoration checks. 793 if (isinstance(trackable_object, 794 data_structures.TrackableDataStructure) and 795 not trackable_object._checkpoint_dependencies): 796 continue 797 self._checkpoint.all_python_objects.add(trackable_object) 798 unused_python_objects = ( 799 object_identity.ObjectIdentitySet( 800 _objects_with_attributes( 801 self._checkpoint.all_python_objects)) - 802 object_identity.ObjectIdentitySet( 803 self._checkpoint.object_by_proto_id.values())) 804 if unused_python_objects: 805 raise AssertionError( 806 ("Some Python objects were not bound to checkpointed values, likely " 807 "due to changes in the Python program: %s") % 808 (list(unused_python_objects),)) 809 return self 810 811 def assert_nontrivial_match(self): 812 """Raises an exception if only the root object matched.""" 813 for trackable_object in self._graph_view.list_objects(): 814 self._checkpoint.all_python_objects.add(trackable_object) 815 if len(self._checkpoint.object_by_proto_id) <= 1: 816 unused_python_objects = ( 817 object_identity.ObjectIdentitySet( 818 _objects_with_attributes(self._checkpoint.all_python_objects)) 819 - object_identity.ObjectIdentitySet( 820 self._checkpoint.object_by_proto_id.values())) 821 if unused_python_objects: 822 raise AssertionError( 823 ("Nothing except the root object matched a checkpointed value. " 824 "Typically this means that the checkpoint does not match the " 825 "Python program. The following objects have no matching " 826 "checkpointed value: %s") % (list(unused_python_objects),)) 827 else: 828 raise AssertionError( 829 "Nothing to load. No dependencies have been added to %s yet." % 830 (self._graph_view.root,)) 831 return self 832 833 def run_restore_ops(self, session=None): 834 """Run operations to restore objects in the dependency graph.""" 835 if context.executing_eagerly(): 836 return # Run eagerly 837 if session is None: 838 session = get_session() 839 session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) 840 841 def initialize_or_restore(self, session=None): 842 """Run operations to initialize or restore objects in the dependency graph. 843 844 Any objects in the dependency graph which have initializers but are not in 845 the checkpoint will have those initializers run, unless those variables are 846 being restored by a later call to `tf.train.Checkpoint.restore()`. 847 848 This method has a sibling in `InitializationOnlyStatus` which instead 849 initializes variables. That type is returned if no checkpoint is specified 850 in `Saver.restore`. 851 852 Args: 853 session: The session to run init/restore ops in. If `None`, uses the 854 default session. 855 """ 856 if context.executing_eagerly(): 857 return # Initialization and restoration ops are run eagerly 858 if session is None: 859 session = get_session() 860 all_objects = self._graph_view.list_objects() 861 already_initialized_objects = object_identity.ObjectIdentitySet( 862 self._checkpoint.object_by_proto_id.values()) 863 initializers_for_non_restored_variables = [ 864 c.initializer for c in all_objects 865 if hasattr(c, "initializer") 866 and c not in already_initialized_objects 867 and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1) 868 < self._checkpoint.restore_uid)] 869 self.run_restore_ops(session=session) 870 session.run(initializers_for_non_restored_variables) 871 872 def expect_partial(self): 873 """Silence warnings about incomplete checkpoint restores.""" 874 self._checkpoint.expect_partial = True 875 return self 876 877 878class InitializationOnlyStatus(_LoadStatus): 879 """Returned from `Saver.restore` when no checkpoint has been specified. 880 881 Objects of this type have the same `assert_consumed` method as 882 `CheckpointLoadStatus`, but it always fails. However, 883 `initialize_or_restore` works on objects of both types, and will 884 initialize variables in `InitializationOnlyStatus` objects or restore them 885 otherwise. 886 """ 887 888 def __init__(self, graph_view, restore_uid): 889 self._restore_uid = restore_uid 890 self._graph_view = graph_view 891 # Keep a reference to the root, since graph_view might only have a weakref. 892 self._root = graph_view.root 893 894 def assert_consumed(self): 895 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 896 raise AssertionError( 897 "No checkpoint specified (save_path=None); nothing is being restored.") 898 899 def assert_existing_objects_matched(self): 900 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 901 raise AssertionError( 902 "No checkpoint specified (save_path=None); nothing is being restored.") 903 904 def assert_nontrivial_match(self): 905 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 906 raise AssertionError( 907 "No checkpoint specified (save_path=None); nothing is being restored.") 908 909 def run_restore_ops(self, session=None): 910 """For consistency with `CheckpointLoadStatus`. 911 912 Use `initialize_or_restore` for initializing if no checkpoint was passed 913 to `Saver.restore` and restoring otherwise. 914 915 Args: 916 session: Not used. 917 """ 918 raise AssertionError( 919 "No checkpoint specified, so no restore ops are available " 920 "(save_path=None to Saver.restore).") 921 922 def initialize_or_restore(self, session=None): 923 """Runs initialization ops for variables. 924 925 Objects which would be saved by `Saver.save` will be initialized, unless 926 those variables are being restored by a later call to 927 `tf.train.Checkpoint.restore()`. 928 929 This method does nothing when executing eagerly (initializers get run 930 eagerly). 931 932 Args: 933 session: The session to run initialization ops in. If `None`, uses the 934 default session. 935 """ 936 if context.executing_eagerly(): 937 return # run eagerly 938 if session is None: 939 session = get_session() 940 trackable_objects = self._graph_view.list_objects() 941 initializers = [ 942 c.initializer for c in trackable_objects 943 if hasattr(c, "initializer") and c.initializer is not None 944 and (getattr(c, "_update_uid", self._restore_uid - 1) 945 < self._restore_uid)] 946 session.run(initializers) 947 948 949_DEPRECATED_RESTORE_INSTRUCTIONS = ( 950 "Restoring a name-based tf.train.Saver checkpoint using the object-based " 951 "restore API. This mode uses global names to match variables, and so is " 952 "somewhat fragile. It also adds new restore ops to the graph each time it " 953 "is called when graph building. Prefer re-encoding training checkpoints in " 954 "the object-based format: run save() on the object-based saver (the same " 955 "one this message is coming from) and use that checkpoint in the future.") 956 957 958class NameBasedSaverStatus(_LoadStatus): 959 """Status for loading a name-based training checkpoint.""" 960 961 # Ideally this deprecation decorator would be on the class, but that 962 # interferes with isinstance checks. 963 @deprecation.deprecated( 964 date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) 965 def __init__(self, checkpoint, graph_view): 966 self._checkpoint = checkpoint 967 self._graph_view = graph_view 968 self._optionally_restored = [] 969 # Keep a reference to the root, since graph_view might only have a weakref. 970 self._root = graph_view.root 971 972 def add_to_optionally_restored(self, var): 973 """Add a variable to the list of optionally restored variables. 974 975 There are situations where certain variables should be ignored in assertions 976 such as assert_existing_objects_matched(). One example is that of a 977 checkpoint saved with train.Saver(), and restored with train.Checkpoint(): 978 it is possible for the train.Saver() checkpoint to be missing the internal 979 `save_counter` variable, which we want to ignore on restore. 980 981 Args: 982 var: The variable to treat as optionally restored. 983 """ 984 self._optionally_restored.append(var) 985 986 def assert_consumed(self): 987 """Raises an exception if any variables are unmatched.""" 988 unused_attributes = list(self._checkpoint.unused_attributes.items()) 989 unused_attributes = [ 990 a for a in unused_attributes 991 if all(a[0] is not x for x in self._optionally_restored) 992 ] 993 if unused_attributes: 994 unused_attribute_strings = [ 995 "\n {}: {}".format(obj, attributes) 996 for obj, attributes in unused_attributes 997 ] 998 raise AssertionError( 999 "Some objects had attributes which were not restored:{}".format( 1000 "".join(unused_attribute_strings))) 1001 for trackable in self._graph_view.list_objects(): 1002 # pylint: disable=protected-access 1003 trackable._maybe_initialize_trackable() 1004 if trackable._update_uid < self._checkpoint.restore_uid: 1005 raise AssertionError("Object not restored: %s" % (trackable,)) 1006 # pylint: enable=protected-access 1007 return self 1008 1009 def assert_existing_objects_matched(self): 1010 """Raises an exception if currently created objects are unmatched.""" 1011 # For name-based checkpoints there's no object information in the 1012 # checkpoint, so there's no distinction between 1013 # assert_existing_objects_matched and assert_consumed (and both are less 1014 # useful since we don't touch Python objects or Python state). 1015 return self.assert_consumed() 1016 1017 def assert_nontrivial_match(self): 1018 """Raises an exception if currently created objects are unmatched.""" 1019 # For name-based checkpoints there's no object information in the 1020 # checkpoint, so there's no distinction between 1021 # assert_nontrivial_match and assert_consumed (and both are less 1022 # useful since we don't touch Python objects or Python state). 1023 return self.assert_consumed() 1024 1025 def _gather_saveable_objects(self): 1026 """Walk the object graph, using global names for SaveableObjects.""" 1027 objects = self._graph_view.list_objects() 1028 saveable_objects = [] 1029 for trackable in objects: 1030 # pylint: disable=protected-access 1031 trackable._maybe_initialize_trackable() 1032 if trackable._update_uid < self._checkpoint.restore_uid: 1033 trackable._update_uid = self._checkpoint.restore_uid 1034 else: 1035 continue 1036 # pylint: enable=protected-access 1037 saveable_objects.extend( 1038 self._checkpoint.globally_named_object_attributes(trackable)) 1039 return saveable_objects 1040 1041 def run_restore_ops(self, session=None): 1042 """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`.""" 1043 if context.executing_eagerly(): 1044 return # Nothing to do, variables are restored on creation. 1045 if session is None: 1046 session = get_session() 1047 with ops.device("/cpu:0"): 1048 saveables = self._gather_saveable_objects() 1049 v1_saver_lib.Saver(saveables).restore( 1050 sess=session, save_path=self._checkpoint.save_path) 1051 1052 def initialize_or_restore(self, session=None): 1053 """Alias for `run_restore_ops`.""" 1054 self.run_restore_ops(session=session) 1055 1056 1057class _SessionWithFeedDictAdditions(session_lib.SessionInterface): 1058 """Pretends to be a session, inserts extra feeds on run().""" 1059 1060 def __init__(self, session, feed_additions): 1061 self._wrapped_session = session 1062 self._feed_additions = feed_additions 1063 1064 def run(self, fetches, feed_dict=None, **kwargs): 1065 if feed_dict is None: 1066 feed_dict = {} 1067 else: 1068 feed_dict = feed_dict.copy() 1069 feed_dict.update(self._feed_additions) 1070 return self._wrapped_session.run( 1071 fetches=fetches, feed_dict=feed_dict, **kwargs) 1072 1073 1074class TrackableSaver(object): 1075 """Saves and restores a `Trackable` object and its dependencies. 1076 1077 See `Trackable` for details of dependency management. `Saver` wraps 1078 `tf.compat.v1.train.Saver` for saving, including extra information about the 1079 graph of 1080 dependencies between Python objects. When restoring, it uses this information 1081 about the save-time dependency graph to more robustly match objects with their 1082 checkpointed values. When executing eagerly, it supports restoring variables 1083 on object creation (see `Saver.restore`). 1084 1085 Values in a checkpoint are mapped to `Trackable` Python objects 1086 (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the 1087 checkpoint was written. To avoid breaking existing checkpoints when modifying 1088 a class, dependency names (the names of attributes to which `Trackable` 1089 objects are assigned) may not change. These names are local to objects, in 1090 contrast to the `Variable.name`-based save/restore from 1091 `tf.compat.v1.train.Saver`, and 1092 so allow additional program transformations. 1093 """ 1094 1095 def __init__(self, graph_view): 1096 """Configure saving. 1097 1098 Args: 1099 graph_view: A `GraphView` object containing a description of the object 1100 graph to save. 1101 """ 1102 # The file prefix placeholder is created lazily when graph building (and not 1103 # at all when executing eagerly) to avoid creating ops in the constructor 1104 # (when they may never be necessary). 1105 self._file_prefix_placeholder = None 1106 1107 # Op caching for save 1108 self._object_graph_feed_tensor = None 1109 self._last_save_object_graph = None 1110 self._file_prefix_feed_tensor = None 1111 self._cached_save_operation = None 1112 1113 # Op caching for restore, shared between _CheckpointRestoreCoordinators 1114 self._restore_op_cache = {} 1115 self._graph_view = graph_view 1116 1117 def _gather_saveables(self, object_graph_tensor=None): 1118 """Wraps _serialize_object_graph to include the object graph proto.""" 1119 (named_saveable_objects, graph_proto, 1120 feed_additions) = self._graph_view.serialize_object_graph() 1121 if object_graph_tensor is None: 1122 with ops.device("/cpu:0"): 1123 object_graph_tensor = constant_op.constant( 1124 graph_proto.SerializeToString(), dtype=dtypes.string) 1125 else: 1126 feed_additions.update( 1127 {object_graph_tensor: graph_proto.SerializeToString()}) 1128 assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects 1129 named_saveable_objects.append( 1130 base.NoRestoreSaveable( 1131 tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)) 1132 return named_saveable_objects, graph_proto, feed_additions 1133 1134 def _save_cached_when_graph_building(self, 1135 file_prefix, 1136 object_graph_tensor, 1137 options): 1138 """Create or retrieve save ops. 1139 1140 Args: 1141 file_prefix: The prefix for saved checkpoint files. 1142 object_graph_tensor: A `Tensor` to which the current object graph will be 1143 fed. 1144 options: `CheckpointOptions` object. 1145 1146 Returns: 1147 A two-element tuple with a filename tensor and a feed_dict of tensors to 1148 feed when running it (if graph building). The feed dict contains the 1149 current object graph and any Python state to be saved in the 1150 checkpoint. When executing eagerly only the first argument is meaningful. 1151 """ 1152 (named_saveable_objects, graph_proto, 1153 feed_additions) = self._gather_saveables( 1154 object_graph_tensor=object_graph_tensor) 1155 if (self._last_save_object_graph != graph_proto 1156 # When executing eagerly, we need to re-create SaveableObjects each time 1157 # save() is called so they pick up new Tensors passed to their 1158 # constructors. That means the Saver needs to be copied with a new 1159 # var_list. 1160 or context.executing_eagerly() or ops.inside_function()): 1161 saver = functional_saver.MultiDeviceSaver(named_saveable_objects) 1162 save_op = saver.save(file_prefix, options=options) 1163 with ops.device("/cpu:0"): 1164 with ops.control_dependencies([save_op]): 1165 self._cached_save_operation = array_ops.identity(file_prefix) 1166 self._last_save_object_graph = graph_proto 1167 return self._cached_save_operation, feed_additions 1168 1169 def save(self, file_prefix, checkpoint_number=None, session=None, 1170 options=None): 1171 """Save a training checkpoint. 1172 1173 The saved checkpoint includes variables created by this object and any 1174 Trackable objects it depends on at the time `Saver.save()` is called. 1175 1176 Args: 1177 file_prefix: A prefix to use for the checkpoint filenames 1178 (/path/to/directory/and_a_prefix). Names are generated based on this 1179 prefix and `checkpoint_number`, if provided. 1180 checkpoint_number: An integer variable or Tensor, used to number 1181 checkpoints. Typically this value is saved along with other variables in 1182 training checkpoints, which will happen automatically if it was created 1183 by `root_trackable` or one of its dependencies (via 1184 `Trackable._add_variable`). 1185 session: The session to evaluate variables in. Ignored when executing 1186 eagerly. If not provided when graph building, the default session is 1187 used. 1188 options: Optional `tf.train.CheckpointOptions` object. 1189 1190 Returns: 1191 The full path to the checkpoint. 1192 """ 1193 options = options or checkpoint_options.CheckpointOptions() 1194 feed_dict = {} 1195 use_session = (not context.executing_eagerly() and 1196 not ops.inside_function()) 1197 if checkpoint_number: 1198 file_prefix = "%s-%d" % (file_prefix, checkpoint_number) 1199 if use_session: 1200 if self._object_graph_feed_tensor is None: 1201 with ops.device("/cpu:0"): 1202 self._object_graph_feed_tensor = constant_op.constant( 1203 "", dtype=dtypes.string) 1204 self._file_prefix_feed_tensor = constant_op.constant( 1205 "", dtype=dtypes.string) 1206 object_graph_tensor = self._object_graph_feed_tensor 1207 file_prefix_tensor = self._file_prefix_feed_tensor 1208 feed_dict[file_prefix_tensor] = file_prefix 1209 else: 1210 with ops.device("/cpu:0"): 1211 file_prefix_tensor = constant_op.constant( 1212 file_prefix, dtype=dtypes.string) 1213 object_graph_tensor = None 1214 1215 file_io.recursive_create_dir(os.path.dirname(file_prefix)) 1216 save_path, new_feed_additions = self._save_cached_when_graph_building( 1217 file_prefix_tensor, object_graph_tensor, options) 1218 if new_feed_additions: 1219 feed_dict.update(new_feed_additions) 1220 if not use_session: 1221 session = None 1222 elif session is None: 1223 session = get_session() 1224 1225 if session: 1226 return session.run(save_path, feed_dict=feed_dict) 1227 else: 1228 return save_path 1229 1230 def restore(self, save_path, options=None): 1231 """Restore a training checkpoint. 1232 1233 Restores `root_trackable` and any objects that it tracks 1234 (transitive). Either assigns values immediately if variables to restore have 1235 been created already, or defers restoration until the variables are 1236 created. Dependencies added to the `root_trackable` passed to the 1237 constructor after this call will be matched if they have a corresponding 1238 object in the checkpoint. 1239 1240 When building a graph, restorations are added to the graph but not run. 1241 1242 To disallow deferred loading, assert immediately that all checkpointed 1243 variables have been matched to variable objects: 1244 1245 ```python 1246 saver = Saver(root) 1247 saver.restore(path).assert_consumed() 1248 ``` 1249 1250 An exception will be raised unless every object was matched and its 1251 variables already exist. 1252 1253 When graph building, `assert_consumed()` indicates that all of the restore 1254 ops which will be created for this checkpoint have been created. They can be 1255 run via the `run_restore_ops()` function of the status object: 1256 1257 ```python 1258 saver.restore(path).assert_consumed().run_restore_ops() 1259 ``` 1260 1261 If the checkpoint has not been consumed completely, then the list of restore 1262 ops will grow as more objects are added to the dependency graph. 1263 1264 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 1265 method. There is no deferred loading, and names are used to match 1266 variables. No restore ops are created/run until `run_restore_ops()` or 1267 `initialize_or_restore()` are called on the returned status object, even 1268 when executing eagerly. Re-encode name-based checkpoints using this 1269 object-based `Saver.save` as soon as possible. 1270 1271 Args: 1272 save_path: The path to the checkpoint, as returned by `save` or 1273 `tf.train.latest_checkpoint`. If None (as when there is no latest 1274 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1275 object which may run initializers for objects in the dependency graph. 1276 If the checkpoint was written by the name-based 1277 `tf.compat.v1.train.Saver`, names are used to match variables. 1278 options: Optional `tf.train.CheckpointOptions` object. 1279 1280 Returns: 1281 A load status object, which can be used to make assertions about the 1282 status of checkpoint restoration and run initialization/restore ops 1283 (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if 1284 `save_path` is `None`). 1285 1286 If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` 1287 object is returned which runs restore ops from a name-based saver. 1288 """ 1289 options = options or checkpoint_options.CheckpointOptions() 1290 if save_path is None: 1291 return InitializationOnlyStatus(self._graph_view, ops.uid()) 1292 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 1293 graph_building = not context.executing_eagerly() 1294 if graph_building: 1295 dtype_map = None 1296 else: 1297 dtype_map = reader.get_variable_to_dtype_map() 1298 try: 1299 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 1300 except errors_impl.NotFoundError: 1301 # The object graph proto does not exist in this checkpoint. Try the 1302 # name-based compatibility mode. 1303 restore_coordinator = _NameBasedRestoreCoordinator( 1304 save_path=save_path, 1305 dtype_map=dtype_map) 1306 if not graph_building: 1307 for existing_trackable in self._graph_view.list_objects(): 1308 # pylint: disable=protected-access 1309 existing_trackable._maybe_initialize_trackable() 1310 existing_trackable._name_based_restores.add(restore_coordinator) 1311 existing_trackable._name_based_attribute_restore(restore_coordinator) 1312 # pylint: enable=protected-access 1313 return NameBasedSaverStatus( 1314 restore_coordinator, 1315 graph_view=self._graph_view) 1316 1317 if graph_building: 1318 if self._file_prefix_placeholder is None: 1319 with ops.device("/cpu:0"): 1320 self._file_prefix_placeholder = constant_op.constant("model") 1321 file_prefix_tensor = self._file_prefix_placeholder 1322 file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} 1323 else: 1324 with ops.device("/cpu:0"): 1325 file_prefix_tensor = constant_op.constant(save_path) 1326 file_prefix_feed_dict = None 1327 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 1328 object_graph_proto.ParseFromString(object_graph_string) 1329 checkpoint = _CheckpointRestoreCoordinator( 1330 object_graph_proto=object_graph_proto, 1331 save_path=save_path, 1332 save_path_tensor=file_prefix_tensor, 1333 restore_op_cache=self._restore_op_cache, 1334 graph_view=self._graph_view, 1335 options=options) 1336 base.CheckpointPosition( 1337 checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root) 1338 1339 # Attached dependencies are not attached to the root, so should be restored 1340 # separately. 1341 if self._graph_view.attached_dependencies: 1342 for ref in self._graph_view.attached_dependencies: 1343 if ref.name == "root": 1344 # Root dependency is automatically added to attached dependencies -- 1345 # this can be ignored since it maps back to the root object. 1346 continue 1347 proto_id = None 1348 # Find proto ID of attached dependency (if it is in the proto). 1349 for proto_ref in object_graph_proto.nodes[0].children: 1350 if proto_ref.local_name == ref.name: 1351 proto_id = proto_ref.node_id 1352 break 1353 1354 if proto_id in checkpoint.object_by_proto_id: 1355 # Object has already been restored. This can happen when there's an 1356 # indirect connection from the attached object to the root. 1357 continue 1358 1359 base.CheckpointPosition( 1360 checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref) 1361 1362 load_status = CheckpointLoadStatus( 1363 checkpoint, 1364 graph_view=self._graph_view, 1365 feed_dict=file_prefix_feed_dict) 1366 return load_status 1367 1368 1369def frozen_saver(root_trackable): 1370 """Creates a static `tf.compat.v1.train.Saver` from a trackable object. 1371 1372 The returned `Saver` saves object-based checkpoints, but these checkpoints 1373 will no longer reflect structural changes to the object graph, only changes to 1374 the values of `Variable`s added as dependencies of the root object before 1375 `freeze` was called. 1376 1377 `restore` works on the returned `Saver`, but requires that the object graph of 1378 the checkpoint being loaded exactly matches the object graph when `freeze` was 1379 called. This is in contrast the object-based restore performed by 1380 `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's 1381 object graph and the current Python object graph. 1382 1383 Args: 1384 root_trackable: A trackable object to save. 1385 1386 Returns: 1387 A saver which saves object-based checkpoints for the object graph frozen at 1388 the time `frozen_saver` was called. 1389 """ 1390 named_saveable_objects = graph_view_lib.ObjectGraphView( 1391 root_trackable).frozen_saveable_objects() 1392 return functional_saver.MultiDeviceSaver(named_saveable_objects) 1393 1394 1395def saver_with_op_caching(obj, attached_dependencies=None): 1396 """A TrackableSaver with a SaveableObject cache when graph building.""" 1397 if context.executing_eagerly(): 1398 saveables_cache = None 1399 else: 1400 saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary() 1401 return TrackableSaver( 1402 graph_view_lib.ObjectGraphView( 1403 weakref.ref(obj), saveables_cache=saveables_cache, 1404 attached_dependencies=attached_dependencies)) 1405 1406 1407def _assert_trackable(obj): 1408 if not isinstance( 1409 obj, (base.Trackable, def_function.Function)): 1410 raise ValueError( 1411 "`Checkpoint` was expecting a trackable object (an object " 1412 "derived from `TrackableBase`), got {}. If you believe this " 1413 "object should be trackable (i.e. it is part of the " 1414 "TensorFlow Python API and manages state), please open an issue." 1415 .format(obj)) 1416 1417 1418# Mentions graph building / Sessions. The v2 version is below. 1419@tf_export(v1=["train.Checkpoint"]) 1420class CheckpointV1(tracking.AutoTrackable): 1421 """Groups trackable objects, saving and restoring them. 1422 1423 `Checkpoint`'s constructor accepts keyword arguments whose values are types 1424 that contain trackable state, such as `tf.compat.v1.train.Optimizer` 1425 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 1426 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 1427 maintains a `save_counter` for numbering checkpoints. 1428 1429 Example usage when graph building: 1430 1431 ```python 1432 import tensorflow as tf 1433 import os 1434 1435 checkpoint_directory = "/tmp/training_checkpoints" 1436 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1437 1438 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1439 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1440 train_op = optimizer.minimize( ... ) 1441 status.assert_consumed() # Optional sanity checks. 1442 with tf.compat.v1.Session() as session: 1443 # Use the Session to restore variables, or initialize them if 1444 # tf.train.latest_checkpoint returned None. 1445 status.initialize_or_restore(session) 1446 for _ in range(num_training_steps): 1447 session.run(train_op) 1448 checkpoint.save(file_prefix=checkpoint_prefix) 1449 ``` 1450 1451 Example usage with eager execution enabled: 1452 1453 ```python 1454 import tensorflow as tf 1455 import os 1456 1457 tf.compat.v1.enable_eager_execution() 1458 1459 checkpoint_directory = "/tmp/training_checkpoints" 1460 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1461 1462 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1463 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1464 for _ in range(num_training_steps): 1465 optimizer.minimize( ... ) # Variables will be restored on creation. 1466 status.assert_consumed() # Optional sanity checks. 1467 checkpoint.save(file_prefix=checkpoint_prefix) 1468 ``` 1469 1470 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 1471 checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads 1472 `variable.name` based checkpoints. Object-based checkpointing saves a graph of 1473 dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, 1474 etc.) with named edges, and this graph is used to match variables when 1475 restoring a checkpoint. It can be more robust to changes in the Python 1476 program, and helps to support restore-on-create for variables when executing 1477 eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new 1478 code. 1479 1480 `Checkpoint` objects have dependencies on the objects passed as keyword 1481 arguments to their constructors, and each dependency is given a name that is 1482 identical to the name of the keyword argument for which it was created. 1483 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1484 dependencies on their variables (e.g. "kernel" and "bias" for 1485 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1486 dependencies easy in user-defined classes, since `Model` hooks into attribute 1487 assignment. For example: 1488 1489 ```python 1490 class Regress(tf.keras.Model): 1491 1492 def __init__(self): 1493 super(Regress, self).__init__() 1494 self.input_transform = tf.keras.layers.Dense(10) 1495 # ... 1496 1497 def call(self, inputs): 1498 x = self.input_transform(inputs) 1499 # ... 1500 ``` 1501 1502 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1503 which in turn depends on its variables. As a result, saving an instance of 1504 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1505 by the `Dense` layer. 1506 1507 When variables are assigned to multiple workers, each worker writes its own 1508 section of the checkpoint. These sections are then merged/re-indexed to behave 1509 as a single checkpoint. This avoids copying all variables to one worker, but 1510 does require that all workers see a common filesystem. 1511 1512 While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the 1513 same format, note that the root of the resulting checkpoint is the object the 1514 save method is attached to. This means saving a `tf.keras.Model` using 1515 `save_weights` and loading into a `tf.train.Checkpoint` with a `Model` 1516 attached (or vice versa) will not match the `Model`'s variables. See the 1517 [guide to training 1518 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 1519 details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for 1520 training checkpoints. 1521 1522 Attributes: 1523 save_counter: Incremented when `save()` is called. Used to number 1524 checkpoints. 1525 """ 1526 1527 def __init__(self, **kwargs): 1528 """Group objects into a training checkpoint. 1529 1530 Args: 1531 **kwargs: Keyword arguments are set as attributes of this object, and are 1532 saved with the checkpoint. Values must be trackable objects. 1533 1534 Raises: 1535 ValueError: If objects in `kwargs` are not trackable. 1536 """ 1537 super(CheckpointV1, self).__init__() 1538 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1539 setattr(self, k, v) 1540 if not isinstance( 1541 getattr(self, k), (base.Trackable, def_function.Function)): 1542 raise ValueError( 1543 ("`Checkpoint` was expecting a trackable object (an object " 1544 "derived from `TrackableBase`), got %s. If you believe this " 1545 "object should be trackable (i.e. it is part of the " 1546 "TensorFlow Python API and manages state), please open an issue.") 1547 % (v,)) 1548 self._save_counter = None # Created lazily for restore-on-create. 1549 self._save_assign_op = None 1550 self._saver = saver_with_op_caching(self) 1551 1552 def _maybe_create_save_counter(self): 1553 """Create a save counter if it does not yet exist.""" 1554 if self._save_counter is None: 1555 # Initialized to 0 and incremented before saving. 1556 with ops.device("/cpu:0"): 1557 # add_variable creates a dependency named "save_counter"; NoDependency 1558 # prevents creating a second dependency named "_save_counter". 1559 self._save_counter = data_structures.NoDependency( 1560 add_variable( 1561 self, 1562 name="save_counter", 1563 initializer=0, 1564 dtype=dtypes.int64, 1565 trainable=False)) 1566 1567 def write(self, file_prefix, session=None): 1568 """Writes a training checkpoint. 1569 1570 The checkpoint includes variables created by this object and any 1571 trackable objects it depends on at the time `Checkpoint.write()` is 1572 called. 1573 1574 `write` does not number checkpoints, increment `save_counter`, or update the 1575 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1576 use by higher level checkpoint management utilities. `save` provides a very 1577 basic implementation of these features. 1578 1579 Args: 1580 file_prefix: A prefix to use for the checkpoint filenames 1581 (/path/to/directory/and_a_prefix). 1582 session: The session to evaluate variables in. Ignored when executing 1583 eagerly. If not provided when graph building, the default session is 1584 used. 1585 1586 Returns: 1587 The full path to the checkpoint (i.e. `file_prefix`). 1588 """ 1589 output = self._saver.save(file_prefix=file_prefix, session=session) 1590 if tensor_util.is_tf_type(output): 1591 if context.executing_eagerly(): 1592 return compat.as_str(output.numpy()) 1593 else: 1594 # Function building 1595 return output 1596 else: 1597 # Graph + Session, so we already session.ran it. 1598 return compat.as_str(output) 1599 1600 @property 1601 def save_counter(self): 1602 """An integer variable which starts at zero and is incremented on save. 1603 1604 Used to number checkpoints. 1605 1606 Returns: 1607 The save counter variable. 1608 """ 1609 self._maybe_create_save_counter() 1610 return self._save_counter 1611 1612 def save(self, file_prefix, session=None): 1613 """Saves a training checkpoint and provides basic checkpoint management. 1614 1615 The saved checkpoint includes variables created by this object and any 1616 trackable objects it depends on at the time `Checkpoint.save()` is 1617 called. 1618 1619 `save` is a basic convenience wrapper around the `write` method, 1620 sequentially numbering checkpoints using `save_counter` and updating the 1621 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 1622 management, for example garbage collection and custom numbering, may be 1623 provided by other utilities which also wrap `write` 1624 (`tf.train.CheckpointManager` for example). 1625 1626 Args: 1627 file_prefix: A prefix to use for the checkpoint filenames 1628 (/path/to/directory/and_a_prefix). Names are generated based on this 1629 prefix and `Checkpoint.save_counter`. 1630 session: The session to evaluate variables in. Ignored when executing 1631 eagerly. If not provided when graph building, the default session is 1632 used. 1633 1634 Returns: 1635 The full path to the checkpoint. 1636 """ 1637 graph_building = not context.executing_eagerly() 1638 if graph_building: 1639 if ops.inside_function(): 1640 raise NotImplementedError( 1641 "Calling tf.train.Checkpoint.save() from a function is not " 1642 "supported, as save() modifies saving metadata in ways not " 1643 "supported by TensorFlow Operations. Consider using " 1644 "tf.train.Checkpoint.write(), a lower-level API which does not " 1645 "update metadata. tf.train.latest_checkpoint and related APIs will " 1646 "not see this checkpoint.") 1647 if session is None: 1648 session = get_session() 1649 if self._save_counter is None: 1650 # When graph building, if this is a new save counter variable then it 1651 # needs to be initialized before assign_add. This is only an issue if 1652 # restore() has not been called first. 1653 session.run(self.save_counter.initializer) 1654 if not graph_building or self._save_assign_op is None: 1655 with ops.colocate_with(self.save_counter): 1656 assign_op = self.save_counter.assign_add(1, read_value=True) 1657 if graph_building: 1658 self._save_assign_op = data_structures.NoDependency(assign_op) 1659 if graph_building: 1660 checkpoint_number = session.run(self._save_assign_op) 1661 else: 1662 checkpoint_number = assign_op.numpy() 1663 file_path = self.write( 1664 "%s-%d" % (file_prefix, checkpoint_number), session=session) 1665 checkpoint_management.update_checkpoint_state_internal( 1666 save_dir=os.path.dirname(file_prefix), 1667 model_checkpoint_path=file_path, 1668 all_model_checkpoint_paths=[file_path], 1669 save_relative_paths=True) 1670 return file_path 1671 1672 def restore(self, save_path): 1673 """Restore a training checkpoint. 1674 1675 Restores this `Checkpoint` and any objects it depends on. 1676 1677 When executing eagerly, either assigns values immediately if variables to 1678 restore have been created already, or defers restoration until the variables 1679 are created. Dependencies added after this call will be matched if they have 1680 a corresponding object in the checkpoint (the restore request will queue in 1681 any trackable object waiting for the expected dependency to be added). 1682 1683 When graph building, restoration ops are added to the graph but not run 1684 immediately. 1685 1686 To ensure that loading is complete and no more assignments will take place, 1687 use the `assert_consumed()` method of the status object returned by 1688 `restore`: 1689 1690 ```python 1691 checkpoint = tf.train.Checkpoint( ... ) 1692 checkpoint.restore(path).assert_consumed() 1693 ``` 1694 1695 An exception will be raised if any Python objects in the dependency graph 1696 were not found in the checkpoint, or if any checkpointed values do not have 1697 a matching Python object. 1698 1699 When graph building, `assert_consumed()` indicates that all of the restore 1700 ops that will be created for this checkpoint have been created. They can be 1701 run via the `run_restore_ops()` method of the status object: 1702 1703 ```python 1704 checkpoint.restore(path).assert_consumed().run_restore_ops() 1705 ``` 1706 1707 If the checkpoint has not been consumed completely, then the list of restore 1708 ops will grow as more objects are added to the dependency graph. 1709 1710 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 1711 method. Names are used to match variables. No restore ops are created/run 1712 until `run_restore_ops()` or `initialize_or_restore()` are called on the 1713 returned status object when graph building, but there is restore-on-creation 1714 when executing eagerly. Re-encode name-based checkpoints using 1715 `tf.train.Checkpoint.save` as soon as possible. 1716 1717 Args: 1718 save_path: The path to the checkpoint, as returned by `save` or 1719 `tf.train.latest_checkpoint`. If None (as when there is no latest 1720 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1721 object which may run initializers for objects in the dependency graph. 1722 If the checkpoint was written by the name-based 1723 `tf.compat.v1.train.Saver`, names are used to match variables. 1724 1725 Returns: 1726 A load status object, which can be used to make assertions about the 1727 status of a checkpoint restoration and run initialization/restore ops. 1728 1729 The returned status object has the following methods: 1730 1731 * `assert_consumed()`: 1732 Raises an exception if any variables are unmatched: either 1733 checkpointed values which don't have a matching Python object or 1734 Python objects in the dependency graph with no values in the 1735 checkpoint. This method returns the status object, and so may be 1736 chained with `initialize_or_restore` or `run_restore_ops`. 1737 1738 * `assert_existing_objects_matched()`: 1739 Raises an exception if any existing Python objects in the dependency 1740 graph are unmatched. Unlike `assert_consumed`, this assertion will 1741 pass if values in the checkpoint have no corresponding Python 1742 objects. For example a `tf.keras.Layer` object which has not yet been 1743 built, and so has not created any variables, will pass this assertion 1744 but fail `assert_consumed`. Useful when loading part of a larger 1745 checkpoint into a new Python program, e.g. a training checkpoint with 1746 a `tf.compat.v1.train.Optimizer` was saved but only the state required 1747 for 1748 inference is being loaded. This method returns the status object, and 1749 so may be chained with `initialize_or_restore` or `run_restore_ops`. 1750 1751 * `assert_nontrivial_match()`: Asserts that something aside from the root 1752 object was matched. This is a very weak assertion, but is useful for 1753 sanity checking in library code where objects may exist in the 1754 checkpoint which haven't been created in Python and some Python 1755 objects may not have a checkpointed value. 1756 1757 * `expect_partial()`: Silence warnings about incomplete checkpoint 1758 restores. Warnings are otherwise printed for unused parts of the 1759 checkpoint file or object when the `Checkpoint` object is deleted 1760 (often at program shutdown). 1761 1762 * `initialize_or_restore(session=None)`: 1763 When graph building, runs variable initializers if `save_path` is 1764 `None`, but otherwise runs restore operations. If no `session` is 1765 explicitly specified, the default session is used. No effect when 1766 executing eagerly (variables are initialized or restored eagerly). 1767 1768 * `run_restore_ops(session=None)`: 1769 When graph building, runs restore operations. If no `session` is 1770 explicitly specified, the default session is used. No effect when 1771 executing eagerly (restore operations are run eagerly). May only be 1772 called when `save_path` is not `None`. 1773 """ 1774 status = self._saver.restore(save_path=save_path) 1775 # Create the save counter now so it gets initialized with other variables 1776 # when graph building. Creating it earlier would lead to errors when using, 1777 # say, train.Saver() to save the model before initializing it. 1778 self._maybe_create_save_counter() 1779 if isinstance(status, NameBasedSaverStatus): 1780 status.add_to_optionally_restored(self.save_counter) 1781 return status 1782 1783 1784@tf_export("train.Checkpoint", v1=[]) 1785class Checkpoint(tracking.AutoTrackable): 1786 """Manages saving/restoring trackable values to disk. 1787 1788 TensorFlow objects may contain trackable state, such as `tf.Variable`s, 1789 `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators, 1790 `tf.keras.Layer` implementations, or `tf.keras.Model` implementations. 1791 These are called **trackable objects**. 1792 1793 A `Checkpoint` object can be constructed to save either a single or group of 1794 trackable objects to a checkpoint file. It maintains a `save_counter` for 1795 numbering checkpoints. 1796 1797 Example: 1798 1799 ```python 1800 model = tf.keras.Model(...) 1801 checkpoint = tf.train.Checkpoint(model) 1802 1803 # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time 1804 # checkpoint.save is called, the save counter is increased. 1805 save_path = checkpoint.save('/tmp/training_checkpoints') 1806 1807 # Restore the checkpointed values to the `model` object. 1808 checkpoint.restore(save_path) 1809 ``` 1810 1811 Example 2: 1812 1813 ```python 1814 import tensorflow as tf 1815 import os 1816 1817 checkpoint_directory = "/tmp/training_checkpoints" 1818 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1819 1820 # Create a Checkpoint that will manage two objects with trackable state, 1821 # one we name "optimizer" and the other we name "model". 1822 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1823 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1824 for _ in range(num_training_steps): 1825 optimizer.minimize( ... ) # Variables will be restored on creation. 1826 status.assert_consumed() # Optional sanity checks. 1827 checkpoint.save(file_prefix=checkpoint_prefix) 1828 ``` 1829 1830 `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based 1831 checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which 1832 writes and 1833 reads `variable.name` based checkpoints. Object-based checkpointing saves a 1834 graph of dependencies between Python objects (`Layer`s, `Optimizer`s, 1835 `Variable`s, etc.) with named edges, and this graph is used to match variables 1836 when restoring a checkpoint. It can be more robust to changes in the Python 1837 program, and helps to support restore-on-create for variables. 1838 1839 `Checkpoint` objects have dependencies on the objects passed as keyword 1840 arguments to their constructors, and each dependency is given a name that is 1841 identical to the name of the keyword argument for which it was created. 1842 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1843 dependencies on their own variables (e.g. "kernel" and "bias" for 1844 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1845 dependencies easy in user-defined classes, since `Model` hooks into attribute 1846 assignment. For example: 1847 1848 ```python 1849 class Regress(tf.keras.Model): 1850 1851 def __init__(self): 1852 super(Regress, self).__init__() 1853 self.input_transform = tf.keras.layers.Dense(10) 1854 # ... 1855 1856 def call(self, inputs): 1857 x = self.input_transform(inputs) 1858 # ... 1859 ``` 1860 1861 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1862 which in turn depends on its variables. As a result, saving an instance of 1863 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1864 by the `Dense` layer. 1865 1866 When variables are assigned to multiple workers, each worker writes its own 1867 section of the checkpoint. These sections are then merged/re-indexed to behave 1868 as a single checkpoint. This avoids copying all variables to one worker, but 1869 does require that all workers see a common filesystem. 1870 1871 This function differs slightly from the Keras Model `save_weights` function. 1872 `tf.keras.Model.save_weights` creates a checkpoint file with the name 1873 specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints, 1874 using `filepath` as the prefix for the checkpoint file names. Aside from this, 1875 `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent. 1876 1877 See the [guide to training 1878 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 1879 details. 1880 1881 Attributes: 1882 save_counter: Incremented when `save()` is called. Used to number 1883 checkpoints. 1884 """ 1885 1886 def __init__(self, root=None, **kwargs): 1887 """Creates a training checkpoint for a single or group of objects. 1888 1889 Args: 1890 root: The root object to checkpoint. 1891 **kwargs: Keyword arguments are set as attributes of this object, and are 1892 saved with the checkpoint. Values must be trackable objects. 1893 1894 Raises: 1895 ValueError: If `root` or the objects in `kwargs` are not trackable. A 1896 `ValueError` is also raised if the `root` object tracks different 1897 objects from the ones listed in attributes in kwargs (e.g. 1898 `root.child = A` and `tf.train.Checkpoint(root, child=B)` are 1899 incompatible). 1900 1901 """ 1902 super(Checkpoint, self).__init__() 1903 1904 saver_root = self 1905 attached_dependencies = None 1906 self._save_counter = None # Created lazily for restore-on-create. 1907 self._save_assign_op = None 1908 1909 if root: 1910 _assert_trackable(root) 1911 saver_root = root 1912 attached_dependencies = [] 1913 1914 # All keyword arguments (including root itself) are set as children 1915 # of root. 1916 kwargs["root"] = root 1917 root._maybe_initialize_trackable() 1918 1919 self._save_counter = data_structures.NoDependency( 1920 root._lookup_dependency("save_counter")) 1921 self._root = data_structures.NoDependency(root) 1922 1923 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1924 setattr(self, k, v) 1925 1926 # Call getattr instead of directly using v because setattr converts 1927 # v to a Trackable data structure when v is a list/dict/tuple. 1928 converted_v = getattr(self, k) 1929 _assert_trackable(converted_v) 1930 1931 if root: 1932 # Make sure that root doesn't already have dependencies with these names 1933 child = root._lookup_dependency(k) 1934 if child is None: 1935 attached_dependencies.append(base.TrackableReference(k, converted_v)) 1936 elif child != converted_v: 1937 raise ValueError( 1938 "Cannot create a Checkpoint with keyword argument {name} if " 1939 "root.{name} already exists.".format(name=k)) 1940 1941 self._saver = saver_with_op_caching(saver_root, attached_dependencies) 1942 self._attached_dependencies = data_structures.NoDependency( 1943 attached_dependencies) 1944 1945 def _maybe_create_save_counter(self): 1946 """Create a save counter if it does not yet exist.""" 1947 if self._save_counter is None: 1948 # Initialized to 0 and incremented before saving. 1949 with ops.device("/cpu:0"): 1950 # add_variable creates a dependency named "save_counter"; NoDependency 1951 # prevents creating a second dependency named "_save_counter". 1952 self._save_counter = data_structures.NoDependency( 1953 add_variable( 1954 self, 1955 name="save_counter", 1956 initializer=0, 1957 dtype=dtypes.int64, 1958 trainable=False)) 1959 if self._attached_dependencies is not None: 1960 self._attached_dependencies.append( 1961 base.TrackableReference("save_counter", self._save_counter)) 1962 # When loading a checkpoint, the save counter is created after 1963 # the checkpoint has been loaded, so it must be handled in a deferred 1964 # manner. 1965 restore = self.root._deferred_dependencies.pop("save_counter", ()) # pylint: disable=protected-access 1966 if restore: 1967 restore[0].restore(self._save_counter) 1968 1969 def write(self, file_prefix, options=None): 1970 """Writes a training checkpoint. 1971 1972 The checkpoint includes variables created by this object and any 1973 trackable objects it depends on at the time `Checkpoint.write()` is 1974 called. 1975 1976 `write` does not number checkpoints, increment `save_counter`, or update the 1977 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1978 use by higher level checkpoint management utilities. `save` provides a very 1979 basic implementation of these features. 1980 1981 Checkpoints written with `write` must be read with `read`. 1982 1983 Example usage: 1984 1985 ``` 1986 step = tf.Variable(0, name="step") 1987 checkpoint = tf.Checkpoint(step=step) 1988 checkpoint.write("/tmp/ckpt") 1989 1990 # Later, read the checkpoint with read() 1991 checkpoint.read("/tmp/ckpt").assert_consumed() 1992 1993 # You can also pass options to write() and read(). For example this 1994 # runs the IO ops on the localhost: 1995 options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 1996 checkpoint.write("/tmp/ckpt", options=options) 1997 1998 # Later, read the checkpoint with read() 1999 checkpoint.read("/tmp/ckpt", options=options).assert_consumed() 2000 ``` 2001 2002 Args: 2003 file_prefix: A prefix to use for the checkpoint filenames 2004 (/path/to/directory/and_a_prefix). 2005 options: Optional `tf.train.CheckpointOptions` object. 2006 2007 Returns: 2008 The full path to the checkpoint (i.e. `file_prefix`). 2009 """ 2010 options = options or checkpoint_options.CheckpointOptions() 2011 output = self._saver.save(file_prefix=file_prefix, options=options) 2012 if tensor_util.is_tf_type(output): 2013 if context.executing_eagerly(): 2014 return compat.as_str(output.numpy()) 2015 else: 2016 # Function building 2017 return output 2018 else: 2019 # Graph + Session, so we already session.ran it. 2020 return compat.as_str(output) 2021 2022 @property 2023 def save_counter(self): 2024 """An integer variable which starts at zero and is incremented on save. 2025 2026 Used to number checkpoints. 2027 2028 Returns: 2029 The save counter variable. 2030 """ 2031 self._maybe_create_save_counter() 2032 return self._save_counter 2033 2034 def save(self, file_prefix, options=None): 2035 """Saves a training checkpoint and provides basic checkpoint management. 2036 2037 The saved checkpoint includes variables created by this object and any 2038 trackable objects it depends on at the time `Checkpoint.save()` is 2039 called. 2040 2041 `save` is a basic convenience wrapper around the `write` method, 2042 sequentially numbering checkpoints using `save_counter` and updating the 2043 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 2044 management, for example garbage collection and custom numbering, may be 2045 provided by other utilities which also wrap `write` and `read`. 2046 (`tf.train.CheckpointManager` for example). 2047 2048 ``` 2049 step = tf.Variable(0, name="step") 2050 checkpoint = tf.Checkpoint(step=step) 2051 checkpoint.save("/tmp/ckpt") 2052 2053 # Later, read the checkpoint with restore() 2054 checkpoint.restore("/tmp/ckpt").assert_consumed() 2055 2056 # You can also pass options to save() and restore(). For example this 2057 # runs the IO ops on the localhost: 2058 options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 2059 checkpoint.save("/tmp/ckpt", options=options) 2060 2061 # Later, read the checkpoint with restore() 2062 checkpoint.restore("/tmp/ckpt", options=options).assert_consumed() 2063 ``` 2064 2065 Args: 2066 file_prefix: A prefix to use for the checkpoint filenames 2067 (/path/to/directory/and_a_prefix). Names are generated based on this 2068 prefix and `Checkpoint.save_counter`. 2069 options: Optional `tf.train.CheckpointOptions` object. 2070 2071 Returns: 2072 The full path to the checkpoint. 2073 """ 2074 options = options or checkpoint_options.CheckpointOptions() 2075 graph_building = not context.executing_eagerly() 2076 if graph_building: 2077 if ops.inside_function(): 2078 raise NotImplementedError( 2079 "Calling tf.train.Checkpoint.save() from a function is not " 2080 "supported, as save() modifies saving metadata in ways not " 2081 "supported by TensorFlow Operations. Consider using " 2082 "tf.train.Checkpoint.write(), a lower-level API which does not " 2083 "update metadata. tf.train.latest_checkpoint and related APIs will " 2084 "not see this checkpoint.") 2085 session = get_session() 2086 if self._save_counter is None: 2087 # When graph building, if this is a new save counter variable then it 2088 # needs to be initialized before assign_add. This is only an issue if 2089 # restore() has not been called first. 2090 session.run(self.save_counter.initializer) 2091 if not graph_building or self._save_assign_op is None: 2092 with ops.colocate_with(self.save_counter): 2093 assign_op = self.save_counter.assign_add(1, read_value=True) 2094 if graph_building: 2095 self._save_assign_op = data_structures.NoDependency(assign_op) 2096 if graph_building: 2097 checkpoint_number = session.run(self._save_assign_op) 2098 else: 2099 checkpoint_number = assign_op.numpy() 2100 file_path = self.write("%s-%d" % (file_prefix, checkpoint_number), 2101 options=options) 2102 checkpoint_management.update_checkpoint_state_internal( 2103 save_dir=os.path.dirname(file_prefix), 2104 model_checkpoint_path=file_path, 2105 all_model_checkpoint_paths=[file_path], 2106 save_relative_paths=True) 2107 return file_path 2108 2109 def read(self, save_path, options=None): 2110 """Reads a training checkpoint written with `write`. 2111 2112 Reads this `Checkpoint` and any objects it depends on. 2113 2114 This method is just like `restore()` but does not expect the `save_counter` 2115 variable in the checkpoint. It only restores the objects that the checkpoint 2116 already depends on. 2117 2118 The method is primarily intended for use by higher level checkpoint 2119 management utilities that use `write()` instead of `save()` and have their 2120 own mechanisms to number and track checkpoints. 2121 2122 Example usage: 2123 2124 ```python 2125 # Create a checkpoint with write() 2126 ckpt = tf.train.Checkpoint(v=tf.Variable(1.)) 2127 path = ckpt.write('/tmp/my_checkpoint') 2128 2129 # Later, load the checkpoint with read() 2130 # With restore() assert_consumed() would have failed. 2131 checkpoint.read(path).assert_consumed() 2132 2133 # You can also pass options to read(). For example this 2134 # runs the IO ops on the localhost: 2135 options = tf.train.CheckpointOptions( 2136 experimental_io_device="/job:localhost") 2137 checkpoint.read(path, options=options) 2138 ``` 2139 2140 Args: 2141 save_path: The path to the checkpoint as returned by `write`. 2142 options: Optional `tf.train.CheckpointOptions` object. 2143 2144 Returns: 2145 A load status object, which can be used to make assertions about the 2146 status of a checkpoint restoration. See `restore` for details. 2147 """ 2148 options = options or checkpoint_options.CheckpointOptions() 2149 return self._saver.restore(save_path=save_path, options=options) 2150 2151 def restore(self, save_path, options=None): 2152 """Restores a training checkpoint. 2153 2154 Restores this `Checkpoint` and any objects it depends on. 2155 2156 This method is intended to be used to load checkpoints created by `save()`. 2157 For checkpoints created by `write()` use the `read()` method which does not 2158 expect the `save_counter` variable added by `save()`. 2159 2160 `restore()` either assigns values immediately if variables to restore have 2161 been created already, or defers restoration until the variables are 2162 created. Dependencies added after this call will be matched if they have a 2163 corresponding object in the checkpoint (the restore request will queue in 2164 any trackable object waiting for the expected dependency to be added). 2165 2166 To ensure that loading is complete and no more assignments will take place, 2167 use the `assert_consumed()` method of the status object returned by 2168 `restore()`: 2169 2170 ```python 2171 checkpoint = tf.train.Checkpoint( ... ) 2172 checkpoint.restore(path).assert_consumed() 2173 2174 # You can additionally pass options to restore(): 2175 options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 2176 checkpoint.restore(path, options=options).assert_consumed() 2177 ``` 2178 2179 An exception will be raised if any Python objects in the dependency graph 2180 were not found in the checkpoint, or if any checkpointed values do not have 2181 a matching Python object. 2182 2183 Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be 2184 loaded using this method. Names are used to match variables. Re-encode 2185 name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible. 2186 2187 **Loading from SavedModel checkpoints** 2188 2189 To load values from a SavedModel, just pass the SavedModel directory 2190 to checkpoint.restore: 2191 2192 ```python 2193 model = tf.keras.Model(...) 2194 tf.saved_model.save(model, path) # or model.save(path, save_format='tf') 2195 2196 checkpoint = tf.train.Checkpoint(model) 2197 checkpoint.restore(path).expect_partial() 2198 ``` 2199 2200 This example calls `expect_partial()` on the loaded status, since 2201 SavedModels saved from Keras often generates extra keys in the checkpoint. 2202 Otherwise, the program prints a lot of warnings about unused keys at exit 2203 time. 2204 2205 Args: 2206 save_path: The path to the checkpoint, as returned by `save` or 2207 `tf.train.latest_checkpoint`. If the checkpoint was written by the 2208 name-based `tf.compat.v1.train.Saver`, names are used to match 2209 variables. This path may also be a SavedModel directory. 2210 options: Optional `tf.train.CheckpointOptions` object. 2211 2212 Returns: 2213 A load status object, which can be used to make assertions about the 2214 status of a checkpoint restoration. 2215 2216 The returned status object has the following methods: 2217 2218 * `assert_consumed()`: 2219 Raises an exception if any variables are unmatched: either 2220 checkpointed values which don't have a matching Python object or 2221 Python objects in the dependency graph with no values in the 2222 checkpoint. This method returns the status object, and so may be 2223 chained with other assertions. 2224 2225 * `assert_existing_objects_matched()`: 2226 Raises an exception if any existing Python objects in the dependency 2227 graph are unmatched. Unlike `assert_consumed`, this assertion will 2228 pass if values in the checkpoint have no corresponding Python 2229 objects. For example a `tf.keras.Layer` object which has not yet been 2230 built, and so has not created any variables, will pass this assertion 2231 but fail `assert_consumed`. Useful when loading part of a larger 2232 checkpoint into a new Python program, e.g. a training checkpoint with 2233 a `tf.compat.v1.train.Optimizer` was saved but only the state required 2234 for 2235 inference is being loaded. This method returns the status object, and 2236 so may be chained with other assertions. 2237 2238 * `assert_nontrivial_match()`: Asserts that something aside from the root 2239 object was matched. This is a very weak assertion, but is useful for 2240 sanity checking in library code where objects may exist in the 2241 checkpoint which haven't been created in Python and some Python 2242 objects may not have a checkpointed value. 2243 2244 * `expect_partial()`: Silence warnings about incomplete checkpoint 2245 restores. Warnings are otherwise printed for unused parts of the 2246 checkpoint file or object when the `Checkpoint` object is deleted 2247 (often at program shutdown). 2248 2249 Raises: 2250 NotFoundError: if the a checkpoint or SavedModel cannot be found at 2251 `save_path`. 2252 """ 2253 orig_save_path = save_path 2254 2255 if save_path is not None and gfile.IsDirectory(save_path) and ( 2256 (gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or 2257 gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))): 2258 save_path = utils_impl.get_variables_path(save_path) 2259 2260 try: 2261 status = self.read(save_path, options=options) 2262 except errors_impl.NotFoundError: 2263 raise errors_impl.NotFoundError( 2264 None, None, 2265 "Could not find checkpoint or SavedModel at {}." 2266 .format(orig_save_path)) 2267 # Create the save counter now so it gets initialized with other variables 2268 # when graph building. Creating it earlier would lead to errors when using, 2269 # say, train.Saver() to save the model before initializing it. 2270 self._maybe_create_save_counter() 2271 if isinstance(status, NameBasedSaverStatus): 2272 status.add_to_optionally_restored(self.save_counter) 2273 return status 2274