1"""An object-local variable management scheme."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import weakref
22
23from tensorflow.python import pywrap_tensorflow
24from tensorflow.python.eager import context
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import gen_io_ops as io_ops
28from tensorflow.python.util import nest
29
30# A key indicating a variable's value in an object's checkpointed Tensors
31# (Checkpointable._gather_tensors_for_checkpoint). If this is the only key and
32# the object has no dependencies, then its value may be restored on object
33# creation (avoiding double assignment when executing eagerly).
34VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
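
# As a hedged illustration (not an API guarantee), a simple variable's gathered
# checkpoint tensors might look like:
#   {VARIABLE_VALUE_KEY: v.read_value()}
# where `v` stands in for some variable object.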

_CheckpointableReference = collections.namedtuple(
    "_CheckpointableReference",
    [
        # The local name for this dependency.
        "name",
        # The Checkpointable object being referenced.
        "ref"
    ])


class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore that the
  initial value came from. This allows deferred restorations to be sequenced in
  the order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  See the comments in _add_variable_with_custom_getter for more information
  about how `CheckpointInitialValue` is used.
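
  A minimal sketch of the intended use (hedged: `position` stands for a
  `_CheckpointPosition` naming a simple variable, and `resource_variable_ops`
  is not imported in this module):

  ```
  initial_value = CheckpointInitialValue(checkpoint_position=position)
  v = resource_variable_ops.ResourceVariable(initial_value)
  ```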
  """

  def __init__(self, checkpoint_position, shape=None):
    self.wrapped_value = checkpoint_position.restore_ops()[
        VARIABLE_VALUE_KEY]
    if shape:
      # We need to set the static shape information on the initializer if
      # possible so we don't get a variable with an unknown shape.
      self.wrapped_value.set_shape(shape)
    self._checkpoint_position = checkpoint_position

  @property
  def __class__(self):
    return (self.wrapped_value.__class__, CheckpointInitialValue)

  def __getattr__(self, attr):
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    return self._checkpoint_position


class _CheckpointPosition(object):
  """Indicates a position within a `_Checkpoint`."""

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _Checkpoint object.
      proto_id: The index of this object in CheckpointableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, checkpointable):
    """Restore this value into `checkpointable`."""
    if self.bind_object(checkpointable):
      # This object's correspondence with a checkpointed object is new, so
      # process deferred restorations for it and its dependencies.
      restore_ops = checkpointable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
      if restore_ops:
        self._checkpoint.restore_ops.extend(restore_ops)

  def bind_object(self, checkpointable):
    """Set a checkpoint<->object correspondence and process slot variables.

    Args:
      checkpointable: The object to record a correspondence for.
    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = checkpointable
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        checkpointable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=_CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=checkpointable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))
        else:
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=_CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=checkpointable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not checkpointable:
        raise AssertionError(
            ("Unable to load the checkpoint into this object graph. Either "
             "the Checkpointable object references in the Python program "
             "have changed in an incompatible way, or the checkpoint was "
             "generated in an incompatible program.\n\nTwo checkpoint "
             "references resolved to different objects (%s and %s).")
            % (current_assignment, checkpointable))
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1
            and attributes[0].name == VARIABLE_VALUE_KEY
            and not self.object_proto.children)

  def restore_ops(self):
    """Create restore ops for this object's attributes."""
    restore_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        restore, = io_ops.restore_v2(
            prefix=self._checkpoint.save_path,
            tensor_names=[checkpoint_key],
            shape_and_slices=[""],
            dtypes=[base_type],
            name="%s_checkpoint_read" % (serialized_tensor.name,))
        restore_tensors[serialized_tensor.name] = restore
    # Return outside the loop so that every attribute, not just the first one,
    # is gathered (and so an empty attribute list still returns a dict).
    return restore_tensors

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def checkpointable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration",
    [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ]
)

_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])
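
# As a hedged illustration: the Adam optimizer creates slots named "m" and "v",
# so one of its slot restorations might be recorded as
#   _SlotVariableRestoration(optimizer_id=3, slot_variable_id=7, slot_name="m")
# where the ids are hypothetical node positions in the object graph proto.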


class _Checkpoint(object):
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The CheckpointableObjectGraph protocol buffer
        associated with this checkpoint.
      save_path: The path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
    """
    self.object_graph_proto = object_graph_proto
    self.restore_uid = ops.uid()
    # Dictionary mapping from an id in the protocol buffer flat array to
    # Checkpointable Python objects. This mapping may be deferred if a
    # checkpoint is restored before all dependencies have been tracked. Uses
    # weak references so that partial restorations don't create reference cycles
    # (as objects with deferred dependencies will generally have references to
    # this object).
    self.object_by_proto_id = weakref.WeakValueDictionary()
    self.save_path = save_path
    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    self.dtype_map = reader.get_variable_to_dtype_map()
    # When graph building, contains a list of ops to run to restore objects from
    # this checkpoint.
    self.restore_ops = []
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables whose
    # regular variables have already been created, and only for optimizer
    # objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over to
    # deferred_slot_restorations if the optimizer hasn't been created when that
    # happens.
    self.slot_restorations = {}
    for node_index, node in enumerate(self.object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        # `node` refers to an `Optimizer`, since only these have slot variables.
        self.slot_restorations.setdefault(
            slot_reference.original_variable_node_id, []).append(
                _SlotVariableRestoration(
                    optimizer_id=node_index,
                    slot_variable_id=slot_reference.slot_variable_node_id,
                    slot_name=slot_reference.slot_name))


class CheckpointableBase(object):
  """Base class for `Checkpointable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `Checkpointable` instead. Use `CheckpointableBase` for `isinstance`
  checks.
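
  A minimal sketch of declaring a dependency explicitly, using only methods
  defined below:

  ```
  class Container(CheckpointableBase):

    def __init__(self, child):
      self._maybe_initialize_checkpointable()
      self.child = self._track_checkpointable(child, name="child")
  ```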
  """

  def _maybe_initialize_checkpointable(self):
    """Initialize dependency management.

    Not __init__, since most objects will forget to call it.
    """
    if hasattr(self, "_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Checkpointable.__init__() in the constructor of every TensorFlow object.
      return
    # A list of _CheckpointableReference objects.
    self._checkpoint_dependencies = []
    # Maps names -> Checkpointable objects
    self._dependency_names = {}
    # Restorations for other Checkpointable objects on which this object may
    # eventually depend.
    self._deferred_dependencies = {}  # local name -> _CheckpointPosition list
    # The UID of the highest assignment to this object. Used to ensure that the
    # last requested assignment determines the final value of an object.
    if hasattr(self, "_update_uid"):
      raise AssertionError(
          "Internal error: the object had an update UID set before its "
          "initialization code was run.")
    self._update_uid = -1

  def _add_variable_with_custom_getter(
      self, name, shape=None, dtype=dtypes.float32,
      initializer=None, getter=None, **kwargs_for_getter):
312    """Restore-on-create for a variable be saved with this `Checkpointable`.

    If the user has requested that this object or another `Checkpointable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if there is a deferred
        restoration left over from a call to
        `_restore_from_checkpoint_position`.
      getter: The getter to wrap which actually fetches the variable.
      **kwargs_for_getter: Passed to the getter.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique.
    """
    self._maybe_initialize_checkpointable()
    if name in self._dependency_names:
      raise ValueError(
          ("A variable named '%s' already exists in this Checkpointable, but "
           "Checkpointable._add_variable called to create another with "
           "that name. Variable names must be unique within a Checkpointable "
           "object.") % (name,))
    if context.in_eager_mode():
      # If this is a variable with a single Tensor stored in the checkpoint, we
      # can set that value as an initializer rather than initializing and then
      # assigning (when executing eagerly). This call returns None if there is
      # nothing to restore.
      checkpoint_initializer = self._preload_simple_restoration(
          name=name, shape=shape)
    else:
      checkpoint_initializer = None
    if (checkpoint_initializer is not None
        and not (
            isinstance(initializer, CheckpointInitialValue)
            and initializer.restore_uid > checkpoint_initializer.restore_uid)):
      # If multiple Checkpointable objects are "creating" the same variable via
      # the magic of custom getters, the one with the highest restore UID (the
      # one called last) has to make the final initializer. If another custom
      # getter interrupts this process by overwriting the initializer, then
      # we'll catch that when we call _track_checkpointable. So this is "best
      # effort" to set the initializer with the highest restore UID.
      initializer = checkpoint_initializer
      shape = None

    new_variable = getter(
        name=name, shape=shape, dtype=dtype, initializer=initializer,
        **kwargs_for_getter)

    # If we set an initializer and the variable processed it, tracking will not
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    return self._track_checkpointable(new_variable, name=name)
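
  # A hedged sketch of how a subclass might call the method above. The
  # `variable_scope.get_variable` getter is one plausible choice, not something
  # this module prescribes:
  #
  #   kernel = self._add_variable_with_custom_getter(
  #       name="kernel", shape=[10, 10],
  #       initializer=init_ops.zeros_initializer(),
  #       getter=variable_scope.get_variable)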

  def _preload_simple_restoration(self, name, shape):
    """Return a dependency's value for restore-on-create.

    Note the restoration is not deleted; if for some reason preload is called
    and then not assigned to the variable (for example because a custom getter
    overrides the initializer), the assignment will still happen once the
    variable is tracked (determined based on checkpoint.restore_uid).

    Args:
      name: The object-local name of the dependency holding the variable's
        value.
      shape: The shape of the variable being loaded into.
    Returns:
      A callable for use as a variable's initializer/initial_value, or None if
      one should not be set (either because there was no variable with this name
      in the checkpoint or because it needs more complex deserialization). Any
      non-trivial deserialization will happen when the variable object is
      tracked.
    """
    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
    if not deferred_dependencies_list:
      # Nothing to do; we don't have a restore for this dependency queued up.
      return
    for checkpoint_position in deferred_dependencies_list:
      if not checkpoint_position.is_simple_variable():
        # If _any_ pending restoration is too complicated to fit in an
        # initializer (because it has dependencies, or because there are
        # multiple Tensors to restore), bail and let the general tracking code
        # handle it.
        return None
    checkpoint_position = max(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid)
    return CheckpointInitialValue(
        checkpoint_position=checkpoint_position, shape=shape)

  def _track_checkpointable(self, checkpointable, name, overwrite=False):
    """Declare a dependency on another `Checkpointable` object.

    Indicates that checkpoints for this object should include variables from
    `checkpointable`.

    Variables in a checkpoint are mapped to `Checkpointable`s based on names,
    if these were provided when the checkpoint was written; otherwise the order
    in which the `Checkpointable`s were declared as dependencies is used.

    To avoid breaking existing checkpoints when modifying a class, neither
    variable names nor dependency names (the names passed to
    `track_checkpointable`) may change.

    Args:
      checkpointable: A `Checkpointable` which this object depends on.
      name: A local name for `checkpointable`, used for loading checkpoints into
        the correct objects.
      overwrite: Boolean, whether silently replacing dependencies is OK. Used
        for __setattr__, where throwing an error on attribute reassignment would
        be inappropriate.

    Returns:
      `checkpointable`, for convenience when declaring a dependency and
      assigning to a member variable in one statement.

    Raises:
      TypeError: If `checkpointable` does not inherit from `Checkpointable`.
      ValueError: If another object is already tracked by this name.
    """
    self._maybe_initialize_checkpointable()
    if not isinstance(checkpointable, CheckpointableBase):
      raise TypeError(
          ("Checkpointable._track_checkpointable() passed type %s, not a "
           "Checkpointable.") % (type(checkpointable),))
    new_reference = _CheckpointableReference(name=name, ref=checkpointable)
    if (name in self._dependency_names
        and self._dependency_names[name] is not checkpointable):
      if not overwrite:
        raise ValueError(
            ("Called Checkpointable._track_checkpointable() with name='%s', "
             "but a Checkpointable with this name is already declared as a "
             "dependency. Names must be unique (or overwrite=True).") % (name,))
      # This is a weird thing to do, but we're not going to stop people from
      # using __setattr__.
      for index, (old_name, _) in enumerate(self._checkpoint_dependencies):
        if name == old_name:
          self._checkpoint_dependencies[index] = new_reference
    else:
      self._checkpoint_dependencies.append(new_reference)

    self._dependency_names[name] = checkpointable
    deferred_dependency_list = self._deferred_dependencies.pop(name, None)
    if deferred_dependency_list is not None:
      for checkpoint_position in deferred_dependency_list:
        checkpoint_position.restore(checkpointable=checkpointable)
    return checkpointable
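
  # For example, a dependency can be declared and the reference kept in one
  # statement, as the Returns section above notes (`my_variable` is
  # hypothetical):
  #
  #   self.v = self._track_checkpointable(my_variable, name="v")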

  def _restore_from_checkpoint_position(self, checkpoint_position):
    """Restore this object and its dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).
    visit_queue = collections.deque([checkpoint_position])
    restore_ops = []
    while visit_queue:
      current_position = visit_queue.popleft()
      restore_ops.extend(nest.flatten(
          current_position.checkpointable  # pylint: disable=protected-access
          ._single_restoration_from_checkpoint_position(
              checkpoint_position=current_position,
              visit_queue=visit_queue)))
    return restore_ops

  def _single_restoration_from_checkpoint_position(
      self, checkpoint_position, visit_queue):
    """Restore this object, and either queue its dependencies or defer them."""
    self._maybe_initialize_checkpointable()
    checkpoint = checkpoint_position.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object. However, we should pass the
    # restoration on to our dependencies.
    if checkpoint.restore_uid > self._update_uid:
      restore_op = self._scatter_tensors_from_checkpoint(
          checkpoint_position.restore_ops())
      self._update_uid = checkpoint.restore_uid
    else:
      restore_op = ()
    for child in checkpoint_position.object_proto.children:
      child_position = _CheckpointPosition(
          checkpoint=checkpoint,
          proto_id=child.node_id)
      local_object = self._dependency_names.get(child.local_name, None)
      if local_object is None:
        # We don't yet have a dependency registered with this name. Save it
        # in case we do.
        self._deferred_dependencies.setdefault(child.local_name, []).append(
            child_position)
      else:
        if child_position.bind_object(checkpointable=local_object):
          # This object's correspondence is new, so dependencies need to be
          # visited. Delay doing it so that we get a breadth-first dependency
          # resolution order (shallowest paths first). The caller is responsible
          # for emptying visit_queue.
          visit_queue.append(child_position)
    return restore_op

  def _scatter_tensors_from_checkpoint(self, attributes):
    """Restores this object from a checkpoint.

    Args:
      attributes: A dictionary of Tensors, with keys corresponding to those
        returned from _gather_tensors_for_checkpoint.
    Returns:
      A restore op to run (if graph building).
    """
    if attributes:
      raise AssertionError(
          ("A Checkpointable object which was not expecting any data received "
           "some from a checkpoint. (Got %s)") % (attributes,))
    return ()  # No restore ops

  def _gather_tensors_for_checkpoint(self):
    """Returns a dictionary of Tensors to save with this object."""
    return {}
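
  # A hedged sketch of how a subclass holding a single extra Tensor might
  # override the pair of methods above (`self._counter` is a hypothetical
  # variable owned by the subclass):
  #
  #   def _gather_tensors_for_checkpoint(self):
  #     return {"counter": self._counter.read_value()}
  #
  #   def _scatter_tensors_from_checkpoint(self, attributes):
  #     return self._counter.assign(attributes["counter"])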


class Checkpointable(CheckpointableBase):
  """Manages dependencies on other objects.

  `Checkpointable` objects may have dependencies: other `Checkpointable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Checkpointable` object is assigned to an attribute of another
  `Checkpointable` object. For example:

  ```
  obj = Checkpointable()
  obj.v = ResourceVariable(0.)
  ```

  The `Checkpointable` object `obj` now has a dependency named "v" on a
  variable.

  `Checkpointable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Checkpointable._scatter_tensors_from_checkpoint` and
  `Checkpointable._gather_tensors_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = checkpointable syntax."""
    # Perform the attribute assignment, and potentially call other __setattr__
    # overrides such as that for tf.keras.Model.
    super(Checkpointable, self).__setattr__(name, value)
    if isinstance(value, CheckpointableBase):
      self._track_checkpointable(
          value, name=name,
          # Allow the user to switch the Checkpointable which is tracked by this
          # name, since assigning a new variable to an attribute has
          # historically been fine (e.g. Adam did this).
          # TODO(allenl): Should this be a warning once Checkpointable save/load
          # is usable?
          overwrite=True)
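
# A hedged end-to-end sketch of the dependency graphs this module manages
# (`ResourceVariable` would need to be imported from
# tensorflow.python.ops.resource_variable_ops):
#
#   root = Checkpointable()
#   root.leaf = Checkpointable()
#   root.leaf.v = ResourceVariable(1.)
#
# Attribute assignment creates named edges: "root" depends on "leaf", which
# depends on "v", so a checkpoint of `root` also covers the variable's value.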