• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utility for creating multiple dependencies with synchronized save/restore."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21
22from tensorflow.python.ops import control_flow_ops
23from tensorflow.python.training import saver as saver_lib
24from tensorflow.python.training.tracking import base as trackable
25
26
class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
  """A `SaveableObject` whose save and restore are delegated to callbacks.

  The save callback is handed to the saver machinery in place of a tensor, and
  the restore callback receives the single tensor read back from a checkpoint.
  """

  def __init__(self, name, dtype, save_callback, restore_callback):
    """Builds the single `SaveSpec` backing this saveable.

    Args:
      name: Checkpoint name, used for both the `SaveSpec` and the saveable.
      dtype: Data type recorded in the `SaveSpec`.
      save_callback: Callable passed as the spec's `tensor` and as the
        saveable's op. NOTE(review): presumably invoked lazily by the saver to
        produce the value -- confirm against `SaveSpec`.
      restore_callback: Callable invoked by `restore` with the restored tensor;
        its return value is returned from `restore`.
    """
    self._restore_callback = restore_callback
    builder = saver_lib.BaseSaverBuilder
    # Exactly one spec, with an empty slice spec.
    specs = [
        builder.SaveSpec(
            tensor=save_callback, slice_spec="", name=name, dtype=dtype)
    ]
    super(_CallbackSaveable, self).__init__(save_callback, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Forwards the one restored tensor to the restore callback.

    Args:
      restored_tensors: A sequence holding exactly one tensor; unpacking raises
        `ValueError` for any other length.
      restored_shapes: Unused.

    Returns:
      Whatever the restore callback returns.
    """
    [restored_value] = restored_tensors
    return self._restore_callback(restored_value)
44
45
class _SplitDependency(trackable.Trackable):
  """One named component of a group that is saved and restored in lockstep.

  Each instance reads from a save buffer and writes to a restore buffer (both
  dictionaries keyed by component name). The save buffer is filled on demand by
  `fill_save_buffer_fn`; the restore buffer is handed to
  `consume_restore_buffer_fn` once it holds `num_components` entries.
  """

  def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
               fill_save_buffer_fn, consume_restore_buffer_fn):
    """Stores the buffers, identity and callbacks for this component.

    Args:
      save_buffer: Dictionary from which this component's saved value is
        popped.
      restore_buffer: Dictionary into which restored tensors are placed.
      name: Key of this component in both buffers.
      dtype: Data type forwarded to the `_CallbackSaveable`.
      num_components: Number of restore-buffer entries that triggers
        `consume_restore_buffer_fn`.
      fill_save_buffer_fn: Called with `save_buffer` when it is empty and this
        component's value is requested.
      consume_restore_buffer_fn: Called with `restore_buffer` once it is full.
    """
    self._name = name
    self._dtype = dtype
    self._num_components = num_components
    self._save_buffer = save_buffer
    self._restore_buffer = restore_buffer
    self._fill_save_buffer_fn = fill_save_buffer_fn
    self._consume_restore_buffer_fn = consume_restore_buffer_fn

  def _save(self):
    """Returns (and removes) this component's entry from the save buffer.

    An empty buffer is populated first via `fill_save_buffer_fn`.

    Raises:
      AssertionError: If the buffer holds other entries but not this one, i.e.
        the components are not being saved as a group.
    """
    key = self._name
    pending = self._save_buffer
    if key in pending:
      return pending.pop(key)
    if pending:
      # Leftovers from a previous fill which do not include this component.
      raise AssertionError(
          ("Split dependency %s (%s) unsynchronized. Split dependencies must "
           "be saved together.") % (key, self))
    self._fill_save_buffer_fn(pending)
    return pending.pop(key)

  def _restore(self, tensor):
    """Records `tensor` in the restore buffer, consuming the buffer when full.

    Args:
      tensor: The restored value for this component.

    Returns:
      The result of `consume_restore_buffer_fn` if this call completed the
      buffer, otherwise a no-op.

    Raises:
      AssertionError: If this component already has an entry waiting in the
        restore buffer.
    """
    key = self._name
    received = self._restore_buffer
    if key in received:
      raise AssertionError(
          ("Split dependency %s (%s) unsynchronized. Split dependencies must "
           "be restored together.") % (key, self))
    received[key] = tensor
    if len(received) != self._num_components:
      # Still waiting on other components.
      return control_flow_ops.no_op()
    restore_op = self._consume_restore_buffer_fn(received)
    # Empty the buffer so that it can be filled again by a later restore.
    received.clear()
    return restore_op

  def _gather_saveables_for_checkpoint(self):
    """Returns a one-entry dict mapping the variable-value key to a factory.

    The factory is `_CallbackSaveable` with `dtype` and both callbacks already
    bound; the remaining `name` argument is supplied by the caller.
    """
    saveable_factory = functools.partial(
        _CallbackSaveable,
        dtype=self._dtype,
        save_callback=self._save,
        restore_callback=self._restore)
    return {trackable.VARIABLE_VALUE_KEY: saveable_factory}
92
93
def split_dependency(component_names, component_dtypes,
                     fill_save_buffer_fn, consume_restore_buffer_fn):
  """Creates multiple dependencies with a synchronized save/restore.

  Useful when a single op produces `Tensor`s which should each be saved under
  different objects, or when `Tensor`s saved with many different objects need to
  be restored together as inputs to a single op (i.e. an object which uses a
  single fused op may be swapped out for a subgraph of objects, and these two
  programs are checkpoint compatible).

  Args:
    component_names: A sequence of unique names for the split
      dependencies. `fill_save_buffer_fn` must add these keys to the dictionary
      it is passed, and `consume_restore_buffer_fn` will receive a dictionary
      with these keys.
    component_dtypes: Data types for the `Tensor`s being saved and restored, a
      sequence corresponding to `component_names` (same length, same order).
    fill_save_buffer_fn: A function which takes an empty dictionary as an
      argument and adds `Tensor`s with `component_names` as keys. These
      `Tensor`s will be saved as if they were individual variables.
    consume_restore_buffer_fn: A function which takes a dictionary with
      `component_names` as keys mapping to restored individual `Tensor`s and
      returns a restore op (or if executing eagerly, runs the restoration and
      may return `None`).

  Returns:
    A dictionary mapping from names to Trackable objects. If one is
    reachable from an object as a dependency, the others should be too; adding
    dependencies on some but not all of the objects will result in errors.

  Raises:
    ValueError: If `component_names` and `component_dtypes` have different
      lengths, or if `component_names` contains duplicates.
  """
  # Materialize once so that lengths can be compared and one-shot iterables
  # are not exhausted by the checks below.
  component_names = list(component_names)
  component_dtypes = list(component_dtypes)
  num_components = len(component_names)

  # A mismatch or a duplicate would leave fewer dependencies than
  # `num_components`, so the restore buffer could never fill and
  # `consume_restore_buffer_fn` would silently never run.
  if len(component_dtypes) != num_components:
    raise ValueError(
        ("component_names and component_dtypes must have the same length, "
         "got %d names and %d dtypes.")
        % (num_components, len(component_dtypes)))
  if len(set(component_names)) != num_components:
    raise ValueError(
        "component_names must be unique, got %s." % (component_names,))

  # Both buffers are shared by every component created below.
  save_buffer = {}
  restore_buffer = {}
  return {
      name: _SplitDependency(
          save_buffer=save_buffer,
          restore_buffer=restore_buffer,
          name=name,
          dtype=dtype,
          num_components=num_components,
          fill_save_buffer_fn=fill_save_buffer_fn,
          consume_restore_buffer_fn=consume_restore_buffer_fn)
      for name, dtype in zip(component_names, component_dtypes)
  }
137