1"""Manages a graph of Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import weakref 22 23from tensorflow.core.protobuf import trackable_object_graph_pb2 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.training import optimizer as optimizer_v1 28from tensorflow.python.training.saving import saveable_object as saveable_object_lib 29from tensorflow.python.training.saving import saveable_object_util 30from tensorflow.python.training.tracking import base 31from tensorflow.python.training.tracking import tracking 32from tensorflow.python.util import object_identity 33 34 35_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. 36 37# Keyword for identifying that the next bit of a checkpoint variable name is a 38# slot name. Checkpoint names for slot variables look like: 39# 40# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name> 41# 42# Where <path to variable> is a full path from the checkpoint root to the 43# variable being slotted for. 44_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" 45# Keyword for separating the path to an object from the name of an 46# attribute in checkpoint names. Used like: 47# <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute> 48_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" 49 50 51def _escape_local_name(name): 52 # We need to support slashes in local names for compatibility, since this 53 # naming scheme is being patched in to things like Layer.add_variable where 54 # slashes were previously accepted. We also want to use slashes to indicate 55 # edges traversed to reach the variable, so we escape forward slashes in 56 # names. 57 return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR) 58 .replace(r"/", _ESCAPE_CHAR + "S")) 59 60 61def _object_prefix_from_path(path_to_root): 62 return "/".join( 63 (_escape_local_name(trackable.name) 64 for trackable in path_to_root)) 65 66 67def _slot_variable_naming_for_optimizer(optimizer_path): 68 """Make a function for naming slot variables in an optimizer.""" 69 # Name slot variables: 70 # 71 # <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name> 72 # 73 # where <variable name> is exactly the checkpoint name used for the original 74 # variable, including the path from the checkpoint root and the local name in 75 # the object which owns it. Note that we only save slot variables if the 76 # variable it's slotting for is also being saved. 77 78 optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path) 79 80 def _name_slot_variable(variable_path, slot_name): 81 """With an optimizer specified, name a slot variable.""" 82 return (variable_path 83 + optimizer_identifier 84 + _escape_local_name(slot_name)) 85 86 return _name_slot_variable 87 88 89def _serialize_slot_variables(trackable_objects, node_ids, object_names): 90 """Gather and name slot variables.""" 91 non_slot_objects = list(trackable_objects) 92 slot_variables = object_identity.ObjectIdentityDictionary() 93 for trackable in non_slot_objects: 94 if (isinstance(trackable, optimizer_v1.Optimizer) 95 # TODO(b/110718070): Fix Keras imports. 96 # Note: dir() is used rather than hasattr() here to avoid triggering 97 # custom __getattr__ code, see b/152031870 for context. 98 or "_create_or_restore_slot_variable" in dir(trackable)): 99 naming_scheme = _slot_variable_naming_for_optimizer( 100 optimizer_path=object_names[trackable]) 101 slot_names = trackable.get_slot_names() 102 for slot_name in slot_names: 103 for original_variable_node_id, original_variable in enumerate( 104 non_slot_objects): 105 try: 106 slot_variable = trackable.get_slot( 107 original_variable, slot_name) 108 except (AttributeError, KeyError): 109 slot_variable = None 110 if slot_variable is None: 111 continue 112 slot_variable._maybe_initialize_trackable() # pylint: disable=protected-access 113 if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access 114 # TODO(allenl): Gather dependencies of slot variables. 115 raise NotImplementedError( 116 "Currently only variables with no dependencies can be saved as " 117 "slot variables. File a feature request if this limitation " 118 "bothers you.") 119 if slot_variable in node_ids: 120 raise NotImplementedError( 121 ("A slot variable was re-used as a dependency of a " 122 "Trackable object: %s. This is not currently " 123 "allowed. File a feature request if this limitation bothers " 124 "you.") % slot_variable) 125 checkpoint_name = naming_scheme( 126 variable_path=object_names[original_variable], 127 slot_name=slot_name) 128 object_names[slot_variable] = checkpoint_name 129 slot_variable_node_id = len(trackable_objects) 130 node_ids[slot_variable] = slot_variable_node_id 131 trackable_objects.append(slot_variable) 132 slot_variable_proto = ( 133 trackable_object_graph_pb2.TrackableObjectGraph 134 .TrackableObject.SlotVariableReference( 135 slot_name=slot_name, 136 original_variable_node_id=original_variable_node_id, 137 slot_variable_node_id=slot_variable_node_id)) 138 slot_variables.setdefault(trackable, []).append( 139 slot_variable_proto) 140 return slot_variables 141 142 143class ObjectGraphView(object): 144 """Gathers and serializes an object graph.""" 145 146 def __init__(self, root, saveables_cache=None, attached_dependencies=None): 147 """Configure the graph view. 148 149 Args: 150 root: A `Trackable` object whose variables (including the variables 151 of dependencies, recursively) should be saved. May be a weak reference. 152 saveables_cache: A dictionary mapping `Trackable` objects -> 153 attribute names -> SaveableObjects, used to avoid re-creating 154 SaveableObjects when graph building. 155 attached_dependencies: Dependencies to attach to the root object. Used 156 when saving a Checkpoint with a defined root object. 157 """ 158 self._root_ref = root 159 self._saveables_cache = saveables_cache 160 self._attached_dependencies = attached_dependencies 161 162 def list_dependencies(self, obj): 163 # pylint: disable=protected-access 164 obj._maybe_initialize_trackable() 165 dependencies = obj._checkpoint_dependencies 166 # pylint: enable=protected-access 167 168 if obj is self.root and self._attached_dependencies: 169 dependencies = dependencies.copy() 170 dependencies.extend(self._attached_dependencies) 171 return dependencies 172 173 @property 174 def saveables_cache(self): 175 """Maps Trackable objects -> attribute names -> list(SaveableObjects). 176 177 Used to avoid re-creating SaveableObjects when graph building. None when 178 executing eagerly. 179 180 Returns: 181 The cache (an object-identity dictionary), or None if caching is disabled. 182 """ 183 return self._saveables_cache 184 185 @property 186 def attached_dependencies(self): 187 """Returns list of dependencies that should be saved in the checkpoint. 188 189 These dependencies are not tracked by root, but are in the checkpoint. 190 This is defined when the user creates a Checkpoint with both root and kwargs 191 set. 192 193 Returns: 194 A list of TrackableReferences. 195 """ 196 return self._attached_dependencies 197 198 @property 199 def root(self): 200 if isinstance(self._root_ref, weakref.ref): 201 derefed = self._root_ref() 202 assert derefed is not None 203 return derefed 204 else: 205 return self._root_ref 206 207 def _breadth_first_traversal(self): 208 """Find shortest paths to all dependencies of self.root.""" 209 bfs_sorted = [] 210 to_visit = collections.deque([self.root]) 211 path_to_root = object_identity.ObjectIdentityDictionary() 212 path_to_root[self.root] = () 213 while to_visit: 214 current_trackable = to_visit.popleft() 215 if isinstance(current_trackable, tracking.NotTrackable): 216 raise NotImplementedError( 217 ("The object %s does not support object-based saving. File a " 218 "feature request if this limitation bothers you. In the meantime, " 219 "you can remove the dependency on this object and save everything " 220 "else.") 221 % (current_trackable,)) 222 bfs_sorted.append(current_trackable) 223 for name, dependency in self.list_dependencies(current_trackable): 224 if dependency not in path_to_root: 225 path_to_root[dependency] = ( 226 path_to_root[current_trackable] + ( 227 base.TrackableReference(name, dependency),)) 228 to_visit.append(dependency) 229 return bfs_sorted, path_to_root 230 231 def _add_attributes_to_object_graph( 232 self, trackable_objects, object_graph_proto, node_ids, object_names, 233 object_map, call_with_mapped_captures): 234 """Create SaveableObjects and corresponding SerializedTensor protos.""" 235 named_saveable_objects = [] 236 if self._saveables_cache is None: 237 # No SaveableObject caching. Either we're executing eagerly, or building a 238 # static save which is specialized to the current Python state. 239 feed_additions = None 240 else: 241 # If we are caching SaveableObjects, we need to build up a feed_dict with 242 # functions computing volatile Python state to be saved with the 243 # checkpoint. 244 feed_additions = {} 245 for checkpoint_id, (trackable, object_proto) in enumerate( 246 zip(trackable_objects, object_graph_proto.nodes)): 247 assert node_ids[trackable] == checkpoint_id 248 object_name = object_names[trackable] 249 if object_map is None: 250 object_to_save = trackable 251 else: 252 object_to_save = object_map.get(trackable, trackable) 253 if self._saveables_cache is not None: 254 cached_attributes = self._saveables_cache.setdefault(object_to_save, {}) 255 else: 256 cached_attributes = None 257 258 for name, saveable_factory in ( 259 object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access 260 attribute = object_proto.attributes.add() 261 attribute.name = name 262 attribute.checkpoint_key = "%s/%s/%s" % ( 263 object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) 264 if cached_attributes is None: 265 saveables = None 266 else: 267 saveables = cached_attributes.get(name, None) 268 if saveables is not None: 269 for saveable in saveables: 270 if attribute.checkpoint_key not in saveable.name: 271 # The checkpoint key for this SaveableObject is different. We 272 # need to re-create it. 273 saveables = None 274 del cached_attributes[name] 275 break 276 if saveables is None: 277 if callable(saveable_factory): 278 maybe_saveable = saveable_object_util.create_saveable_object( 279 saveable_factory, attribute.checkpoint_key, 280 call_with_mapped_captures) 281 else: 282 maybe_saveable = saveable_factory 283 if isinstance(maybe_saveable, saveable_object_lib.SaveableObject): 284 saveables = (maybe_saveable,) 285 else: 286 # Figure out the name-based Saver's name for this variable. If it's 287 # already a SaveableObject we'd just get the checkpoint key back, so 288 # we leave full_name blank. 289 saver_dict = saveable_object_util.op_list_to_dict( 290 [maybe_saveable], convert_variable_to_tensor=False) 291 full_name, = saver_dict.keys() 292 saveables = tuple(saveable_object_util.saveable_objects_for_op( 293 op=maybe_saveable, name=attribute.checkpoint_key)) 294 for saveable in saveables: 295 saveable.full_name = full_name 296 for saveable in saveables: 297 if attribute.checkpoint_key not in saveable.name: 298 raise AssertionError( 299 ("The object %s produced a SaveableObject with name '%s' for " 300 "attribute '%s'. Expected a name containing '%s'.") 301 % (trackable, name, saveable.name, 302 attribute.checkpoint_key)) 303 if cached_attributes is not None: 304 cached_attributes[name] = saveables 305 306 optional_restore = None 307 for saveable in saveables: 308 if optional_restore is None: 309 optional_restore = saveable.optional_restore 310 else: 311 optional_restore = optional_restore and saveable.optional_restore 312 313 if hasattr(saveable, "full_name"): 314 attribute.full_name = saveable.full_name 315 if isinstance(saveable, base.PythonStateSaveable): 316 if feed_additions is None: 317 assert self._saveables_cache is None 318 # If we're not caching saveables, then we're either executing 319 # eagerly or building a static save/restore (e.g. for a 320 # SavedModel). In either case, we should embed the current Python 321 # state in the graph rather than relying on a feed dict. 322 saveable = saveable.freeze() 323 else: 324 saveable_feed_dict = saveable.feed_dict_additions() 325 for new_feed_key in saveable_feed_dict.keys(): 326 if new_feed_key in feed_additions: 327 raise AssertionError( 328 ("The object %s tried to feed a value for the Tensor %s " 329 "when saving, but another object is already feeding a " 330 "value.") 331 % (trackable, new_feed_key)) 332 feed_additions.update(saveable_feed_dict) 333 named_saveable_objects.append(saveable) 334 if optional_restore is None: 335 optional_restore = False 336 attribute.optional_restore = optional_restore 337 338 return named_saveable_objects, feed_additions 339 340 def _fill_object_graph_proto(self, trackable_objects, 341 node_ids, 342 slot_variables, 343 object_graph_proto=None): 344 """Name non-slot `Trackable`s and add them to `object_graph_proto`.""" 345 if object_graph_proto is None: 346 object_graph_proto = ( 347 trackable_object_graph_pb2.TrackableObjectGraph()) 348 for checkpoint_id, trackable in enumerate(trackable_objects): 349 assert node_ids[trackable] == checkpoint_id 350 object_proto = object_graph_proto.nodes.add() 351 object_proto.slot_variables.extend(slot_variables.get(trackable, ())) 352 for child in self.list_dependencies(trackable): 353 child_proto = object_proto.children.add() 354 child_proto.node_id = node_ids[child.ref] 355 child_proto.local_name = child.name 356 return object_graph_proto 357 358 def _serialize_gathered_objects(self, trackable_objects, path_to_root, 359 object_map=None, 360 call_with_mapped_captures=None): 361 """Create SaveableObjects and protos for gathered objects.""" 362 object_names = object_identity.ObjectIdentityDictionary() 363 for obj, path in path_to_root.items(): 364 object_names[obj] = _object_prefix_from_path(path) 365 node_ids = object_identity.ObjectIdentityDictionary() 366 for node_id, node in enumerate(trackable_objects): 367 node_ids[node] = node_id 368 slot_variables = _serialize_slot_variables( 369 trackable_objects=trackable_objects, 370 node_ids=node_ids, 371 object_names=object_names) 372 object_graph_proto = self._fill_object_graph_proto( 373 trackable_objects=trackable_objects, 374 node_ids=node_ids, 375 slot_variables=slot_variables) 376 named_saveable_objects, feed_additions = ( 377 self._add_attributes_to_object_graph( 378 trackable_objects=trackable_objects, 379 object_graph_proto=object_graph_proto, 380 node_ids=node_ids, 381 object_names=object_names, 382 object_map=object_map, 383 call_with_mapped_captures=call_with_mapped_captures)) 384 return named_saveable_objects, object_graph_proto, feed_additions 385 386 def serialize_object_graph(self): 387 """Determine checkpoint keys for variables and build a serialized graph. 388 389 Non-slot variables are keyed based on a shortest path from the root saveable 390 to the object which owns the variable (i.e. the one which called 391 `Trackable._add_variable` to create it). 392 393 Slot variables are keyed based on a shortest path to the variable being 394 slotted for, a shortest path to their optimizer, and the slot name. 395 396 Returns: 397 A tuple of (named_variables, object_graph_proto, feed_additions): 398 named_variables: A dictionary mapping names to variable objects. 399 object_graph_proto: A TrackableObjectGraph protocol buffer 400 containing the serialized object graph and variable references. 401 feed_additions: A dictionary mapping from Tensors to values which should 402 be fed when saving. 403 404 Raises: 405 ValueError: If there are invalid characters in an optimizer's slot names. 406 """ 407 trackable_objects, path_to_root = self._breadth_first_traversal() 408 return self._serialize_gathered_objects( 409 trackable_objects, path_to_root) 410 411 def frozen_saveable_objects(self, object_map=None, to_graph=None, 412 call_with_mapped_captures=None): 413 """Creates SaveableObjects with the current object graph frozen.""" 414 trackable_objects, path_to_root = self._breadth_first_traversal() 415 if to_graph: 416 target_context = to_graph.as_default 417 else: 418 target_context = ops.NullContextmanager 419 with target_context(): 420 named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects( 421 trackable_objects, 422 path_to_root, 423 object_map, 424 call_with_mapped_captures) 425 with ops.device("/cpu:0"): 426 object_graph_tensor = constant_op.constant( 427 graph_proto.SerializeToString(), dtype=dtypes.string) 428 named_saveable_objects.append( 429 base.NoRestoreSaveable( 430 tensor=object_graph_tensor, 431 name=base.OBJECT_GRAPH_PROTO_KEY)) 432 return named_saveable_objects 433 434 def objects_ids_and_slot_variables_and_paths(self): 435 """Traverse the object graph and list all accessible objects. 436 437 Looks for `Trackable` objects which are dependencies of 438 `root_trackable`. Includes slot variables only if the variable they are 439 slotting for and the optimizer are dependencies of `root_trackable` 440 (i.e. if they would be saved with a checkpoint). 441 442 Returns: 443 A tuple of (trackable objects, paths from root for each object, 444 object -> node id, slot variables) 445 """ 446 trackable_objects, path_to_root = self._breadth_first_traversal() 447 object_names = object_identity.ObjectIdentityDictionary() 448 for obj, path in path_to_root.items(): 449 object_names[obj] = _object_prefix_from_path(path) 450 node_ids = object_identity.ObjectIdentityDictionary() 451 for node_id, node in enumerate(trackable_objects): 452 node_ids[node] = node_id 453 slot_variables = _serialize_slot_variables( 454 trackable_objects=trackable_objects, 455 node_ids=node_ids, 456 object_names=object_names) 457 return trackable_objects, path_to_root, node_ids, slot_variables 458 459 def objects_ids_and_slot_variables(self): 460 trackable_objects, _, node_ids, slot_variables = ( 461 self.objects_ids_and_slot_variables_and_paths()) 462 return trackable_objects, node_ids, slot_variables 463 464 def list_objects(self): 465 """Traverse the object graph and list all accessible objects.""" 466 trackable_objects, _, _ = self.objects_ids_and_slot_variables() 467 return trackable_objects 468