• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Manages a graph of Trackable objects."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import weakref
22
23from tensorflow.core.protobuf import trackable_object_graph_pb2
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.training import optimizer as optimizer_v1
28from tensorflow.python.training.saving import saveable_object as saveable_object_lib
29from tensorflow.python.training.saving import saveable_object_util
30from tensorflow.python.training.tracking import base
31from tensorflow.python.util import object_identity
32from tensorflow.python.util.tf_export import tf_export
33
34
# Reserved prefix for checkpoint-internal keywords. Used to avoid conflicts
# with user-specified names: `_escape_local_name` doubles every occurrence of
# this character in user-provided local names.
_ESCAPE_CHAR = "."

# Keyword marking that the next component of a checkpoint variable name is a
# slot name. Checkpoint names for slot variables have the form:
#
#   <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
#
# where <path to variable> is the full path from the checkpoint root to the
# variable the slot belongs to.
_OPTIMIZER_SLOTS_NAME = "".join((_ESCAPE_CHAR, "OPTIMIZER_SLOT"))

# Keyword separating the path to an object from the name of one of its
# attributes in checkpoint keys:
#
#   <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
_OBJECT_ATTRIBUTES_NAME = "".join((_ESCAPE_CHAR, "ATTRIBUTES"))
49
50
def _escape_local_name(name):
  """Escapes a local name so it can be used as one checkpoint path component.

  Slashes must remain legal in local names for compatibility: this naming
  scheme was patched into APIs such as `Layer.add_variable`, which already
  accepted them. Slashes are also used to separate the edges traversed to
  reach a variable. To reconcile the two, escape characters in `name` are
  doubled first, and every forward slash is then rewritten as the escape
  character followed by "S".

  Args:
    name: The local name to escape.

  Returns:
    The escaped name.
  """
  doubled_escape = _ESCAPE_CHAR * 2
  escaped_slash = _ESCAPE_CHAR + "S"
  escapes_doubled = name.replace(_ESCAPE_CHAR, doubled_escape)
  return escapes_doubled.replace("/", escaped_slash)
59
60
def _object_prefix_from_path(path_to_root):
  """Builds the checkpoint name prefix of an object from its path.

  Args:
    path_to_root: Sequence of references (each with a `name` attribute)
      describing the edges from the root to the object, e.g. the
      `TrackableReference` tuples built by the breadth-first traversal.

  Returns:
    The escaped name of every edge, joined with "/".
  """
  escaped_names = [_escape_local_name(reference.name)
                   for reference in path_to_root]
  return "/".join(escaped_names)
65
66
def _slot_variable_naming_for_optimizer(optimizer_path):
  """Returns a function that names slot variables belonging to an optimizer.

  Slot variables are named

    <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>

  where <variable name> is exactly the checkpoint name of the original
  variable: its path from the checkpoint root plus the local name in the
  object which owns it. Slot variables are only saved when the variable they
  slot for is also being saved.

  Args:
    optimizer_path: Checkpoint path of the optimizer owning the slots.

  Returns:
    A function `(variable_path, slot_name) -> str` producing slot names.
  """
  separator = "/{}/{}/".format(_OPTIMIZER_SLOTS_NAME, optimizer_path)

  def _name_slot_variable(variable_path, slot_name):
    """Names the slot `slot_name` of the variable at `variable_path`."""
    return "".join((variable_path, separator, _escape_local_name(slot_name)))

  return _name_slot_variable
