1"""Manages a graph of Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import weakref 22 23from tensorflow.core.protobuf import trackable_object_graph_pb2 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.training import optimizer as optimizer_v1 28from tensorflow.python.training.saving import saveable_object as saveable_object_lib 29from tensorflow.python.training.saving import saveable_object_util 30from tensorflow.python.training.tracking import base 31from tensorflow.python.training.tracking import object_identity 32from tensorflow.python.training.tracking import tracking 33 34 35_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. 36 37# Keyword for identifying that the next bit of a checkpoint variable name is a 38# slot name. Checkpoint names for slot variables look like: 39# 40# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name> 41# 42# Where <path to variable> is a full path from the checkpoint root to the 43# variable being slotted for. 
_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"
# Keyword for separating the path to an object from the name of an
# attribute in checkpoint names. Used like:
#   <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"


def _escape_local_name(name):
  """Escapes a local name so it can be used as one segment of a checkpoint key.

  Slashes must be supported in local names for compatibility, since this
  naming scheme was patched in to things like `Layer.add_variable` where
  slashes were previously accepted. Checkpoint keys also use "/" to indicate
  the edges traversed to reach an object, so forward slashes in names are
  escaped.

  Args:
    name: The local name to escape.

  Returns:
    `name` with every escape character doubled and every "/" replaced by the
    escape character followed by "S". The result contains no "/".
  """
  # Double the escape character first so an escaped slash can never be
  # confused with a literal escape character in the original name.
  return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR)
          .replace("/", _ESCAPE_CHAR + "S"))


def _object_prefix_from_path(path_to_root):
  """Builds the checkpoint key prefix for the object at the end of a path.

  Args:
    path_to_root: A sequence of references with a `name` attribute (the
      `base.TrackableReference`s collected during the breadth-first traversal)
      leading from the root to an object.

  Returns:
    The escaped local names along the path joined with "/". The root itself
    (an empty path) maps to the empty string.
  """
  return "/".join(
      _escape_local_name(reference.name) for reference in path_to_root)


def _object_names_and_node_ids(trackable_objects, path_to_root):
  """Computes checkpoint name prefixes and node ids for gathered objects.

  Args:
    trackable_objects: Objects in breadth-first order. An object's index in
      this list is its node id.
    path_to_root: Object-identity mapping from each object to its path from
      the root.

  Returns:
    A tuple `(object_names, node_ids)` of object-identity dictionaries, mapping
    objects to checkpoint name prefixes and to node ids respectively.
  """
  object_names = object_identity.ObjectIdentityDictionary()
  for obj, path in path_to_root.items():
    object_names[obj] = _object_prefix_from_path(path)
  node_ids = object_identity.ObjectIdentityDictionary()
  for node_id, node in enumerate(trackable_objects):
    node_ids[node] = node_id
  return object_names, node_ids


def _slot_variable_naming_for_optimizer(optimizer_path):
  """Make a function for naming slot variables in an optimizer.

  Slot variables are named:

    <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable name> is exactly the checkpoint name used for the original
  variable, including the path from the checkpoint root and the local name in
  the object which owns it. Note that slot variables are only saved if the
  variable they are slotting for is also being saved.

  Args:
    optimizer_path: The checkpoint name prefix of the optimizer.

  Returns:
    A function `(variable_path, slot_name) -> checkpoint name`.
  """
  optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path)

  def _name_slot_variable(variable_path, slot_name):
    """With an optimizer specified, name a slot variable."""
    return (variable_path
            + optimizer_identifier
            + _escape_local_name(slot_name))

  return _name_slot_variable


def _serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gather and name slot variables.

  For every gathered object that looks like an optimizer (an
  `optimizer_v1.Optimizer`, or anything with a
  `_create_or_restore_slot_variable` attribute), this asks it for each of its
  slots on every non-slot object and records the slot variables it returns.

  Note: this mutates its arguments. Each discovered slot variable is appended
  to `trackable_objects` and given entries in `node_ids` and `object_names`.

  Args:
    trackable_objects: List of gathered objects in node id order. It is
      extended in place with slot variables.
    node_ids: Object-identity map from object to node id. It is updated in
      place.
    object_names: Object-identity map from object to checkpoint name prefix.
      It is updated in place.

  Returns:
    An object-identity dictionary mapping each optimizer to a list of
    `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable has checkpoint dependencies of
      its own, or is also a regular dependency of some gathered object.
  """
  # Snapshot the non-slot objects, since `trackable_objects` grows below.
  non_slot_objects = list(trackable_objects)
  slot_variables = object_identity.ObjectIdentityDictionary()
  for trackable in non_slot_objects:
    if (isinstance(trackable, optimizer_v1.Optimizer)
        # TODO(b/110718070): Fix Keras imports.
        or hasattr(trackable, "_create_or_restore_slot_variable")):
      naming_scheme = _slot_variable_naming_for_optimizer(
          optimizer_path=object_names[trackable])
      slot_names = trackable.get_slot_names()
      for slot_name in slot_names:
        # Indices into `non_slot_objects` equal node ids, because it is a
        # prefix-preserving copy of `trackable_objects`.
        for original_variable_node_id, original_variable in enumerate(
            non_slot_objects):
          try:
            slot_variable = trackable.get_slot(
                original_variable, slot_name)
          except (AttributeError, KeyError):
            slot_variable = None
          if slot_variable is None:
            continue
          slot_variable._maybe_initialize_trackable()  # pylint: disable=protected-access
          if slot_variable._checkpoint_dependencies:  # pylint: disable=protected-access
            # TODO(allenl): Gather dependencies of slot variables.
            raise NotImplementedError(
                "Currently only variables with no dependencies can be saved as "
                "slot variables. File a feature request if this limitation "
                "bothers you.")
          if slot_variable in node_ids:
            raise NotImplementedError(
                "A slot variable was re-used as a dependency of a "
                "Trackable object. This is not currently allowed. File a "
                "feature request if this limitation bothers you.")
          checkpoint_name = naming_scheme(
              variable_path=object_names[original_variable],
              slot_name=slot_name)
          object_names[slot_variable] = checkpoint_name
          slot_variable_node_id = len(trackable_objects)
          node_ids[slot_variable] = slot_variable_node_id
          trackable_objects.append(slot_variable)
          slot_variable_proto = (
              trackable_object_graph_pb2.TrackableObjectGraph
              .TrackableObject.SlotVariableReference(
                  slot_name=slot_name,
                  original_variable_node_id=original_variable_node_id,
                  slot_variable_node_id=slot_variable_node_id))
          slot_variables.setdefault(trackable, []).append(
              slot_variable_proto)
  return slot_variables


