/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/io_ops.cc.

#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {

namespace {

// Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
void ValidateInputs(bool is_save_op, OpKernelContext* context,
                    const Tensor& prefix, const Tensor& tensor_names,
                    const Tensor& shape_and_slices) {
  const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
  const int num_tensors = static_cast<int>(tensor_names.NumElements());
  OP_REQUIRES(
      context, prefix.NumElements() == 1,
      errors::InvalidArgument("Input prefix should have a single element, got ",
                              prefix.NumElements(), " instead."));
  OP_REQUIRES(context,
              TensorShapeUtils::IsVector(tensor_names.shape()) &&
                  TensorShapeUtils::IsVector(shape_and_slices.shape()),
              errors::InvalidArgument(
                  "Input tensor_names and shape_and_slices "
                  "should be 1-D tensors, got ",
                  tensor_names.shape().DebugString(), " and ",
                  shape_and_slices.shape().DebugString(), " instead."));
  OP_REQUIRES(context,
              tensor_names.NumElements() == shape_and_slices.NumElements(),
              errors::InvalidArgument("tensor_names and shape_and_slices "
                                      "have different number of elements: ",
                                      tensor_names.NumElements(), " vs. ",
                                      shape_and_slices.NumElements()));
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to the op"));
  OP_REQUIRES(
      context, shape_and_slices.NumElements() == num_tensors,
      errors::InvalidArgument("Expected ", num_tensors,
                              " elements in shapes_and_slices, but got ",
                              context->input(2).NumElements()));
  if (is_save_op) {
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Got ", num_tensors, " tensor names but ",
                    context->num_inputs() - kFixedInputs, " tensors."));
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Expected a total of ", num_tensors + kFixedInputs,
                    " inputs as input #1 (which is a string "
                    "tensor of saved names) contains ",
                    num_tensors, " names, but received ", context->num_inputs(),
                    " inputs"));
  }
}

}  // namespace

// Saves a list of named tensors using the tensor bundle library.
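// The op takes a scalar string "prefix", 1-D string tensors "tensor_names"
// and "shape_and_slices" (one entry per tensor; an empty entry means "save
// the whole tensor"), followed by the tensors themselves.  Illustrative use
// at the Python level (roughly):
//   tf.raw_ops.SaveV2(prefix="/tmp/ckpt", tensor_names=["v"],
//                     shape_and_slices=[""], tensors=[v])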
class SaveV2 : public OpKernel {
 public:
  explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
    const int num_tensors = static_cast<int>(tensor_names.NumElements());
    const string& prefix_string = prefix.scalar<tstring>()();
    const auto& tensor_names_flat = tensor_names.flat<tstring>();
    const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

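    // Each Add()/AddSlice() call below streams one tensor into the bundle's
    // data file(s); writer.Finish() then writes the metadata (index) table.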
    BundleWriter writer(Env::Default(), prefix_string);
    OP_REQUIRES_OK(context, writer.status());
    VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;

    for (int i = 0; i < num_tensors; ++i) {
      const string& tensor_name = tensor_names_flat(i);
      const Tensor& tensor = context->input(i + kFixedInputs);
      VLOG(2) << "Starting save of " << tensor_name;

      if (!shape_and_slices_flat(i).empty()) {
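        // A non-empty entry is a shape-and-slice spec of the form
        // "dim0 dim1 ... dimN-1 slice0:slice1:...:sliceN-1", where each sliceI
        // is either "-" (take all of that dimension) or "start,length".  For
        // example, "4 6 0,2:-" saves a [2, 6] slice (rows 0 and 1) of a
        // [4, 6] tensor.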
        const string& shape_spec = shape_and_slices_flat(i);
        TensorShape shape;
        TensorSlice slice(tensor.dims());
        TensorShape slice_shape;

        OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                    shape_spec, &shape, &slice, &slice_shape));
        OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
                    errors::InvalidArgument("Slice in shape_and_slice "
                                            "specification does not match the "
                                            "shape of the tensor to save: ",
                                            shape_spec, ", tensor: ",
                                            tensor.shape().DebugString()));

        OP_REQUIRES_OK(context,
                       writer.AddSlice(tensor_name, shape, slice, tensor));
      } else {
        OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
      }

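      // At very high verbosity, log simple summary statistics (min/max/mean)
      // of float tensors to help debug checkpoint contents.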
      if (VLOG_IS_ON(5)) {
        if (tensor.dtype() == DT_FLOAT) {
          const float* t_data = tensor.flat<float>().data();
          float min = std::numeric_limits<float>::infinity();
          float max = -std::numeric_limits<float>::infinity();
          double avg = 0.0;
          for (int i = 0; i < tensor.NumElements(); ++i) {
            if (t_data[i] < min) min = t_data[i];
            if (t_data[i] > max) max = t_data[i];
            avg += t_data[i];
          }
          VLOG(5) << " min " << min << " max " << max << " avg "
                  << avg / tensor.NumElements() << " total elts "
                  << tensor.NumElements();
        }
      }

      VLOG(2) << "Done save of " << tensor_name;
    }
    OP_REQUIRES_OK(context, writer.Finish());
    VLOG(1) << "Done BundleWriter, prefix_string: " << prefix_string;
  }
};
REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);

// Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
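// Inputs: a scalar string "prefix", 1-D string tensors "tensor_names" and
// "shape_and_slices" (same format as SaveV2), plus the "dtypes" attr giving
// the expected type of each restored tensor; outputs the restored tensors in
// the same order.  Illustrative use at the Python level (roughly):
//   tf.raw_ops.RestoreV2(prefix="/tmp/ckpt", tensor_names=["v"],
//                        shape_and_slices=[""], dtypes=[tf.float32])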
class RestoreV2 : public OpKernel {
 public:
  explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
                errors::InvalidArgument("Got ", tensor_names.NumElements(),
                                        " tensor names, but ", dtypes_.size(),
                                        " expected dtypes."));
    ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const string& prefix_string = prefix.scalar<tstring>()();

    // Intention: RestoreV2 is meant to serve as a backward-compatible reader
    // during the upgrade to the V2 checkpoint format, allowing a transparent
    // upgrade.  If "prefix_string" does not refer to a V2 checkpoint, we fall
    // back to reading it as a V1 checkpoint.
    Env* env = Env::Default();
    std::vector<string> paths;
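    // MetaFilename(prefix) names the bundle's metadata file (the ".index"
    // file of a V2 checkpoint); if no matching path exists, the prefix does
    // not name a V2 checkpoint.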
    if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
        paths.empty()) {
      // Cannot find V2's metadata file, so "prefix_string" does not point to a
      // V2 checkpoint.  Invokes the V1 read path instead.
      for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
        RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
                      /* preferred_shard */ -1, /* restore_slice */ true,
                      /* restore_index */ i);
        if (!context->status().ok()) {
          return;
        }
      }
      return;
    }
    // If found, invokes the V2 reader.
    OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
                                             shape_and_slices, dtypes_));
  }

 private:
  // Expected dtypes of the to-restore tensors.
  std::vector<DataType> dtypes_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);

// The final step in saving sharded V2 checkpoints: merges metadata files.
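// Inputs: a 1-D string tensor "checkpoint_prefixes" naming the temporary
// per-shard prefixes (e.g. "<prefix>_temp/part-00000-of-00002") and a scalar
// "destination_prefix"; the bundles are merged into files under the
// destination prefix, and the temporary directories are optionally deleted.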
class MergeV2Checkpoints : public OpKernel {
 public:
  explicit MergeV2Checkpoints(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("delete_old_dirs", &delete_old_dirs_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& checkpoint_prefixes = context->input(0);
    const Tensor& destination_prefix = context->input(1);
    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
                errors::InvalidArgument(
                    "Input checkpoint_prefixes should be a 1-D tensor, got ",
                    checkpoint_prefixes.shape().DebugString(), " instead."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
                errors::InvalidArgument(
                    "Input destination_prefix should be a scalar tensor, got ",
                    destination_prefix.shape().DebugString(), " instead."));

    const gtl::ArraySlice<tstring> input_prefixes =
        gtl::ArraySlice<tstring>(checkpoint_prefixes.flat<tstring>());
    Env* env = Env::Default();
    const string& merged_prefix = destination_prefix.scalar<tstring>()();
    OP_REQUIRES_OK(
        context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));

    if (delete_old_dirs_) {
      const string merged_dir(io::Dirname(merged_prefix));
      for (const string& input_prefix : input_prefixes) {
        const string dirname(io::Dirname(input_prefix));
        if (dirname == merged_dir) continue;
        Status status = env->DeleteDir(dirname);
        // For sharded save, only the first delete will go through and all
        // others will hit NotFound.  Use vlog to be less verbose.
        if (!status.ok()) VLOG(1) << status;
      }
    }
  }

 private:
  // On merge, whether or not to delete the input (temporary) directories.
  bool delete_old_dirs_;
};
REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
                        MergeV2Checkpoints);

}  // namespace tensorflow