1"""Manages a graph of Trackable objects.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import weakref 22 23from tensorflow.core.protobuf import trackable_object_graph_pb2 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.training import optimizer as optimizer_v1 28from tensorflow.python.training.saving import saveable_object as saveable_object_lib 29from tensorflow.python.training.saving import saveable_object_util 30from tensorflow.python.training.tracking import base 31from tensorflow.python.util import object_identity 32from tensorflow.python.util.tf_export import tf_export 33 34 35_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. 36 37# Keyword for identifying that the next bit of a checkpoint variable name is a 38# slot name. Checkpoint names for slot variables look like: 39# 40# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name> 41# 42# Where <path to variable> is a full path from the checkpoint root to the 43# variable being slotted for. 44_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" 45# Keyword for separating the path to an object from the name of an 46# attribute in checkpoint names. Used like: 47# <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute> 48_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" 49 50 51def _escape_local_name(name): 52 # We need to support slashes in local names for compatibility, since this 53 # naming scheme is being patched in to things like Layer.add_variable where 54 # slashes were previously accepted. We also want to use slashes to indicate 55 # edges traversed to reach the variable, so we escape forward slashes in 56 # names. 57 return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR) 58 .replace(r"/", _ESCAPE_CHAR + "S")) 59 60 61def _object_prefix_from_path(path_to_root): 62 return "/".join( 63 (_escape_local_name(trackable.name) 64 for trackable in path_to_root)) 65 66 67def _slot_variable_naming_for_optimizer(optimizer_path): 68 """Make a function for naming slot variables in an optimizer.""" 69 # Name slot variables: 70 # 71 # <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name> 72 # 73 # where <variable name> is exactly the checkpoint name used for the original 74 # variable, including the path from the checkpoint root and the local name in 75 # the object which owns it. Note that we only save slot variables if the 76 # variable it's slotting for is also being saved. 77 78 optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path) 79 80 def _name_slot_variable(variable_path, slot_name): 81 """With an optimizer specified, name a slot variable.""" 82 return (variable_path 83 + optimizer_identifier 84 + _escape_local_name(slot_name)) 85 86 return _name_slot_variable 87 88 89def _serialize_slot_variables(trackable_objects, node_ids, object_names): 90 """Gather and name slot variables.""" 91 non_slot_objects = list(trackable_objects) 92 slot_variables = object_identity.ObjectIdentityDictionary() 93 for trackable in non_slot_objects: 94 if (isinstance(trackable, optimizer_v1.Optimizer) 95 # TODO(b/110718070): Fix Keras imports. 96 # Note: dir() is used rather than hasattr() here to avoid triggering 97 # custom __getattr__ code, see b/152031870 for context. 98 or "_create_or_restore_slot_variable" in dir(trackable)): 99 naming_scheme = _slot_variable_naming_for_optimizer( 100 optimizer_path=object_names[trackable]) 101 slot_names = trackable.get_slot_names() 102 for slot_name in slot_names: 103 for original_variable_node_id, original_variable in enumerate( 104 non_slot_objects): 105 try: 106 slot_variable = trackable.get_slot( 107 original_variable, slot_name) 108 except (AttributeError, KeyError): 109 slot_variable = None 110 if slot_variable is None: 111 continue 112 slot_variable._maybe_initialize_trackable() # pylint: disable=protected-access 113 if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access 114 # TODO(allenl): Gather dependencies of slot variables. 115 raise NotImplementedError( 116 "Currently only variables with no dependencies can be saved as " 117 "slot variables. File a feature request if this limitation " 118 "bothers you.") 119 if slot_variable in node_ids: 120 raise NotImplementedError( 121 ("A slot variable was re-used as a dependency of a " 122 "Trackable object: %s. This is not currently " 123 "allowed. File a feature request if this limitation bothers " 124 "you.") % slot_variable) 125 checkpoint_name = naming_scheme( 126 variable_path=object_names[original_variable], 127 slot_name=slot_name) 128 object_names[slot_variable] = checkpoint_name 129 slot_variable_node_id = len(trackable_objects) 130 node_ids[slot_variable] = slot_variable_node_id 131 trackable_objects.append(slot_variable) 132 slot_variable_proto = ( 133 trackable_object_graph_pb2.TrackableObjectGraph 134 .TrackableObject.SlotVariableReference( 135 slot_name=slot_name, 136 original_variable_node_id=original_variable_node_id, 137 slot_variable_node_id=slot_variable_node_id)) 138 slot_variables.setdefault(trackable, []).append( 139 slot_variable_proto) 140 return slot_variables 141 142 143@tf_export("__internal__.tracking.ObjectGraphView", v1=[]) 144class ObjectGraphView(object): 145 """Gathers and serializes an object graph.""" 146 147 def __init__(self, root, saveables_cache=None, attached_dependencies=None): 148 """Configure the graph view. 149 150 Args: 151 root: A `Trackable` object whose variables (including the variables 152 of dependencies, recursively) should be saved. May be a weak reference. 153 saveables_cache: A dictionary mapping `Trackable` objects -> 154 attribute names -> SaveableObjects, used to avoid re-creating 155 SaveableObjects when graph building. 156 attached_dependencies: Dependencies to attach to the root object. Used 157 when saving a Checkpoint with a defined root object. 158 """ 159 self._root_ref = root 160 self._saveables_cache = saveables_cache 161 self._attached_dependencies = attached_dependencies 162 163 def list_dependencies(self, obj): 164 # pylint: disable=protected-access 165 obj._maybe_initialize_trackable() 166 dependencies = obj._checkpoint_dependencies 167 # pylint: enable=protected-access 168 169 if obj is self.root and self._attached_dependencies: 170 dependencies = dependencies.copy() 171 dependencies.extend(self._attached_dependencies) 172 return dependencies 173 174 @property 175 def saveables_cache(self): 176 """Maps Trackable objects -> attribute names -> list(SaveableObjects). 177 178 Used to avoid re-creating SaveableObjects when graph building. None when 179 executing eagerly. 180 181 Returns: 182 The cache (an object-identity dictionary), or None if caching is disabled. 183 """ 184 return self._saveables_cache 185 186 @property 187 def attached_dependencies(self): 188 """Returns list of dependencies that should be saved in the checkpoint. 189 190 These dependencies are not tracked by root, but are in the checkpoint. 191 This is defined when the user creates a Checkpoint with both root and kwargs 192 set. 193 194 Returns: 195 A list of TrackableReferences. 196 """ 197 return self._attached_dependencies 198 199 @property 200 def root(self): 201 if isinstance(self._root_ref, weakref.ref): 202 derefed = self._root_ref() 203 assert derefed is not None 204 return derefed 205 else: 206 return self._root_ref 207 208 def _breadth_first_traversal(self): 209 """Find shortest paths to all dependencies of self.root.""" 210 bfs_sorted = [] 211 to_visit = collections.deque([self.root]) 212 path_to_root = object_identity.ObjectIdentityDictionary() 213 path_to_root[self.root] = () 214 while to_visit: 215 current_trackable = to_visit.popleft() 216 bfs_sorted.append(current_trackable) 217 for name, dependency in self.list_dependencies(current_trackable): 218 if dependency not in path_to_root: 219 path_to_root[dependency] = ( 220 path_to_root[current_trackable] + ( 221 base.TrackableReference(name, dependency),)) 222 to_visit.append(dependency) 223 return bfs_sorted, path_to_root 224 225 def _add_attributes_to_object_graph( 226 self, trackable_objects, object_graph_proto, node_ids, object_names, 227 object_map, call_with_mapped_captures): 228 """Create SaveableObjects and corresponding SerializedTensor protos.""" 229 named_saveable_objects = [] 230 if self._saveables_cache is None: 231 # No SaveableObject caching. Either we're executing eagerly, or building a 232 # static save which is specialized to the current Python state. 233 feed_additions = None 234 else: 235 # If we are caching SaveableObjects, we need to build up a feed_dict with 236 # functions computing volatile Python state to be saved with the 237 # checkpoint. 238 feed_additions = {} 239 for checkpoint_id, (trackable, object_proto) in enumerate( 240 zip(trackable_objects, object_graph_proto.nodes)): 241 assert node_ids[trackable] == checkpoint_id 242 object_name = object_names[trackable] 243 if object_map is None: 244 object_to_save = trackable 245 else: 246 object_to_save = object_map.get(trackable, trackable) 247 if self._saveables_cache is not None: 248 cached_attributes = self._saveables_cache.setdefault(object_to_save, {}) 249 else: 250 cached_attributes = None 251 252 for name, saveable_factory in ( 253 object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access 254 attribute = object_proto.attributes.add() 255 attribute.name = name 256 attribute.checkpoint_key = "%s/%s/%s" % ( 257 object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) 258 if cached_attributes is None: 259 saveables = None 260 else: 261 saveables = cached_attributes.get(name, None) 262 if saveables is not None: 263 for saveable in saveables: 264 if attribute.checkpoint_key not in saveable.name: 265 # The checkpoint key for this SaveableObject is different. We 266 # need to re-create it. 267 saveables = None 268 del cached_attributes[name] 269 break 270 if saveables is None: 271 if callable(saveable_factory): 272 maybe_saveable = saveable_object_util.create_saveable_object( 273 saveable_factory, attribute.checkpoint_key, 274 call_with_mapped_captures) 275 else: 276 maybe_saveable = saveable_factory 277 if isinstance(maybe_saveable, saveable_object_lib.SaveableObject): 278 saveables = (maybe_saveable,) 279 else: 280 # Figure out the name-based Saver's name for this variable. If it's 281 # already a SaveableObject we'd just get the checkpoint key back, so 282 # we leave full_name blank. 283 saver_dict = saveable_object_util.op_list_to_dict( 284 [maybe_saveable], convert_variable_to_tensor=False) 285 full_name, = saver_dict.keys() 286 saveables = tuple(saveable_object_util.saveable_objects_for_op( 287 op=maybe_saveable, name=attribute.checkpoint_key)) 288 for saveable in saveables: 289 saveable.full_name = full_name 290 for saveable in saveables: 291 if attribute.checkpoint_key not in saveable.name: 292 raise AssertionError( 293 ("The object %s produced a SaveableObject with name '%s' for " 294 "attribute '%s'. Expected a name containing '%s'.") 295 % (trackable, name, saveable.name, 296 attribute.checkpoint_key)) 297 if cached_attributes is not None: 298 cached_attributes[name] = saveables 299 300 optional_restore = None 301 for saveable in saveables: 302 if optional_restore is None: 303 optional_restore = saveable.optional_restore 304 else: 305 optional_restore = optional_restore and saveable.optional_restore 306 307 if hasattr(saveable, "full_name"): 308 attribute.full_name = saveable.full_name 309 if isinstance(saveable, base.PythonStateSaveable): 310 if feed_additions is None: 311 assert self._saveables_cache is None 312 # If we're not caching saveables, then we're either executing 313 # eagerly or building a static save/restore (e.g. for a 314 # SavedModel). In either case, we should embed the current Python 315 # state in the graph rather than relying on a feed dict. 316 saveable = saveable.freeze() 317 else: 318 saveable_feed_dict = saveable.feed_dict_additions() 319 for new_feed_key in saveable_feed_dict.keys(): 320 if new_feed_key in feed_additions: 321 raise AssertionError( 322 ("The object %s tried to feed a value for the Tensor %s " 323 "when saving, but another object is already feeding a " 324 "value.") 325 % (trackable, new_feed_key)) 326 feed_additions.update(saveable_feed_dict) 327 named_saveable_objects.append(saveable) 328 if optional_restore is None: 329 optional_restore = False 330 attribute.optional_restore = optional_restore 331 332 return named_saveable_objects, feed_additions 333 334 def _fill_object_graph_proto(self, trackable_objects, 335 node_ids, 336 slot_variables, 337 object_graph_proto=None): 338 """Name non-slot `Trackable`s and add them to `object_graph_proto`.""" 339 if object_graph_proto is None: 340 object_graph_proto = ( 341 trackable_object_graph_pb2.TrackableObjectGraph()) 342 for checkpoint_id, trackable in enumerate(trackable_objects): 343 assert node_ids[trackable] == checkpoint_id 344 object_proto = object_graph_proto.nodes.add() 345 object_proto.slot_variables.extend(slot_variables.get(trackable, ())) 346 for child in self.list_dependencies(trackable): 347 child_proto = object_proto.children.add() 348 child_proto.node_id = node_ids[child.ref] 349 child_proto.local_name = child.name 350 return object_graph_proto 351 352 def _serialize_gathered_objects(self, trackable_objects, path_to_root, 353 object_map=None, 354 call_with_mapped_captures=None): 355 """Create SaveableObjects and protos for gathered objects.""" 356 object_names = object_identity.ObjectIdentityDictionary() 357 for obj, path in path_to_root.items(): 358 object_names[obj] = _object_prefix_from_path(path) 359 node_ids = object_identity.ObjectIdentityDictionary() 360 for node_id, node in enumerate(trackable_objects): 361 node_ids[node] = node_id 362 slot_variables = _serialize_slot_variables( 363 trackable_objects=trackable_objects, 364 node_ids=node_ids, 365 object_names=object_names) 366 object_graph_proto = self._fill_object_graph_proto( 367 trackable_objects=trackable_objects, 368 node_ids=node_ids, 369 slot_variables=slot_variables) 370 named_saveable_objects, feed_additions = ( 371 self._add_attributes_to_object_graph( 372 trackable_objects=trackable_objects, 373 object_graph_proto=object_graph_proto, 374 node_ids=node_ids, 375 object_names=object_names, 376 object_map=object_map, 377 call_with_mapped_captures=call_with_mapped_captures)) 378 return named_saveable_objects, object_graph_proto, feed_additions 379 380 def serialize_object_graph(self): 381 """Determine checkpoint keys for variables and build a serialized graph. 382 383 Non-slot variables are keyed based on a shortest path from the root saveable 384 to the object which owns the variable (i.e. the one which called 385 `Trackable._add_variable` to create it). 386 387 Slot variables are keyed based on a shortest path to the variable being 388 slotted for, a shortest path to their optimizer, and the slot name. 389 390 Returns: 391 A tuple of (named_variables, object_graph_proto, feed_additions): 392 named_variables: A dictionary mapping names to variable objects. 393 object_graph_proto: A TrackableObjectGraph protocol buffer 394 containing the serialized object graph and variable references. 395 feed_additions: A dictionary mapping from Tensors to values which should 396 be fed when saving. 397 398 Raises: 399 ValueError: If there are invalid characters in an optimizer's slot names. 400 """ 401 trackable_objects, path_to_root = self._breadth_first_traversal() 402 return self._serialize_gathered_objects( 403 trackable_objects, path_to_root) 404 405 def frozen_saveable_objects(self, object_map=None, to_graph=None, 406 call_with_mapped_captures=None): 407 """Creates SaveableObjects with the current object graph frozen.""" 408 trackable_objects, path_to_root = self._breadth_first_traversal() 409 if to_graph: 410 target_context = to_graph.as_default 411 else: 412 target_context = ops.NullContextmanager 413 with target_context(): 414 named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects( 415 trackable_objects, 416 path_to_root, 417 object_map, 418 call_with_mapped_captures) 419 with ops.device("/cpu:0"): 420 object_graph_tensor = constant_op.constant( 421 graph_proto.SerializeToString(), dtype=dtypes.string) 422 named_saveable_objects.append( 423 base.NoRestoreSaveable( 424 tensor=object_graph_tensor, 425 name=base.OBJECT_GRAPH_PROTO_KEY)) 426 return named_saveable_objects 427 428 def objects_ids_and_slot_variables_and_paths(self): 429 """Traverse the object graph and list all accessible objects. 430 431 Looks for `Trackable` objects which are dependencies of 432 `root_trackable`. Includes slot variables only if the variable they are 433 slotting for and the optimizer are dependencies of `root_trackable` 434 (i.e. if they would be saved with a checkpoint). 435 436 Returns: 437 A tuple of (trackable objects, paths from root for each object, 438 object -> node id, slot variables) 439 """ 440 trackable_objects, path_to_root = self._breadth_first_traversal() 441 object_names = object_identity.ObjectIdentityDictionary() 442 for obj, path in path_to_root.items(): 443 object_names[obj] = _object_prefix_from_path(path) 444 node_ids = object_identity.ObjectIdentityDictionary() 445 for node_id, node in enumerate(trackable_objects): 446 node_ids[node] = node_id 447 slot_variables = _serialize_slot_variables( 448 trackable_objects=trackable_objects, 449 node_ids=node_ids, 450 object_names=object_names) 451 return trackable_objects, path_to_root, node_ids, slot_variables 452 453 def objects_ids_and_slot_variables(self): 454 trackable_objects, _, node_ids, slot_variables = ( 455 self.objects_ids_and_slot_variables_and_paths()) 456 return trackable_objects, node_ids, slot_variables 457 458 def list_objects(self): 459 """Traverse the object graph and list all accessible objects.""" 460 trackable_objects, _, _ = self.objects_ids_and_slot_variables() 461 return trackable_objects 462