/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/io_ops.cc.

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {

namespace {

// Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
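// On failure, these checks set an InvalidArgument status on "context", so
// callers must check context->status() and return early when it is not OK.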
void ValidateInputs(bool is_save_op, OpKernelContext* context,
                    const Tensor& prefix, const Tensor& tensor_names,
                    const Tensor& shape_and_slices) {
  const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
  const int num_tensors = static_cast<int>(tensor_names.NumElements());
  OP_REQUIRES(
      context, prefix.NumElements() == 1,
      errors::InvalidArgument("Input prefix should have a single element, got ",
                              prefix.NumElements(), " instead."));
  OP_REQUIRES(context,
              TensorShapeUtils::IsVector(tensor_names.shape()) &&
                  TensorShapeUtils::IsVector(shape_and_slices.shape()),
              errors::InvalidArgument(
                  "Input tensor_names and shape_and_slices "
                  "should be 1-D tensors, got ",
                  tensor_names.shape().DebugString(), " and ",
                  shape_and_slices.shape().DebugString(), " instead."));
  OP_REQUIRES(context,
              tensor_names.NumElements() == shape_and_slices.NumElements(),
              errors::InvalidArgument("tensor_names and shape_and_slices "
                                      "have different numbers of elements: ",
                                      tensor_names.NumElements(), " vs. ",
                                      shape_and_slices.NumElements()));
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to the op"));
  OP_REQUIRES(
      context, shape_and_slices.NumElements() == num_tensors,
      errors::InvalidArgument("Expected ", num_tensors,
                              " elements in shape_and_slices, but got ",
                              shape_and_slices.NumElements()));
  if (is_save_op) {
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Expected a total of ", num_tensors + kFixedInputs,
                    " inputs as input #1 (which is a string "
                    "tensor of saved names) contains ",
                    num_tensors, " names, but received ", context->num_inputs(),
                    " inputs"));
  }
}

}  // namespace

// Saves a list of named tensors using the tensor bundle library.
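// Inputs: a scalar string prefix, a 1-D string tensor of tensor names, a 1-D
// string tensor of shape-and-slice specs, followed by the tensors to save.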
class SaveV2 : public OpKernel {
 public:
  explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
    const int num_tensors = static_cast<int>(tensor_names.NumElements());
    const string& prefix_string = prefix.scalar<tstring>()();
    const auto& tensor_names_flat = tensor_names.flat<tstring>();
    const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

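    // Stream all tensors into a single bundle rooted at prefix_string; the
    // bundle's metadata (index) is written out by writer.Finish() below.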
    BundleWriter writer(Env::Default(), prefix_string);
    OP_REQUIRES_OK(context, writer.status());
    VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;

    for (int i = 0; i < num_tensors; ++i) {
      const string& tensor_name = tensor_names_flat(i);
      const Tensor& tensor = context->input(i + kFixedInputs);
      VLOG(2) << "Starting save of " << tensor_name;

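      // A non-empty shape_and_slices entry marks this tensor as a slice of a
      // larger tensor; ParseShapeAndSlice recovers the full shape and the
      // slice extents from that spec string.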
      if (!shape_and_slices_flat(i).empty()) {
        const string& shape_spec = shape_and_slices_flat(i);
        TensorShape shape;
        TensorSlice slice(tensor.dims());
        TensorShape slice_shape;

        OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                    shape_spec, &shape, &slice, &slice_shape));
        OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
                    errors::InvalidArgument("Slice in shape_and_slice "
                                            "specification does not match the "
                                            "shape of the tensor to save: ",
                                            shape_spec, ", tensor: ",
                                            tensor.shape().DebugString()));

        OP_REQUIRES_OK(context,
                       writer.AddSlice(tensor_name, shape, slice, tensor));
      } else {
        OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
      }

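      // Optional debug logging: at VLOG level 5, summarize float tensors with
      // their min, max, and mean values.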
      if (VLOG_IS_ON(5)) {
        if (tensor.dtype() == DT_FLOAT) {
          const float* t_data = tensor.flat<float>().data();
          float min = std::numeric_limits<float>::infinity();
          float max = -std::numeric_limits<float>::infinity();
          double avg = 0.0;
          for (int i = 0; i < tensor.NumElements(); ++i) {
            if (t_data[i] < min) min = t_data[i];
            if (t_data[i] > max) max = t_data[i];
            avg += t_data[i];
          }
          VLOG(5) << " min " << min << " max " << max << " avg "
                  << avg / tensor.NumElements() << " total elts "
                  << tensor.NumElements();
        }
      }

      VLOG(2) << "Done save of " << tensor_name;
    }
    OP_REQUIRES_OK(context, writer.Finish());
    VLOG(1) << "Done BundleWriter, prefix_string: " << prefix_string;
  }
};
REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);

// Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
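// Inputs: a scalar string prefix, a 1-D string tensor of tensor names, and a
// 1-D string tensor of shape-and-slice specs; produces one output tensor per
// requested name, with dtypes given by the "dtypes" attr.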
165 class RestoreV2 : public OpKernel {
166  public:
RestoreV2(OpKernelConstruction * context)167   explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
168     OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
169   }
170 
Compute(OpKernelContext * context)171   void Compute(OpKernelContext* context) override {
172     const Tensor& prefix = context->input(0);
173     const Tensor& tensor_names = context->input(1);
174     const Tensor& shape_and_slices = context->input(2);
175     OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
176                 errors::InvalidArgument("Got ", tensor_names.NumElements(),
177                                         " tensor names, but ", dtypes_.size(),
178                                         " expected dtypes."));
179     ValidateInputs(false /* not save op */, context, prefix, tensor_names,
180                    shape_and_slices);
181     if (!context->status().ok()) return;
182 
    const string& prefix_string = prefix.scalar<tstring>()();

    // Intention: we plan to use the RestoreV2 op as a backward-compatible
    // reader as we upgrade to the V2 format, which allows a transparent
    // upgrade.  Here we attempt to read a V1 checkpoint if "prefix_string"
    // does not refer to a V2 checkpoint.
    Env* env = Env::Default();
    std::vector<string> paths;
    if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
        paths.empty()) {
      // Cannot find V2's metadata file, so "prefix_string" does not point to a
      // V2 checkpoint.  Invokes the V1 read path instead.
      for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
        RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
                      /* preferred_shard */ -1, /* restore_slice */ true,
                      /* restore_index */ i);
        if (!context->status().ok()) {
          return;
        }
      }
      return;
    }
    // If found, invokes the V2 reader.
    OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
                                             shape_and_slices, dtypes_));
  }

 private:
  // Expected dtypes of the tensors to restore.
  std::vector<DataType> dtypes_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);

// The final step in saving sharded V2 checkpoints: merges metadata files.
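// Inputs: a 1-D string tensor of per-shard checkpoint prefixes and a scalar
// string destination prefix for the merged checkpoint.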
217 class MergeV2Checkpoints : public OpKernel {
218  public:
MergeV2Checkpoints(OpKernelConstruction * context)219   explicit MergeV2Checkpoints(OpKernelConstruction* context)
220       : OpKernel(context) {
221     OP_REQUIRES_OK(context,
222                    context->GetAttr("delete_old_dirs", &delete_old_dirs_));
223   }
224 
Compute(OpKernelContext * context)225   void Compute(OpKernelContext* context) override {
226     const Tensor& checkpoint_prefixes = context->input(0);
227     const Tensor& destination_prefix = context->input(1);
228     OP_REQUIRES(context,
229                 TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
230                 errors::InvalidArgument(
231                     "Input checkpoint_prefixes should be an 1-D tensor, got ",
232                     checkpoint_prefixes.shape().DebugString(), " instead."));
233     OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
234                 errors::InvalidArgument(
235                     "Input destination_prefix should be a scalar tensor, got ",
236                     destination_prefix.shape().DebugString(), " instead."));
237 
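    // Merge the metadata of the per-shard bundles named by the input prefixes
    // into a single bundle rooted at the destination prefix.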
    const gtl::ArraySlice<tstring> input_prefixes =
        gtl::ArraySlice<tstring>(checkpoint_prefixes.flat<tstring>());
    Env* env = Env::Default();
    const string& merged_prefix = destination_prefix.scalar<tstring>()();
    OP_REQUIRES_OK(
        context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));

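    // Optionally clean up the temporary per-shard directories, skipping the
    // directory that now holds the merged checkpoint.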
    if (delete_old_dirs_) {
      const string merged_dir(io::Dirname(merged_prefix));
      for (const string& input_prefix : input_prefixes) {
        const string dirname(io::Dirname(input_prefix));
        if (dirname == merged_dir) continue;
        Status status = env->DeleteDir(dirname);
        // For a sharded save, only the first delete goes through; all others
        // hit NotFound.  Use VLOG to keep the logs quiet.
        if (!status.ok()) VLOG(1) << status;
      }
    }
  }

 private:
  // Whether to delete the input (temporary) directories after merging.
  bool delete_old_dirs_;
};
REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
                        MergeV2Checkpoints);

}  // namespace tensorflow