• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/core/kernels/save_restore_tensor.h"

#include <algorithm>
#include <limits>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
37 
38 namespace tensorflow {
39 
SaveTensors(OpKernelContext * context,checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,bool save_slices)40 void SaveTensors(
41     OpKernelContext* context,
42     checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
43     bool save_slices) {
44   const Tensor& filename_t = context->input(0);
45   {
46     const int64 size = filename_t.NumElements();
47     OP_REQUIRES(
48         context, size == 1,
49         errors::InvalidArgument(
50             "Input 0 (filename) must be a string scalar; got a tensor of ",
51             size, "elements"));
52   }
53 
54   // Path, names, and slices if save_slices is true.
55   const int kFixedInputs = save_slices ? 3 : 2;
56   const Tensor& tensor_names_t = context->input(1);
57   OP_REQUIRES(context,
58               FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
59                               std::numeric_limits<int>::max()),
60               errors::InvalidArgument("Too many inputs to SaveTensors"));
61   const int N = static_cast<int>(tensor_names_t.NumElements());
62   const string* tensor_shapes_and_slices_ptr = nullptr;
63   if (save_slices) {
64     const Tensor& tensor_shapes_and_slices_t = context->input(2);
65     OP_REQUIRES(
66         context,
67         tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N),
68         errors::InvalidArgument("Expected ", N,
69                                 " elements for the tensor "
70                                 "shapes and slices but got ",
71                                 tensor_shapes_and_slices_t.NumElements()));
72     tensor_shapes_and_slices_ptr =
73         tensor_shapes_and_slices_t.flat<string>().data();
74   }
75   OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
76               errors::InvalidArgument("Expected totally ", N + kFixedInputs,
77                                       " inputs as input #1 (which is a string "
78                                       "tensor of saved names) contains ",
79                                       N, " names, but received ",
80                                       context->num_inputs(), " inputs"));
81 
82   VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0)
83           << "...";
84   checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0),
85                                        std::move(builder_func));
86 
87   Status s;
88   auto tensor_names_flat = tensor_names_t.flat<string>();
89 
90   // Process tensors in sorted name order.  This allows us to avoid seeking
91   // during restoration in the common case where we are restoring a full
92   // checkpoint.
93   std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
94   std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
95   std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
96             [&tensor_names_flat](size_t a, size_t b) {
97               return tensor_names_flat(a) < tensor_names_flat(b);
98             });
99 
100   for (const size_t i : sorted_name_idx) {
101     const string& name = tensor_names_flat(i);
102     const Tensor& input = context->input(i + kFixedInputs);
103     TensorShape shape(input.shape());
104     TensorSlice slice(input.dims());
105     if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
106       const string& shape_spec = tensor_shapes_and_slices_ptr[i];
107       TensorShape slice_shape;
108       OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
109                                   shape_spec, &shape, &slice, &slice_shape));
110       OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
111                   errors::InvalidArgument(
112                       "Slice in shape_and_slice "
113                       "specification does not match the "
114                       "shape of the tensor to  save: ",
115                       shape_spec, ", tensor: ", input.shape().DebugString()));
116     }
117 
118 #define WRITER_ADD(T)                                           \
119   case DataTypeToEnum<T>::value:                                \
120     s = writer.Add(name, shape, slice, input.flat<T>().data()); \
121     break;
122 
123     switch (input.dtype()) {
124       TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
125       default:
126         context->SetStatus(errors::Unimplemented("Saving data type ",
127                                                  DataTypeString(input.dtype()),
128                                                  " not yet supported"));
129         return;
130     }
131 #undef WRITER_ADD
132     if (!s.ok()) {
133       context->SetStatus(s);
134       return;
135     }
136   }
137 
138   s = writer.Finish();
139   if (!s.ok()) {
140     context->SetStatus(s);
141   }
142 }
143 
RestoreTensor(OpKernelContext * context,checkpoint::TensorSliceReader::OpenTableFunction open_func,int preferred_shard,bool restore_slice,int restore_index)144 void RestoreTensor(OpKernelContext* context,
145                    checkpoint::TensorSliceReader::OpenTableFunction open_func,
146                    int preferred_shard, bool restore_slice, int restore_index) {
147   const Tensor& file_pattern_t = context->input(0);
148   {
149     const int64 size = file_pattern_t.NumElements();
150     OP_REQUIRES(
151         context, size == 1,
152         errors::InvalidArgument(
153             "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
154             size, "elements"));
155   }
156   const string& file_pattern = file_pattern_t.flat<string>()(0);
157 
158   const Tensor& tensor_name_t = context->input(1);
159   const string& tensor_name = tensor_name_t.flat<string>()(restore_index);
160 
161   // If we cannot find a cached reader we will allocate our own.
162   std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
163 
164   const checkpoint::TensorSliceReader* reader = nullptr;
165 
166   if (context->slice_reader_cache()) {
167     reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
168                                                       preferred_shard);
169   }
170   if (!reader) {
171     allocated_reader.reset(new checkpoint::TensorSliceReader(
172         file_pattern, open_func, preferred_shard));
173     reader = allocated_reader.get();
174   }
175   OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
176 
177   // Get the shape and type from the save file.
178   DataType type;
179   TensorShape saved_shape;
180   OP_REQUIRES(
181       context, reader->HasTensor(tensor_name, &saved_shape, &type),
182       errors::NotFound("Tensor name \"", tensor_name,
183                        "\" not found in checkpoint files ", file_pattern));
184   OP_REQUIRES(
185       context, type == context->expected_output_dtype(restore_index),
186       errors::InvalidArgument("Expected to restore a tensor of type ",
187                               DataTypeString(context->expected_output_dtype(0)),
188                               ", got a tensor of type ", DataTypeString(type),
189                               " instead: tensor_name = ", tensor_name));
190 
191   // Shape of the output and slice to load.
192   TensorShape output_shape(saved_shape);
193   TensorSlice slice_to_load(saved_shape.dims());
194   if (restore_slice) {
195     const string& shape_spec = context->input(2).flat<string>()(restore_index);
196     if (!shape_spec.empty()) {
197       TensorShape parsed_shape;
198       OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
199                                   shape_spec, &parsed_shape, &slice_to_load,
200                                   &output_shape));
201       OP_REQUIRES(
202           context, parsed_shape.IsSameSize(saved_shape),
203           errors::InvalidArgument(
204               "Shape in shape_and_slice spec does not match the shape in the "
205               "save file: ",
206               parsed_shape.DebugString(),
207               ", save file shape: ", saved_shape.DebugString()));
208     }
209   }
210 
211   Tensor* t = nullptr;
212   OP_REQUIRES_OK(context,
213                  context->allocate_output(restore_index, output_shape, &t));
214 
215   if (output_shape.num_elements() == 0) return;
216 
217 #define READER_COPY(T)                                                \
218   case DataTypeToEnum<T>::value:                                      \
219     OP_REQUIRES(context,                                              \
220                 reader->CopySliceData(tensor_name, slice_to_load,     \
221                                       t->flat<T>().data()),           \
222                 errors::InvalidArgument("Error copying slice data")); \
223     break;
224 
225   switch (type) {
226     TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
227     default:
228       context->SetStatus(errors::Unimplemented(
229           "Restoring data type ", DataTypeString(type), " not yet supported"));
230   }
231 #undef READER_COPY
232 }
233 
234 namespace {
235 
236 // Tensors larger than this threshold will be restored from a thread-pool.
237 const int64 kLargeShapeThreshold = 16 << 20;  // 16M
238 
239 // A restore operation for a single tensor.  Small tensors may be restored
240 // directly from the op thread to improve read locality.  Large tensors can be
241 // restored from a thread pool: this requires creating a separate BundleReader
242 // for each restore.
243 struct RestoreOp {
244   RestoreOp& operator=(const RestoreOp&) = delete;
245 
should_run_in_pooltensorflow::__anone6f391f90211::RestoreOp246   bool should_run_in_pool(BundleReader* reader) const {
247     TensorShape restored_full_shape;
248 
249     // Ignore status here; we'll catch the error later.
250     if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
251       return false;
252     }
253 
254     return restored_full_shape.num_elements() > kLargeShapeThreshold;
255   }
256 
257   // Run this restore operation using a new BundleReader.
run_with_new_readertensorflow::__anone6f391f90211::RestoreOp258   void run_with_new_reader() {
259     BundleReader reader(Env::Default(), reader_prefix);
260     if (!reader.status().ok()) {
261       status = reader.status();
262       return;
263     }
264 
265     status = run(&reader);
266   }
267 
runtensorflow::__anone6f391f90211::RestoreOp268   Status run(BundleReader* reader) {
269     TensorShape restored_full_shape;
270     TF_RETURN_IF_ERROR(
271         reader->LookupTensorShape(tensor_name, &restored_full_shape));
272 
273     VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
274             << restored_full_shape.num_elements();
275     Tensor* restored_tensor;
276     if (shape_and_slice.empty()) {
277       // Lookup the full tensor.
278       TF_RETURN_IF_ERROR(
279           context->allocate_output(idx, restored_full_shape, &restored_tensor));
280       TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
281     } else {
282       // Lookup the slice.
283       TensorShape parsed_full_shape;
284       TensorSlice parsed_slice;
285       TensorShape parsed_slice_shape;
286 
287       TF_RETURN_IF_ERROR(
288           checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
289                                          &parsed_slice, &parsed_slice_shape));
290 
291       if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
292         return errors::InvalidArgument(
293             "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
294             parsed_full_shape.DebugString(),
295             " does not match the shape stored in checkpoint: ",
296             restored_full_shape.DebugString());
297       }
298       TF_RETURN_IF_ERROR(
299           context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
300       TF_RETURN_IF_ERROR(
301           reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
302     }
303     return Status::OK();
304   }
305 
306   OpKernelContext* context;
307   size_t idx;
308   string tensor_name;
309   string shape_and_slice;
310   string reader_prefix;
311 
312   ::tensorflow::Status status;
313 };
314 
315 }  // namespace
316 
RestoreTensorsV2(OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices,gtl::ArraySlice<DataType> dtypes)317 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
318                         const Tensor& tensor_names,
319                         const Tensor& shape_and_slices,
320                         gtl::ArraySlice<DataType> dtypes) {
321   const string& prefix_string = prefix.scalar<string>()();
322 
323   const auto& tensor_names_flat = tensor_names.flat<string>();
324   const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
325 
326   // Sort lookup keys to improve locality when reading multiple tensors.
327   std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
328   std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
329   std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
330             [&tensor_names_flat](size_t a, size_t b) {
331               return tensor_names_flat(a) < tensor_names_flat(b);
332             });
333 
334   std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
335   std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;
336 
337   BundleReader default_reader(Env::Default(), prefix_string);
338   TF_RETURN_IF_ERROR(default_reader.status());
339 
340   std::vector<string> mismatched_errors;
341   for (const size_t i : sorted_name_idx) {
342     TensorShape restored_full_shape;
343     DataType original_dtype;
344     const string& tensor_name = tensor_names_flat(i);
345     TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
346         tensor_name, &original_dtype, &restored_full_shape));
347     if (dtypes[i] != original_dtype) {
348       string error_msg = strings::StrCat(
349           "tensor_name = ", tensor_name, "; expected dtype ",
350           DataTypeString(dtypes[i]), " does not equal original dtype ",
351           DataTypeString(original_dtype));
352       mismatched_errors.emplace_back(error_msg);
353     }
354   }
355   if (!mismatched_errors.empty()) {
356     const string error_msg = str_util::Join(mismatched_errors, "\n");
357     return errors::InvalidArgument(error_msg);
358   }
359 
360   for (auto i : sorted_name_idx) {
361     const string& tensor_name = tensor_names_flat(i);
362     const string& shape_and_slice = shape_and_slices_flat(i);
363     auto op =
364         new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
365     if (op->should_run_in_pool(&default_reader)) {
366       pool_restore_ops.emplace_back(op);
367     } else {
368       direct_restore_ops.emplace_back(op);
369     }
370   }
371 
372   {
373     // Schedule any threaded operations first, skipping thread pool creation if
374     // we don't have any expensive operations.
375     std::unique_ptr<thread::ThreadPool> reader_pool;
376     if (!pool_restore_ops.empty()) {
377       reader_pool.reset(
378           new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
379       for (auto& op : pool_restore_ops) {
380         reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
381       }
382     }
383 
384     // Read small tensors from the op thread
385     for (auto& op : direct_restore_ops) {
386       TF_RETURN_IF_ERROR(op->run(&default_reader));
387     }
388   }
389 
390   // Check status of pool ops; this must come after the pool shuts down.
391   for (auto& op : pool_restore_ops) {
392     TF_RETURN_IF_ERROR(op->status);
393   }
394 
395   for (auto i : sorted_name_idx) {
396     const string& tensor_name = tensor_names_flat(i);
397     if (dtypes[i] != context->mutable_output(i)->dtype()) {
398       return errors::InvalidArgument(
399           "tensor_name = ", tensor_name, "; expected dtype ",
400           DataTypeString(dtypes[i]), " does not equal restored dtype ",
401           DataTypeString(context->mutable_output(i)->dtype()));
402     }
403   }
404 
405   return Status::OK();
406 }
407 
408 }  // namespace tensorflow
409