# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restores variables inside traced @tf.functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import nest


class Saver(object):
  """A minimal utility class for saving and restoring checkpoints.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
  """

  def __init__(self, saveable_objects):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      saveable_objects: An iterable of `SaveableObject`s. It is copied into a
        list, so a one-shot iterable (e.g. a generator) is consumed here.

    Raises:
      ValueError: If any element is not a `SaveableObject`.
    """
    saveable_objects = list(saveable_objects)
    for saveable in saveable_objects:
      if not isinstance(saveable, saveable_object.SaveableObject):
        raise ValueError(
            "Saver expected a list of SaveableObjects, got %s." % (saveable,))
    self._saveable_objects = saveable_objects

  def to_proto(self):
    """Serializes to a SaverDef referencing the current graph.

    Adds a scalar string placeholder for the filename to the current graph,
    traces `save` and `restore` as functions taking that filename, and
    records the resulting tensor/op names in a V2 `SaverDef`.

    Returns:
      A `saver_pb2.SaverDef`.
    """
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    # TODO(allenl): Add save and restore function names to the proto directly.
    signature = (tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),)
    # Autograph is off because of reference cycles which must be collected when
    # a function is created and destroyed (as in tf.saved_model.save). It's also
    # not necessary, so having it off may be slightly faster.
    #
    # TODO(b/121302372): We should be able to decorate save() and restore()
    # unconditionally.
    save_tensor = def_function.function(
        self.save, input_signature=signature, autograph=False)(filename_tensor)
    # `restore` returns an identity tensor; the SaverDef records an op name,
    # so take the op that produces that tensor.
    restore_op = def_function.function(
        self.restore, input_signature=signature, autograph=False)(
            filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  def save(self, file_prefix):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
    Returns:
      A scalar string Tensor containing `file_prefix` with control dependencies
      on the save ops.
    """
    tensor_names = []
    tensors = []
    tensor_slices = []
    # Flatten every spec of every saveable into the parallel lists SaveV2
    # expects (checkpoint key, slice spec, value).
    for saveable in self._saveable_objects:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    with ops.device("cpu:0"):
      # Returning an identity gated on the save op lets callers depend on the
      # returned tensor to ensure the write has happened.
      with ops.control_dependencies([io_ops.save_v2(
          file_prefix, tensor_names, tensor_slices, tensors)]):
        return array_ops.identity(file_prefix)

  def restore(self, file_prefix):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      A scalar string Tensor containing `file_prefix` with control dependencies
      on the restore ops.
    """
    restore_ops = restore_from_saveable_objects(
        file_prefix, self._saveable_objects)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops):
        return array_ops.identity(file_prefix)


def restore_from_saveable_objects(file_prefix, saveable_objects):
  """Reads from a checkpoint and returns restore ops for `saveable_objects`.

  All tensors for all saveables are read with a single `RestoreV2` call (when
  there is anything to read). The results are then regrouped per saveable and
  handed to that saveable's `restore` method.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    saveable_objects: An iterable of `SaveableObject`s to restore.

  Returns:
    A list with one entry per saveable, holding whatever that saveable's
    `restore(restored_tensors, restored_shapes=None)` call returned. The list
    is empty when `saveable_objects` is empty.
  """
  # Materialize once: `saveable_objects` is iterated twice below.
  saveable_objects = list(saveable_objects)
  restore_specs = []
  # One list of checkpoint keys per saveable; used as the template to regroup
  # the flat list of restored tensors.
  tensor_structure = []
  for saveable in saveable_objects:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      saveable_tensor_structure.append(spec.name)
      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
  if restore_specs:
    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
    with ops.device("cpu:0"):
      flat_restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
  else:
    # Nothing to read. `zip(*[])` cannot be unpacked into three names, so skip
    # building the read op entirely.
    flat_restored_tensors = []
  structured_restored_tensors = nest.pack_sequence_as(
      tensor_structure, flat_restored_tensors)
  restore_ops = []
  for saveable, saveable_restored_tensors in zip(saveable_objects,
                                                 structured_restored_tensors):
    restore_ops.append(saveable.restore(saveable_restored_tensors,
                                        restored_shapes=None))
  return restore_ops