1"""Utilities for saving/loading Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22import functools 23import os 24import threading 25import time 26import weakref 27 28import six 29 30from tensorflow.core.protobuf import trackable_object_graph_pb2 31from tensorflow.python.client import session as session_lib 32from tensorflow.python.eager import context 33from tensorflow.python.eager import def_function 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors_impl 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_shape 39from tensorflow.python.framework import tensor_util 40from tensorflow.python.lib.io import file_io 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import gen_io_ops as io_ops 43from tensorflow.python.ops import init_ops 44from tensorflow.python.ops import variable_scope 45from tensorflow.python.ops import variables 46from tensorflow.python.platform import gfile 47from tensorflow.python.platform import tf_logging as logging 48from tensorflow.python.saved_model import utils_impl 49from tensorflow.python.saved_model.pywrap_saved_model import metrics 50from tensorflow.python.training import checkpoint_management 51from tensorflow.python.training import py_checkpoint_reader 52from tensorflow.python.training import saver as v1_saver_lib 53from tensorflow.python.training.saving import checkpoint_options 54from tensorflow.python.training.saving import functional_saver 55from tensorflow.python.training.saving import saveable_object_util 56from tensorflow.python.training.tracking import base 57from tensorflow.python.training.tracking import data_structures 58from tensorflow.python.training.tracking import graph_view as graph_view_lib 59from tensorflow.python.training.tracking import tracking 60from tensorflow.python.util import compat 61from tensorflow.python.util import deprecation 62from tensorflow.python.util import object_identity 63from tensorflow.python.util import tf_contextlib 64from tensorflow.python.util import tf_inspect 65from tensorflow.python.util.tf_export import tf_export 66 67# The callable that provide Keras default session that is needed for saving. 68_SESSION_PROVIDER = None 69 70# Captures the timestamp of the first Checkpoint instantiation or end of a write 71# operation. Can be accessed by multiple Checkpoint instances. 72_END_TIME_OF_LAST_WRITE = None 73_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock() 74 75# API labels for cell names used in checkpoint metrics. 76_CHECKPOINT_V1 = "checkpoint_v1" 77_CHECKPOINT_V2 = "checkpoint_v2" 78 79 80def _get_duration_microseconds(start_time_seconds, end_time_seconds): 81 if end_time_seconds < start_time_seconds: 82 # Avoid returning negative value in case of clock skew. 83 return 0 84 return round((end_time_seconds - start_time_seconds) * 1000000) 85 86 87@tf_export("__internal__.tracking.register_session_provider", v1=[]) 88def register_session_provider(session_provider): 89 global _SESSION_PROVIDER 90 # TODO(scottzhu): Change it back to only allow one time setting for session 91 # provider once we finished the keras repo split. 92 # if _SESSION_PROVIDER is None: 93 _SESSION_PROVIDER = session_provider 94 95 96def get_session(): 97 # Prefer TF's default session since get_session from Keras has side-effects. 98 session = ops.get_default_session() 99 if session is None: 100 global _SESSION_PROVIDER 101 if _SESSION_PROVIDER is not None: 102 session = _SESSION_PROVIDER() # pylint: disable=not-callable 103 return session 104 105 106class _ObjectGraphProtoPrettyPrinter(object): 107 """Lazily traverses an object graph proto to pretty print names. 108 109 If no calls to `node_names` are made this object has no performance 110 overhead. On the other hand, it will only traverse the object graph once, so 111 repeated naming is cheap after the first. 112 """ 113 114 __slots__ = ["_object_graph_proto", "_node_name_cache"] 115 116 def __init__(self, object_graph_proto): 117 self._object_graph_proto = object_graph_proto 118 self._node_name_cache = None 119 120 @property 121 def node_names(self): 122 """Lazily creates a mapping from node id to ("path", "to", "root").""" 123 if self._node_name_cache is not None: 124 return self._node_name_cache 125 path_to_root = {} 126 path_to_root[0] = ("(root)",) 127 to_visit = collections.deque([0]) 128 while to_visit: 129 node_id = to_visit.popleft() 130 obj = self._object_graph_proto.nodes[node_id] 131 for child in obj.children: 132 if child.node_id not in path_to_root: 133 path_to_root[child.node_id] = ( 134 path_to_root[node_id] + (child.local_name,)) 135 to_visit.append(child.node_id) 136 137 node_names = {} 138 for node_id, path_to_root in path_to_root.items(): 139 node_names[node_id] = ".".join(path_to_root) 140 141 for node_id, node in enumerate(self._object_graph_proto.nodes): 142 for slot_reference in node.slot_variables: 143 node_names[slot_reference.slot_variable_node_id] = ( 144 "{}'s state '{}' for {}".format( 145 node_names[node_id], slot_reference.slot_name, 146 node_names[slot_reference.original_variable_node_id])) 147 self._node_name_cache = node_names 148 return node_names 149 150 151class _CheckpointRestoreCoordinatorDeleter(object): 152 """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__().""" 153 154 __slots__ = [ 155 "expect_partial", "object_graph_proto", "matched_proto_ids", 156 "unused_attributes" 157 ] 158 159 def __init__(self, expect_partial, object_graph_proto, matched_proto_ids, 160 unused_attributes): 161 self.expect_partial = expect_partial 162 self.object_graph_proto = object_graph_proto 163 self.matched_proto_ids = matched_proto_ids 164 self.unused_attributes = unused_attributes 165 166 def set_expect_partial(self, expect_partial): 167 self.expect_partial = expect_partial 168 169 def __del__(self): 170 if self.expect_partial: 171 return 172 if logging is None: 173 # The logging module may have been unloaded when __del__ is called. 174 log_fn = print 175 else: 176 log_fn = logging.warning 177 printed_warning = False 178 pretty_printer = _ObjectGraphProtoPrettyPrinter(self.object_graph_proto) 179 for node_id in range(len(self.object_graph_proto.nodes)): 180 if node_id not in self.matched_proto_ids: 181 log_fn("Unresolved object in checkpoint: {}" 182 .format(pretty_printer.node_names[node_id])) 183 printed_warning = True 184 for node_id, attribute_name in self.unused_attributes.items(): 185 log_fn(("Unused attribute in object {}: {}" 186 .format(pretty_printer.node_names[node_id], attribute_name))) 187 printed_warning = True 188 if printed_warning: 189 log_fn( 190 "A checkpoint was restored (e.g. tf.train.Checkpoint.restore or " 191 "tf.keras.Model.load_weights) but not all checkpointed values were " 192 "used. See above for specific issues. Use expect_partial() on the " 193 "load status object, e.g. " 194 "tf.train.Checkpoint.restore(...).expect_partial(), to silence these " 195 "warnings, or use assert_consumed() to make the check explicit. See " 196 "https://www.tensorflow.org/guide/checkpoint#loading_mechanics" 197 " for details.") 198 199 200class _CheckpointRestoreCoordinator(object): 201 """Holds the status of an object-based checkpoint load.""" 202 203 def __init__(self, object_graph_proto, save_path, save_path_tensor, 204 restore_op_cache, graph_view, options): 205 """Specify the checkpoint being loaded. 206 207 Args: 208 object_graph_proto: The TrackableObjectGraph protocol buffer associated 209 with this checkpoint. 210 save_path: A string, the path to the checkpoint, as returned by 211 `tf.train.latest_checkpoint`. 212 save_path_tensor: A string `Tensor` which contains or will be fed the save 213 path. 214 restore_op_cache: A dictionary shared between 215 `_CheckpointRestoreCoordinator`s for the same Python objects, used to 216 look up restore ops by name to avoid re-creating them across multiple 217 `restore()` calls. 218 graph_view: A graph_view_lib.ObjectGraphView object for the restored 219 objects. 220 options: A CheckpointOptions object. 221 """ 222 self.options = options 223 self.object_graph_proto = object_graph_proto 224 self.restore_uid = ops.uid() 225 # Maps from proto ids to lists of attributes which were in the checkpoint 226 # but not loaded into any object, for error checking. 227 self.unused_attributes = {} 228 # Dictionary mapping from an id in the protocol buffer flat array to 229 # Trackable Python objects. This mapping may be deferred if a 230 # checkpoint is restored before all dependencies have been tracked. Uses 231 # weak references so that partial restorations don't create reference cycles 232 # (as objects with deferred dependencies will generally have references to 233 # this object). 234 self.object_by_proto_id = weakref.WeakValueDictionary() 235 self.matched_proto_ids = set() 236 # A set of all Python objects we've seen as dependencies, even if we didn't 237 # use them (for example because of inconsistent references when 238 # loading). Used to make status assertions fail when loading checkpoints 239 # that don't quite match. 240 self.all_python_objects = object_identity.ObjectIdentityWeakSet() 241 self.save_path_tensor = save_path_tensor 242 self.save_path_string = save_path 243 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 244 self.dtype_map = reader.get_variable_to_dtype_map() 245 self.shape_map = reader.get_variable_to_shape_map() 246 # A NewCheckpointReader for the most recent checkpoint, for streaming Python 247 # state restoration. 248 # When graph building, contains a list of ops to run to restore objects from 249 # this checkpoint. 250 self.restore_ops = [] 251 self.restore_ops_by_name = restore_op_cache 252 self.graph_view = graph_view 253 self.new_restore_ops_callback = None 254 # A mapping from optimizer proto ids to lists of slot variables to be 255 # restored when the optimizer is tracked. Only includes slot variables whose 256 # regular variables have already been created, and only for optimizer 257 # objects which have not yet been created/tracked. 258 self.deferred_slot_restorations = {} 259 # A mapping from variable proto ids to lists of slot variables to be 260 # restored when the variable is created/tracked. These get shifted over to 261 # deferred_slot_restorations if the optimizer hasn't been created when that 262 # happens. 263 self.slot_restorations = {} 264 # Controls whether errors are printed in __del__ if some objects did not 265 # match. 266 self.expect_partial_attr = False 267 for node_index, node in enumerate(self.object_graph_proto.nodes): 268 for slot_reference in node.slot_variables: 269 # `node` refers to an `Optimizer`, since only these have slot variables. 270 self.slot_restorations.setdefault( 271 slot_reference.original_variable_node_id, []).append( 272 base._SlotVariableRestoration( # pylint: disable=protected-access 273 optimizer_id=node_index, 274 slot_variable_id=slot_reference.slot_variable_node_id, 275 slot_name=slot_reference.slot_name)) 276 277 self._deleter = _CheckpointRestoreCoordinatorDeleter( 278 self.expect_partial_attr, 279 self.object_graph_proto, 280 self.matched_proto_ids, 281 self.unused_attributes) 282 283 @property 284 def expect_partial(self): 285 return self.expect_partial_attr 286 287 @expect_partial.setter 288 def expect_partial(self, expect_partial): 289 self.expect_partial_attr = expect_partial 290 self._deleter.set_expect_partial(expect_partial) 291 292 def new_restore_ops(self, new_ops): 293 self.restore_ops.extend(new_ops) 294 if self.new_restore_ops_callback: 295 self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable 296 297 def restore_saveables(self, tensor_saveables, python_saveables): 298 """Run or build restore operations for SaveableObjects. 299 300 Args: 301 tensor_saveables: `SaveableObject`s which correspond to Tensors. 302 python_saveables: `PythonStateSaveable`s which correspond to Python 303 values. 304 305 Returns: 306 When graph building, a list of restore operations, either cached or newly 307 created, to restore `tensor_saveables`. 308 """ 309 restore_ops = [] 310 # Eagerly run restorations for Python state. 311 reader = None 312 for saveable in python_saveables: 313 if reader is None: 314 # Lazily create the NewCheckpointReader, since this requires file access 315 # and we may not have any Python saveables. 316 reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string) 317 spec_names = [spec.name for spec in saveable.specs] 318 saveable.python_restore([reader.get_tensor(name) for name in spec_names]) 319 320 # If we have new SaveableObjects, extract and cache restore ops. 321 if tensor_saveables: 322 validated_saveables = saveable_object_util.validate_and_slice_inputs( 323 tensor_saveables) 324 validated_names = set(saveable.name for saveable in validated_saveables) 325 if set(tensor_saveables.keys()) != validated_names: 326 raise AssertionError( 327 ("Saveable keys changed when validating. Got back %s, was " 328 "expecting %s") % (tensor_saveables.keys(), validated_names)) 329 new_restore_ops = functional_saver.MultiDeviceSaver( 330 validated_saveables).restore(self.save_path_tensor, self.options) 331 if not context.executing_eagerly(): 332 for name, restore_op in sorted(new_restore_ops.items()): 333 restore_ops.append(restore_op) 334 assert name not in self.restore_ops_by_name 335 self.restore_ops_by_name[name] = restore_op 336 return restore_ops 337 338 339class _NameBasedRestoreCoordinator(object): 340 """Keeps the status of a name-based checkpoint restore.""" 341 342 def __init__(self, save_path, dtype_map=None): 343 self.save_path = save_path 344 self.dtype_map = dtype_map 345 # A map from trackable objects to unused attribute names. We don't have 346 # proto IDs when doing a name-based restore, so the map keys differ from 347 # those in _CheckpointRestoreCoordinator. 348 self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary() 349 self.restore_uid = ops.uid() 350 351 def globally_named_object_attributes(self, trackable): 352 """Create globally named SaveableObjects from attributes. 353 354 If an object's attribute has no global name specified (default construction 355 for the SaveableObject factory), records the failure in 356 `self.unused_attributes` (which can then be used to make status assertions 357 fail; see `NameBasedSaverStatus`). 358 359 Args: 360 trackable: An object to save. 361 362 Yields: 363 SaveableObjects for `trackable`'s attributes. 364 """ 365 for attribute_name, saveable_factory in ( 366 trackable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access 367 if callable(saveable_factory): 368 try: 369 # This saveable object factory does not have a default name= argument, 370 # which means there's no way to save/restore it using a name-based 371 # checkpoint. Ignore the error now and make sure assert_consumed() 372 # fails. 373 saveable = saveable_factory() 374 except TypeError: 375 # Even if we can't name this object, we should construct it and check 376 # whether it's optional to restore it. If it's optional we don't need 377 # to make assertions fail. 378 if not saveable_factory("").optional_restore: 379 self.unused_attributes.setdefault(trackable, 380 []).append(attribute_name) 381 continue 382 else: 383 saveable = saveable_factory 384 names_to_saveables = saveable_object_util.op_list_to_dict( 385 [saveable], convert_variable_to_tensor=False) 386 for name, op in names_to_saveables.items(): 387 for saveable_object in saveable_object_util.saveable_objects_for_op( 388 op=op, name=name): 389 yield saveable_object 390 391 def eager_restore(self, trackable): 392 """Runs restore ops for `trackable`'s attributes.""" 393 # When graph building, we don't add any restore ops to the graph until 394 # run_restore_ops/initialize_or_restore on the status object for name-based 395 # checkpoints. 396 assert context.executing_eagerly() 397 for saveable in self.globally_named_object_attributes(trackable): 398 restored_tensors = [] 399 tensor_missing = False 400 for spec in saveable.specs: 401 if spec.name in self.dtype_map: 402 with ops.device("cpu:0"): 403 restored, = io_ops.restore_v2( 404 prefix=self.save_path, 405 tensor_names=[spec.name], 406 shape_and_slices=[""], 407 dtypes=[self.dtype_map[spec.name]], 408 name="%s_checkpoint_read" % (spec.name,)) 409 restored_tensors.append(array_ops.identity(restored)) 410 else: 411 tensor_missing = True 412 413 if tensor_missing: 414 # Record that this variable didn't match so assertions will fail. 415 self.unused_attributes.setdefault(trackable, []).append(saveable.name) 416 else: 417 # Ignores values missing from the checkpoint, as with object-based 418 # restore. Status assertions can be used to check exact matches, 419 # although it's unlikely to ever happen for name-based checkpoints. 420 saveable.restore( 421 restored_tensors=restored_tensors, restored_shapes=None) 422 423 424# TODO(allenl): If this ends up in a public API, consider adding LINT.If Change 425# or consolidating the implementation with get_variable. 426def _default_getter(name, 427 shape, 428 dtype, 429 initializer=None, 430 partition_info=None, 431 **kwargs): 432 """A pared-down version of get_variable which does not reuse variables.""" 433 dtype = dtypes.as_dtype(dtype) 434 shape_object = tensor_shape.as_shape(shape) 435 with ops.init_scope(): 436 if initializer is None: 437 initializer, initializing_from_value = ( 438 variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access 439 name=name, 440 shape=shape_object, 441 dtype=dtype)) 442 else: 443 initializing_from_value = not callable(initializer) 444 # Same logic as get_variable 445 variable_dtype = dtype.base_dtype 446 if initializing_from_value: 447 if shape is not None: 448 raise ValueError("If initializer is a constant, do not specify shape.") 449 initial_value = initializer 450 else: 451 # Instantiate initializer if provided initializer is a type object. 452 if isinstance(initializer, type(init_ops.Initializer)): 453 initializer = initializer(dtype=dtype) 454 shape_list = None if shape is None else shape_object.as_list() 455 if "partition_info" in tf_inspect.getargspec(initializer).args: 456 initial_value = functools.partial(initializer, 457 shape_list, 458 dtype=dtype, 459 partition_info=partition_info) 460 else: 461 initial_value = functools.partial(initializer, 462 shape_list, 463 dtype=dtype) 464 465 return variables.VariableV1( 466 initial_value=initial_value, 467 name=name, 468 dtype=variable_dtype, 469 use_resource=True, 470 **kwargs) 471 472 473def add_variable(trackable, 474 name, 475 shape=None, 476 dtype=dtypes.float32, 477 initializer=None, 478 trainable=True): 479 """Add a variable to a Trackable with no scope influence.""" 480 return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access 481 name=name, 482 shape=shape, 483 dtype=dtype, 484 initializer=initializer, 485 getter=_default_getter, 486 trainable=trainable) 487 488 489def object_metadata(save_path): 490 """Retrieves information about the objects in a checkpoint. 491 492 Example usage: 493 494 ```python 495 object_graph = tf.contrib.checkpoint.object_metadata( 496 tf.train.latest_checkpoint(checkpoint_directory)) 497 ckpt_variable_names = set() 498 for node in object_graph.nodes: 499 for attribute in node.attributes: 500 ckpt_variable_names.add(attribute.full_name) 501 ``` 502 503 Args: 504 save_path: The path to the checkpoint, as returned by `save` or 505 `tf.train.latest_checkpoint`. 506 507 Returns: 508 A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer. 509 Raises: 510 ValueError: If an object graph was not found in the checkpoint. 511 """ 512 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 513 try: 514 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 515 except errors_impl.NotFoundError: 516 raise ValueError( 517 ('The specified checkpoint "%s" does not appear to be object-based (it ' 518 'is missing the key "%s"). Likely it was created with a name-based ' 519 "saver and does not contain an object dependency graph.") % 520 (save_path, base.OBJECT_GRAPH_PROTO_KEY)) 521 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 522 object_graph_proto.ParseFromString(object_graph_string) 523 return object_graph_proto 524 525 526def list_objects(root_trackable): 527 """Traverse the object graph and list all accessible objects. 528 529 Looks for `Trackable` objects which are dependencies of 530 `root_trackable`. Includes slot variables only if the variable they are 531 slotting for and the optimizer are dependencies of `root_trackable` 532 (i.e. if they would be saved with a checkpoint). 533 534 Args: 535 root_trackable: A `Trackable` object whose dependencies should be flattened. 536 537 Returns: 538 A flat list of objects. 539 """ 540 return graph_view_lib.ObjectGraphView(root_trackable).list_objects() 541 542 543def gather_initializers(root_trackable): 544 """Traverse the object graph and find initialization ops. 545 546 Looks for `Trackable` objects which are dependencies of 547 `root_trackable` and which have an `initializer` property. Includes 548 initializers for slot variables only if the variable they are slotting for and 549 the optimizer are dependencies of `root_trackable` (i.e. if they would be 550 saved with a checkpoint). 551 552 Args: 553 root_trackable: A `Trackable` object to gather initializers for. 554 555 Returns: 556 A list of initialization ops. 557 """ 558 trackable_objects = list_objects(root_trackable) 559 return [ 560 c.initializer 561 for c in trackable_objects 562 if hasattr(c, "initializer") and c.initializer is not None 563 ] 564 565 566@tf_contextlib.contextmanager 567def capture_dependencies(template): 568 """Capture variables created within this scope as `Template` dependencies. 569 570 Requires that `template.variable_scope` is active. 571 572 This scope is intended as a compatibility measure, allowing a trackable 573 object to add dependencies on variables created in a block of code which is 574 not aware of object-based saving (and instead uses variable names 575 heavily). This is how `Template` objects add dependencies on variables and 576 sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly. 577 578 Args: 579 template: The `Template` object to register dependencies with. 580 581 Yields: 582 None (when used as a context manager). 583 """ 584 name_prefix = template.variable_scope.name 585 586 def _trackable_custom_creator(next_creator, 587 name, 588 initial_value, 589 trackable_parent=None, 590 **kwargs): 591 """A variable creation hook which adds Trackable dependencies. 592 593 Set for example during a `Template`'s first wrapped function 594 execution. Ensures that (a) `template` depends on any trackable 595 objects using their own `capture_dependencies` scope inside this scope which 596 create variables, and (b) that any variables not in a more deeply nested 597 scope are added as dependencies directly. 598 599 The `trackable_parent` argument is passed between custom creators but 600 ignored when the variable object itself is created. This argument indicates 601 (if not `None`) that a more deeply nested scope has already added the 602 variable as a dependency, and that parent scopes should add a dependency on 603 that object rather than on the variable directly. 604 605 Args: 606 next_creator: See `variable_scope.variable_creator_scope`; the next 607 creator in the chain. 608 name: The (full, scope-influenced) name of the variable. The `name_prefix` 609 itself is stripped for the purposes of object-based dependency tracking, 610 but scopes opened within this scope are respected. 611 initial_value: See `variable_scope.variable_creator_scope`. Taken 612 explicitly so the argument can be re-named and used with 613 `Trackable._add_variable_with_custom_getter`. 614 trackable_parent: If not None, a more deeply nested trackable object and 615 its name prefix which were passed to `capture_dependencies` to add a 616 dependency on (rather than depending on the variable directly). 617 **kwargs: Passed through to the next creator. 618 619 Returns: 620 The output of `next_creator`: the fetched/created variable object. 621 """ 622 623 def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): 624 inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which 625 # we don't want to propagate. 626 return next_creator(initial_value=initializer, name=name, **inner_kwargs) 627 628 if name is not None and name.startswith(name_prefix): 629 scope_stripped_name = name[len(name_prefix) + 1:] 630 if not trackable_parent: 631 return template._add_variable_with_custom_getter( # pylint: disable=protected-access 632 initializer=initial_value, 633 name=scope_stripped_name, 634 getter=_call_next_creator_renaming_initializer, 635 # Disable error checking for Trackable. Exceptions are instead 636 # raised if necessary when the object-based saver tries to 637 # save/restore the object. 638 overwrite=True, 639 trackable_parent=(template, name_prefix), 640 **kwargs) 641 else: 642 parent_object, parent_name_prefix = trackable_parent 643 template._track_trackable( # pylint: disable=protected-access 644 parent_object, 645 name=parent_name_prefix[len(name_prefix) + 1:], 646 overwrite=True) 647 return next_creator( 648 name=name, 649 initial_value=initial_value, 650 trackable_parent=(template, name_prefix), 651 **kwargs) 652 653 with variable_scope.variable_creator_scope(_trackable_custom_creator): 654 yield 655 656 657class _LoadStatus(object): 658 """Abstract base for load status callbacks.""" 659 660 @abc.abstractmethod 661 def assert_consumed(self): 662 """Raises an exception unless a non-trivial restoration has completed.""" 663 pass 664 665 @abc.abstractmethod 666 def assert_existing_objects_matched(self): 667 """Raises an exception unless existing Python objects have been matched.""" 668 pass 669 670 @abc.abstractmethod 671 def assert_nontrivial_match(self): 672 """Raises an exception if only the root object matched.""" 673 pass 674 675 @abc.abstractmethod 676 def run_restore_ops(self, session=None): 677 """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" 678 pass 679 680 @abc.abstractmethod 681 def initialize_or_restore(self, session=None): 682 """Runs restore ops from the checkpoint, or initializes variables.""" 683 pass 684 685 def expect_partial(self): 686 """Silence warnings about incomplete checkpoint restores.""" 687 return self 688 689 690@tf_export("__internal__.tracking.streaming_restore", v1=[]) 691def streaming_restore(status, session=None): 692 """When graph building, runs restore ops as soon as they come in. 693 694 Args: 695 status: A _LoadStatus objects from an object-based saver's restore(). 696 Streaming restore from name-based checkpoints is not currently supported. 697 session: A session to run new restore ops in. 698 """ 699 if context.executing_eagerly(): 700 # Streaming restore is the default/only behavior when executing eagerly. 701 return 702 if session is None: 703 session = get_session() 704 if isinstance(status, NameBasedSaverStatus): 705 raise NotImplementedError( 706 "Streaming restore not supported from name-based checkpoints when " 707 "graph building. File a feature request if this limitation bothers " 708 "you. As a workaround, consider either using tf.train.Checkpoint to " 709 "load name-based checkpoints or enabling eager execution.") 710 status.run_restore_ops(session=session) 711 # pylint: disable=protected-access 712 status._checkpoint.new_restore_ops_callback = ( 713 lambda ops: session.run(ops, feed_dict=status._feed_dict)) 714 # pylint: enable=protected-access 715 716 717def _objects_with_attributes(full_list): 718 """Filters out objects with no direct variable dependencies for assertions.""" 719 return [o for o in full_list if o._gather_saveables_for_checkpoint()] # pylint: disable=protected-access 720 721 722class CheckpointLoadStatus(_LoadStatus): 723 """Checks the status of checkpoint loading and manages restore ops. 724 725 Returned from `Saver.restore`. Since `restore` may defer the loading of values 726 in the checkpoint which don't yet have corresponding Python objects, 727 `CheckpointLoadStatus` provides a callback to verify that checkpoint loading 728 is complete (`assert_consumed`). 729 730 When graph building, `restore` does not run restore ops itself since their 731 creation may be deferred. The `run_restore_ops` method must be called once all 732 Python objects with values to restore have been created and added to the 733 dependency graph (this does not necessarily have to be the whole checkpoint; 734 calling `run_restore_ops` while `assert_consumed` fails is supported and will 735 partially restore the checkpoint). 736 737 See `Saver.restore` for usage examples. 738 """ 739 740 def __init__(self, checkpoint, feed_dict, graph_view): 741 self._checkpoint = checkpoint 742 self._feed_dict = feed_dict 743 self._graph_view = graph_view 744 # Keep a reference to the root, since graph_view might only have a weakref. 745 self._root = graph_view.root 746 747 def assert_consumed(self): 748 """Asserts that all objects in the checkpoint have been created/matched. 749 750 Returns: 751 `self` for chaining. 752 Raises: 753 AssertionError: If there are any Python objects in the dependency graph 754 which have not been restored from this checkpoint or a later `restore`, 755 or if there are any checkpointed values which have not been matched to 756 Python objects. 757 """ 758 pretty_printer = _ObjectGraphProtoPrettyPrinter( 759 self._checkpoint.object_graph_proto) 760 self.assert_existing_objects_matched() 761 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 762 if not node.attributes: 763 # Only raise exceptions for the nodes with attributes themselves. Either 764 # they're ultimately not important, or they have a child with an 765 # attribute. 766 continue 767 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 768 if trackable is None: 769 raise AssertionError("Unresolved object in checkpoint {}: {}" 770 .format(pretty_printer.node_names[node_id], node)) 771 if self._checkpoint.slot_restorations: 772 # Sanity check; this collection should be clear if everything has been 773 # restored. 774 raise AssertionError("Unresolved slot restorations: %s" % 775 (self._checkpoint.slot_restorations,)) 776 if self._checkpoint.unused_attributes: 777 unused_attribute_messages = [] 778 for node_id, attribute in six.iteritems( 779 self._checkpoint.unused_attributes): 780 obj = self._checkpoint.object_by_proto_id[node_id] 781 unused_attribute_messages.append( 782 "{} ({}): {}" 783 .format(pretty_printer.node_names[node_id], obj, attribute)) 784 raise AssertionError( 785 ("Unused attributes in these objects (the attributes exist in the " 786 "checkpoint but were not restored):\n{}") 787 .format("\n".join(unused_attribute_messages))) 788 return self 789 790 def assert_existing_objects_matched(self): 791 """Asserts that trackable Python objects have been matched. 792 793 Note that this is a weaker assertion than `assert_consumed`. It will only 794 fail for existing Python objects which are (transitive) dependencies of the 795 root object and which do not have an entry in the checkpoint. 796 797 It will not fail, for example, if a `tf.keras.Layer` object has not yet been 798 built and so has not created any `tf.Variable` objects. 799 800 Returns: 801 `self` for chaining. 802 803 Raises: 804 AssertionError: If a Python object exists in the transitive dependencies 805 of the root object but does not have a value in the checkpoint. 806 """ 807 for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): 808 trackable = self._checkpoint.object_by_proto_id.get(node_id, None) 809 if (trackable is not None and 810 trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access 811 raise AssertionError("Object not assigned a value from checkpoint: %s" % 812 (node,)) 813 for trackable_object in self._graph_view.list_objects(): 814 # Remove data structures that do not contain any variables from 815 # restoration checks. 816 if (isinstance(trackable_object, 817 data_structures.TrackableDataStructure) and 818 not trackable_object._checkpoint_dependencies): 819 continue 820 self._checkpoint.all_python_objects.add(trackable_object) 821 unused_python_objects = ( 822 object_identity.ObjectIdentitySet( 823 _objects_with_attributes( 824 self._checkpoint.all_python_objects)) - 825 object_identity.ObjectIdentitySet( 826 self._checkpoint.object_by_proto_id.values())) 827 if unused_python_objects: 828 raise AssertionError( 829 ("Some Python objects were not bound to checkpointed values, likely " 830 "due to changes in the Python program: %s") % 831 (list(unused_python_objects),)) 832 return self 833 834 def assert_nontrivial_match(self): 835 """Raises an exception if only the root object matched.""" 836 for trackable_object in self._graph_view.list_objects(): 837 self._checkpoint.all_python_objects.add(trackable_object) 838 if len(self._checkpoint.object_by_proto_id) <= 1: 839 unused_python_objects = ( 840 object_identity.ObjectIdentitySet( 841 _objects_with_attributes(self._checkpoint.all_python_objects)) 842 - object_identity.ObjectIdentitySet( 843 self._checkpoint.object_by_proto_id.values())) 844 if unused_python_objects: 845 raise AssertionError( 846 ("Nothing except the root object matched a checkpointed value. " 847 "Typically this means that the checkpoint does not match the " 848 "Python program. The following objects have no matching " 849 "checkpointed value: %s") % (list(unused_python_objects),)) 850 else: 851 raise AssertionError( 852 "Nothing to load. No dependencies have been added to %s yet." % 853 (self._graph_view.root,)) 854 return self 855 856 def run_restore_ops(self, session=None): 857 """Run operations to restore objects in the dependency graph.""" 858 if context.executing_eagerly(): 859 return # Run eagerly 860 if session is None: 861 session = get_session() 862 session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) 863 864 def initialize_or_restore(self, session=None): 865 """Run operations to initialize or restore objects in the dependency graph. 866 867 Any objects in the dependency graph which have initializers but are not in 868 the checkpoint will have those initializers run, unless those variables are 869 being restored by a later call to `tf.train.Checkpoint.restore()`. 870 871 This method has a sibling in `InitializationOnlyStatus` which instead 872 initializes variables. That type is returned if no checkpoint is specified 873 in `Saver.restore`. 874 875 Args: 876 session: The session to run init/restore ops in. If `None`, uses the 877 default session. 878 """ 879 if context.executing_eagerly(): 880 return # Initialization and restoration ops are run eagerly 881 if session is None: 882 session = get_session() 883 all_objects = self._graph_view.list_objects() 884 already_initialized_objects = object_identity.ObjectIdentitySet( 885 self._checkpoint.object_by_proto_id.values()) 886 initializers_for_non_restored_variables = [ 887 c.initializer for c in all_objects 888 if hasattr(c, "initializer") 889 and c not in already_initialized_objects 890 and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1) 891 < self._checkpoint.restore_uid)] 892 self.run_restore_ops(session=session) 893 session.run(initializers_for_non_restored_variables) 894 895 def expect_partial(self): 896 """Silence warnings about incomplete checkpoint restores.""" 897 self._checkpoint.expect_partial = True 898 return self 899 900 901class InitializationOnlyStatus(_LoadStatus): 902 """Returned from `Saver.restore` when no checkpoint has been specified. 903 904 Objects of this type have the same `assert_consumed` method as 905 `CheckpointLoadStatus`, but it always fails. However, 906 `initialize_or_restore` works on objects of both types, and will 907 initialize variables in `InitializationOnlyStatus` objects or restore them 908 otherwise. 909 """ 910 911 def __init__(self, graph_view, restore_uid): 912 self._restore_uid = restore_uid 913 self._graph_view = graph_view 914 # Keep a reference to the root, since graph_view might only have a weakref. 915 self._root = graph_view.root 916 917 def assert_consumed(self): 918 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 919 raise AssertionError( 920 "No checkpoint specified (save_path=None); nothing is being restored.") 921 922 def assert_existing_objects_matched(self): 923 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 924 raise AssertionError( 925 "No checkpoint specified (save_path=None); nothing is being restored.") 926 927 def assert_nontrivial_match(self): 928 """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" 929 raise AssertionError( 930 "No checkpoint specified (save_path=None); nothing is being restored.") 931 932 def run_restore_ops(self, session=None): 933 """For consistency with `CheckpointLoadStatus`. 934 935 Use `initialize_or_restore` for initializing if no checkpoint was passed 936 to `Saver.restore` and restoring otherwise. 937 938 Args: 939 session: Not used. 940 """ 941 raise AssertionError( 942 "No checkpoint specified, so no restore ops are available " 943 "(save_path=None to Saver.restore).") 944 945 def initialize_or_restore(self, session=None): 946 """Runs initialization ops for variables. 947 948 Objects which would be saved by `Saver.save` will be initialized, unless 949 those variables are being restored by a later call to 950 `tf.train.Checkpoint.restore()`. 951 952 This method does nothing when executing eagerly (initializers get run 953 eagerly). 954 955 Args: 956 session: The session to run initialization ops in. If `None`, uses the 957 default session. 958 """ 959 if context.executing_eagerly(): 960 return # run eagerly 961 if session is None: 962 session = get_session() 963 trackable_objects = self._graph_view.list_objects() 964 initializers = [ 965 c.initializer for c in trackable_objects 966 if hasattr(c, "initializer") and c.initializer is not None 967 and (getattr(c, "_update_uid", self._restore_uid - 1) 968 < self._restore_uid)] 969 session.run(initializers) 970 971 972_DEPRECATED_RESTORE_INSTRUCTIONS = ( 973 "Restoring a name-based tf.train.Saver checkpoint using the object-based " 974 "restore API. This mode uses global names to match variables, and so is " 975 "somewhat fragile. It also adds new restore ops to the graph each time it " 976 "is called when graph building. Prefer re-encoding training checkpoints in " 977 "the object-based format: run save() on the object-based saver (the same " 978 "one this message is coming from) and use that checkpoint in the future.") 979 980 981class NameBasedSaverStatus(_LoadStatus): 982 """Status for loading a name-based training checkpoint.""" 983 984 # Ideally this deprecation decorator would be on the class, but that 985 # interferes with isinstance checks. 986 @deprecation.deprecated( 987 date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) 988 def __init__(self, checkpoint, graph_view): 989 self._checkpoint = checkpoint 990 self._graph_view = graph_view 991 self._optionally_restored = [] 992 # Keep a reference to the root, since graph_view might only have a weakref. 993 self._root = graph_view.root 994 995 def add_to_optionally_restored(self, var): 996 """Add a variable to the list of optionally restored variables. 997 998 There are situations where certain variables should be ignored in assertions 999 such as assert_existing_objects_matched(). One example is that of a 1000 checkpoint saved with train.Saver(), and restored with train.Checkpoint(): 1001 it is possible for the train.Saver() checkpoint to be missing the internal 1002 `save_counter` variable, which we want to ignore on restore. 1003 1004 Args: 1005 var: The variable to treat as optionally restored. 1006 """ 1007 self._optionally_restored.append(var) 1008 1009 def assert_consumed(self): 1010 """Raises an exception if any variables are unmatched.""" 1011 unused_attributes = list(self._checkpoint.unused_attributes.items()) 1012 unused_attributes = [ 1013 a for a in unused_attributes 1014 if all(a[0] is not x for x in self._optionally_restored) 1015 ] 1016 if unused_attributes: 1017 unused_attribute_strings = [ 1018 "\n {}: {}".format(obj, attributes) 1019 for obj, attributes in unused_attributes 1020 ] 1021 raise AssertionError( 1022 "Some objects had attributes which were not restored:{}".format( 1023 "".join(unused_attribute_strings))) 1024 for trackable in self._graph_view.list_objects(): 1025 # pylint: disable=protected-access 1026 trackable._maybe_initialize_trackable() 1027 if trackable._update_uid < self._checkpoint.restore_uid: 1028 raise AssertionError("Object not restored: %s" % (trackable,)) 1029 # pylint: enable=protected-access 1030 return self 1031 1032 def assert_existing_objects_matched(self): 1033 """Raises an exception if currently created objects are unmatched.""" 1034 # For name-based checkpoints there's no object information in the 1035 # checkpoint, so there's no distinction between 1036 # assert_existing_objects_matched and assert_consumed (and both are less 1037 # useful since we don't touch Python objects or Python state). 1038 return self.assert_consumed() 1039 1040 def assert_nontrivial_match(self): 1041 """Raises an exception if currently created objects are unmatched.""" 1042 # For name-based checkpoints there's no object information in the 1043 # checkpoint, so there's no distinction between 1044 # assert_nontrivial_match and assert_consumed (and both are less 1045 # useful since we don't touch Python objects or Python state). 1046 return self.assert_consumed() 1047 1048 def _gather_saveable_objects(self): 1049 """Walk the object graph, using global names for SaveableObjects.""" 1050 objects = self._graph_view.list_objects() 1051 saveable_objects = [] 1052 for trackable in objects: 1053 # pylint: disable=protected-access 1054 trackable._maybe_initialize_trackable() 1055 if trackable._update_uid < self._checkpoint.restore_uid: 1056 trackable._update_uid = self._checkpoint.restore_uid 1057 else: 1058 continue 1059 # pylint: enable=protected-access 1060 saveable_objects.extend( 1061 self._checkpoint.globally_named_object_attributes(trackable)) 1062 return saveable_objects 1063 1064 def run_restore_ops(self, session=None): 1065 """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`.""" 1066 if context.executing_eagerly(): 1067 return # Nothing to do, variables are restored on creation. 1068 if session is None: 1069 session = get_session() 1070 with ops.device("/cpu:0"): 1071 saveables = self._gather_saveable_objects() 1072 v1_saver_lib.Saver(saveables).restore( 1073 sess=session, save_path=self._checkpoint.save_path) 1074 1075 def initialize_or_restore(self, session=None): 1076 """Alias for `run_restore_ops`.""" 1077 self.run_restore_ops(session=session) 1078 1079 1080class _SessionWithFeedDictAdditions(session_lib.SessionInterface): 1081 """Pretends to be a session, inserts extra feeds on run().""" 1082 1083 def __init__(self, session, feed_additions): 1084 self._wrapped_session = session 1085 self._feed_additions = feed_additions 1086 1087 def run(self, fetches, feed_dict=None, **kwargs): 1088 if feed_dict is None: 1089 feed_dict = {} 1090 else: 1091 feed_dict = feed_dict.copy() 1092 feed_dict.update(self._feed_additions) 1093 return self._wrapped_session.run( 1094 fetches=fetches, feed_dict=feed_dict, **kwargs) 1095 1096 1097@tf_export("__internal__.tracking.TrackableSaver", v1=[]) 1098class TrackableSaver(object): 1099 """Saves and restores a `Trackable` object and its dependencies. 1100 1101 See `Trackable` for details of dependency management. `Saver` wraps 1102 `tf.compat.v1.train.Saver` for saving, including extra information about the 1103 graph of 1104 dependencies between Python objects. When restoring, it uses this information 1105 about the save-time dependency graph to more robustly match objects with their 1106 checkpointed values. When executing eagerly, it supports restoring variables 1107 on object creation (see `Saver.restore`). 1108 1109 Values in a checkpoint are mapped to `Trackable` Python objects 1110 (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the 1111 checkpoint was written. To avoid breaking existing checkpoints when modifying 1112 a class, dependency names (the names of attributes to which `Trackable` 1113 objects are assigned) may not change. These names are local to objects, in 1114 contrast to the `Variable.name`-based save/restore from 1115 `tf.compat.v1.train.Saver`, and 1116 so allow additional program transformations. 1117 """ 1118 1119 def __init__(self, graph_view): 1120 """Configure saving. 1121 1122 Args: 1123 graph_view: A `GraphView` object containing a description of the object 1124 graph to save. 1125 """ 1126 # The file prefix placeholder is created lazily when graph building (and not 1127 # at all when executing eagerly) to avoid creating ops in the constructor 1128 # (when they may never be necessary). 1129 self._file_prefix_placeholder = None 1130 1131 # Op caching for save 1132 self._object_graph_feed_tensor = None 1133 self._last_save_object_graph = None 1134 self._file_prefix_feed_tensor = None 1135 self._cached_save_operation = None 1136 1137 # Op caching for restore, shared between _CheckpointRestoreCoordinators 1138 self._restore_op_cache = {} 1139 self._graph_view = graph_view 1140 1141 def _gather_saveables(self, object_graph_tensor=None): 1142 """Wraps _serialize_object_graph to include the object graph proto.""" 1143 (named_saveable_objects, graph_proto, 1144 feed_additions) = self._graph_view.serialize_object_graph() 1145 if object_graph_tensor is None: 1146 with ops.device("/cpu:0"): 1147 object_graph_tensor = constant_op.constant( 1148 graph_proto.SerializeToString(), dtype=dtypes.string) 1149 else: 1150 feed_additions.update( 1151 {object_graph_tensor: graph_proto.SerializeToString()}) 1152 assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects 1153 named_saveable_objects.append( 1154 base.NoRestoreSaveable( 1155 tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)) 1156 return named_saveable_objects, graph_proto, feed_additions 1157 1158 def _save_cached_when_graph_building(self, 1159 file_prefix, 1160 object_graph_tensor, 1161 options): 1162 """Create or retrieve save ops. 1163 1164 Args: 1165 file_prefix: The prefix for saved checkpoint files. 1166 object_graph_tensor: A `Tensor` to which the current object graph will be 1167 fed. 1168 options: `CheckpointOptions` object. 1169 1170 Returns: 1171 A two-element tuple with a filename tensor and a feed_dict of tensors to 1172 feed when running it (if graph building). The feed dict contains the 1173 current object graph and any Python state to be saved in the 1174 checkpoint. When executing eagerly only the first argument is meaningful. 1175 """ 1176 (named_saveable_objects, graph_proto, 1177 feed_additions) = self._gather_saveables( 1178 object_graph_tensor=object_graph_tensor) 1179 if (self._last_save_object_graph != graph_proto 1180 # When executing eagerly, we need to re-create SaveableObjects each time 1181 # save() is called so they pick up new Tensors passed to their 1182 # constructors. That means the Saver needs to be copied with a new 1183 # var_list. 1184 or context.executing_eagerly() or ops.inside_function()): 1185 saver = functional_saver.MultiDeviceSaver(named_saveable_objects) 1186 save_op = saver.save(file_prefix, options=options) 1187 with ops.device("/cpu:0"): 1188 with ops.control_dependencies([save_op]): 1189 self._cached_save_operation = array_ops.identity(file_prefix) 1190 self._last_save_object_graph = graph_proto 1191 return self._cached_save_operation, feed_additions 1192 1193 def save(self, file_prefix, checkpoint_number=None, session=None, 1194 options=None): 1195 """Save a training checkpoint. 1196 1197 The saved checkpoint includes variables created by this object and any 1198 Trackable objects it depends on at the time `Saver.save()` is called. 1199 1200 Args: 1201 file_prefix: A prefix to use for the checkpoint filenames 1202 (/path/to/directory/and_a_prefix). Names are generated based on this 1203 prefix and `checkpoint_number`, if provided. 1204 checkpoint_number: An integer variable or Tensor, used to number 1205 checkpoints. Typically this value is saved along with other variables in 1206 training checkpoints, which will happen automatically if it was created 1207 by `root_trackable` or one of its dependencies (via 1208 `Trackable._add_variable`). 1209 session: The session to evaluate variables in. Ignored when executing 1210 eagerly. If not provided when graph building, the default session is 1211 used. 1212 options: Optional `tf.train.CheckpointOptions` object. 1213 1214 Returns: 1215 The full path to the checkpoint. 1216 """ 1217 options = options or checkpoint_options.CheckpointOptions() 1218 feed_dict = {} 1219 use_session = (not context.executing_eagerly() and 1220 not ops.inside_function()) 1221 if checkpoint_number: 1222 file_prefix = "%s-%d" % (file_prefix, checkpoint_number) 1223 if use_session: 1224 if self._object_graph_feed_tensor is None: 1225 with ops.device("/cpu:0"): 1226 self._object_graph_feed_tensor = constant_op.constant( 1227 "", dtype=dtypes.string) 1228 self._file_prefix_feed_tensor = constant_op.constant( 1229 "", dtype=dtypes.string) 1230 object_graph_tensor = self._object_graph_feed_tensor 1231 file_prefix_tensor = self._file_prefix_feed_tensor 1232 feed_dict[file_prefix_tensor] = file_prefix 1233 else: 1234 with ops.device("/cpu:0"): 1235 file_prefix_tensor = constant_op.constant( 1236 file_prefix, dtype=dtypes.string) 1237 object_graph_tensor = None 1238 1239 file_io.recursive_create_dir(os.path.dirname(file_prefix)) 1240 save_path, new_feed_additions = self._save_cached_when_graph_building( 1241 file_prefix_tensor, object_graph_tensor, options) 1242 if new_feed_additions: 1243 feed_dict.update(new_feed_additions) 1244 if not use_session: 1245 session = None 1246 elif session is None: 1247 session = get_session() 1248 1249 if session: 1250 return session.run(save_path, feed_dict=feed_dict) 1251 else: 1252 return save_path 1253 1254 def restore(self, save_path, options=None): 1255 """Restore a training checkpoint. 1256 1257 Restores `root_trackable` and any objects that it tracks 1258 (transitive). Either assigns values immediately if variables to restore have 1259 been created already, or defers restoration until the variables are 1260 created. Dependencies added to the `root_trackable` passed to the 1261 constructor after this call will be matched if they have a corresponding 1262 object in the checkpoint. 1263 1264 When building a graph, restorations are added to the graph but not run. 1265 1266 ```python 1267 saver = Saver(root) 1268 saver.restore(path) 1269 ``` 1270 1271 To ensure that loading is complete and no more assignments will take place 1272 you can use the `assert_consumed()` method of the status object returned 1273 by the `restore` call. 1274 1275 The assert will raise an exception unless every object was matched and all 1276 checkpointed values have a matching variable object. 1277 1278 ```python 1279 saver = Saver(root) 1280 saver.restore(path).assert_consumed() 1281 ``` 1282 1283 When graph building, `assert_consumed()` indicates that all of the restore 1284 ops which will be created for this checkpoint have been created. They can be 1285 run via the `run_restore_ops()` function of the status object: 1286 1287 ```python 1288 saver.restore(path).assert_consumed().run_restore_ops() 1289 ``` 1290 1291 If the checkpoint has not been consumed completely, then the list of restore 1292 ops will grow as more objects are added to the dependency graph. 1293 1294 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 1295 method. There is no deferred loading, and names are used to match 1296 variables. No restore ops are created/run until `run_restore_ops()` or 1297 `initialize_or_restore()` are called on the returned status object, even 1298 when executing eagerly. Re-encode name-based checkpoints using this 1299 object-based `Saver.save` as soon as possible. 1300 1301 Args: 1302 save_path: The path to the checkpoint, as returned by `save` or 1303 `tf.train.latest_checkpoint`. If None (as when there is no latest 1304 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1305 object which may run initializers for objects in the dependency graph. 1306 If the checkpoint was written by the name-based 1307 `tf.compat.v1.train.Saver`, names are used to match variables. 1308 options: Optional `tf.train.CheckpointOptions` object. 1309 1310 Returns: 1311 A load status object, which can be used to make assertions about the 1312 status of checkpoint restoration and run initialization/restore ops 1313 (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if 1314 `save_path` is `None`). 1315 1316 If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` 1317 object is returned which runs restore ops from a name-based saver. 1318 """ 1319 options = options or checkpoint_options.CheckpointOptions() 1320 if save_path is None: 1321 return InitializationOnlyStatus(self._graph_view, ops.uid()) 1322 reader = py_checkpoint_reader.NewCheckpointReader(save_path) 1323 graph_building = not context.executing_eagerly() 1324 if graph_building: 1325 dtype_map = None 1326 else: 1327 dtype_map = reader.get_variable_to_dtype_map() 1328 try: 1329 object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY) 1330 except errors_impl.NotFoundError: 1331 # The object graph proto does not exist in this checkpoint. Try the 1332 # name-based compatibility mode. 1333 restore_coordinator = _NameBasedRestoreCoordinator( 1334 save_path=save_path, 1335 dtype_map=dtype_map) 1336 if not graph_building: 1337 for existing_trackable in self._graph_view.list_objects(): 1338 # pylint: disable=protected-access 1339 existing_trackable._maybe_initialize_trackable() 1340 existing_trackable._name_based_restores.add(restore_coordinator) 1341 existing_trackable._name_based_attribute_restore(restore_coordinator) 1342 # pylint: enable=protected-access 1343 return NameBasedSaverStatus( 1344 restore_coordinator, 1345 graph_view=self._graph_view) 1346 1347 if graph_building: 1348 if self._file_prefix_placeholder is None: 1349 with ops.device("/cpu:0"): 1350 self._file_prefix_placeholder = constant_op.constant("model") 1351 file_prefix_tensor = self._file_prefix_placeholder 1352 file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} 1353 else: 1354 with ops.device("/cpu:0"): 1355 file_prefix_tensor = constant_op.constant(save_path) 1356 file_prefix_feed_dict = None 1357 object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) 1358 object_graph_proto.ParseFromString(object_graph_string) 1359 checkpoint = _CheckpointRestoreCoordinator( 1360 object_graph_proto=object_graph_proto, 1361 save_path=save_path, 1362 save_path_tensor=file_prefix_tensor, 1363 restore_op_cache=self._restore_op_cache, 1364 graph_view=self._graph_view, 1365 options=options) 1366 base.CheckpointPosition( 1367 checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root) 1368 1369 # Attached dependencies are not attached to the root, so should be restored 1370 # separately. 1371 if self._graph_view.attached_dependencies: 1372 for ref in self._graph_view.attached_dependencies: 1373 if ref.name == "root": 1374 # Root dependency is automatically added to attached dependencies -- 1375 # this can be ignored since it maps back to the root object. 1376 continue 1377 proto_id = None 1378 # Find proto ID of attached dependency (if it is in the proto). 1379 for proto_ref in object_graph_proto.nodes[0].children: 1380 if proto_ref.local_name == ref.name: 1381 proto_id = proto_ref.node_id 1382 break 1383 1384 if proto_id in checkpoint.object_by_proto_id: 1385 # Object has already been restored. This can happen when there's an 1386 # indirect connection from the attached object to the root. 1387 continue 1388 1389 base.CheckpointPosition( 1390 checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref) 1391 1392 load_status = CheckpointLoadStatus( 1393 checkpoint, 1394 graph_view=self._graph_view, 1395 feed_dict=file_prefix_feed_dict) 1396 return load_status 1397 1398 1399def frozen_saver(root_trackable): 1400 """Creates a static `tf.compat.v1.train.Saver` from a trackable object. 1401 1402 The returned `Saver` saves object-based checkpoints, but these checkpoints 1403 will no longer reflect structural changes to the object graph, only changes to 1404 the values of `Variable`s added as dependencies of the root object before 1405 `freeze` was called. 1406 1407 `restore` works on the returned `Saver`, but requires that the object graph of 1408 the checkpoint being loaded exactly matches the object graph when `freeze` was 1409 called. This is in contrast the object-based restore performed by 1410 `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's 1411 object graph and the current Python object graph. 1412 1413 Args: 1414 root_trackable: A trackable object to save. 1415 1416 Returns: 1417 A saver which saves object-based checkpoints for the object graph frozen at 1418 the time `frozen_saver` was called. 1419 """ 1420 named_saveable_objects = graph_view_lib.ObjectGraphView( 1421 root_trackable).frozen_saveable_objects() 1422 return functional_saver.MultiDeviceSaver(named_saveable_objects) 1423 1424 1425def saver_with_op_caching(obj, attached_dependencies=None): 1426 if context.executing_eagerly(): 1427 saveables_cache = None 1428 else: 1429 saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary() 1430 return TrackableSaver( 1431 graph_view_lib.ObjectGraphView( 1432 weakref.ref(obj), saveables_cache=saveables_cache, 1433 attached_dependencies=attached_dependencies)) 1434 1435 1436def _assert_trackable(obj): 1437 if not isinstance( 1438 obj, (base.Trackable, def_function.Function)): 1439 raise ValueError( 1440 "`Checkpoint` was expecting a trackable object (an object " 1441 "derived from `TrackableBase`), got {}. If you believe this " 1442 "object should be trackable (i.e. it is part of the " 1443 "TensorFlow Python API and manages state), please open an issue." 1444 .format(obj)) 1445 1446 1447# Mentions graph building / Sessions. The v2 version is below. 1448@tf_export(v1=["train.Checkpoint"]) 1449class CheckpointV1(tracking.AutoTrackable): 1450 """Groups trackable objects, saving and restoring them. 1451 1452 `Checkpoint`'s constructor accepts keyword arguments whose values are types 1453 that contain trackable state, such as `tf.compat.v1.train.Optimizer` 1454 implementations, `tf.Variable`, `tf.keras.Layer` implementations, or 1455 `tf.keras.Model` implementations. It saves these values with a checkpoint, and 1456 maintains a `save_counter` for numbering checkpoints. 1457 1458 Example usage when graph building: 1459 1460 ```python 1461 import tensorflow as tf 1462 import os 1463 1464 checkpoint_directory = "/tmp/training_checkpoints" 1465 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1466 1467 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1468 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1469 train_op = optimizer.minimize( ... ) 1470 status.assert_consumed() # Optional sanity checks. 1471 with tf.compat.v1.Session() as session: 1472 # Use the Session to restore variables, or initialize them if 1473 # tf.train.latest_checkpoint returned None. 1474 status.initialize_or_restore(session) 1475 for _ in range(num_training_steps): 1476 session.run(train_op) 1477 checkpoint.save(file_prefix=checkpoint_prefix) 1478 ``` 1479 1480 Example usage with eager execution enabled: 1481 1482 ```python 1483 import tensorflow as tf 1484 import os 1485 1486 tf.compat.v1.enable_eager_execution() 1487 1488 checkpoint_directory = "/tmp/training_checkpoints" 1489 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1490 1491 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1492 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1493 for _ in range(num_training_steps): 1494 optimizer.minimize( ... ) # Variables will be restored on creation. 1495 status.assert_consumed() # Optional sanity checks. 1496 checkpoint.save(file_prefix=checkpoint_prefix) 1497 ``` 1498 1499 `Checkpoint.save` and `Checkpoint.restore` write and read object-based 1500 checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads 1501 `variable.name` based checkpoints. Object-based checkpointing saves a graph of 1502 dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, 1503 etc.) with named edges, and this graph is used to match variables when 1504 restoring a checkpoint. It can be more robust to changes in the Python 1505 program, and helps to support restore-on-create for variables when executing 1506 eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new 1507 code. 1508 1509 `Checkpoint` objects have dependencies on the objects passed as keyword 1510 arguments to their constructors, and each dependency is given a name that is 1511 identical to the name of the keyword argument for which it was created. 1512 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1513 dependencies on their variables (e.g. "kernel" and "bias" for 1514 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1515 dependencies easy in user-defined classes, since `Model` hooks into attribute 1516 assignment. For example: 1517 1518 ```python 1519 class Regress(tf.keras.Model): 1520 1521 def __init__(self): 1522 super(Regress, self).__init__() 1523 self.input_transform = tf.keras.layers.Dense(10) 1524 # ... 1525 1526 def call(self, inputs): 1527 x = self.input_transform(inputs) 1528 # ... 1529 ``` 1530 1531 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1532 which in turn depends on its variables. As a result, saving an instance of 1533 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1534 by the `Dense` layer. 1535 1536 When variables are assigned to multiple workers, each worker writes its own 1537 section of the checkpoint. These sections are then merged/re-indexed to behave 1538 as a single checkpoint. This avoids copying all variables to one worker, but 1539 does require that all workers see a common filesystem. 1540 1541 While `tf.keras.Model.save_weights` and `tf.train.Checkpoint.save` save in the 1542 same format, note that the root of the resulting checkpoint is the object the 1543 save method is attached to. This means saving a `tf.keras.Model` using 1544 `save_weights` and loading into a `tf.train.Checkpoint` with a `Model` 1545 attached (or vice versa) will not match the `Model`'s variables. See the 1546 [guide to training 1547 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 1548 details. Prefer `tf.train.Checkpoint` over `tf.keras.Model.save_weights` for 1549 training checkpoints. 1550 1551 Attributes: 1552 save_counter: Incremented when `save()` is called. Used to number 1553 checkpoints. 1554 """ 1555 1556 def __init__(self, **kwargs): 1557 """Group objects into a training checkpoint. 1558 1559 Args: 1560 **kwargs: Keyword arguments are set as attributes of this object, and are 1561 saved with the checkpoint. Values must be trackable objects. 1562 1563 Raises: 1564 ValueError: If objects in `kwargs` are not trackable. 1565 """ 1566 super(CheckpointV1, self).__init__() 1567 global _END_TIME_OF_LAST_WRITE 1568 with _END_TIME_OF_LAST_WRITE_LOCK: 1569 if _END_TIME_OF_LAST_WRITE is None: 1570 _END_TIME_OF_LAST_WRITE = time.time() 1571 1572 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1573 setattr(self, k, v) 1574 if not isinstance( 1575 getattr(self, k), (base.Trackable, def_function.Function)): 1576 raise ValueError( 1577 ("`Checkpoint` was expecting a trackable object (an object " 1578 "derived from `TrackableBase`), got %s. If you believe this " 1579 "object should be trackable (i.e. it is part of the " 1580 "TensorFlow Python API and manages state), please open an issue.") 1581 % (v,)) 1582 self._save_counter = None # Created lazily for restore-on-create. 1583 self._save_assign_op = None 1584 self._saver = saver_with_op_caching(self) 1585 1586 def _maybe_create_save_counter(self): 1587 """Create a save counter if it does not yet exist.""" 1588 if self._save_counter is None: 1589 # Initialized to 0 and incremented before saving. 1590 with ops.device("/cpu:0"): 1591 # add_variable creates a dependency named "save_counter"; NoDependency 1592 # prevents creating a second dependency named "_save_counter". 1593 self._save_counter = data_structures.NoDependency( 1594 add_variable( 1595 self, 1596 name="save_counter", 1597 initializer=0, 1598 dtype=dtypes.int64, 1599 trainable=False)) 1600 1601 def write(self, file_prefix, session=None): 1602 """Writes a training checkpoint. 1603 1604 The checkpoint includes variables created by this object and any 1605 trackable objects it depends on at the time `Checkpoint.write()` is 1606 called. 1607 1608 `write` does not number checkpoints, increment `save_counter`, or update the 1609 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 1610 use by higher level checkpoint management utilities. `save` provides a very 1611 basic implementation of these features. 1612 1613 Args: 1614 file_prefix: A prefix to use for the checkpoint filenames 1615 (/path/to/directory/and_a_prefix). 1616 session: The session to evaluate variables in. Ignored when executing 1617 eagerly. If not provided when graph building, the default session is 1618 used. 1619 1620 Returns: 1621 The full path to the checkpoint (i.e. `file_prefix`). 1622 """ 1623 start_time = time.time() 1624 output = self._saver.save(file_prefix=file_prefix, session=session) 1625 end_time = time.time() 1626 1627 metrics.AddCheckpointWriteDuration( 1628 api_label=_CHECKPOINT_V1, 1629 microseconds=_get_duration_microseconds(start_time, end_time)) 1630 1631 global _END_TIME_OF_LAST_WRITE 1632 with _END_TIME_OF_LAST_WRITE_LOCK: 1633 metrics.AddTrainingTimeSaved( 1634 api_label=_CHECKPOINT_V1, 1635 microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE, 1636 end_time)) 1637 _END_TIME_OF_LAST_WRITE = end_time 1638 1639 if tensor_util.is_tf_type(output): 1640 if context.executing_eagerly(): 1641 return compat.as_str(output.numpy()) 1642 else: 1643 # Function building 1644 return output 1645 else: 1646 # Graph + Session, so we already session.ran it. 1647 return compat.as_str(output) 1648 1649 @property 1650 def save_counter(self): 1651 """An integer variable which starts at zero and is incremented on save. 1652 1653 Used to number checkpoints. 1654 1655 Returns: 1656 The save counter variable. 1657 """ 1658 self._maybe_create_save_counter() 1659 return self._save_counter 1660 1661 def save(self, file_prefix, session=None): 1662 """Saves a training checkpoint and provides basic checkpoint management. 1663 1664 The saved checkpoint includes variables created by this object and any 1665 trackable objects it depends on at the time `Checkpoint.save()` is 1666 called. 1667 1668 `save` is a basic convenience wrapper around the `write` method, 1669 sequentially numbering checkpoints using `save_counter` and updating the 1670 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 1671 management, for example garbage collection and custom numbering, may be 1672 provided by other utilities which also wrap `write` 1673 (`tf.train.CheckpointManager` for example). 1674 1675 Args: 1676 file_prefix: A prefix to use for the checkpoint filenames 1677 (/path/to/directory/and_a_prefix). Names are generated based on this 1678 prefix and `Checkpoint.save_counter`. 1679 session: The session to evaluate variables in. Ignored when executing 1680 eagerly. If not provided when graph building, the default session is 1681 used. 1682 1683 Returns: 1684 The full path to the checkpoint. 1685 """ 1686 graph_building = not context.executing_eagerly() 1687 if graph_building: 1688 if ops.inside_function(): 1689 raise NotImplementedError( 1690 "Calling tf.train.Checkpoint.save() from a function is not " 1691 "supported, as save() modifies saving metadata in ways not " 1692 "supported by TensorFlow Operations. Consider using " 1693 "tf.train.Checkpoint.write(), a lower-level API which does not " 1694 "update metadata. tf.train.latest_checkpoint and related APIs will " 1695 "not see this checkpoint.") 1696 if session is None: 1697 session = get_session() 1698 if self._save_counter is None: 1699 # When graph building, if this is a new save counter variable then it 1700 # needs to be initialized before assign_add. This is only an issue if 1701 # restore() has not been called first. 1702 session.run(self.save_counter.initializer) 1703 if not graph_building or self._save_assign_op is None: 1704 with ops.colocate_with(self.save_counter): 1705 assign_op = self.save_counter.assign_add(1, read_value=True) 1706 if graph_building: 1707 self._save_assign_op = data_structures.NoDependency(assign_op) 1708 if graph_building: 1709 checkpoint_number = session.run(self._save_assign_op) 1710 else: 1711 checkpoint_number = assign_op.numpy() 1712 file_path = self.write( 1713 "%s-%d" % (file_prefix, checkpoint_number), session=session) 1714 checkpoint_management.update_checkpoint_state_internal( 1715 save_dir=os.path.dirname(file_prefix), 1716 model_checkpoint_path=file_path, 1717 all_model_checkpoint_paths=[file_path], 1718 save_relative_paths=True) 1719 return file_path 1720 1721 def restore(self, save_path): 1722 """Restore a training checkpoint. 1723 1724 Restores this `Checkpoint` and any objects it depends on. 1725 1726 When executing eagerly, either assigns values immediately if variables to 1727 restore have been created already, or defers restoration until the variables 1728 are created. Dependencies added after this call will be matched if they have 1729 a corresponding object in the checkpoint (the restore request will queue in 1730 any trackable object waiting for the expected dependency to be added). 1731 1732 When graph building, restoration ops are added to the graph but not run 1733 immediately. 1734 1735 ```python 1736 checkpoint = tf.train.Checkpoint( ... ) 1737 checkpoint.restore(path) 1738 ``` 1739 1740 To ensure that loading is complete and no more assignments will take place, 1741 you can use the `assert_consumed()` method of the status object returned by 1742 `restore`. 1743 The assert will raise an exception if any Python objects in the dependency 1744 graph were not found in the checkpoint, or if any checkpointed values do not 1745 have a matching Python object: 1746 1747 ```python 1748 checkpoint = tf.train.Checkpoint( ... ) 1749 checkpoint.restore(path).assert_consumed() 1750 ``` 1751 1752 When graph building, `assert_consumed()` indicates that all of the restore 1753 ops that will be created for this checkpoint have been created. They can be 1754 run via the `run_restore_ops()` method of the status object: 1755 1756 ```python 1757 checkpoint.restore(path).assert_consumed().run_restore_ops() 1758 ``` 1759 1760 If the checkpoint has not been consumed completely, then the list of restore 1761 ops will grow as more objects are added to the dependency graph. 1762 1763 Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this 1764 method. Names are used to match variables. No restore ops are created/run 1765 until `run_restore_ops()` or `initialize_or_restore()` are called on the 1766 returned status object when graph building, but there is restore-on-creation 1767 when executing eagerly. Re-encode name-based checkpoints using 1768 `tf.train.Checkpoint.save` as soon as possible. 1769 1770 Args: 1771 save_path: The path to the checkpoint, as returned by `save` or 1772 `tf.train.latest_checkpoint`. If None (as when there is no latest 1773 checkpoint for `tf.train.latest_checkpoint` to return), returns an 1774 object which may run initializers for objects in the dependency graph. 1775 If the checkpoint was written by the name-based 1776 `tf.compat.v1.train.Saver`, names are used to match variables. 1777 1778 Returns: 1779 A load status object, which can be used to make assertions about the 1780 status of a checkpoint restoration and run initialization/restore ops. 1781 1782 The returned status object has the following methods: 1783 1784 * `assert_consumed()`: 1785 Raises an exception if any variables are unmatched: either 1786 checkpointed values which don't have a matching Python object or 1787 Python objects in the dependency graph with no values in the 1788 checkpoint. This method returns the status object, and so may be 1789 chained with `initialize_or_restore` or `run_restore_ops`. 1790 1791 * `assert_existing_objects_matched()`: 1792 Raises an exception if any existing Python objects in the dependency 1793 graph are unmatched. Unlike `assert_consumed`, this assertion will 1794 pass if values in the checkpoint have no corresponding Python 1795 objects. For example a `tf.keras.Layer` object which has not yet been 1796 built, and so has not created any variables, will pass this assertion 1797 but fail `assert_consumed`. Useful when loading part of a larger 1798 checkpoint into a new Python program, e.g. a training checkpoint with 1799 a `tf.compat.v1.train.Optimizer` was saved but only the state required 1800 for 1801 inference is being loaded. This method returns the status object, and 1802 so may be chained with `initialize_or_restore` or `run_restore_ops`. 1803 1804 * `assert_nontrivial_match()`: Asserts that something aside from the root 1805 object was matched. This is a very weak assertion, but is useful for 1806 sanity checking in library code where objects may exist in the 1807 checkpoint which haven't been created in Python and some Python 1808 objects may not have a checkpointed value. 1809 1810 * `expect_partial()`: Silence warnings about incomplete checkpoint 1811 restores. Warnings are otherwise printed for unused parts of the 1812 checkpoint file or object when the `Checkpoint` object is deleted 1813 (often at program shutdown). 1814 1815 * `initialize_or_restore(session=None)`: 1816 When graph building, runs variable initializers if `save_path` is 1817 `None`, but otherwise runs restore operations. If no `session` is 1818 explicitly specified, the default session is used. No effect when 1819 executing eagerly (variables are initialized or restored eagerly). 1820 1821 * `run_restore_ops(session=None)`: 1822 When graph building, runs restore operations. If no `session` is 1823 explicitly specified, the default session is used. No effect when 1824 executing eagerly (restore operations are run eagerly). May only be 1825 called when `save_path` is not `None`. 1826 """ 1827 start_time = time.time() 1828 status = self._saver.restore(save_path=save_path) 1829 # Create the save counter now so it gets initialized with other variables 1830 # when graph building. Creating it earlier would lead to errors when using, 1831 # say, train.Saver() to save the model before initializing it. 1832 self._maybe_create_save_counter() 1833 if isinstance(status, NameBasedSaverStatus): 1834 status.add_to_optionally_restored(self.save_counter) 1835 1836 metrics.AddCheckpointReadDuration( 1837 api_label=_CHECKPOINT_V1, 1838 microseconds=_get_duration_microseconds(start_time, time.time())) 1839 return status 1840 1841 1842@tf_export("train.Checkpoint", v1=[]) 1843class Checkpoint(tracking.AutoTrackable): 1844 """Manages saving/restoring trackable values to disk. 1845 1846 TensorFlow objects may contain trackable state, such as `tf.Variable`s, 1847 `tf.keras.optimizers.Optimizer` implementations, `tf.data.Dataset` iterators, 1848 `tf.keras.Layer` implementations, or `tf.keras.Model` implementations. 1849 These are called **trackable objects**. 1850 1851 A `Checkpoint` object can be constructed to save either a single or group of 1852 trackable objects to a checkpoint file. It maintains a `save_counter` for 1853 numbering checkpoints. 1854 1855 Example: 1856 1857 ```python 1858 model = tf.keras.Model(...) 1859 checkpoint = tf.train.Checkpoint(model) 1860 1861 # Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time 1862 # checkpoint.save is called, the save counter is increased. 1863 save_path = checkpoint.save('/tmp/training_checkpoints') 1864 1865 # Restore the checkpointed values to the `model` object. 1866 checkpoint.restore(save_path) 1867 ``` 1868 1869 Example 2: 1870 1871 ```python 1872 import tensorflow as tf 1873 import os 1874 1875 checkpoint_directory = "/tmp/training_checkpoints" 1876 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 1877 1878 # Create a Checkpoint that will manage two objects with trackable state, 1879 # one we name "optimizer" and the other we name "model". 1880 checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 1881 status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 1882 for _ in range(num_training_steps): 1883 optimizer.minimize( ... ) # Variables will be restored on creation. 1884 status.assert_consumed() # Optional sanity checks. 1885 checkpoint.save(file_prefix=checkpoint_prefix) 1886 ``` 1887 1888 `Checkpoint.save()` and `Checkpoint.restore()` write and read object-based 1889 checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which 1890 writes and 1891 reads `variable.name` based checkpoints. Object-based checkpointing saves a 1892 graph of dependencies between Python objects (`Layer`s, `Optimizer`s, 1893 `Variable`s, etc.) with named edges, and this graph is used to match variables 1894 when restoring a checkpoint. It can be more robust to changes in the Python 1895 program, and helps to support restore-on-create for variables. 1896 1897 `Checkpoint` objects have dependencies on the objects passed as keyword 1898 arguments to their constructors, and each dependency is given a name that is 1899 identical to the name of the keyword argument for which it was created. 1900 TensorFlow classes like `Layer`s and `Optimizer`s will automatically add 1901 dependencies on their own variables (e.g. "kernel" and "bias" for 1902 `tf.keras.layers.Dense`). Inheriting from `tf.keras.Model` makes managing 1903 dependencies easy in user-defined classes, since `Model` hooks into attribute 1904 assignment. For example: 1905 1906 ```python 1907 class Regress(tf.keras.Model): 1908 1909 def __init__(self): 1910 super(Regress, self).__init__() 1911 self.input_transform = tf.keras.layers.Dense(10) 1912 # ... 1913 1914 def call(self, inputs): 1915 x = self.input_transform(inputs) 1916 # ... 1917 ``` 1918 1919 This `Model` has a dependency named "input_transform" on its `Dense` layer, 1920 which in turn depends on its variables. As a result, saving an instance of 1921 `Regress` using `tf.train.Checkpoint` will also save all the variables created 1922 by the `Dense` layer. 1923 1924 When variables are assigned to multiple workers, each worker writes its own 1925 section of the checkpoint. These sections are then merged/re-indexed to behave 1926 as a single checkpoint. This avoids copying all variables to one worker, but 1927 does require that all workers see a common filesystem. 1928 1929 This function differs slightly from the Keras Model `save_weights` function. 1930 `tf.keras.Model.save_weights` creates a checkpoint file with the name 1931 specified in `filepath`, while `tf.train.Checkpoint` numbers the checkpoints, 1932 using `filepath` as the prefix for the checkpoint file names. Aside from this, 1933 `model.save_weights()` and `tf.train.Checkpoint(model).save()` are equivalent. 1934 1935 See the [guide to training 1936 checkpoints](https://www.tensorflow.org/guide/checkpoint) for 1937 details. 1938 1939 Attributes: 1940 save_counter: Incremented when `save()` is called. Used to number 1941 checkpoints. 1942 """ 1943 1944 def __init__(self, root=None, **kwargs): 1945 """Creates a training checkpoint for a single or group of objects. 1946 1947 Args: 1948 root: The root object to checkpoint. 1949 **kwargs: Keyword arguments are set as attributes of this object, and are 1950 saved with the checkpoint. Values must be trackable objects. 1951 1952 Raises: 1953 ValueError: If `root` or the objects in `kwargs` are not trackable. A 1954 `ValueError` is also raised if the `root` object tracks different 1955 objects from the ones listed in attributes in kwargs (e.g. 1956 `root.child = A` and `tf.train.Checkpoint(root, child=B)` are 1957 incompatible). 1958 1959 """ 1960 super(Checkpoint, self).__init__() 1961 global _END_TIME_OF_LAST_WRITE 1962 with _END_TIME_OF_LAST_WRITE_LOCK: 1963 if _END_TIME_OF_LAST_WRITE is None: 1964 _END_TIME_OF_LAST_WRITE = time.time() 1965 1966 saver_root = self 1967 attached_dependencies = None 1968 self._save_counter = None # Created lazily for restore-on-create. 1969 self._save_assign_op = None 1970 1971 if root: 1972 _assert_trackable(root) 1973 saver_root = root 1974 attached_dependencies = [] 1975 1976 # All keyword arguments (including root itself) are set as children 1977 # of root. 1978 kwargs["root"] = root 1979 root._maybe_initialize_trackable() 1980 1981 self._save_counter = data_structures.NoDependency( 1982 root._lookup_dependency("save_counter")) 1983 self._root = data_structures.NoDependency(root) 1984 1985 for k, v in sorted(kwargs.items(), key=lambda item: item[0]): 1986 setattr(self, k, v) 1987 1988 # Call getattr instead of directly using v because setattr converts 1989 # v to a Trackable data structure when v is a list/dict/tuple. 1990 converted_v = getattr(self, k) 1991 _assert_trackable(converted_v) 1992 1993 if root: 1994 # Make sure that root doesn't already have dependencies with these names 1995 child = root._lookup_dependency(k) 1996 if child is None: 1997 attached_dependencies.append(base.TrackableReference(k, converted_v)) 1998 elif child != converted_v: 1999 raise ValueError( 2000 "Cannot create a Checkpoint with keyword argument {name} if " 2001 "root.{name} already exists.".format(name=k)) 2002 2003 self._saver = saver_with_op_caching(saver_root, attached_dependencies) 2004 self._attached_dependencies = data_structures.NoDependency( 2005 attached_dependencies) 2006 2007 def _maybe_create_save_counter(self): 2008 """Create a save counter if it does not yet exist.""" 2009 if self._save_counter is None: 2010 # Initialized to 0 and incremented before saving. 2011 with ops.device("/cpu:0"): 2012 # add_variable creates a dependency named "save_counter"; NoDependency 2013 # prevents creating a second dependency named "_save_counter". 2014 self._save_counter = data_structures.NoDependency( 2015 add_variable( 2016 self, 2017 name="save_counter", 2018 initializer=0, 2019 dtype=dtypes.int64, 2020 trainable=False)) 2021 if self._attached_dependencies is not None: 2022 self._attached_dependencies.append( 2023 base.TrackableReference("save_counter", self._save_counter)) 2024 # When loading a checkpoint, the save counter is created after 2025 # the checkpoint has been loaded, so it must be handled in a deferred 2026 # manner. 2027 restore = self.root._deferred_dependencies.pop("save_counter", ()) # pylint: disable=protected-access 2028 if restore: 2029 restore[0].restore(self._save_counter) 2030 2031 def write(self, file_prefix, options=None): 2032 """Writes a training checkpoint. 2033 2034 The checkpoint includes variables created by this object and any 2035 trackable objects it depends on at the time `Checkpoint.write()` is 2036 called. 2037 2038 `write` does not number checkpoints, increment `save_counter`, or update the 2039 metadata used by `tf.train.latest_checkpoint`. It is primarily intended for 2040 use by higher level checkpoint management utilities. `save` provides a very 2041 basic implementation of these features. 2042 2043 Checkpoints written with `write` must be read with `read`. 2044 2045 Example usage: 2046 2047 ``` 2048 step = tf.Variable(0, name="step") 2049 checkpoint = tf.Checkpoint(step=step) 2050 checkpoint.write("/tmp/ckpt") 2051 2052 # Later, read the checkpoint with read() 2053 checkpoint.read("/tmp/ckpt").assert_consumed() 2054 2055 # You can also pass options to write() and read(). For example this 2056 # runs the IO ops on the localhost: 2057 options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 2058 checkpoint.write("/tmp/ckpt", options=options) 2059 2060 # Later, read the checkpoint with read() 2061 checkpoint.read("/tmp/ckpt", options=options).assert_consumed() 2062 ``` 2063 2064 Args: 2065 file_prefix: A prefix to use for the checkpoint filenames 2066 (/path/to/directory/and_a_prefix). 2067 options: Optional `tf.train.CheckpointOptions` object. 2068 2069 Returns: 2070 The full path to the checkpoint (i.e. `file_prefix`). 2071 """ 2072 start_time = time.time() 2073 options = options or checkpoint_options.CheckpointOptions() 2074 output = self._saver.save(file_prefix=file_prefix, options=options) 2075 end_time = time.time() 2076 2077 metrics.AddCheckpointWriteDuration( 2078 api_label=_CHECKPOINT_V2, 2079 microseconds=_get_duration_microseconds(start_time, end_time)) 2080 2081 global _END_TIME_OF_LAST_WRITE 2082 with _END_TIME_OF_LAST_WRITE_LOCK: 2083 metrics.AddTrainingTimeSaved( 2084 api_label=_CHECKPOINT_V2, 2085 microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE, 2086 end_time)) 2087 _END_TIME_OF_LAST_WRITE = end_time 2088 2089 if tensor_util.is_tf_type(output): 2090 if context.executing_eagerly(): 2091 return compat.as_str(output.numpy()) 2092 else: 2093 # Function building 2094 return output 2095 else: 2096 # Graph + Session, so we already session.ran it. 2097 return compat.as_str(output) 2098 2099 @property 2100 def save_counter(self): 2101 """An integer variable which starts at zero and is incremented on save. 2102 2103 Used to number checkpoints. 2104 2105 Returns: 2106 The save counter variable. 2107 """ 2108 self._maybe_create_save_counter() 2109 return self._save_counter 2110 2111 def save(self, file_prefix, options=None): 2112 """Saves a training checkpoint and provides basic checkpoint management. 2113 2114 The saved checkpoint includes variables created by this object and any 2115 trackable objects it depends on at the time `Checkpoint.save()` is 2116 called. 2117 2118 `save` is a basic convenience wrapper around the `write` method, 2119 sequentially numbering checkpoints using `save_counter` and updating the 2120 metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint 2121 management, for example garbage collection and custom numbering, may be 2122 provided by other utilities which also wrap `write` and `read`. 2123 (`tf.train.CheckpointManager` for example). 2124 2125 ``` 2126 step = tf.Variable(0, name="step") 2127 checkpoint = tf.Checkpoint(step=step) 2128 checkpoint.save("/tmp/ckpt") 2129 2130 # Later, read the checkpoint with restore() 2131 checkpoint.restore("/tmp/ckpt").assert_consumed() 2132 2133 # You can also pass options to save() and restore(). For example this 2134 # runs the IO ops on the localhost: 2135 options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 2136 checkpoint.save("/tmp/ckpt", options=options) 2137 2138 # Later, read the checkpoint with restore() 2139 checkpoint.restore("/tmp/ckpt", options=options).assert_consumed() 2140 ``` 2141 2142 Args: 2143 file_prefix: A prefix to use for the checkpoint filenames 2144 (/path/to/directory/and_a_prefix). Names are generated based on this 2145 prefix and `Checkpoint.save_counter`. 2146 options: Optional `tf.train.CheckpointOptions` object. 2147 2148 Returns: 2149 The full path to the checkpoint. 2150 """ 2151 options = options or checkpoint_options.CheckpointOptions() 2152 graph_building = not context.executing_eagerly() 2153 if graph_building: 2154 if ops.inside_function(): 2155 raise NotImplementedError( 2156 "Calling tf.train.Checkpoint.save() from a function is not " 2157 "supported, as save() modifies saving metadata in ways not " 2158 "supported by TensorFlow Operations. Consider using " 2159 "tf.train.Checkpoint.write(), a lower-level API which does not " 2160 "update metadata. tf.train.latest_checkpoint and related APIs will " 2161 "not see this checkpoint.") 2162 session = get_session() 2163 if self._save_counter is None: 2164 # When graph building, if this is a new save counter variable then it 2165 # needs to be initialized before assign_add. This is only an issue if 2166 # restore() has not been called first. 2167 session.run(self.save_counter.initializer) 2168 if not graph_building or self._save_assign_op is None: 2169 with ops.colocate_with(self.save_counter): 2170 assign_op = self.save_counter.assign_add(1, read_value=True) 2171 if graph_building: 2172 self._save_assign_op = data_structures.NoDependency(assign_op) 2173 if graph_building: 2174 checkpoint_number = session.run(self._save_assign_op) 2175 else: 2176 checkpoint_number = assign_op.numpy() 2177 file_path = self.write("%s-%d" % (file_prefix, checkpoint_number), 2178 options=options) 2179 checkpoint_management.update_checkpoint_state_internal( 2180 save_dir=os.path.dirname(file_prefix), 2181 model_checkpoint_path=file_path, 2182 all_model_checkpoint_paths=[file_path], 2183 save_relative_paths=True) 2184 if not graph_building: 2185 context.async_wait() # Ensure save operations have completed. 2186 return file_path 2187 2188 def read(self, save_path, options=None): 2189 """Reads a training checkpoint written with `write`. 2190 2191 Reads this `Checkpoint` and any objects it depends on. 2192 2193 This method is just like `restore()` but does not expect the `save_counter` 2194 variable in the checkpoint. It only restores the objects that the checkpoint 2195 already depends on. 2196 2197 The method is primarily intended for use by higher level checkpoint 2198 management utilities that use `write()` instead of `save()` and have their 2199 own mechanisms to number and track checkpoints. 2200 2201 Example usage: 2202 2203 ```python 2204 # Create a checkpoint with write() 2205 ckpt = tf.train.Checkpoint(v=tf.Variable(1.)) 2206 path = ckpt.write('/tmp/my_checkpoint') 2207 2208 # Later, load the checkpoint with read() 2209 # With restore() assert_consumed() would have failed. 2210 checkpoint.read(path).assert_consumed() 2211 2212 # You can also pass options to read(). For example this 2213 # runs the IO ops on the localhost: 2214 options = tf.train.CheckpointOptions( 2215 experimental_io_device="/job:localhost") 2216 checkpoint.read(path, options=options) 2217 ``` 2218 2219 Args: 2220 save_path: The path to the checkpoint as returned by `write`. 2221 options: Optional `tf.train.CheckpointOptions` object. 2222 2223 Returns: 2224 A load status object, which can be used to make assertions about the 2225 status of a checkpoint restoration. See `restore` for details. 2226 """ 2227 start_time = time.time() 2228 options = options or checkpoint_options.CheckpointOptions() 2229 result = self._saver.restore(save_path=save_path, options=options) 2230 metrics.AddCheckpointReadDuration( 2231 api_label=_CHECKPOINT_V2, 2232 microseconds=_get_duration_microseconds(start_time, time.time())) 2233 return result 2234 2235 def restore(self, save_path, options=None): 2236 """Restores a training checkpoint. 2237 2238 Restores this `Checkpoint` and any objects it depends on. 2239 2240 This method is intended to be used to load checkpoints created by `save()`. 2241 For checkpoints created by `write()` use the `read()` method which does not 2242 expect the `save_counter` variable added by `save()`. 2243 2244 `restore()` either assigns values immediately if variables to restore have 2245 been created already, or defers restoration until the variables are 2246 created. Dependencies added after this call will be matched if they have a 2247 corresponding object in the checkpoint (the restore request will queue in 2248 any trackable object waiting for the expected dependency to be added). 2249 2250 ```python 2251 checkpoint = tf.train.Checkpoint( ... ) 2252 checkpoint.restore(path) 2253 2254 # You can additionally pass options to restore(): 2255 options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 2256 checkpoint.restore(path, options=options) 2257 ``` 2258 2259 To ensure that loading is complete and no more assignments will take place, 2260 use the `assert_consumed()` method of the status object returned by 2261 `restore()`: 2262 2263 ```python 2264 checkpoint = tf.train.Checkpoint( ... ) 2265 checkpoint.restore(path).assert_consumed() 2266 2267 # You can additionally pass options to restore(): 2268 options = tf.CheckpointOptions(experimental_io_device="/job:localhost") 2269 checkpoint.restore(path, options=options).assert_consumed() 2270 ``` 2271 2272 The assert will raise an error if any Python objects in the dependency graph 2273 were not found in the checkpoint, or if any checkpointed values do not have 2274 a matching Python object. 2275 2276 Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be 2277 loaded using this method. Names are used to match variables. Re-encode 2278 name-based checkpoints using `tf.train.Checkpoint.save` as soon as possible. 2279 2280 **Loading from SavedModel checkpoints** 2281 2282 To load values from a SavedModel, just pass the SavedModel directory 2283 to checkpoint.restore: 2284 2285 ```python 2286 model = tf.keras.Model(...) 2287 tf.saved_model.save(model, path) # or model.save(path, save_format='tf') 2288 2289 checkpoint = tf.train.Checkpoint(model) 2290 checkpoint.restore(path).expect_partial() 2291 ``` 2292 2293 This example calls `expect_partial()` on the loaded status, since 2294 SavedModels saved from Keras often generates extra keys in the checkpoint. 2295 Otherwise, the program prints a lot of warnings about unused keys at exit 2296 time. 2297 2298 Args: 2299 save_path: The path to the checkpoint, as returned by `save` or 2300 `tf.train.latest_checkpoint`. If the checkpoint was written by the 2301 name-based `tf.compat.v1.train.Saver`, names are used to match 2302 variables. This path may also be a SavedModel directory. 2303 options: Optional `tf.train.CheckpointOptions` object. 2304 2305 Returns: 2306 A load status object, which can be used to make assertions about the 2307 status of a checkpoint restoration. 2308 2309 The returned status object has the following methods: 2310 2311 * `assert_consumed()`: 2312 Raises an exception if any variables are unmatched: either 2313 checkpointed values which don't have a matching Python object or 2314 Python objects in the dependency graph with no values in the 2315 checkpoint. This method returns the status object, and so may be 2316 chained with other assertions. 2317 2318 * `assert_existing_objects_matched()`: 2319 Raises an exception if any existing Python objects in the dependency 2320 graph are unmatched. Unlike `assert_consumed`, this assertion will 2321 pass if values in the checkpoint have no corresponding Python 2322 objects. For example a `tf.keras.Layer` object which has not yet been 2323 built, and so has not created any variables, will pass this assertion 2324 but fail `assert_consumed`. Useful when loading part of a larger 2325 checkpoint into a new Python program, e.g. a training checkpoint with 2326 a `tf.compat.v1.train.Optimizer` was saved but only the state required 2327 for 2328 inference is being loaded. This method returns the status object, and 2329 so may be chained with other assertions. 2330 2331 * `assert_nontrivial_match()`: Asserts that something aside from the root 2332 object was matched. This is a very weak assertion, but is useful for 2333 sanity checking in library code where objects may exist in the 2334 checkpoint which haven't been created in Python and some Python 2335 objects may not have a checkpointed value. 2336 2337 * `expect_partial()`: Silence warnings about incomplete checkpoint 2338 restores. Warnings are otherwise printed for unused parts of the 2339 checkpoint file or object when the `Checkpoint` object is deleted 2340 (often at program shutdown). 2341 2342 Raises: 2343 NotFoundError: if the a checkpoint or SavedModel cannot be found at 2344 `save_path`. 2345 """ 2346 orig_save_path = save_path 2347 2348 if save_path is not None and gfile.IsDirectory(save_path) and ( 2349 (gfile.Exists(utils_impl.get_saved_model_pb_path(save_path)) or 2350 gfile.Exists(utils_impl.get_saved_model_pbtxt_path(save_path)))): 2351 save_path = utils_impl.get_variables_path(save_path) 2352 2353 try: 2354 status = self.read(save_path, options=options) 2355 if context.executing_eagerly(): 2356 context.async_wait() # Ensure restore operations have completed. 2357 except errors_impl.NotFoundError as e: 2358 raise errors_impl.NotFoundError( 2359 None, None, 2360 "Failed to restore from checkpoint or SavedModel at {}: {}".format( 2361 orig_save_path, e.message)) 2362 # Create the save counter now so it gets initialized with other variables 2363 # when graph building. Creating it earlier would lead to errors when using, 2364 # say, train.Saver() to save the model before initializing it. 2365 self._maybe_create_save_counter() 2366 if isinstance(status, NameBasedSaverStatus): 2367 status.add_to_optionally_restored(self.save_counter) 2368 return status 2369