# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restores variables inside traced @tf.functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import nest


class _SingleDeviceSaver(object):
  """Saves and restores checkpoints from the current device."""

  __slots__ = ["_saveable_objects"]

  def __init__(self, saveable_objects):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
    """
    saveable_objects = list(saveable_objects)
    for saveable in saveable_objects:
      if not isinstance(saveable, saveable_object.SaveableObject):
        raise ValueError(
            "Expected a list of SaveableObjects, got %s." % (saveable,))
    self._saveable_objects = saveable_objects

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in self._saveable_objects:
      for spec in saveable.specs:
        tensor = spec.tensor
        # A tensor value of `None` indicates that this SaveableObject gets
        # recorded in the object graph, but that no value is saved in the
        # checkpoint.
        if tensor is not None:
          tensor_names.append(spec.name)
          tensors.append(tensor)
          tensor_slices.append(spec.slice_spec)
    save_device = options.experimental_io_device or "cpu:0"
    with ops.device(save_device):
      return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    options = options or checkpoint_options.CheckpointOptions()
    restore_specs = []
    tensor_structure = []
    for saveable in self._saveable_objects:
      saveable_tensor_structure = []
      tensor_structure.append(saveable_tensor_structure)
      for spec in saveable.specs:
        saveable_tensor_structure.append(spec.name)
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    restore_ops = {}
    for saveable, restored_tensors in zip(self._saveable_objects,
                                          structured_restored_tensors):
      restore_ops[saveable.name] = saveable.restore(
          restored_tensors, restored_shapes=None)
    return restore_ops
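
# For illustration only (not executed): a minimal sketch of a
# `_SingleDeviceSaver` round trip for a single variable. It assumes the
# `SaveableObject`s come from `saveable_object_util` (as higher-level
# checkpointing code does); the variable `v`, the name "v", and the prefix
# below are hypothetical.
#
#   saveables = list(saveable_object_util.saveable_objects_for_op(v, "v"))
#   saver = _SingleDeviceSaver(saveables)
#   saver.save(constant_op.constant("/tmp/ckpt/single"))
#   saver.restore(constant_op.constant("/tmp/ckpt/single"))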


def sharded_filename(filename_tensor, shard, num_shards):
  """Append sharding information to a filename.

  Args:
    filename_tensor: A string tensor.
    shard: Integer. The shard for the filename.
    num_shards: An int Tensor for the number of shards.

  Returns:
    A string tensor.
  """
  return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
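
# For illustration: with the temporary prefix "myckpt_temp/part" used by
# `MultiDeviceSaver.save` below, shard 0 of 2 would be named roughly
# "myckpt_temp/part-00000-of-00002" (the exact zero-padding is determined by
# the underlying ShardedFilename op).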


class MultiDeviceSaver(object):
  """Saves checkpoints directly from multiple devices.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
  """

  def __init__(self, saveable_objects):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: A list of `SaveableObject`s.
        Objects extending `SaveableObject` will be saved and restored, and
        objects extending `SaveableHook` will be called into at save and
        restore time.
    """
    self._before_save_callbacks = []
    self._after_restore_callbacks = []

    saveable_objects = list(saveable_objects)
    saveables_by_device = {}
    for saveable in saveable_objects:
      is_saveable = isinstance(saveable, saveable_object.SaveableObject)
      is_hook = isinstance(saveable, saveable_hook.SaveableHook)

      if not is_saveable and not is_hook:
        raise ValueError(
            "Expected a list of SaveableObjects or SaveableHooks, got {}."
            .format(saveable))

      if is_hook:
        self._before_save_callbacks.append(saveable.before_save)
        self._after_restore_callbacks.append(saveable.after_restore)

      if is_saveable:
        host_device = saveable_object_util.set_cpu0(saveable.device)
        saveables_by_device.setdefault(host_device, []).append(saveable)

    self._single_device_savers = {
        device: _SingleDeviceSaver(saveables)
        for device, saveables in saveables_by_device.items()}

  def to_proto(self):
    """Serializes to a SaverDef referencing the current graph."""
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix):
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(self, file_prefix):
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)
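
  # For illustration only (not executed): the `SaverDef` returned by
  # `to_proto` can be driven from a graph-mode session by feeding the filename
  # placeholder it references; the session `sess` and the path below are
  # hypothetical.
  #
  #   saver_def = multi_device_saver.to_proto()
  #   sess.run(saver_def.save_tensor_name,
  #            feed_dict={saver_def.filename_tensor_name: "/tmp/ckpt/model"})
  #   sess.run(saver_def.restore_op_name,
  #            feed_dict={saver_def.filename_tensor_name: "/tmp/ckpt/model"})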

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    for callback in self._before_save_callbacks:
      callback()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore(). Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #       part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #       myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case. Save() and Restore() work with the
    # prefix directly, instead of any physical pathname. (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])

    def save_fn():
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      sharded_prefixes = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        sharded_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locate the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # The V2 format write path includes a metadata merge step. Once
          # merged, it attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              sharded_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager execution with multiple single-device
    # savers. Note that the retrace is needed to ensure we pick up the latest
    # values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()
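
  # For illustration only (not executed): a hypothetical call routing the save
  # I/O through a specific device via `CheckpointOptions`. After the merge
  # step in save(), the user-visible files are "<prefix>.index" and
  # "<prefix>.data-?????-of-?????".
  #
  #   opts = checkpoint_options.CheckpointOptions(
  #       experimental_io_device="/job:localhost/device:CPU:0")
  #   multi_device_saver.save(constant_op.constant("/tmp/ckpt/model"), opts)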

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn():
      restore_ops = {}
      # Sort by device name to avoid propagating non-deterministic dictionary
      # ordering in some Python versions.
      for device, saver in sorted(self._single_device_savers.items()):
        with ops.device(device):
          restore_ops.update(saver.restore(file_prefix, options))

      return restore_ops

    # Since this will cause a function re-trace on each restore, limit this to
    # the cases where it is needed: eager execution with multiple single-device
    # savers. Note that the retrace is needed to ensure we pick up the latest
    # values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      first_device, _ = list(self._single_device_savers.items())[0]

      @def_function.function(jit_compile=False)
      def tf_function_restore():
        restore_ops = restore_fn()
        restore_tensors = {}
        # tf.functions must return tensors, so use control dependencies to
        # return a tensor which depends on each restore op.
        with ops.device(saveable_object_util.set_cpu0(first_device)):
          for name, op in restore_ops.items():
            with ops.control_dependencies([op]):
              restore_tensors[name] = array_ops.identity(file_prefix)
        return restore_tensors

      restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    for callback in self._after_restore_callbacks:
      callback()

    return restore_ops
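
# For illustration only (not executed): a minimal end-to-end sketch of
# `MultiDeviceSaver`, assuming `saveables` is a list of `SaveableObject`s
# and/or `SaveableHook`s gathered from the objects being checkpointed; the
# prefix below is hypothetical.
#
#   saver = MultiDeviceSaver(saveables)
#   prefix = constant_op.constant("/tmp/ckpt/model")
#   saver.save(prefix)     # writes and merges one sharded checkpoint
#   saver.restore(prefix)  # returns a dict keyed by SaveableObject name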