1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/io_ops.cc.
17
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
35
36 namespace tensorflow {
37
38 namespace {
39
40 // Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
ValidateInputs(bool is_save_op,OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices)41 void ValidateInputs(bool is_save_op, OpKernelContext* context,
42 const Tensor& prefix, const Tensor& tensor_names,
43 const Tensor& shape_and_slices) {
44 const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices.
45 const int num_tensors = static_cast<int>(tensor_names.NumElements());
46 OP_REQUIRES(
47 context, prefix.NumElements() == 1,
48 errors::InvalidArgument("Input prefix should have a single element, got ",
49 prefix.NumElements(), " instead."));
50 OP_REQUIRES(context,
51 TensorShapeUtils::IsVector(tensor_names.shape()) &&
52 TensorShapeUtils::IsVector(shape_and_slices.shape()),
53 errors::InvalidArgument(
54 "Input tensor_names and shape_and_slices "
55 "should be an 1-D tensors, got ",
56 tensor_names.shape().DebugString(), " and ",
57 shape_and_slices.shape().DebugString(), " instead."));
58 OP_REQUIRES(context,
59 tensor_names.NumElements() == shape_and_slices.NumElements(),
60 errors::InvalidArgument("tensor_names and shape_and_slices "
61 "have different number of elements: ",
62 tensor_names.NumElements(), " vs. ",
63 shape_and_slices.NumElements()));
64 OP_REQUIRES(context,
65 FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
66 std::numeric_limits<int>::max()),
67 errors::InvalidArgument("Too many inputs to the op"));
68 OP_REQUIRES(
69 context, shape_and_slices.NumElements() == num_tensors,
70 errors::InvalidArgument("Expected ", num_tensors,
71 " elements in shapes_and_slices, but got ",
72 context->input(2).NumElements()));
73 if (is_save_op) {
74 OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
75 errors::InvalidArgument(
76 "Got ", num_tensors, " tensor names but ",
77 context->num_inputs() - kFixedInputs, " tensors."));
78 OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
79 errors::InvalidArgument(
80 "Expected a total of ", num_tensors + kFixedInputs,
81 " inputs as input #1 (which is a string "
82 "tensor of saved names) contains ",
83 num_tensors, " names, but received ", context->num_inputs(),
84 " inputs"));
85 }
86 }
87
88 } // namespace
89
90 // Saves a list of named tensors using the tensor bundle library.
91 class SaveV2 : public OpKernel {
92 public:
SaveV2(OpKernelConstruction * context)93 explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}
94
Compute(OpKernelContext * context)95 void Compute(OpKernelContext* context) override {
96 const Tensor& prefix = context->input(0);
97 const Tensor& tensor_names = context->input(1);
98 const Tensor& shape_and_slices = context->input(2);
99 ValidateInputs(true /* is save op */, context, prefix, tensor_names,
100 shape_and_slices);
101
102 const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices.
103 const int num_tensors = static_cast<int>(tensor_names.NumElements());
104 const string& prefix_string = prefix.scalar<tstring>()();
105 const auto& tensor_names_flat = tensor_names.flat<tstring>();
106 const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();
107
108 BundleWriter writer(Env::Default(), prefix_string);
109 OP_REQUIRES_OK(context, writer.status());
110 VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;
111
112 for (int i = 0; i < num_tensors; ++i) {
113 const string& tensor_name = tensor_names_flat(i);
114 const Tensor& tensor = context->input(i + kFixedInputs);
115 VLOG(2) << "Starting save of " << tensor_name;
116
117 if (!shape_and_slices_flat(i).empty()) {
118 const string& shape_spec = shape_and_slices_flat(i);
119 TensorShape shape;
120 TensorSlice slice(tensor.dims());
121 TensorShape slice_shape;
122
123 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
124 shape_spec, &shape, &slice, &slice_shape));
125 OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
126 errors::InvalidArgument("Slice in shape_and_slice "
127 "specification does not match the "
128 "shape of the tensor to save: ",
129 shape_spec, ", tensor: ",
130 tensor.shape().DebugString()));
131
132 OP_REQUIRES_OK(context,
133 writer.AddSlice(tensor_name, shape, slice, tensor));
134 } else {
135 OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
136 }
137
138 if (VLOG_IS_ON(5)) {
139 if (tensor.dtype() == DT_FLOAT) {
140 const float* t_data = tensor.flat<float>().data();
141 float min = std::numeric_limits<float>::infinity();
142 float max = -std::numeric_limits<float>::infinity();
143 double avg = 0.0;
144 for (int i = 0; i < tensor.NumElements(); ++i) {
145 if (t_data[i] < min) min = t_data[i];
146 if (t_data[i] > max) max = t_data[i];
147 avg += t_data[i];
148 }
149 VLOG(5) << " min " << min << " max " << max << " avg "
150 << avg / tensor.NumElements() << " total elts "
151 << tensor.NumElements();
152 }
153 }
154
155 VLOG(2) << "Done save of " << tensor_name;
156 }
157 OP_REQUIRES_OK(context, writer.Finish());
158 VLOG(1) << "Done BundleWriter, prefix_string: " << prefix_string;
159 }
160 };
161 REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);
162
163 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
164 class RestoreV2 : public OpKernel {
165 public:
RestoreV2(OpKernelConstruction * context)166 explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
167 OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
168 }
169
Compute(OpKernelContext * context)170 void Compute(OpKernelContext* context) override {
171 const Tensor& prefix = context->input(0);
172 const Tensor& tensor_names = context->input(1);
173 const Tensor& shape_and_slices = context->input(2);
174 OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
175 errors::InvalidArgument("Got ", tensor_names.NumElements(),
176 " tensor names, but ", dtypes_.size(),
177 " expected dtypes."));
178 ValidateInputs(false /* not save op */, context, prefix, tensor_names,
179 shape_and_slices);
180
181 const string& prefix_string = prefix.scalar<tstring>()();
182
183 // Intention: we plan to use the RestoreV2 op as a backward-compatible
184 // reader as we upgrade to the V2 format. This allows transparent upgrade.
185 // We here attempt to read a V1 checkpoint, if "prefix_string" does not
186 // refer to a V2 checkpoint.
187 Env* env = Env::Default();
188 std::vector<string> paths;
189 if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
190 paths.empty()) {
191 // Cannot find V2's metadata file, so "prefix_string" does not point to a
192 // V2 checkpoint. Invokes the V1 read path instead.
193 for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
194 RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
195 /* preferred_shard */ -1, /* restore_slice */ true,
196 /* restore_index */ i);
197 if (!context->status().ok()) {
198 return;
199 }
200 }
201 return;
202 }
203 // If found, invokes the V2 reader.
204 OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
205 shape_and_slices, dtypes_));
206 }
207
208 private:
209 // Expected dtypes of the to-restore tensors.
210 std::vector<DataType> dtypes_;
211 };
212 REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);
213
214 // The final step in saving sharded V2 checkpoints: merges metadata files.
215 class MergeV2Checkpoints : public OpKernel {
216 public:
MergeV2Checkpoints(OpKernelConstruction * context)217 explicit MergeV2Checkpoints(OpKernelConstruction* context)
218 : OpKernel(context) {
219 OP_REQUIRES_OK(context,
220 context->GetAttr("delete_old_dirs", &delete_old_dirs_));
221 }
222
Compute(OpKernelContext * context)223 void Compute(OpKernelContext* context) override {
224 const Tensor& checkpoint_prefixes = context->input(0);
225 const Tensor& destination_prefix = context->input(1);
226 OP_REQUIRES(context,
227 TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
228 errors::InvalidArgument(
229 "Input checkpoint_prefixes should be an 1-D tensor, got ",
230 checkpoint_prefixes.shape().DebugString(), " instead."));
231 OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
232 errors::InvalidArgument(
233 "Input destination_prefix should be a scalar tensor, got ",
234 destination_prefix.shape().DebugString(), " instead."));
235
236 const gtl::ArraySlice<tstring> input_prefixes =
237 gtl::ArraySlice<tstring>(checkpoint_prefixes.flat<tstring>());
238 Env* env = Env::Default();
239 const string& merged_prefix = destination_prefix.scalar<tstring>()();
240 OP_REQUIRES_OK(
241 context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
242
243 if (delete_old_dirs_) {
244 const string merged_dir(io::Dirname(merged_prefix));
245 for (const string& input_prefix : input_prefixes) {
246 const string dirname(io::Dirname(input_prefix));
247 if (dirname == merged_dir) continue;
248 Status status = env->DeleteDir(dirname);
249 // For sharded save, only the first delete will go through and all
250 // others will hit NotFound. Use vlog to be less verbose.
251 if (!status.ok()) VLOG(1) << status;
252 }
253 }
254 }
255
256 private:
257 // On merge, whether or not to delete the input (temporary) directories.
258 bool delete_old_dirs_;
259 };
260 REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
261 MergeV2Checkpoints);
262
263 } // namespace tensorflow
264