87
88
def _serialize_slot_variables(trackable_objects, node_ids, object_names):
  """Gathers and names the slot variables of optimizers in the object list.

  Every optimizer-like object in `trackable_objects` is asked for its slot for
  each (slot name, gathered object) pair. Each slot variable found is
  appended to `trackable_objects`, given the next node id in `node_ids`, and
  named in `object_names` after the variable it slots for. Only objects
  present when this function is called are considered, either as optimizers
  or as slotted-for variables.

  Args:
    trackable_objects: List of gathered objects; slot variables are appended
      to it in place.
    node_ids: Object-identity map from object to node id; updated in place.
    object_names: Object-identity map from object to checkpoint name prefix;
      updated in place.

  Returns:
    An `ObjectIdentityDictionary` mapping each optimizer that owns slot
    variables to a list of `SlotVariableReference` protos.

  Raises:
    NotImplementedError: If a slot variable has checkpoint dependencies of
      its own, or if it is already present in `node_ids`.
  """
  # Snapshot the objects gathered so far so slot variables appended below are
  # never treated as optimizers or slotted-for variables themselves.
  candidates = list(trackable_objects)
  references_by_optimizer = object_identity.ObjectIdentityDictionary()
  for optimizer in candidates:
    # TODO(b/110718070): Fix Keras imports.
    # Note: dir() is used rather than hasattr() here to avoid triggering
    # custom __getattr__ code, see b/152031870 for context.
    is_optimizer = (
        isinstance(optimizer, optimizer_v1.Optimizer)
        or "_create_or_restore_slot_variable" in dir(optimizer))
    if not is_optimizer:
      continue
    name_slot = _slot_variable_naming_for_optimizer(
        optimizer_path=object_names[optimizer])
    for slot_name in optimizer.get_slot_names():
      for variable_id, variable in enumerate(candidates):
        try:
          slot = optimizer.get_slot(variable, slot_name)
        except (AttributeError, KeyError):
          continue
        if slot is None:
          continue
        slot._maybe_initialize_trackable()  # pylint: disable=protected-access
        if slot._checkpoint_dependencies:  # pylint: disable=protected-access
          # TODO(allenl): Gather dependencies of slot variables.
          raise NotImplementedError(
              "Currently only variables with no dependencies can be saved as "
              "slot variables. File a feature request if this limitation "
              "bothers you.")
        if slot in node_ids:
          raise NotImplementedError(
              ("A slot variable was re-used as a dependency of a "
               "Trackable object: %s. This is not currently "
               "allowed. File a feature request if this limitation bothers "
               "you.") % slot)
        object_names[slot] = name_slot(
            variable_path=object_names[variable], slot_name=slot_name)
        # Slot variables take the next free node id, after all gathered
        # objects and any slot variables already appended.
        slot_id = len(trackable_objects)
        node_ids[slot] = slot_id
        trackable_objects.append(slot)
        reference = (
            trackable_object_graph_pb2.TrackableObjectGraph
            .TrackableObject.SlotVariableReference(
                slot_name=slot_name,
                original_variable_node_id=variable_id,
                slot_variable_node_id=slot_id))
        references_by_optimizer.setdefault(optimizer, []).append(reference)
  return references_by_optimizer
