"""Manages a graph of Trackable objects.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import weakref from tensorflow.core.protobuf import trackable_object_graph_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training.saving import saveable_object as saveable_object_lib from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.training.tracking import base from tensorflow.python.training.tracking import tracking from tensorflow.python.util import object_identity _ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. # Keyword for identifying that the next bit of a checkpoint variable name is a # slot name. Checkpoint names for slot variables look like: # # /<_OPTIMIZER_SLOTS_NAME>// # # Where is a full path from the checkpoint root to the # variable being slotted for. _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" # Keyword for separating the path to an object from the name of an # attribute in checkpoint names. Used like: # /<_OBJECT_ATTRIBUTES_NAME>/ _OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" def _escape_local_name(name): # We need to support slashes in local names for compatibility, since this # naming scheme is being patched in to things like Layer.add_variable where # slashes were previously accepted. We also want to use slashes to indicate # edges traversed to reach the variable, so we escape forward slashes in # names. return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR) .replace(r"/", _ESCAPE_CHAR + "S")) def _object_prefix_from_path(path_to_root): return "/".join( (_escape_local_name(trackable.name) for trackable in path_to_root)) def _slot_variable_naming_for_optimizer(optimizer_path): """Make a function for naming slot variables in an optimizer.""" # Name slot variables: # # /<_OPTIMIZER_SLOTS_NAME>// # # where is exactly the checkpoint name used for the original # variable, including the path from the checkpoint root and the local name in # the object which owns it. Note that we only save slot variables if the # variable it's slotting for is also being saved. optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path) def _name_slot_variable(variable_path, slot_name): """With an optimizer specified, name a slot variable.""" return (variable_path + optimizer_identifier + _escape_local_name(slot_name)) return _name_slot_variable def _serialize_slot_variables(trackable_objects, node_ids, object_names): """Gather and name slot variables.""" non_slot_objects = list(trackable_objects) slot_variables = object_identity.ObjectIdentityDictionary() for trackable in non_slot_objects: if (isinstance(trackable, optimizer_v1.Optimizer) # TODO(b/110718070): Fix Keras imports. # Note: dir() is used rather than hasattr() here to avoid triggering # custom __getattr__ code, see b/152031870 for context. or "_create_or_restore_slot_variable" in dir(trackable)): naming_scheme = _slot_variable_naming_for_optimizer( optimizer_path=object_names[trackable]) slot_names = trackable.get_slot_names() for slot_name in slot_names: for original_variable_node_id, original_variable in enumerate( non_slot_objects): try: slot_variable = trackable.get_slot( original_variable, slot_name) except (AttributeError, KeyError): slot_variable = None if slot_variable is None: continue slot_variable._maybe_initialize_trackable() # pylint: disable=protected-access if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access # TODO(allenl): Gather dependencies of slot variables. raise NotImplementedError( "Currently only variables with no dependencies can be saved as " "slot variables. File a feature request if this limitation " "bothers you.") if slot_variable in node_ids: raise NotImplementedError( ("A slot variable was re-used as a dependency of a " "Trackable object: %s. This is not currently " "allowed. File a feature request if this limitation bothers " "you.") % slot_variable) checkpoint_name = naming_scheme( variable_path=object_names[original_variable], slot_name=slot_name) object_names[slot_variable] = checkpoint_name slot_variable_node_id = len(trackable_objects) node_ids[slot_variable] = slot_variable_node_id trackable_objects.append(slot_variable) slot_variable_proto = ( trackable_object_graph_pb2.TrackableObjectGraph .TrackableObject.SlotVariableReference( slot_name=slot_name, original_variable_node_id=original_variable_node_id, slot_variable_node_id=slot_variable_node_id)) slot_variables.setdefault(trackable, []).append( slot_variable_proto) return slot_variables class ObjectGraphView(object): """Gathers and serializes an object graph.""" def __init__(self, root, saveables_cache=None, attached_dependencies=None): """Configure the graph view. Args: root: A `Trackable` object whose variables (including the variables of dependencies, recursively) should be saved. May be a weak reference. saveables_cache: A dictionary mapping `Trackable` objects -> attribute names -> SaveableObjects, used to avoid re-creating SaveableObjects when graph building. attached_dependencies: Dependencies to attach to the root object. Used when saving a Checkpoint with a defined root object. """ self._root_ref = root self._saveables_cache = saveables_cache self._attached_dependencies = attached_dependencies def list_dependencies(self, obj): # pylint: disable=protected-access obj._maybe_initialize_trackable() dependencies = obj._checkpoint_dependencies # pylint: enable=protected-access if obj is self.root and self._attached_dependencies: dependencies = dependencies.copy() dependencies.extend(self._attached_dependencies) return dependencies @property def saveables_cache(self): """Maps Trackable objects -> attribute names -> list(SaveableObjects). Used to avoid re-creating SaveableObjects when graph building. None when executing eagerly. Returns: The cache (an object-identity dictionary), or None if caching is disabled. """ return self._saveables_cache @property def attached_dependencies(self): """Returns list of dependencies that should be saved in the checkpoint. These dependencies are not tracked by root, but are in the checkpoint. This is defined when the user creates a Checkpoint with both root and kwargs set. Returns: A list of TrackableReferences. """ return self._attached_dependencies @property def root(self): if isinstance(self._root_ref, weakref.ref): derefed = self._root_ref() assert derefed is not None return derefed else: return self._root_ref def _breadth_first_traversal(self): """Find shortest paths to all dependencies of self.root.""" bfs_sorted = [] to_visit = collections.deque([self.root]) path_to_root = object_identity.ObjectIdentityDictionary() path_to_root[self.root] = () while to_visit: current_trackable = to_visit.popleft() if isinstance(current_trackable, tracking.NotTrackable): raise NotImplementedError( ("The object %s does not support object-based saving. File a " "feature request if this limitation bothers you. In the meantime, " "you can remove the dependency on this object and save everything " "else.") % (current_trackable,)) bfs_sorted.append(current_trackable) for name, dependency in self.list_dependencies(current_trackable): if dependency not in path_to_root: path_to_root[dependency] = ( path_to_root[current_trackable] + ( base.TrackableReference(name, dependency),)) to_visit.append(dependency) return bfs_sorted, path_to_root def _add_attributes_to_object_graph( self, trackable_objects, object_graph_proto, node_ids, object_names, object_map, call_with_mapped_captures): """Create SaveableObjects and corresponding SerializedTensor protos.""" named_saveable_objects = [] if self._saveables_cache is None: # No SaveableObject caching. Either we're executing eagerly, or building a # static save which is specialized to the current Python state. feed_additions = None else: # If we are caching SaveableObjects, we need to build up a feed_dict with # functions computing volatile Python state to be saved with the # checkpoint. feed_additions = {} for checkpoint_id, (trackable, object_proto) in enumerate( zip(trackable_objects, object_graph_proto.nodes)): assert node_ids[trackable] == checkpoint_id object_name = object_names[trackable] if object_map is None: object_to_save = trackable else: object_to_save = object_map.get(trackable, trackable) if self._saveables_cache is not None: cached_attributes = self._saveables_cache.setdefault(object_to_save, {}) else: cached_attributes = None for name, saveable_factory in ( object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) if cached_attributes is None: saveables = None else: saveables = cached_attributes.get(name, None) if saveables is not None: for saveable in saveables: if attribute.checkpoint_key not in saveable.name: # The checkpoint key for this SaveableObject is different. We # need to re-create it. saveables = None del cached_attributes[name] break if saveables is None: if callable(saveable_factory): maybe_saveable = saveable_object_util.create_saveable_object( saveable_factory, attribute.checkpoint_key, call_with_mapped_captures) else: maybe_saveable = saveable_factory if isinstance(maybe_saveable, saveable_object_lib.SaveableObject): saveables = (maybe_saveable,) else: # Figure out the name-based Saver's name for this variable. If it's # already a SaveableObject we'd just get the checkpoint key back, so # we leave full_name blank. saver_dict = saveable_object_util.op_list_to_dict( [maybe_saveable], convert_variable_to_tensor=False) full_name, = saver_dict.keys() saveables = tuple(saveable_object_util.saveable_objects_for_op( op=maybe_saveable, name=attribute.checkpoint_key)) for saveable in saveables: saveable.full_name = full_name for saveable in saveables: if attribute.checkpoint_key not in saveable.name: raise AssertionError( ("The object %s produced a SaveableObject with name '%s' for " "attribute '%s'. Expected a name containing '%s'.") % (trackable, name, saveable.name, attribute.checkpoint_key)) if cached_attributes is not None: cached_attributes[name] = saveables optional_restore = None for saveable in saveables: if optional_restore is None: optional_restore = saveable.optional_restore else: optional_restore = optional_restore and saveable.optional_restore if hasattr(saveable, "full_name"): attribute.full_name = saveable.full_name if isinstance(saveable, base.PythonStateSaveable): if feed_additions is None: assert self._saveables_cache is None # If we're not caching saveables, then we're either executing # eagerly or building a static save/restore (e.g. for a # SavedModel). In either case, we should embed the current Python # state in the graph rather than relying on a feed dict. saveable = saveable.freeze() else: saveable_feed_dict = saveable.feed_dict_additions() for new_feed_key in saveable_feed_dict.keys(): if new_feed_key in feed_additions: raise AssertionError( ("The object %s tried to feed a value for the Tensor %s " "when saving, but another object is already feeding a " "value.") % (trackable, new_feed_key)) feed_additions.update(saveable_feed_dict) named_saveable_objects.append(saveable) if optional_restore is None: optional_restore = False attribute.optional_restore = optional_restore return named_saveable_objects, feed_additions def _fill_object_graph_proto(self, trackable_objects, node_ids, slot_variables, object_graph_proto=None): """Name non-slot `Trackable`s and add them to `object_graph_proto`.""" if object_graph_proto is None: object_graph_proto = ( trackable_object_graph_pb2.TrackableObjectGraph()) for checkpoint_id, trackable in enumerate(trackable_objects): assert node_ids[trackable] == checkpoint_id object_proto = object_graph_proto.nodes.add() object_proto.slot_variables.extend(slot_variables.get(trackable, ())) for child in self.list_dependencies(trackable): child_proto = object_proto.children.add() child_proto.node_id = node_ids[child.ref] child_proto.local_name = child.name return object_graph_proto def _serialize_gathered_objects(self, trackable_objects, path_to_root, object_map=None, call_with_mapped_captures=None): """Create SaveableObjects and protos for gathered objects.""" object_names = object_identity.ObjectIdentityDictionary() for obj, path in path_to_root.items(): object_names[obj] = _object_prefix_from_path(path) node_ids = object_identity.ObjectIdentityDictionary() for node_id, node in enumerate(trackable_objects): node_ids[node] = node_id slot_variables = _serialize_slot_variables( trackable_objects=trackable_objects, node_ids=node_ids, object_names=object_names) object_graph_proto = self._fill_object_graph_proto( trackable_objects=trackable_objects, node_ids=node_ids, slot_variables=slot_variables) named_saveable_objects, feed_additions = ( self._add_attributes_to_object_graph( trackable_objects=trackable_objects, object_graph_proto=object_graph_proto, node_ids=node_ids, object_names=object_names, object_map=object_map, call_with_mapped_captures=call_with_mapped_captures)) return named_saveable_objects, object_graph_proto, feed_additions def serialize_object_graph(self): """Determine checkpoint keys for variables and build a serialized graph. Non-slot variables are keyed based on a shortest path from the root saveable to the object which owns the variable (i.e. the one which called `Trackable._add_variable` to create it). Slot variables are keyed based on a shortest path to the variable being slotted for, a shortest path to their optimizer, and the slot name. Returns: A tuple of (named_variables, object_graph_proto, feed_additions): named_variables: A dictionary mapping names to variable objects. object_graph_proto: A TrackableObjectGraph protocol buffer containing the serialized object graph and variable references. feed_additions: A dictionary mapping from Tensors to values which should be fed when saving. Raises: ValueError: If there are invalid characters in an optimizer's slot names. """ trackable_objects, path_to_root = self._breadth_first_traversal() return self._serialize_gathered_objects( trackable_objects, path_to_root) def frozen_saveable_objects(self, object_map=None, to_graph=None, call_with_mapped_captures=None): """Creates SaveableObjects with the current object graph frozen.""" trackable_objects, path_to_root = self._breadth_first_traversal() if to_graph: target_context = to_graph.as_default else: target_context = ops.NullContextmanager with target_context(): named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects( trackable_objects, path_to_root, object_map, call_with_mapped_captures) with ops.device("/cpu:0"): object_graph_tensor = constant_op.constant( graph_proto.SerializeToString(), dtype=dtypes.string) named_saveable_objects.append( base.NoRestoreSaveable( tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY)) return named_saveable_objects def objects_ids_and_slot_variables_and_paths(self): """Traverse the object graph and list all accessible objects. Looks for `Trackable` objects which are dependencies of `root_trackable`. Includes slot variables only if the variable they are slotting for and the optimizer are dependencies of `root_trackable` (i.e. if they would be saved with a checkpoint). Returns: A tuple of (trackable objects, paths from root for each object, object -> node id, slot variables) """ trackable_objects, path_to_root = self._breadth_first_traversal() object_names = object_identity.ObjectIdentityDictionary() for obj, path in path_to_root.items(): object_names[obj] = _object_prefix_from_path(path) node_ids = object_identity.ObjectIdentityDictionary() for node_id, node in enumerate(trackable_objects): node_ids[node] = node_id slot_variables = _serialize_slot_variables( trackable_objects=trackable_objects, node_ids=node_ids, object_names=object_names) return trackable_objects, path_to_root, node_ids, slot_variables def objects_ids_and_slot_variables(self): trackable_objects, _, node_ids, slot_variables = ( self.objects_ids_and_slot_variables_and_paths()) return trackable_objects, node_ids, slot_variables def list_objects(self): """Traverse the object graph and list all accessible objects.""" trackable_objects, _, _ = self.objects_ids_and_slot_variables() return trackable_objects