/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/io_ops.cc.

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/checkpoint_callback_manager.h"
#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {

namespace {

// Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
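// Checks that prefix has a single element, that tensor_names and
// shape_and_slices are 1-D tensors of matching length, and (for the save op)
// that exactly one input tensor is supplied per name.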
void ValidateInputs(bool is_save_op, OpKernelContext* context,
                    const Tensor& prefix, const Tensor& tensor_names,
                    const Tensor& shape_and_slices) {
  const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
  const int num_tensors = static_cast<int>(tensor_names.NumElements());
  OP_REQUIRES(
      context, prefix.NumElements() == 1,
      errors::InvalidArgument("Input prefix should have a single element, got ",
                              prefix.NumElements(), " instead."));
  OP_REQUIRES(context,
              TensorShapeUtils::IsVector(tensor_names.shape()) &&
                  TensorShapeUtils::IsVector(shape_and_slices.shape()),
              errors::InvalidArgument(
                  "Input tensor_names and shape_and_slices "
                  "should be 1-D tensors, got ",
                  tensor_names.shape().DebugString(), " and ",
                  shape_and_slices.shape().DebugString(), " instead."));
  OP_REQUIRES(context,
              tensor_names.NumElements() == shape_and_slices.NumElements(),
              errors::InvalidArgument("tensor_names and shape_and_slices "
                                      "have different number of elements: ",
                                      tensor_names.NumElements(), " vs. ",
                                      shape_and_slices.NumElements()));
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to the op"));
  OP_REQUIRES(
      context, shape_and_slices.NumElements() == num_tensors,
      errors::InvalidArgument("Expected ", num_tensors,
                              " elements in shapes_and_slices, but got ",
                              context->input(2).NumElements()));
  if (is_save_op) {
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Got ", num_tensors, " tensor names but ",
                    context->num_inputs() - kFixedInputs, " tensors."));
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Expected a total of ", num_tensors + kFixedInputs,
                    " inputs as input #1 (which is a string "
                    "tensor of saved names) contains ",
                    num_tensors, " names, but received ", context->num_inputs(),
                    " inputs"));
  }
}

}  // namespace

// Saves a list of named tensors using the tensor bundle library.
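// Inputs (in order): prefix (a string tensor with a single element),
// tensor_names and shape_and_slices (1-D string tensors, one entry per tensor;
// a shape_and_slices entry may be empty), followed by the N tensors to save.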
class SaveV2 : public OpKernel {
 public:
  explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
    const int num_tensors = static_cast<int>(tensor_names.NumElements());
    const string& prefix_string = prefix.scalar<tstring>()();
    const auto& tensor_names_flat = tensor_names.flat<tstring>();
    const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

    BundleWriter writer(Env::Default(), prefix_string);
    OP_REQUIRES_OK(context, writer.status());
    VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;

    for (int i = 0; i < num_tensors; ++i) {
      const string& tensor_name = tensor_names_flat(i);
      const Tensor& tensor = context->input(i + kFixedInputs);
      VLOG(2) << "Starting save of " << tensor_name;

      if (!shape_and_slices_flat(i).empty()) {
        const string& shape_spec = shape_and_slices_flat(i);
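        // A non-empty spec gives the full tensor shape followed by the slice
        // to save in each dimension; illustratively, "4 5 0,2:-" would mean a
        // [4, 5] tensor of which rows [0, 2) and all columns are written.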
        TensorShape shape;
        TensorSlice slice(tensor.dims());
        TensorShape slice_shape;

        OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                    shape_spec, &shape, &slice, &slice_shape));
        OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
                    errors::InvalidArgument("Slice in shape_and_slice "
                                            "specification does not match the "
                                            "shape of the tensor to save: ",
                                            shape_spec, ", tensor: ",
                                            tensor.shape().DebugString()));

        OP_REQUIRES_OK(context,
                       writer.AddSlice(tensor_name, shape, slice, tensor));
      } else {
        OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
      }

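      // At very high verbosity, log basic statistics (min, max, mean) of
      // float tensors as they are written.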
      if (VLOG_IS_ON(5)) {
        if (tensor.dtype() == DT_FLOAT) {
          const float* t_data = tensor.flat<float>().data();
          float min = std::numeric_limits<float>::infinity();
          float max = -std::numeric_limits<float>::infinity();
          double avg = 0.0;
          for (int i = 0; i < tensor.NumElements(); ++i) {
            if (t_data[i] < min) min = t_data[i];
            if (t_data[i] > max) max = t_data[i];
            avg += t_data[i];
          }
          VLOG(5) << " min " << min << " max " << max << " avg "
                  << avg / tensor.NumElements() << " total elts "
                  << tensor.NumElements();
        }
      }

      VLOG(2) << "Done save of " << tensor_name;
    }
    OP_REQUIRES_OK(context, writer.Finish());
    VLOG(1) << "Done BundleWriter, prefix_string: " << prefix_string;

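    // If a CheckpointCallbackManager resource is available (creating one if
    // needed), notify it so that any registered checkpoint callbacks run for
    // this save.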
    ResourceMgr* resource_manager = context->resource_manager();
    if (resource_manager != nullptr) {
      checkpoint::CheckpointCallbackManager* checkpoint_callback_manager;
      OP_REQUIRES_OK(
          context,
          resource_manager
              ->LookupOrCreate<checkpoint::CheckpointCallbackManager>(
                  resource_manager->default_container(),
                  std::string(
                      checkpoint::kCheckpointCallbackManagerResourceName),
                  &checkpoint_callback_manager,
                  [](checkpoint::CheckpointCallbackManager** out) {
                    *out = new checkpoint::CheckpointCallbackManager();
                    return OkStatus();
                  }));
      checkpoint_callback_manager->Save(prefix_string);
      checkpoint_callback_manager->Unref();
    }
  }
};
REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);

// Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
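// Inputs: prefix, tensor_names, and shape_and_slices, as in SaveV2; the
// "dtypes" attr lists one expected dtype per name, and the op produces one
// restored tensor per name.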
class RestoreV2 : public OpKernel {
 public:
  explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
                errors::InvalidArgument("Got ", tensor_names.NumElements(),
                                        " tensor names, but ", dtypes_.size(),
                                        " expected dtypes."));
    ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                   shape_and_slices);
    if (!context->status().ok()) return;

    const string& prefix_string = prefix.scalar<tstring>()();

    // Intention: we plan to use the RestoreV2 op as a backward-compatible
    // reader as we upgrade to the V2 format.  This allows transparent upgrade.
    // Here we attempt to read a V1 checkpoint if "prefix_string" does not
    // refer to a V2 checkpoint.
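    // (A V2 checkpoint is detected via its metadata file, conventionally
    // "<prefix>.index".)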
    Env* env = Env::Default();
    std::vector<string> paths;
    if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
        paths.empty()) {
      // Cannot find V2's metadata file, so "prefix_string" does not point to a
      // V2 checkpoint.  Invokes the V1 read path instead.
      for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
        RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
                      /* preferred_shard */ -1, /* restore_slice */ true,
                      /* restore_index */ i);
        if (!context->status().ok()) {
          return;
        }
      }
      return;
    }
    // If found, invokes the V2 reader.
    OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
                                             shape_and_slices, dtypes_));

    ResourceMgr* resource_manager = context->resource_manager();
    if (resource_manager != nullptr) {
      checkpoint::CheckpointCallbackManager* checkpoint_callback_manager;
      OP_REQUIRES_OK(
          context,
          resource_manager
              ->LookupOrCreate<checkpoint::CheckpointCallbackManager>(
                  resource_manager->default_container(),
                  std::string(
                      checkpoint::kCheckpointCallbackManagerResourceName),
                  &checkpoint_callback_manager,
                  [](checkpoint::CheckpointCallbackManager** out) {
                    *out = new checkpoint::CheckpointCallbackManager();
                    return OkStatus();
                  }));
      checkpoint_callback_manager->Restore(prefix_string);
      checkpoint_callback_manager->Unref();
    }
  }

 private:
  // Expected dtypes of the tensors to restore.
  std::vector<DataType> dtypes_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);

// The final step in saving sharded V2 checkpoints: merges metadata files.
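// Illustratively (exact shard naming is an assumption, not guaranteed here):
// per-shard prefixes such as "<prefix>_temp/part-00000-of-00002" are merged
// into files under the single destination prefix, e.g. "<prefix>.index".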
class MergeV2Checkpoints : public OpKernel {
 public:
  explicit MergeV2Checkpoints(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("delete_old_dirs", &delete_old_dirs_));
    OP_REQUIRES_OK(context, context->GetAttr("allow_missing_files",
                                             &allow_missing_files_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& checkpoint_prefixes = context->input(0);
    const Tensor& destination_prefix = context->input(1);
    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
                errors::InvalidArgument(
                    "Input checkpoint_prefixes should be a 1-D tensor, got ",
                    checkpoint_prefixes.shape().DebugString(), " instead."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
                errors::InvalidArgument(
                    "Input destination_prefix should be a scalar tensor, got ",
                    destination_prefix.shape().DebugString(), " instead."));

    const gtl::ArraySlice<tstring> input_prefixes =
        gtl::ArraySlice<tstring>(checkpoint_prefixes.flat<tstring>());
    Env* env = Env::Default();
    const string& merged_prefix = destination_prefix.scalar<tstring>()();
    OP_REQUIRES_OK(context,
                   tensorflow::MergeBundles(env, input_prefixes, merged_prefix,
                                            allow_missing_files_));

    if (delete_old_dirs_) {
      const string merged_dir(io::Dirname(merged_prefix));
      for (const string& input_prefix : input_prefixes) {
        const string dirname(io::Dirname(input_prefix));
        if (dirname == merged_dir) continue;
        Status status = env->DeleteDir(dirname);
        // For sharded save, only the first delete will go through and all
        // others will hit NotFound.  Use vlog to be less verbose.
        if (!status.ok()) VLOG(1) << status;
      }
    }
  }

 private:
  // On merge, whether or not to delete the input (temporary) directories.
  bool delete_old_dirs_;

  // On merge, whether or not to relax the condition that all input prefix
  // filenames exist.
  bool allow_missing_files_;
};
REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
                        MergeV2Checkpoints);

}  // namespace tensorflow