141
142
@tf_export("__internal__.tracking.ObjectGraphView", v1=[])
class ObjectGraphView(object):
  """Gathers and serializes an object graph.

  The graph is discovered by a breadth-first traversal of checkpoint
  dependencies starting at `root`. Each reachable object gets a node id (its
  position in traversal order) and a checkpoint name derived from a shortest
  path from the root. Optimizer slot variables are appended after the
  regular objects.
  """

  def __init__(self, root, saveables_cache=None, attached_dependencies=None):
    """Configure the graph view.

    Args:
      root: A `Trackable` object whose variables (including the variables
        of dependencies, recursively) should be saved. May be a weak reference.
      saveables_cache: A dictionary mapping `Trackable` objects ->
        attribute names -> SaveableObjects, used to avoid re-creating
        SaveableObjects when graph building.
      attached_dependencies: Dependencies to attach to the root object. Used
        when saving a Checkpoint with a defined root object.
    """
    self._root_ref = root
    self._saveables_cache = saveables_cache
    self._attached_dependencies = attached_dependencies

  def list_dependencies(self, obj):
    """Lists the named dependencies of `obj` that are traversed and saved.

    Args:
      obj: A `Trackable` object.

    Returns:
      The checkpoint dependencies of `obj`: references exposing `name` and
      `ref`. For the root object, `attached_dependencies` are appended to a
      copy of that list, so the object's own dependency list is not mutated.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    dependencies = obj._checkpoint_dependencies
    # pylint: enable=protected-access

    if obj is self.root and self._attached_dependencies:
      dependencies = dependencies.copy()
      dependencies.extend(self._attached_dependencies)
    return dependencies

  @property
  def saveables_cache(self):
    """Maps Trackable objects -> attribute names -> list(SaveableObjects).

    Used to avoid re-creating SaveableObjects when graph building. None when
    executing eagerly.

    Returns:
      The cache (an object-identity dictionary), or None if caching is disabled.
    """
    return self._saveables_cache

  @property
  def attached_dependencies(self):
    """Returns list of dependencies that should be saved in the checkpoint.

    These dependencies are not tracked by root, but are in the checkpoint.
    This is defined when the user creates a Checkpoint with both root and kwargs
    set.

    Returns:
      A list of TrackableReferences.
    """
    return self._attached_dependencies

  @property
  def root(self):
    """The root object, dereferenced if it was passed as a `weakref.ref`."""
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      assert derefed is not None
      return derefed
    else:
      return self._root_ref

  def _breadth_first_traversal(self):
    """Find shortest paths to all dependencies of self.root.

    Returns:
      A tuple `(bfs_sorted, path_to_root)`: the reachable objects in
      breadth-first order (root first), and an object-identity map from each
      object to the tuple of `TrackableReference`s leading to it from the
      root (empty for the root itself).
    """
    bfs_sorted = []
    to_visit = collections.deque([self.root])
    path_to_root = object_identity.ObjectIdentityDictionary()
    path_to_root[self.root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      bfs_sorted.append(current_trackable)
      for name, dependency in self.list_dependencies(current_trackable):
        # The first path found in breadth-first order is a shortest one.
        if dependency not in path_to_root:
          path_to_root[dependency] = (
              path_to_root[current_trackable] + (
                  base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, path_to_root

  def _gather_names_ids_and_slot_variables(self, trackable_objects,
                                           path_to_root):
    """Names and numbers gathered objects, then appends their slot variables.

    Args:
      trackable_objects: Objects in traversal order; slot variables are
        appended to this list in place.
      path_to_root: Object-identity map from object to its path from the root.

    Returns:
      A tuple `(object_names, node_ids, slot_variables)` of object-identity
      maps: object -> checkpoint name prefix, object -> node id, and
      optimizer -> list of `SlotVariableReference` protos.
    """
    object_names = object_identity.ObjectIdentityDictionary()
    for obj, path in path_to_root.items():
      object_names[obj] = _object_prefix_from_path(path)
    node_ids = object_identity.ObjectIdentityDictionary()
    for node_id, node in enumerate(trackable_objects):
      node_ids[node] = node_id
    slot_variables = _serialize_slot_variables(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        object_names=object_names)
    return object_names, node_ids, slot_variables

  def _add_attributes_to_object_graph(
      self, trackable_objects, object_graph_proto, node_ids, object_names,
      object_map, call_with_mapped_captures):
    """Create SaveableObjects and corresponding SerializedTensor protos.

    For every object (or its replacement from `object_map`), each entry of
    `_gather_saveables_for_checkpoint()` is recorded as an attribute on the
    matching node of `object_graph_proto`. It is then turned into
    SaveableObjects whose names contain the attribute's checkpoint key.

    Args:
      trackable_objects: Gathered objects, in node-id order.
      object_graph_proto: A `TrackableObjectGraph` whose `nodes` line up with
        `trackable_objects`; attributes are added to it in place.
      node_ids: Object-identity map from object to node id.
      object_names: Object-identity map from object to checkpoint name prefix.
      object_map: Optional map from a gathered object to the object whose
        saveables should be used in its place.
      call_with_mapped_captures: Forwarded to
        `saveable_object_util.create_saveable_object`.

    Returns:
      A tuple `(named_saveable_objects, feed_additions)`: the list of
      SaveableObjects to save, and a dict of extra Tensor feeds for Python
      state (None when SaveableObjects are not cached).

    Raises:
      AssertionError: If a SaveableObject's name does not contain its
        attribute's checkpoint key, or two objects feed the same Tensor.
    """
    named_saveable_objects = []
    if self._saveables_cache is None:
      # No SaveableObject caching. Either we're executing eagerly, or building a
      # static save which is specialized to the current Python state.
      feed_additions = None
    else:
      # If we are caching SaveableObjects, we need to build up a feed_dict with
      # functions computing volatile Python state to be saved with the
      # checkpoint.
      feed_additions = {}
    for checkpoint_id, (trackable, object_proto) in enumerate(
        zip(trackable_objects, object_graph_proto.nodes)):
      assert node_ids[trackable] == checkpoint_id
      object_name = object_names[trackable]
      if object_map is None:
        object_to_save = trackable
      else:
        object_to_save = object_map.get(trackable, trackable)
      if self._saveables_cache is not None:
        cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
      else:
        cached_attributes = None

      for name, saveable_factory in (
          object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
        attribute = object_proto.attributes.add()
        attribute.name = name
        attribute.checkpoint_key = "%s/%s/%s" % (
            object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
        if cached_attributes is None:
          saveables = None
        else:
          saveables = cached_attributes.get(name, None)
          if saveables is not None:
            for saveable in saveables:
              if attribute.checkpoint_key not in saveable.name:
                # The checkpoint key for this SaveableObject is different. We
                # need to re-create it.
                saveables = None
                del cached_attributes[name]
                break
        if saveables is None:
          if callable(saveable_factory):
            maybe_saveable = saveable_object_util.create_saveable_object(
                saveable_factory, attribute.checkpoint_key,
                call_with_mapped_captures)
          else:
            maybe_saveable = saveable_factory
          if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
            saveables = (maybe_saveable,)
          else:
            # Figure out the name-based Saver's name for this variable. If it's
            # already a SaveableObject we'd just get the checkpoint key back, so
            # we leave full_name blank.
            saver_dict = saveable_object_util.op_list_to_dict(
                [maybe_saveable], convert_variable_to_tensor=False)
            full_name, = saver_dict.keys()
            saveables = tuple(saveable_object_util.saveable_objects_for_op(
                op=maybe_saveable, name=attribute.checkpoint_key))
            for saveable in saveables:
              saveable.full_name = full_name
          for saveable in saveables:
            if attribute.checkpoint_key not in saveable.name:
              # Arguments follow the placeholder order: the SaveableObject's
              # name first, then the attribute name.
              raise AssertionError(
                  ("The object %s produced a SaveableObject with name '%s' for "
                   "attribute '%s'. Expected a name containing '%s'.")
                  % (trackable, saveable.name, name,
                     attribute.checkpoint_key))
          if cached_attributes is not None:
            cached_attributes[name] = saveables

        # The attribute is optional to restore only if every SaveableObject
        # for it is; with no SaveableObjects it defaults to False.
        optional_restore = None
        for saveable in saveables:
          if optional_restore is None:
            optional_restore = saveable.optional_restore
          else:
            optional_restore = optional_restore and saveable.optional_restore

          if hasattr(saveable, "full_name"):
            attribute.full_name = saveable.full_name
          if isinstance(saveable, base.PythonStateSaveable):
            if feed_additions is None:
              assert self._saveables_cache is None
              # If we're not caching saveables, then we're either executing
              # eagerly or building a static save/restore (e.g. for a
              # SavedModel). In either case, we should embed the current Python
              # state in the graph rather than relying on a feed dict.
              saveable = saveable.freeze()
            else:
              saveable_feed_dict = saveable.feed_dict_additions()
              for new_feed_key in saveable_feed_dict.keys():
                if new_feed_key in feed_additions:
                  raise AssertionError(
                      ("The object %s tried to feed a value for the Tensor %s "
                       "when saving, but another object is already feeding a "
                       "value.")
                      % (trackable, new_feed_key))
              feed_additions.update(saveable_feed_dict)
          named_saveable_objects.append(saveable)
        if optional_restore is None:
          optional_restore = False
        attribute.optional_restore = optional_restore

    return named_saveable_objects, feed_additions

  def _fill_object_graph_proto(self, trackable_objects,
                               node_ids,
                               slot_variables,
                               object_graph_proto=None):
    """Name non-slot `Trackable`s and add them to `object_graph_proto`.

    Adds one node per object (in node-id order), recording its slot variable
    references and its children by node id and local name.

    Args:
      trackable_objects: Objects in node-id order, including slot variables.
      node_ids: Object-identity map from object to node id.
      slot_variables: Map from optimizer to its `SlotVariableReference`s.
      object_graph_proto: Optional `TrackableObjectGraph` to append to; a new
        one is created when None.

    Returns:
      The filled `TrackableObjectGraph` proto.
    """
    if object_graph_proto is None:
      object_graph_proto = (
          trackable_object_graph_pb2.TrackableObjectGraph())
    for checkpoint_id, trackable in enumerate(trackable_objects):
      assert node_ids[trackable] == checkpoint_id
      object_proto = object_graph_proto.nodes.add()
      object_proto.slot_variables.extend(slot_variables.get(trackable, ()))
      for child in self.list_dependencies(trackable):
        child_proto = object_proto.children.add()
        child_proto.node_id = node_ids[child.ref]
        child_proto.local_name = child.name
    return object_graph_proto

  def _serialize_gathered_objects(self, trackable_objects, path_to_root,
                                  object_map=None,
                                  call_with_mapped_captures=None):
    """Create SaveableObjects and protos for gathered objects.

    Args:
      trackable_objects: Objects in traversal order; slot variables are
        appended to this list in place.
      path_to_root: Object-identity map from object to its path from the root.
      object_map: Optional map from a gathered object to the object whose
        saveables should be used in its place.
      call_with_mapped_captures: Forwarded to SaveableObject creation.

    Returns:
      A tuple `(named_saveable_objects, object_graph_proto, feed_additions)`.
    """
    object_names, node_ids, slot_variables = (
        self._gather_names_ids_and_slot_variables(
            trackable_objects, path_to_root))
    object_graph_proto = self._fill_object_graph_proto(
        trackable_objects=trackable_objects,
        node_ids=node_ids,
        slot_variables=slot_variables)
    named_saveable_objects, feed_additions = (
        self._add_attributes_to_object_graph(
            trackable_objects=trackable_objects,
            object_graph_proto=object_graph_proto,
            node_ids=node_ids,
            object_names=object_names,
            object_map=object_map,
            call_with_mapped_captures=call_with_mapped_captures))
    return named_saveable_objects, object_graph_proto, feed_additions

  def serialize_object_graph(self):
    """Determine checkpoint keys for variables and build a serialized graph.

    Non-slot variables are keyed based on a shortest path from the root saveable
    to the object which owns the variable (i.e. the one which called
    `Trackable._add_variable` to create it).

    Slot variables are keyed based on a shortest path to the variable being
    slotted for, a shortest path to their optimizer, and the slot name.

    Returns:
      A tuple of (named_saveable_objects, object_graph_proto, feed_additions):
        named_saveable_objects: A list of `SaveableObject`s, one or more per
          saved attribute.
        object_graph_proto: A TrackableObjectGraph protocol buffer
          containing the serialized object graph and variable references.
        feed_additions: A dictionary mapping from Tensors to values which should
          be fed when saving, or None when SaveableObjects are not cached.

    Raises:
      NotImplementedError: If a slot variable has checkpoint dependencies of
        its own, or is also a regular dependency of a `Trackable` object.
      AssertionError: If a SaveableObject's name does not contain its
        checkpoint key, or two objects feed a value for the same Tensor.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    return self._serialize_gathered_objects(
        trackable_objects, path_to_root)

  def frozen_saveable_objects(self, object_map=None, to_graph=None,
                              call_with_mapped_captures=None):
    """Creates SaveableObjects with the current object graph frozen.

    The serialized object graph proto itself is appended as a final,
    non-restorable SaveableObject under `base.OBJECT_GRAPH_PROTO_KEY`.

    Args:
      object_map: Optional map from a gathered object to the object whose
        saveables should be used in its place.
      to_graph: Optional graph; when given, ops are created inside its
        `as_default()` context.
      call_with_mapped_captures: Forwarded to SaveableObject creation.

    Returns:
      A list of SaveableObjects.
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    if to_graph:
      target_context = to_graph.as_default
    else:
      target_context = ops.NullContextmanager
    with target_context():
      named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
          trackable_objects,
          path_to_root,
          object_map,
          call_with_mapped_captures)
      with ops.device("/cpu:0"):
        object_graph_tensor = constant_op.constant(
            graph_proto.SerializeToString(), dtype=dtypes.string)
      named_saveable_objects.append(
          base.NoRestoreSaveable(
              tensor=object_graph_tensor,
              name=base.OBJECT_GRAPH_PROTO_KEY))
    return named_saveable_objects

  def objects_ids_and_slot_variables_and_paths(self):
    """Traverse the object graph and list all accessible objects.

    Looks for `Trackable` objects which are dependencies of
    `root_trackable`. Includes slot variables only if the variable they are
    slotting for and the optimizer are dependencies of `root_trackable`
    (i.e. if they would be saved with a checkpoint).

    Returns:
      A tuple of (trackable objects, paths from root for each object,
                  object -> node id, slot variables)
    """
    trackable_objects, path_to_root = self._breadth_first_traversal()
    _, node_ids, slot_variables = self._gather_names_ids_and_slot_variables(
        trackable_objects, path_to_root)
    return trackable_objects, path_to_root, node_ids, slot_variables

  def objects_ids_and_slot_variables(self):
    """Like `objects_ids_and_slot_variables_and_paths`, without the paths.

    Returns:
      A tuple of (trackable objects, object -> node id, slot variables).
    """
    trackable_objects, _, node_ids, slot_variables = (
        self.objects_ids_and_slot_variables_and_paths())
    return trackable_objects, node_ids, slot_variables

  def list_objects(self):
    """Traverse the object graph and list all accessible objects.

    Returns:
      The reachable objects in node-id order, including slot variables.
    """
    trackable_objects, _, _ = self.objects_ids_and_slot_variables()
    return trackable_objects
462