• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Saves and restores variables inside traced @tf.functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import saver_pb2
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gen_io_ops
30from tensorflow.python.ops import io_ops
31from tensorflow.python.ops import string_ops
32from tensorflow.python.training.saving import checkpoint_options
33from tensorflow.python.training.saving import saveable_hook
34from tensorflow.python.training.saving import saveable_object
35from tensorflow.python.training.saving import saveable_object_util
36from tensorflow.python.util import nest
37
38
class _SingleDeviceSaver(object):
  """Saves and restores a fixed list of `SaveableObject`s with one IO op each.

  `save` gathers every spec of every held `SaveableObject` into a single
  `io_ops.save_v2` call; `restore` reads them back with a single
  `io_ops.restore_v2` call and hands each saveable its own tensors.
  """

  __slots__ = ["_saveable_objects"]

  def __init__(self, saveable_objects):
    """Records the `SaveableObject`s this saver handles.

    Args:
      saveable_objects: An iterable of `SaveableObject`s; it is copied into a
        list.

    Raises:
      ValueError: If any element is not a `SaveableObject`.
    """
    objects = list(saveable_objects)
    for candidate in objects:
      if isinstance(candidate, saveable_object.SaveableObject):
        continue
      raise ValueError(
          "Expected a list of SaveableObjects, got %s." % (candidate,))
    self._saveable_objects = objects

  def save(self, file_prefix, options=None):
    """Writes the spec tensors of all held saveables under `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object. When it is falsy a default
        `CheckpointOptions()` is used. A set `experimental_io_device` replaces
        "cpu:0" as the device of the save op.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    names, values, slices = [], [], []
    for saveable in self._saveable_objects:
      for spec in saveable.specs:
        # Read the tensor exactly once per spec.
        value = spec.tensor
        if value is None:
          # The SaveableObject is recorded in the object graph, but no value
          # for this spec is written to the checkpoint.
          continue
        names.append(spec.name)
        values.append(value)
        slices.append(spec.slice_spec)
    io_device = options.experimental_io_device or "cpu:0"
    with ops.device(io_device):
      return io_ops.save_v2(file_prefix, names, slices, values)

  def restore(self, file_prefix, options=None):
    """Reads all specs from `file_prefix` and restores each held saveable.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object. When it is falsy a default
        `CheckpointOptions()` is used. A set `experimental_io_device` replaces
        "cpu:0" as the device of the restore op.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    options = options or checkpoint_options.CheckpointOptions()
    # `grouping` holds one list of spec names per saveable. It is the nesting
    # used to split the flat restored tensors back up per saveable.
    grouping = []
    requests = []
    for saveable in self._saveable_objects:
      spec_names = []
      for spec in saveable.specs:
        spec_names.append(spec.name)
        requests.append((spec.name, spec.slice_spec, spec.dtype))
      grouping.append(spec_names)
    # NOTE(review): with no specs at all, this unpacking raises ValueError.
    names, slices, spec_dtypes = zip(*requests)
    io_device = options.experimental_io_device or "cpu:0"
    with ops.device(io_device):
      flat_tensors = io_ops.restore_v2(file_prefix, names, slices, spec_dtypes)
    per_saveable = nest.pack_sequence_as(grouping, flat_tensors)
    ops_by_name = {}
    for saveable, tensors in zip(self._saveable_objects, per_saveable):
      ops_by_name[saveable.name] = saveable.restore(
          tensors, restored_shapes=None)
    return ops_by_name
118
119
def sharded_filename(filename_tensor, shard, num_shards):
  """Builds the per-shard filename for `filename_tensor`.

  This is a thin wrapper over the `gen_io_ops.sharded_filename` op.

  Args:
    filename_tensor: A string tensor.
    shard: Integer.  The shard for the filename.
    num_shards: An int Tensor for the number of shards.

  Returns:
    A string tensor.
  """
  shard_name = gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
  return shard_name
132
133
class MultiDeviceSaver(object):
  """Saves checkpoints directly from multiple devices.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
  """

  def __init__(self, saveable_objects):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
        Objects extending `SaveableObject` will be saved and restored, and
        objects extending `SaveableHook` will be called into at save and
        restore time.

    Raises:
      ValueError: If an element is neither a `SaveableObject` nor a
        `SaveableHook`.
    """
    self._before_save_callbacks = []
    self._after_restore_callbacks = []

    saveable_objects = list(saveable_objects)
    saveables_by_device = {}
    for saveable in saveable_objects:
      is_saveable = isinstance(saveable, saveable_object.SaveableObject)
      is_hook = isinstance(saveable, saveable_hook.SaveableHook)

      if not is_saveable and not is_hook:
        # The argument is a list (see the docstring), not a dictionary.
        raise ValueError(
            "Expected a list of SaveableObjects or SaveableHooks, got {}."
            .format(saveable))

      if is_hook:
        self._before_save_callbacks.append(saveable.before_save)
        self._after_restore_callbacks.append(saveable.after_restore)

      if is_saveable:
        # Saveables are grouped by `set_cpu0(device)`, i.e. their host device.
        # Each group becomes one `_SingleDeviceSaver` and one checkpoint shard.
        host_device = saveable_object_util.set_cpu0(saveable.device)
        saveables_by_device.setdefault(host_device, []).append(saveable)

    self._single_device_savers = {
        device: _SingleDeviceSaver(saveables)
        for device, saveables in saveables_by_device.items()}

  def to_proto(self):
    """Serializes to a SaverDef referencing the current graph."""
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix):
    """Traced save; returns `file_prefix` gated on the save op."""
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(self, file_prefix):
    """Traced restore; returns `file_prefix` gated on all restore ops."""
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    All registered `before_save` hook callbacks are invoked first.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    for callback in self._before_save_callbacks:
      callback()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])

    def save_fn():
      """Writes one shard per device saver, then merges them."""
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      sharded_prefixes = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        sharded_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step.  Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              sharded_prefixes, file_prefix, delete_old_dirs=True)

    # Since this causes a function re-trace on each save, limit this to the
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pick up the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # A new tf.function is created on every call, so it is traced with the
      # current `options`. Its result is discarded; this path returns None.
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    All registered `after_restore` hook callbacks are invoked after the
    restore ops have been created.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
      When executing eagerly with more than one device saver, each value is
      instead a tensor (an identity of `file_prefix`) that depends on the
      corresponding restore op.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn():
      """Runs every device saver's restore and merges the resulting dicts."""
      restore_ops = {}
      # Sort by device name to avoid propagating non-deterministic dictionary
      # ordering in some Python versions.
      for device, saver in sorted(self._single_device_savers.items()):
        with ops.device(device):
          restore_ops.update(saver.restore(file_prefix, options))

      return restore_ops

    # Since this causes a function re-trace on each restore, limit this to the
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pick up the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Use the first device in the same sorted order as restore_fn, so the
      # placement below does not depend on dictionary iteration order.
      first_device = min(self._single_device_savers)
      @def_function.function(jit_compile=False)
      def tf_function_restore():
        restore_ops = restore_fn()
        restore_tensors = {}
        # tf.functions must return tensors, thus we use control dependencies so
        # that we can return a tensor which depends on the given op.
        with ops.device(saveable_object_util.set_cpu0(first_device)):
          for name, op in restore_ops.items():
            with ops.control_dependencies([op]):
              restore_tensors[name] = array_ops.identity(file_prefix)
        return restore_tensors

      restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    for callback in self._after_restore_callbacks:
      callback()

    return restore_ops
351