def _get_or_create_saveables(trackable, name, saveable_factory, checkpoint_key,
                             cached_attributes):
  """Returns the SaveableObjects for one attribute of a gathered object.

  Cached SaveableObjects are reused only if every one of them still contains
  `checkpoint_key` in its name. Otherwise they are dropped from the cache and
  re-created from `saveable_factory`. Newly created SaveableObjects are stored
  back in the cache when caching is enabled.

  Args:
    trackable: The gathered object being saved. It is used only in error
      messages.
    name: The attribute name (a key returned by
      `_gather_saveables_for_checkpoint()`).
    saveable_factory: Either a callable taking a `name` keyword argument, or
      a ready-made value (a SaveableObject, or something the name-based saver
      utilities can convert).
    checkpoint_key: The checkpoint key for this attribute.
    cached_attributes: Dictionary mapping attribute names to previously
      created SaveableObjects, or None if caching is disabled.

  Returns:
    A tuple of SaveableObjects.

  Raises:
    AssertionError: If a newly created SaveableObject's name does not contain
      `checkpoint_key`.
  """
  saveables = None
  if cached_attributes is not None:
    saveables = cached_attributes.get(name, None)
    if saveables is not None and any(
        checkpoint_key not in saveable.name for saveable in saveables):
      # The checkpoint key for this SaveableObject is different. We need to
      # re-create it.
      saveables = None
      del cached_attributes[name]
  if saveables is not None:
    return saveables

  if callable(saveable_factory):
    maybe_saveable = saveable_factory(name=checkpoint_key)
  else:
    maybe_saveable = saveable_factory
  if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
    saveables = (maybe_saveable,)
  else:
    # Figure out the name-based Saver's name for this variable. If it's
    # already a SaveableObject we'd just get the checkpoint key back, so
    # we leave full_name blank.
    saver_dict = saveable_object_util.op_list_to_dict(
        [maybe_saveable], convert_variable_to_tensor=False)
    full_name, = saver_dict.keys()
    saveables = tuple(saveable_object_util.saveable_objects_for_op(
        op=maybe_saveable, name=checkpoint_key))
    for saveable in saveables:
      saveable.full_name = full_name
  for saveable in saveables:
    if checkpoint_key not in saveable.name:
      raise AssertionError(
          ("The object %s produced a SaveableObject with name '%s' for "
           "attribute '%s'. Expected a name containing '%s'.")
          % (trackable, saveable.name, name, checkpoint_key))
  if cached_attributes is not None:
    cached_attributes[name] = saveables
  return saveables


