/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/save_restore_tensor.h"
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"

namespace tensorflow {

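// Saves the tensors passed after the fixed inputs (input 0: filename,
// input 1: tensor names, and, when save_slices is true, input 2:
// shape-and-slice specs) to a single checkpoint file via TensorSliceWriter.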
void SaveTensors(
    OpKernelContext* context,
    checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
    bool save_slices) {
  const Tensor& filename_t = context->input(0);
  {
    const int64_t size = filename_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        errors::InvalidArgument(
            "Input 0 (filename) must be a string scalar; got a tensor of ",
            size, " elements"));
  }

  // Path, names, and slices if save_slices is true.
  const int kFixedInputs = save_slices ? 3 : 2;
  const Tensor& tensor_names_t = context->input(1);
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to SaveTensors"));
  const int N = static_cast<int>(tensor_names_t.NumElements());
  const tstring* tensor_shapes_and_slices_ptr = nullptr;
  if (save_slices) {
    const Tensor& tensor_shapes_and_slices_t = context->input(2);
    OP_REQUIRES(
        context,
        tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N),
        errors::InvalidArgument("Expected ", N,
                                " elements for the tensor "
                                "shapes and slices but got ",
                                tensor_shapes_and_slices_t.NumElements()));
    tensor_shapes_and_slices_ptr =
        tensor_shapes_and_slices_t.flat<tstring>().data();
  }
  OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
              errors::InvalidArgument("Expected a total of ", N + kFixedInputs,
                                      " inputs as input #1 (which is a string "
                                      "tensor of saved names) contains ",
                                      N, " names, but received ",
                                      context->num_inputs(), " inputs"));

  VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0)
          << "...";
  checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0),
                                       std::move(builder_func));

  Status s;
  auto tensor_names_flat = tensor_names_t.flat<tstring>();

  // Process tensors in sorted name order.  This allows us to avoid seeking
  // during restoration in the common case where we are restoring a full
  // checkpoint.
  std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
  std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
  std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
            [&tensor_names_flat](size_t a, size_t b) {
              return tensor_names_flat(a) < tensor_names_flat(b);
            });

  for (const size_t i : sorted_name_idx) {
    const string& name = tensor_names_flat(i);
    const Tensor& input = context->input(i + kFixedInputs);
    TensorShape shape(input.shape());
    TensorSlice slice(input.dims());
    if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
      const tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
      TensorShape slice_shape;
      OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                  shape_spec, &shape, &slice, &slice_shape));
      OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
                  errors::InvalidArgument(
                      "Slice in shape_and_slice "
                      "specification does not match the "
                      "shape of the tensor to save: ",
                      shape_spec, ", tensor: ", input.shape().DebugString()));
    }

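// Expands to one switch case per dtype T: write this tensor's buffer under
// `name` for the given shape and slice, recording any error in `s`.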
#define WRITER_ADD(T)                                           \
  case DataTypeToEnum<T>::value:                                \
    s = writer.Add(name, shape, slice, input.flat<T>().data()); \
    break;

    switch (input.dtype()) {
      TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
      default:
        context->SetStatus(errors::Unimplemented("Saving data type ",
                                                 DataTypeString(input.dtype()),
                                                 " not yet supported"));
        return;
    }
#undef WRITER_ADD
    if (!s.ok()) {
      context->SetStatus(s);
      return;
    }
  }

  s = writer.Finish();
  if (!s.ok()) {
    context->SetStatus(s);
  }
}
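
// Illustrative sketch only (the registered Save/SaveSlices kernels live in a
// separate file): a kernel's Compute() is expected to forward its context to
// SaveTensors together with a table-builder factory, roughly like
//
//   void Compute(OpKernelContext* context) override {
//     SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder,
//                 /*save_slices=*/false);
//   }
//
// CreateTableTensorSliceBuilder is assumed here; any function matching
// TensorSliceWriter::CreateBuilderFunction works.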

void RestoreTensor(OpKernelContext* context,
                   checkpoint::TensorSliceReader::OpenTableFunction open_func,
                   int preferred_shard, bool restore_slice, int restore_index) {
  const Tensor& file_pattern_t = context->input(0);
  {
    const int64_t size = file_pattern_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        errors::InvalidArgument(
            "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
            size, " elements"));
  }
  const string& file_pattern = file_pattern_t.flat<tstring>()(0);

  const Tensor& tensor_name_t = context->input(1);
  {
    const int64_t size = tensor_name_t.NumElements();
    OP_REQUIRES(context, size > restore_index,
                errors::InvalidArgument(
                    "Input 1 (tensor_name) must have at least ",
                    restore_index + 1, " elements"));
  }
  const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index);

  // If we cannot find a cached reader we will allocate our own.
  std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;

  const checkpoint::TensorSliceReader* reader = nullptr;

  if (context->slice_reader_cache()) {
    reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
                                                      preferred_shard);
  }
  if (!reader) {
    allocated_reader.reset(new checkpoint::TensorSliceReader(
        file_pattern, open_func, preferred_shard));
    reader = allocated_reader.get();
  }
  OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());

  // Get the shape and type from the save file.
  DataType type;
  TensorShape saved_shape;
  OP_REQUIRES(
      context, reader->HasTensor(tensor_name, &saved_shape, &type),
      errors::NotFound("Tensor name \"", tensor_name,
                       "\" not found in checkpoint files ", file_pattern));
  OP_REQUIRES(
      context, type == context->expected_output_dtype(restore_index),
      errors::InvalidArgument(
          "Expected to restore a tensor of type ",
          DataTypeString(context->expected_output_dtype(restore_index)),
          ", got a tensor of type ", DataTypeString(type),
          " instead: tensor_name = ", tensor_name));

  // Shape of the output and slice to load.
  TensorShape output_shape(saved_shape);
  TensorSlice slice_to_load(saved_shape.dims());
  if (restore_slice) {
    const tstring& shape_spec =
        context->input(2).flat<tstring>()(restore_index);
    if (!shape_spec.empty()) {
      TensorShape parsed_shape;
      OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                  shape_spec, &parsed_shape, &slice_to_load,
                                  &output_shape));
      OP_REQUIRES(
          context, parsed_shape.IsSameSize(saved_shape),
          errors::InvalidArgument(
              "Shape in shape_and_slice spec does not match the shape in the "
              "save file: ",
              parsed_shape.DebugString(),
              ", save file shape: ", saved_shape.DebugString()));
    }
  }

  Tensor* t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(restore_index, output_shape, &t));

  if (output_shape.num_elements() == 0) return;

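// Expands to one switch case per dtype T: copy the requested slice from the
// reader directly into the already-allocated output tensor.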
#define READER_COPY(T)                                                \
  case DataTypeToEnum<T>::value:                                      \
    OP_REQUIRES(context,                                              \
                reader->CopySliceData(tensor_name, slice_to_load,     \
                                      t->flat<T>().data()),           \
                errors::InvalidArgument("Error copying slice data")); \
    break;

  switch (type) {
    TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
    default:
      context->SetStatus(errors::Unimplemented(
          "Restoring data type ", DataTypeString(type), " not yet supported"));
  }
#undef READER_COPY
}
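
// Illustrative sketch only (the registered Restore/RestoreSlice kernels live
// elsewhere): a kernel would typically read its "preferred_shard" attr and
// then delegate, roughly like
//
//   void Compute(OpKernelContext* context) override {
//     RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
//                   preferred_shard_, /*restore_slice=*/false,
//                   /*restore_index=*/0);
//   }
//
// OpenTableTensorSliceReader is assumed here; any function matching
// TensorSliceReader::OpenTableFunction works.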

namespace {

// Tensors with more elements than this threshold will be restored from a
// thread-pool.
const int64_t kLargeShapeThreshold = 16 << 20;  // 16M elements

// A restore operation for a single tensor.  Small tensors may be restored
// directly from the op thread to improve read locality.  Large tensors can be
// restored from a thread pool: this requires creating a separate BundleReader
// for each restore.
struct RestoreOp {
  RestoreOp& operator=(const RestoreOp&) = delete;

  bool should_run_in_pool(BundleReader* reader) const {
    TensorShape restored_full_shape;

    // Ignore status here; we'll catch the error later.
    if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
      return false;
    }

    return restored_full_shape.num_elements() > kLargeShapeThreshold;
  }

  // Run this restore operation using a new BundleReader.
  void run_with_new_reader() {
    BundleReader reader(Env::Default(), reader_prefix);
    if (!reader.status().ok()) {
      status = reader.status();
      return;
    }

    status = run(&reader);
  }

  Status run(BundleReader* reader) {
    TensorShape restored_full_shape;
    TF_RETURN_IF_ERROR(
        reader->LookupTensorShape(tensor_name, &restored_full_shape));

    VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    Tensor* restored_tensor;
    if (shape_and_slice.empty()) {
      // Look up the full tensor.
      TF_RETURN_IF_ERROR(
          context->allocate_output(idx, restored_full_shape, &restored_tensor));
      TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
    } else {
      // Look up the slice.
      TensorShape parsed_full_shape;
      TensorSlice parsed_slice;
      TensorShape parsed_slice_shape;

      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));

      if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
        return errors::InvalidArgument(
            "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
            parsed_full_shape.DebugString(),
            " does not match the shape stored in checkpoint: ",
            restored_full_shape.DebugString());
      }
      TF_RETURN_IF_ERROR(
          context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
      TF_RETURN_IF_ERROR(
          reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
    }
    if (VLOG_IS_ON(5)) {
      if (restored_tensor->dtype() == DT_FLOAT) {
        const float* t_data = restored_tensor->flat<float>().data();
        float min = std::numeric_limits<float>::infinity();
        float max = -std::numeric_limits<float>::infinity();
        double avg = 0.0;
        for (int i = 0; i < restored_tensor->NumElements(); ++i) {
          if (t_data[i] < min) min = t_data[i];
          if (t_data[i] > max) max = t_data[i];
          avg += t_data[i];
        }
        VLOG(5) << " min " << min << " max " << max << " avg "
                << avg / restored_tensor->NumElements() << " total elts "
                << restored_tensor->NumElements();
      }
    }
    VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : "
            << restored_full_shape.num_elements();
    return Status::OK();
  }

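  // Inputs captured at construction time; `status` is only meaningful for
  // operations that ran on the thread pool (run_with_new_reader).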
  OpKernelContext* context;
  size_t idx;
  string tensor_name;
  string shape_and_slice;
  string reader_prefix;

  ::tensorflow::Status status;
};

}  // namespace

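// Restores the tensors named in `tensor_names` (optionally sliced according
// to `shape_and_slices`) from the V2 checkpoint at `prefix`, allocating one
// output per name.  Large tensors are restored in parallel on a thread pool,
// small ones directly on the calling thread.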
Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
                        const Tensor& tensor_names,
                        const Tensor& shape_and_slices,
                        gtl::ArraySlice<DataType> dtypes) {
  const string& prefix_string = prefix.scalar<tstring>()();

  const auto& tensor_names_flat = tensor_names.flat<tstring>();
  const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

  // Sort lookup keys to improve locality when reading multiple tensors.
  std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
  std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
  std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
            [&tensor_names_flat](size_t a, size_t b) {
              return tensor_names_flat(a) < tensor_names_flat(b);
            });

  std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
  std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;

  BundleReader default_reader(Env::Default(), prefix_string);
  TF_RETURN_IF_ERROR(default_reader.status());

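  // Verify all requested dtypes against the checkpoint before allocating any
  // outputs, so that a mismatch fails fast with a single combined error.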
  std::vector<string> mismatched_errors;
  for (const size_t i : sorted_name_idx) {
    TensorShape restored_full_shape;
    DataType original_dtype;
    const string& tensor_name = tensor_names_flat(i);
    TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
        tensor_name, &original_dtype, &restored_full_shape));
    if (dtypes[i] != original_dtype) {
      string error_msg = strings::StrCat(
          "tensor_name = ", tensor_name, "; expected dtype ",
          DataTypeString(dtypes[i]), " does not equal original dtype ",
          DataTypeString(original_dtype));
      mismatched_errors.emplace_back(error_msg);
    }
  }
  if (!mismatched_errors.empty()) {
    const string error_msg = absl::StrJoin(mismatched_errors, "\n");
    return errors::InvalidArgument(error_msg);
  }

  for (auto i : sorted_name_idx) {
    const string& tensor_name = tensor_names_flat(i);
    const string& shape_and_slice = shape_and_slices_flat(i);
    auto op =
        new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
    if (op->should_run_in_pool(&default_reader)) {
      pool_restore_ops.emplace_back(op);
    } else {
      direct_restore_ops.emplace_back(op);
    }
  }

  {
    // Schedule any threaded operations first, skipping thread pool creation if
    // we don't have any expensive operations.
    std::unique_ptr<thread::ThreadPool> reader_pool;
    if (!pool_restore_ops.empty()) {
      reader_pool.reset(
          new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
      for (auto& op : pool_restore_ops) {
        reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
      }
    }

    // Read small tensors from the op thread.
    for (auto& op : direct_restore_ops) {
      TF_RETURN_IF_ERROR(op->run(&default_reader));
    }
  }

  // Check status of pool ops; this must come after the pool shuts down.
  for (auto& op : pool_restore_ops) {
    TF_RETURN_IF_ERROR(op->status);
  }

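  // Final sanity check: every restored output must have the dtype the caller
  // requested.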
  for (auto i : sorted_name_idx) {
    const string& tensor_name = tensor_names_flat(i);
    if (dtypes[i] != context->mutable_output(i)->dtype()) {
      return errors::InvalidArgument(
          "tensor_name = ", tensor_name, "; expected dtype ",
          DataTypeString(dtypes[i]), " does not equal restored dtype ",
          DataTypeString(context->mutable_output(i)->dtype()));
    }
  }

  return Status::OK();
}

}  // namespace tensorflow