/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/save_restore_tensor.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"

namespace tensorflow {

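// Writes the tensors passed to `context` into a single checkpoint file.
// Input 0 is the filename, input 1 holds the tensor names, input 2 holds the
// optional shape-and-slice specs when `save_slices` is true, and the
// remaining inputs are the tensors to save.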
void SaveTensors(
    OpKernelContext* context,
    checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
    bool save_slices) {
  const Tensor& filename_t = context->input(0);
  {
    const int64 size = filename_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        errors::InvalidArgument(
            "Input 0 (filename) must be a string scalar; got a tensor of ",
            size, " elements"));
  }

  // Path, names, and slices if save_slices is true.
  const int kFixedInputs = save_slices ? 3 : 2;
  const Tensor& tensor_names_t = context->input(1);
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to SaveTensors"));
  const int N = static_cast<int>(tensor_names_t.NumElements());
  const tstring* tensor_shapes_and_slices_ptr = nullptr;
  if (save_slices) {
    const Tensor& tensor_shapes_and_slices_t = context->input(2);
    OP_REQUIRES(
        context,
        tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N),
        errors::InvalidArgument("Expected ", N,
                                " elements for the tensor "
                                "shapes and slices but got ",
                                tensor_shapes_and_slices_t.NumElements()));
    tensor_shapes_and_slices_ptr =
        tensor_shapes_and_slices_t.flat<tstring>().data();
  }
  OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
              errors::InvalidArgument("Expected a total of ", N + kFixedInputs,
                                      " inputs as input #1 (which is a string "
                                      "tensor of saved names) contains ",
                                      N, " names, but received ",
                                      context->num_inputs(), " inputs"));

  VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0)
          << "...";
  checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0),
                                       std::move(builder_func));

  Status s;
  auto tensor_names_flat = tensor_names_t.flat<tstring>();

  // Process tensors in sorted name order.  This allows us to avoid seeking
  // during restoration in the common case where we are restoring a full
  // checkpoint.
  std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
  std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
  std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
            [&tensor_names_flat](size_t a, size_t b) {
              return tensor_names_flat(a) < tensor_names_flat(b);
            });

  for (const size_t i : sorted_name_idx) {
    const string& name = tensor_names_flat(i);
    const Tensor& input = context->input(i + kFixedInputs);
    TensorShape shape(input.shape());
    TensorSlice slice(input.dims());
    if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
      const tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
      TensorShape slice_shape;
      OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                  shape_spec, &shape, &slice, &slice_shape));
      OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
                  errors::InvalidArgument(
                      "Slice in shape_and_slice "
                      "specification does not match the "
                      "shape of the tensor to save: ",
                      shape_spec, ", tensor: ", input.shape().DebugString()));
    }

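// Expands to a switch case that hands the tensor's buffer to the
// TensorSliceWriter for dtype T.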
#define WRITER_ADD(T)                                           \
  case DataTypeToEnum<T>::value:                                \
    s = writer.Add(name, shape, slice, input.flat<T>().data()); \
    break;

    switch (input.dtype()) {
      TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
      default:
        context->SetStatus(errors::Unimplemented("Saving data type ",
                                                 DataTypeString(input.dtype()),
                                                 " not yet supported"));
        return;
    }
#undef WRITER_ADD
    if (!s.ok()) {
      context->SetStatus(s);
      return;
    }
  }

  s = writer.Finish();
  if (!s.ok()) {
    context->SetStatus(s);
  }
}

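// Restores a single tensor from the checkpoint matching `file_pattern`
// (input 0), using the tensor name stored at `restore_index` of input 1.
// When `restore_slice` is true, input 2 supplies an optional shape-and-slice
// spec for a partial restore.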
void RestoreTensor(OpKernelContext* context,
                   checkpoint::TensorSliceReader::OpenTableFunction open_func,
                   int preferred_shard, bool restore_slice, int restore_index) {
  const Tensor& file_pattern_t = context->input(0);
  {
    const int64 size = file_pattern_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        errors::InvalidArgument(
            "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
            size, " elements"));
  }
  const string& file_pattern = file_pattern_t.flat<tstring>()(0);

  const Tensor& tensor_name_t = context->input(1);
  const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index);

  // If we cannot find a cached reader we will allocate our own.
  std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;

  const checkpoint::TensorSliceReader* reader = nullptr;

  if (context->slice_reader_cache()) {
    reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
                                                      preferred_shard);
  }
  if (!reader) {
    allocated_reader.reset(new checkpoint::TensorSliceReader(
        file_pattern, open_func, preferred_shard));
    reader = allocated_reader.get();
  }
  OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());

  // Get the shape and type from the save file.
  DataType type;
  TensorShape saved_shape;
  OP_REQUIRES(
      context, reader->HasTensor(tensor_name, &saved_shape, &type),
      errors::NotFound("Tensor name \"", tensor_name,
                       "\" not found in checkpoint files ", file_pattern));
  OP_REQUIRES(
      context, type == context->expected_output_dtype(restore_index),
      errors::InvalidArgument(
          "Expected to restore a tensor of type ",
          DataTypeString(context->expected_output_dtype(restore_index)),
          ", got a tensor of type ", DataTypeString(type),
          " instead: tensor_name = ", tensor_name));

  // Shape of the output and slice to load.
  TensorShape output_shape(saved_shape);
  TensorSlice slice_to_load(saved_shape.dims());
  if (restore_slice) {
    const tstring& shape_spec =
        context->input(2).flat<tstring>()(restore_index);
    if (!shape_spec.empty()) {
      TensorShape parsed_shape;
      OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                  shape_spec, &parsed_shape, &slice_to_load,
                                  &output_shape));
      OP_REQUIRES(
          context, parsed_shape.IsSameSize(saved_shape),
          errors::InvalidArgument(
              "Shape in shape_and_slice spec does not match the shape in the "
              "save file: ",
              parsed_shape.DebugString(),
              ", save file shape: ", saved_shape.DebugString()));
    }
  }

  Tensor* t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(restore_index, output_shape, &t));

  if (output_shape.num_elements() == 0) return;

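// Expands to a switch case that copies the requested slice out of the
// checkpoint into the output tensor's buffer for dtype T.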
#define READER_COPY(T)                                                \
  case DataTypeToEnum<T>::value:                                      \
    OP_REQUIRES(context,                                              \
                reader->CopySliceData(tensor_name, slice_to_load,     \
                                      t->flat<T>().data()),           \
                errors::InvalidArgument("Error copying slice data")); \
    break;

  switch (type) {
    TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
    default:
      context->SetStatus(errors::Unimplemented(
          "Restoring data type ", DataTypeString(type), " not yet supported"));
  }
#undef READER_COPY
}

namespace {

// Tensors larger than this threshold will be restored from a thread-pool.
const int64 kLargeShapeThreshold = 16 << 20;  // 16M elements

// A restore operation for a single tensor.  Small tensors may be restored
// directly from the op thread to improve read locality.  Large tensors can be
// restored from a thread pool: this requires creating a separate BundleReader
// for each restore.
struct RestoreOp {
  RestoreOp& operator=(const RestoreOp&) = delete;

  bool should_run_in_pool(BundleReader* reader) const {
    TensorShape restored_full_shape;

    // Ignore status here; we'll catch the error later.
    if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
      return false;
    }

    return restored_full_shape.num_elements() > kLargeShapeThreshold;
  }

  // Run this restore operation using a new BundleReader.
  void run_with_new_reader() {
    BundleReader reader(Env::Default(), reader_prefix);
    if (!reader.status().ok()) {
      status = reader.status();
      return;
    }

    status = run(&reader);
  }

  Status run(BundleReader* reader) {
    TensorShape restored_full_shape;
    TF_RETURN_IF_ERROR(
        reader->LookupTensorShape(tensor_name, &restored_full_shape));

    VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    Tensor* restored_tensor;
    if (shape_and_slice.empty()) {
      // Lookup the full tensor.
      TF_RETURN_IF_ERROR(
          context->allocate_output(idx, restored_full_shape, &restored_tensor));
      TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
    } else {
      // Lookup the slice.
      TensorShape parsed_full_shape;
      TensorSlice parsed_slice;
      TensorShape parsed_slice_shape;

      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));

      if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
        return errors::InvalidArgument(
            "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
            parsed_full_shape.DebugString(),
            " does not match the shape stored in checkpoint: ",
            restored_full_shape.DebugString());
      }
      TF_RETURN_IF_ERROR(
          context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
      TF_RETURN_IF_ERROR(
          reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
    }
    if (VLOG_IS_ON(5)) {
      if (restored_tensor->dtype() == DT_FLOAT) {
        const float* t_data = restored_tensor->flat<float>().data();
        float min = std::numeric_limits<float>::infinity();
        float max = -std::numeric_limits<float>::infinity();
        double avg = 0.0;
        for (int i = 0; i < restored_tensor->NumElements(); ++i) {
          if (t_data[i] < min) min = t_data[i];
          if (t_data[i] > max) max = t_data[i];
          avg += t_data[i];
        }
        VLOG(5) << " min " << min << " max " << max << " avg "
                << avg / restored_tensor->NumElements() << " total elts "
                << restored_tensor->NumElements();
      }
    }
    VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    return Status::OK();
  }

  // Kernel context of the originating restore op; outputs are allocated here.
  OpKernelContext* context;
  // Output index to fill with the restored tensor.
  size_t idx;
  // Name of the tensor inside the checkpoint bundle.
  string tensor_name;
  // Optional shape-and-slice spec; empty means restore the full tensor.
  string shape_and_slice;
  // Checkpoint prefix used to open a fresh BundleReader when running in the
  // thread pool.
  string reader_prefix;

  // Result of run()/run_with_new_reader(); checked after the pool drains.
  ::tensorflow::Status status;
};

}  // namespace

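// Restores the tensors listed in `tensor_names` from the bundle at `prefix`,
// writing each result to the corresponding output of `context` and checking
// that the stored dtypes match `dtypes`.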
Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
                        const Tensor& tensor_names,
                        const Tensor& shape_and_slices,
                        gtl::ArraySlice<DataType> dtypes) {
  const string& prefix_string = prefix.scalar<tstring>()();

  const auto& tensor_names_flat = tensor_names.flat<tstring>();
  const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

  // Sort lookup keys to improve locality when reading multiple tensors.
  std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
  std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
  std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
            [&tensor_names_flat](size_t a, size_t b) {
              return tensor_names_flat(a) < tensor_names_flat(b);
            });

  std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
  std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;

  BundleReader default_reader(Env::Default(), prefix_string);
  TF_RETURN_IF_ERROR(default_reader.status());

  std::vector<string> mismatched_errors;
  for (const size_t i : sorted_name_idx) {
    TensorShape restored_full_shape;
    DataType original_dtype;
    const string& tensor_name = tensor_names_flat(i);
    TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
        tensor_name, &original_dtype, &restored_full_shape));
    if (dtypes[i] != original_dtype) {
      string error_msg = strings::StrCat(
          "tensor_name = ", tensor_name, "; expected dtype ",
          DataTypeString(dtypes[i]), " does not equal original dtype ",
          DataTypeString(original_dtype));
      mismatched_errors.emplace_back(error_msg);
    }
  }
  if (!mismatched_errors.empty()) {
    const string error_msg = absl::StrJoin(mismatched_errors, "\n");
    return errors::InvalidArgument(error_msg);
  }

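  // Partition the restore ops by tensor size: large tensors go to a thread
  // pool (each with its own BundleReader), small ones are read directly on
  // the op thread for better locality.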
  for (auto i : sorted_name_idx) {
    const string& tensor_name = tensor_names_flat(i);
    const string& shape_and_slice = shape_and_slices_flat(i);
    auto op =
        new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
    if (op->should_run_in_pool(&default_reader)) {
      pool_restore_ops.emplace_back(op);
    } else {
      direct_restore_ops.emplace_back(op);
    }
  }

  {
    // Schedule any threaded operations first, skipping thread pool creation if
    // we don't have any expensive operations.
    std::unique_ptr<thread::ThreadPool> reader_pool;
    if (!pool_restore_ops.empty()) {
      reader_pool.reset(
          new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
      for (auto& op : pool_restore_ops) {
        reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
      }
    }

    // Read small tensors from the op thread.
    for (auto& op : direct_restore_ops) {
      TF_RETURN_IF_ERROR(op->run(&default_reader));
    }
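    // Leaving this scope destroys `reader_pool`, which blocks until every
    // scheduled pool restore has finished.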
  }

  // Check status of pool ops; this must come after the pool shuts down.
  for (auto& op : pool_restore_ops) {
    TF_RETURN_IF_ERROR(op->status);
  }

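  // Verify that each restored output has the dtype the caller expected.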
  for (auto i : sorted_name_idx) {
    const string& tensor_name = tensor_names_flat(i);
    if (dtypes[i] != context->mutable_output(i)->dtype()) {
      return errors::InvalidArgument(
          "tensor_name = ", tensor_name, "; expected dtype ",
          DataTypeString(dtypes[i]), " does not equal restored dtype ",
          DataTypeString(context->mutable_output(i)->dtype()));
    }
  }

  return Status::OK();
}

}  // namespace tensorflow