• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/io_ops.cc.
17 
18 #include <string>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/bounds_check.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/kernels/save_restore_tensor.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/platform/env.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/saved_tensor_slice_util.h"
33 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
34 #include "tensorflow/core/util/tensor_slice_reader.h"
35 
36 namespace tensorflow {
37 
38 namespace {
39 
40 // Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
ValidateInputs(bool is_save_op,OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices)41 void ValidateInputs(bool is_save_op, OpKernelContext* context,
42                     const Tensor& prefix, const Tensor& tensor_names,
43                     const Tensor& shape_and_slices) {
44   const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
45   const int num_tensors = static_cast<int>(tensor_names.NumElements());
46   OP_REQUIRES(
47       context, prefix.NumElements() == 1,
48       errors::InvalidArgument("Input prefix should have a single element, got ",
49                               prefix.NumElements(), " instead."));
50   OP_REQUIRES(context,
51               TensorShapeUtils::IsVector(tensor_names.shape()) &&
52                   TensorShapeUtils::IsVector(shape_and_slices.shape()),
53               errors::InvalidArgument(
54                   "Input tensor_names and shape_and_slices "
55                   "should be an 1-D tensors, got ",
56                   tensor_names.shape().DebugString(), " and ",
57                   shape_and_slices.shape().DebugString(), " instead."));
58   OP_REQUIRES(context,
59               tensor_names.NumElements() == shape_and_slices.NumElements(),
60               errors::InvalidArgument("tensor_names and shape_and_slices "
61                                       "have different number of elements: ",
62                                       tensor_names.NumElements(), " vs. ",
63                                       shape_and_slices.NumElements()));
64   OP_REQUIRES(context,
65               FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
66                               std::numeric_limits<int>::max()),
67               errors::InvalidArgument("Too many inputs to the op"));
68   OP_REQUIRES(
69       context, shape_and_slices.NumElements() == num_tensors,
70       errors::InvalidArgument("Expected ", num_tensors,
71                               " elements in shapes_and_slices, but got ",
72                               context->input(2).NumElements()));
73   if (is_save_op) {
74     OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
75                 errors::InvalidArgument(
76                     "Got ", num_tensors, " tensor names but ",
77                     context->num_inputs() - kFixedInputs, " tensors."));
78     OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
79                 errors::InvalidArgument(
80                     "Expected a total of ", num_tensors + kFixedInputs,
81                     " inputs as input #1 (which is a string "
82                     "tensor of saved names) contains ",
83                     num_tensors, " names, but received ", context->num_inputs(),
84                     " inputs"));
85   }
86 }
87 
88 }  // namespace
89 
90 // Saves a list of named tensors using the tensor bundle library.
91 class SaveV2 : public OpKernel {
92  public:
SaveV2(OpKernelConstruction * context)93   explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}
94 
Compute(OpKernelContext * context)95   void Compute(OpKernelContext* context) override {
96     const Tensor& prefix = context->input(0);
97     const Tensor& tensor_names = context->input(1);
98     const Tensor& shape_and_slices = context->input(2);
99     ValidateInputs(true /* is save op */, context, prefix, tensor_names,
100                    shape_and_slices);
101 
102     const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
103     const int num_tensors = static_cast<int>(tensor_names.NumElements());
104     const string& prefix_string = prefix.scalar<tstring>()();
105     const auto& tensor_names_flat = tensor_names.flat<tstring>();
106     const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();
107 
108     BundleWriter writer(Env::Default(), prefix_string);
109     OP_REQUIRES_OK(context, writer.status());
110     VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;
111 
112     for (int i = 0; i < num_tensors; ++i) {
113       const string& tensor_name = tensor_names_flat(i);
114       const Tensor& tensor = context->input(i + kFixedInputs);
115       VLOG(2) << "Starting save of " << tensor_name;
116 
117       if (!shape_and_slices_flat(i).empty()) {
118         const string& shape_spec = shape_and_slices_flat(i);
119         TensorShape shape;
120         TensorSlice slice(tensor.dims());
121         TensorShape slice_shape;
122 
123         OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
124                                     shape_spec, &shape, &slice, &slice_shape));
125         OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
126                     errors::InvalidArgument("Slice in shape_and_slice "
127                                             "specification does not match the "
128                                             "shape of the tensor to  save: ",
129                                             shape_spec, ", tensor: ",
130                                             tensor.shape().DebugString()));
131 
132         OP_REQUIRES_OK(context,
133                        writer.AddSlice(tensor_name, shape, slice, tensor));
134       } else {
135         OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
136       }
137 
138       if (VLOG_IS_ON(5)) {
139         if (tensor.dtype() == DT_FLOAT) {
140           const float* t_data = tensor.flat<float>().data();
141           float min = std::numeric_limits<float>::infinity();
142           float max = -std::numeric_limits<float>::infinity();
143           double avg = 0.0;
144           for (int i = 0; i < tensor.NumElements(); ++i) {
145             if (t_data[i] < min) min = t_data[i];
146             if (t_data[i] > max) max = t_data[i];
147             avg += t_data[i];
148           }
149           VLOG(5) << " min " << min << " max " << max << " avg "
150                   << avg / tensor.NumElements() << " total elts "
151                   << tensor.NumElements();
152         }
153       }
154 
155       VLOG(2) << "Done save of " << tensor_name;
156     }
157     OP_REQUIRES_OK(context, writer.Finish());
158     VLOG(1) << "Done BundleWriter, prefix_string: " << prefix_string;
159   }
160 };
161 REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);
162 
163 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
164 class RestoreV2 : public OpKernel {
165  public:
RestoreV2(OpKernelConstruction * context)166   explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
167     OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
168   }
169 
Compute(OpKernelContext * context)170   void Compute(OpKernelContext* context) override {
171     const Tensor& prefix = context->input(0);
172     const Tensor& tensor_names = context->input(1);
173     const Tensor& shape_and_slices = context->input(2);
174     OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
175                 errors::InvalidArgument("Got ", tensor_names.NumElements(),
176                                         " tensor names, but ", dtypes_.size(),
177                                         " expected dtypes."));
178     ValidateInputs(false /* not save op */, context, prefix, tensor_names,
179                    shape_and_slices);
180 
181     const string& prefix_string = prefix.scalar<tstring>()();
182 
183     // Intention: we plan to use the RestoreV2 op as a backward-compatible
184     // reader as we upgrade to the V2 format.  This allows transparent upgrade.
185     // We here attempt to read a V1 checkpoint, if "prefix_string" does not
186     // refer to a V2 checkpoint.
187     Env* env = Env::Default();
188     std::vector<string> paths;
189     if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
190         paths.empty()) {
191       // Cannot find V2's metadata file, so "prefix_string" does not point to a
192       // V2 checkpoint.  Invokes the V1 read path instead.
193       for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
194         RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
195                       /* preferred_shard */ -1, /* restore_slice */ true,
196                       /* restore_index */ i);
197         if (!context->status().ok()) {
198           return;
199         }
200       }
201       return;
202     }
203     // If found, invokes the V2 reader.
204     OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
205                                              shape_and_slices, dtypes_));
206   }
207 
208  private:
209   // Expected dtypes of the to-restore tensors.
210   std::vector<DataType> dtypes_;
211 };
212 REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);
213 
214 // The final step in saving sharded V2 checkpoints: merges metadata files.
215 class MergeV2Checkpoints : public OpKernel {
216  public:
MergeV2Checkpoints(OpKernelConstruction * context)217   explicit MergeV2Checkpoints(OpKernelConstruction* context)
218       : OpKernel(context) {
219     OP_REQUIRES_OK(context,
220                    context->GetAttr("delete_old_dirs", &delete_old_dirs_));
221   }
222 
Compute(OpKernelContext * context)223   void Compute(OpKernelContext* context) override {
224     const Tensor& checkpoint_prefixes = context->input(0);
225     const Tensor& destination_prefix = context->input(1);
226     OP_REQUIRES(context,
227                 TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
228                 errors::InvalidArgument(
229                     "Input checkpoint_prefixes should be an 1-D tensor, got ",
230                     checkpoint_prefixes.shape().DebugString(), " instead."));
231     OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
232                 errors::InvalidArgument(
233                     "Input destination_prefix should be a scalar tensor, got ",
234                     destination_prefix.shape().DebugString(), " instead."));
235 
236     const gtl::ArraySlice<tstring> input_prefixes =
237         gtl::ArraySlice<tstring>(checkpoint_prefixes.flat<tstring>());
238     Env* env = Env::Default();
239     const string& merged_prefix = destination_prefix.scalar<tstring>()();
240     OP_REQUIRES_OK(
241         context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
242 
243     if (delete_old_dirs_) {
244       const string merged_dir(io::Dirname(merged_prefix));
245       for (const string& input_prefix : input_prefixes) {
246         const string dirname(io::Dirname(input_prefix));
247         if (dirname == merged_dir) continue;
248         Status status = env->DeleteDir(dirname);
249         // For sharded save, only the first delete will go through and all
250         // others will hit NotFound.  Use vlog to be less verbose.
251         if (!status.ok()) VLOG(1) << status;
252       }
253     }
254   }
255 
256  private:
257   // On merge, whether or not to delete the input (temporary) directories.
258   bool delete_old_dirs_;
259 };
260 REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
261                         MergeV2Checkpoints);
262 
263 }  // namespace tensorflow
264