class ObjectGraphView(object):
  """Gathers and serializes an object graph."""

  def __init__(self, root, saveables_cache=None):
    """Configure the graph view.

    Args:
      root: A `Trackable` object whose variables (including the variables
        of dependencies, recursively) should be saved. May be a weak reference.
      saveables_cache: A dictionary mapping `Trackable` objects ->
        attribute names -> SaveableObjects, used to avoid re-creating
        SaveableObjects when graph building.
    """
    self._root_ref = root
    self._saveables_cache = saveables_cache

  def list_dependencies(self, obj):
    """Returns `obj._checkpoint_dependencies` after lazily initializing it.

    Callers in this class use each element both as a `(name, ref)` pair and
    through its `.name` / `.ref` attributes.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    return obj._checkpoint_dependencies
    # pylint: enable=protected-access

  @property
  def saveables_cache(self):
    """Maps Trackable objects -> attribute names -> list(SaveableObjects).

    Used to avoid re-creating SaveableObjects when graph building. None when
    executing eagerly.

    Returns:
      The cache (an object-identity dictionary), or None if caching is disabled.
    """
    return self._saveables_cache

  @property
  def root(self):
    """The root object, dereferenced if it was given as a weak reference.

    Raises:
      AssertionError: If the root was given as a weak reference whose
        referent no longer exists.
    """
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      # An explicit raise (rather than `assert`) so the check also runs
      # under `python -O`.
      if derefed is None:
        raise AssertionError(
            "The root object of this ObjectGraphView was held by a weak "
            "reference which is now dead.")
      return derefed
    return self._root_ref

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns:
      A tuple `(bfs_sorted, path_to_root)`. `bfs_sorted` is a list of reachable
      objects in breadth-first order, starting with the root. `path_to_root`
      is an object-identity dictionary mapping each object to a tuple of
      `base.TrackableReference`s from the root to it.

    Raises:
      NotImplementedError: If a reachable object is a `tracking.NotTrackable`.
    """
    root = self.root
    bfs_sorted = []
    to_visit = collections.deque([root])
    path_to_root = object_identity.ObjectIdentityDictionary()
    path_to_root[root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      if isinstance(current_trackable, tracking.NotTrackable):
        raise NotImplementedError(
            ("The object %s does not support object-based saving. File a "
             "feature request if this limitation bothers you. In the meantime, "
             "you can remove the dependency on this object and save everything "
             "else.")
            % (current_trackable,))
      bfs_sorted.append(current_trackable)
      for name, dependency in self.list_dependencies(current_trackable):
        # The first path found in BFS order is a shortest one; keep it.
        if dependency not in path_to_root:
          path_to_root[dependency] = (
              path_to_root[current_trackable] + (
                  base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, path_to_root

  def _add_attributes_to_object_graph(
      self, trackable_objects, object_graph_proto, node_ids, object_names,
      object_map):
    """Create SaveableObjects and corresponding SerializedTensor protos.

    Args:
      trackable_objects: Gathered objects in node id order.
      object_graph_proto: A `TrackableObjectGraph` whose `nodes` line up with
        `trackable_objects`. Attribute protos are added to its nodes.
      node_ids: Object-identity map from object to node id.
      object_names: Object-identity map from object to checkpoint name prefix.
      object_map: Optional mapping. When given, `object_map.get(obj, obj)` is
        the object whose saveables are gathered in place of `obj`.

    Returns:
      A tuple `(named_saveable_objects, feed_additions)`. The first element is
      a list of SaveableObjects. `feed_additions` is a dict of extra feeds for
      Python state when SaveableObjects are cached, and None otherwise.

    Raises:
      AssertionError: If a SaveableObject has an unexpected name, or two
        objects try to feed the same Tensor.
    """
    named_saveable_objects = []
    if self._saveables_cache is None:
      # No SaveableObject caching. Either we're executing eagerly, or building a
      # static save which is specialized to the current Python state.
      feed_additions = None
    else:
      # If we are caching SaveableObjects, we need to build up a feed_dict with
      # functions computing volatile Python state to be saved with the
      # checkpoint.
      feed_additions = {}
    for checkpoint_id, (trackable, object_proto) in enumerate(
        zip(trackable_objects, object_graph_proto.nodes)):
      assert node_ids[trackable] == checkpoint_id
      object_name = object_names[trackable]
      if object_map is None:
        object_to_save = trackable
      else:
        object_to_save = object_map.get(trackable, trackable)
      if self._saveables_cache is not None:
        cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
      else:
        cached_attributes = None

      for name, saveable_factory in (
          object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
        attribute = object_proto.attributes.add()
        attribute.name = name
        attribute.checkpoint_key = "%s/%s/%s" % (
            object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
        saveables = _get_or_create_saveables(
            trackable=trackable,
            name=name,
            saveable_factory=saveable_factory,
            checkpoint_key=attribute.checkpoint_key,
            cached_attributes=cached_attributes)

        # An attribute may be skipped on restore only if all of its
        # SaveableObjects may be; an attribute with none is not optional.
        optional_restore = bool(saveables) and all(
            saveable.optional_restore for saveable in saveables)
        for saveable in saveables:
          if hasattr(saveable, "full_name"):
            attribute.full_name = saveable.full_name
          if isinstance(saveable, base.PythonStateSaveable):
            if feed_additions is None:
              assert self._saveables_cache is None
              # If we're not caching saveables, then we're either executing
              # eagerly or building a static save/restore (e.g. for a
              # SavedModel). In either case, we should embed the current Python
              # state in the graph rather than relying on a feed dict.
              saveable = saveable.freeze()
            else:
              saveable_feed_dict = saveable.feed_dict_additions()
              for new_feed_key in saveable_feed_dict:
                if new_feed_key in feed_additions:
                  raise AssertionError(
                      ("The object %s tried to feed a value for the Tensor %s "
                       "when saving, but another object is already feeding a "
                       "value.")
                      % (trackable, new_feed_key))
              feed_additions.update(saveable_feed_dict)
          named_saveable_objects.append(saveable)
        attribute.optional_restore = optional_restore

    return named_saveable_objects, feed_additions

  def _fill_object_graph_proto(self, trackable_objects,
                               node_ids,
                               slot_variables,
                               object_graph_proto=None):
    """Adds a node proto for each object in `trackable_objects`.

    Each node records its slot variable references (if the object is an
    optimizer with slots) and one child entry per checkpoint dependency.

    Args:
      trackable_objects: Gathered objects (including slot variables) in node
        id order.
      node_ids: Object-identity map from object to node id.
      slot_variables: Map from optimizer to its `SlotVariableReference` protos.
      object_graph_proto: Optional `TrackableObjectGraph` to append to. A new
        one is created if None.

    Returns:
      The filled `TrackableObjectGraph` proto.
    """
    if object_graph_proto is None:
      object_graph_proto = (
          trackable_object_graph_pb2.TrackableObjectGraph())
    for checkpoint_id, trackable in enumerate(trackable_objects):
      assert node_ids[trackable] == checkpoint_id
      object_proto = object_graph_proto.nodes.add()
      object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
      for child in self.list_dependencies(trackable):
        child_proto = object_proto.children.add()
        child_proto.node_id = node_ids[child.ref]
        child_proto.local_name = child.name
    return object_graph_proto

  def _serialize_gathered_objects(self, trackable_objects, path_to_root,
                                  object_map=None):
    """Create SaveableObjects and protos for gathered objects.

    Note: `trackable_objects` is extended in place with slot variables.

    Returns:
      A tuple `(named_saveable_objects, object_graph_proto, feed_additions)`.
    """
    object_names, node_ids = _object_names_and_node_ids(
        trackable_objects, path_to_root)
    slot_variables = _serialize_slot_variables(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        object_names=object_names)
    object_graph_proto = self._fill_object_graph_proto(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        slot_variables=slot_variables)
    named_saveable_objects, feed_additions = (
        self._add_attributes_to_object_graph(
            trackable_objects=trackable_objects,
            object_graph_proto=object_graph_proto,
            node_ids=node_ids,
            object_names=object_names,
            object_map=object_map))
    return named_saveable_objects, object_graph_proto, feed_additions

  def serialize_object_graph(self):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Trackable._add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Returns:
      A tuple of (named_saveable_objects, object_graph_proto, feed_additions):
        named_saveable_objects: A list of SaveableObjects to save.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving, or None when SaveableObjects are not cached.

    Raises:
      NotImplementedError: If a reachable object does not support object-based
        saving, or a slot variable cannot be saved.
      AssertionError: If a SaveableObject has an unexpected name, or two
        objects try to feed the same Tensor.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    return self._serialize_gathered_objects(
        trackable_objects, path_to_root)

  def frozen_saveable_objects(self, object_map=None, to_graph=None):
    """Creates SaveableObjects with the current object graph frozen.

    Args:
      object_map: Optional mapping. When given, `object_map.get(obj, obj)` is
        the object whose saveables are gathered in place of `obj`.
      to_graph: Optional graph. When given, the SaveableObjects and the
        serialized object graph tensor are created with it as the default
        graph.

    Returns:
      A list of SaveableObjects ending with a `base.NoRestoreSaveable` that
      holds the serialized object graph proto under
      `base.OBJECT_GRAPH_PROTO_KEY`. Feed additions are discarded.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    if to_graph:
      target_context = to_graph.as_default
    else:
      target_context = ops.NullContextmanager
    with target_context():
      named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
          trackable_objects,
          path_to_root,
          object_map)
      # Build the object-graph constant inside `target_context` so it lives in
      # the same graph as the other SaveableObjects.
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
      named_saveable_objects.append(
          base.NoRestoreSaveable(
              tensor=object_graph_tensor,
              name=base.OBJECT_GRAPH_PROTO_KEY))
    return named_saveable_objects

  def objects_ids_and_slot_variables(self):
    """Traverse the object graph and list all accessible objects.

    Looks for `Trackable` objects which are dependencies of
    `root_trackable`. Includes slot variables only if the variable they are
    slotting for and the optimizer are dependencies of `root_trackable`
    (i.e. if they would be saved with a checkpoint).

    Returns:
      A tuple of (trackable objects, object -> node id, slot variables)
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    object_names, node_ids = _object_names_and_node_ids(
        trackable_objects, path_to_root)
    slot_variables = _serialize_slot_variables(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        object_names=object_names)
    return trackable_objects, node_ids, slot_variables

  def list_objects(self):
    """Traverse the object graph and list all accessible objects."""
    trackable_objects, _, _ = self.objects_ids_and_slot_variables()
    return trackable_objects