• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/save_restore_tensor.h"
17 
18 #include <memory>
19 #include <numeric>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/bounds_check.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/threadpool.h"
30 #include "tensorflow/core/lib/gtl/array_slice.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/lib/strings/stringprintf.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
37 #include "tensorflow/core/util/tensor_slice_reader.h"
38 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
39 #include "tensorflow/core/util/tensor_slice_writer.h"
40 
41 namespace tensorflow {
42 
SaveTensors(OpKernelContext * context,checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,bool save_slices)43 void SaveTensors(
44     OpKernelContext* context,
45     checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
46     bool save_slices) {
47   const Tensor& filename_t = context->input(0);
48   {
49     const int64_t size = filename_t.NumElements();
50     OP_REQUIRES(
51         context, size == 1,
52         errors::InvalidArgument(
53             "Input 0 (filename) must be a string scalar; got a tensor of ",
54             size, "elements"));
55   }
56 
57   // Path, names, and slices if save_slices is true.
58   const int kFixedInputs = save_slices ? 3 : 2;
59   const Tensor& tensor_names_t = context->input(1);
60   OP_REQUIRES(context,
61               FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
62                               std::numeric_limits<int>::max()),
63               errors::InvalidArgument("Too many inputs to SaveTensors"));
64   const int N = static_cast<int>(tensor_names_t.NumElements());
65   const tstring* tensor_shapes_and_slices_ptr = nullptr;
66   if (save_slices) {
67     const Tensor& tensor_shapes_and_slices_t = context->input(2);
68     OP_REQUIRES(
69         context,
70         tensor_shapes_and_slices_t.NumElements() == static_cast<int64_t>(N),
71         errors::InvalidArgument("Expected ", N,
72                                 " elements for the tensor "
73                                 "shapes and slices but got ",
74                                 tensor_shapes_and_slices_t.NumElements()));
75     tensor_shapes_and_slices_ptr =
76         tensor_shapes_and_slices_t.flat<tstring>().data();
77   }
78   OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
79               errors::InvalidArgument("Expected totally ", N + kFixedInputs,
80                                       " inputs as input #1 (which is a string "
81                                       "tensor of saved names) contains ",
82                                       N, " names, but received ",
83                                       context->num_inputs(), " inputs"));
84 
85   VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0)
86           << "...";
87   checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0),
88                                        std::move(builder_func));
89 
90   Status s;
91   auto tensor_names_flat = tensor_names_t.flat<tstring>();
92 
93   // Process tensors in sorted name order.  This allows us to avoid seeking
94   // during restoration in the common case where we are restoring a full
95   // checkpoint.
96   // RestoreTensorsV2 was changed to sort by file offset, so this sorting isn't
97   // strictly necessary anymore. However, restores with TF version <= 2.7 will
98   // still benefit.
99   std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
100   std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
101   std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
102             [&tensor_names_flat](size_t a, size_t b) {
103               return tensor_names_flat(a) < tensor_names_flat(b);
104             });
105 
106   for (const size_t i : sorted_name_idx) {
107     const string& name = tensor_names_flat(i);
108     const Tensor& input = context->input(i + kFixedInputs);
109     TensorShape shape(input.shape());
110     TensorSlice slice(input.dims());
111     if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
112       const tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
113       TensorShape slice_shape;
114       OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
115                                   shape_spec, &shape, &slice, &slice_shape));
116       OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
117                   errors::InvalidArgument(
118                       "Slice in shape_and_slice "
119                       "specification does not match the "
120                       "shape of the tensor to  save: ",
121                       shape_spec, ", tensor: ", input.shape().DebugString()));
122     }
123 
124 #define WRITER_ADD(T)                                           \
125   case DataTypeToEnum<T>::value:                                \
126     s = writer.Add(name, shape, slice, input.flat<T>().data()); \
127     break;
128 
129     switch (input.dtype()) {
130       TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
131       default:
132         context->SetStatus(errors::Unimplemented("Saving data type ",
133                                                  DataTypeString(input.dtype()),
134                                                  " not yet supported"));
135         return;
136     }
137 #undef WRITER_ADD
138     if (!s.ok()) {
139       context->SetStatus(s);
140       return;
141     }
142   }
143 
144   s = writer.Finish();
145   if (!s.ok()) {
146     context->SetStatus(s);
147   }
148 }
149 
RestoreTensor(OpKernelContext * context,checkpoint::TensorSliceReader::OpenTableFunction open_func,int preferred_shard,bool restore_slice,int restore_index)150 void RestoreTensor(OpKernelContext* context,
151                    checkpoint::TensorSliceReader::OpenTableFunction open_func,
152                    int preferred_shard, bool restore_slice, int restore_index) {
153   const Tensor& file_pattern_t = context->input(0);
154   {
155     const int64_t size = file_pattern_t.NumElements();
156     OP_REQUIRES(
157         context, size == 1,
158         errors::InvalidArgument(
159             "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
160             size, " elements"));
161   }
162   const string& file_pattern = file_pattern_t.flat<tstring>()(0);
163 
164   const Tensor& tensor_name_t = context->input(1);
165   {
166     const int64_t size = tensor_name_t.NumElements();
167     OP_REQUIRES(context, size > restore_index,
168                 errors::InvalidArgument(
169                     "Input 1 (file_pattern) must be a have at least ",
170                     restore_index + 1, " elements"));
171   }
172   const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index);
173 
174   // If we cannot find a cached reader we will allocate our own.
175   std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
176 
177   const checkpoint::TensorSliceReader* reader = nullptr;
178 
179   if (context->slice_reader_cache()) {
180     reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
181                                                       preferred_shard);
182   }
183   if (!reader) {
184     allocated_reader.reset(new checkpoint::TensorSliceReader(
185         file_pattern, open_func, preferred_shard));
186     reader = allocated_reader.get();
187   }
188   OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
189 
190   // Get the shape and type from the save file.
191   DataType type;
192   TensorShape saved_shape;
193   OP_REQUIRES(
194       context, reader->HasTensor(tensor_name, &saved_shape, &type),
195       errors::NotFound("Tensor name \"", tensor_name,
196                        "\" not found in checkpoint files ", file_pattern));
197   OP_REQUIRES(
198       context, type == context->expected_output_dtype(restore_index),
199       errors::InvalidArgument("Expected to restore a tensor of type ",
200                               DataTypeString(context->expected_output_dtype(0)),
201                               ", got a tensor of type ", DataTypeString(type),
202                               " instead: tensor_name = ", tensor_name));
203 
204   // Shape of the output and slice to load.
205   TensorShape output_shape(saved_shape);
206   TensorSlice slice_to_load(saved_shape.dims());
207   if (restore_slice) {
208     const tstring& shape_spec =
209         context->input(2).flat<tstring>()(restore_index);
210     if (!shape_spec.empty()) {
211       TensorShape parsed_shape;
212       OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
213                                   shape_spec, &parsed_shape, &slice_to_load,
214                                   &output_shape));
215       OP_REQUIRES(
216           context, parsed_shape.IsSameSize(saved_shape),
217           errors::InvalidArgument(
218               "Shape in shape_and_slice spec does not match the shape in the "
219               "save file: ",
220               parsed_shape.DebugString(),
221               ", save file shape: ", saved_shape.DebugString()));
222     }
223   }
224 
225   Tensor* t = nullptr;
226   OP_REQUIRES_OK(context,
227                  context->allocate_output(restore_index, output_shape, &t));
228 
229   if (output_shape.num_elements() == 0) return;
230 
231 #define READER_COPY(T)                                                \
232   case DataTypeToEnum<T>::value:                                      \
233     OP_REQUIRES(context,                                              \
234                 reader->CopySliceData(tensor_name, slice_to_load,     \
235                                       t->flat<T>().data()),           \
236                 errors::InvalidArgument("Error copying slice data")); \
237     break;
238 
239   switch (type) {
240     TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
241     default:
242       context->SetStatus(errors::Unimplemented(
243           "Restoring data type ", DataTypeString(type), " not yet supported"));
244   }
245 #undef READER_COPY
246 }
247 
248 namespace {
249 
250 // Tensors larger than this threshold will be restored from a thread-pool.
251 const int64_t kLargeShapeThreshold = 16 << 20;  // 16M
252 
253 // A restore operation for a single tensor.  Small tensors may be restored
254 // directly from the op thread to improve read locality.  Large tensors can be
255 // restored from a thread pool: this requires creating a separate BundleReader
256 // for each restore.
257 struct RestoreOp {
RestoreOptensorflow::__anon3b4d6c350211::RestoreOp258   RestoreOp(OpKernelContext* context, int idx, const string& tensor_name,
259             const string& shape_and_slice, const string& reader_prefix,
260             DataType dtype)
261       : context(context),
262         idx(idx),
263         tensor_name(tensor_name),
264         shape_and_slice(shape_and_slice),
265         reader_prefix(reader_prefix),
266         dtype(dtype) {}
267 
268   // Move-only. It does not make sense to "run()" a copied RestoreOp.
269   RestoreOp(const RestoreOp&) = delete;
270   RestoreOp& operator=(const RestoreOp&) = delete;
271   RestoreOp(RestoreOp&&) = default;
272   RestoreOp& operator=(RestoreOp&&) = default;
273 
should_run_in_pooltensorflow::__anon3b4d6c350211::RestoreOp274   bool should_run_in_pool(BundleReader* reader) const {
275     TensorShape restored_full_shape;
276 
277     // Ignore status here; we'll catch the error later.
278     if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
279       return false;
280     }
281 
282     return restored_full_shape.num_elements() > kLargeShapeThreshold;
283   }
284 
285   // Run this restore operation using a new BundleReader.
run_with_new_readertensorflow::__anon3b4d6c350211::RestoreOp286   void run_with_new_reader() {
287     BundleReader reader(Env::Default(), reader_prefix);
288     if (!reader.status().ok()) {
289       status = reader.status();
290       return;
291     }
292 
293     status = run(&reader);
294   }
295 
runtensorflow::__anon3b4d6c350211::RestoreOp296   Status run(BundleReader* reader) {
297     TensorShape restored_full_shape;
298     TF_RETURN_IF_ERROR(
299         reader->LookupTensorShape(tensor_name, &restored_full_shape));
300 
301     VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
302             << restored_full_shape.num_elements();
303     Tensor* restored_tensor;
304     if (shape_and_slice.empty()) {
305       // Lookup the full tensor.
306       TF_RETURN_IF_ERROR(
307           context->allocate_output(idx, restored_full_shape, &restored_tensor));
308       TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
309     } else {
310       // Lookup the slice.
311       TensorShape parsed_full_shape;
312       TensorSlice parsed_slice;
313       TensorShape parsed_slice_shape;
314 
315       TF_RETURN_IF_ERROR(
316           checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
317                                          &parsed_slice, &parsed_slice_shape));
318 
319       if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
320         return errors::InvalidArgument(
321             "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
322             parsed_full_shape.DebugString(),
323             " does not match the shape stored in checkpoint: ",
324             restored_full_shape.DebugString());
325       }
326       TF_RETURN_IF_ERROR(
327           context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
328       TF_RETURN_IF_ERROR(
329           reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
330     }
331     if (VLOG_IS_ON(5)) {
332       if (restored_tensor->dtype() == DT_FLOAT) {
333         const float* t_data = restored_tensor->flat<float>().data();
334         float min = std::numeric_limits<float>::infinity();
335         float max = -std::numeric_limits<float>::infinity();
336         double avg = 0.0;
337         for (int i = 0; i < restored_tensor->NumElements(); ++i) {
338           if (t_data[i] < min) min = t_data[i];
339           if (t_data[i] > max) max = t_data[i];
340           avg += t_data[i];
341         }
342         VLOG(5) << " min " << min << " max " << max << " avg "
343                 << avg / restored_tensor->NumElements() << " total elts "
344                 << restored_tensor->NumElements();
345       }
346     }
347     VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : "
348             << restored_full_shape.num_elements();
349     return OkStatus();
350   }
351 
352   OpKernelContext* context;
353   int idx;
354   string tensor_name;
355   string shape_and_slice;
356   string reader_prefix;
357   DataType dtype;
358 
359   ::tensorflow::Status status;
360 };
361 
362 }  // namespace
363 
RestoreTensorsV2(OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices,gtl::ArraySlice<DataType> dtypes)364 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
365                         const Tensor& tensor_names,
366                         const Tensor& shape_and_slices,
367                         gtl::ArraySlice<DataType> dtypes) {
368   const string& prefix_string = prefix.scalar<tstring>()();
369 
370   const auto& tensor_names_flat = tensor_names.flat<tstring>();
371   const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();
372 
373   std::vector<RestoreOp> restore_ops;
374   restore_ops.reserve(tensor_names_flat.size());
375   for (int i = 0; i < tensor_names_flat.size(); ++i) {
376     restore_ops.push_back({context, i, tensor_names_flat(i),
377                            shape_and_slices_flat(i), prefix_string, dtypes[i]});
378   }
379 
380   BundleReader default_reader(Env::Default(), prefix_string);
381   TF_RETURN_IF_ERROR(default_reader.status());
382 
383   TF_RETURN_IF_ERROR(default_reader.SortForSequentialAccess<RestoreOp>(
384       restore_ops, [](const RestoreOp& op) { return op.tensor_name; }));
385 
386   std::vector<string> mismatched_errors;
387   for (const RestoreOp& restore_op : restore_ops) {
388     TensorShape restored_full_shape;
389     DataType original_dtype;
390     TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
391         restore_op.tensor_name, &original_dtype, &restored_full_shape));
392     if (restore_op.dtype != original_dtype) {
393       string error_msg = strings::StrCat(
394           "tensor_name = ", restore_op.tensor_name, "; expected dtype ",
395           DataTypeString(restore_op.dtype), " does not equal original dtype ",
396           DataTypeString(original_dtype));
397       mismatched_errors.emplace_back(error_msg);
398     }
399   }
400   if (!mismatched_errors.empty()) {
401     const string error_msg = absl::StrJoin(mismatched_errors, "\n");
402     return errors::InvalidArgument(error_msg);
403   }
404 
405   std::vector<RestoreOp*> pool_restore_ops;
406   std::vector<RestoreOp*> direct_restore_ops;
407   for (RestoreOp& restore_op : restore_ops) {
408     if (restore_op.should_run_in_pool(&default_reader)) {
409       pool_restore_ops.push_back(&restore_op);
410     } else {
411       direct_restore_ops.push_back(&restore_op);
412     }
413   }
414 
415   {
416     // Schedule any threaded operations first, skipping thread pool creation if
417     // we don't have any expensive operations.
418     std::unique_ptr<thread::ThreadPool> reader_pool;
419     if (!pool_restore_ops.empty()) {
420       reader_pool.reset(
421           new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
422       for (auto* op : pool_restore_ops) {
423         reader_pool->Schedule([op]() { op->run_with_new_reader(); });
424       }
425     }
426 
427     // Read small tensors from the op thread
428     for (auto* op : direct_restore_ops) {
429       TF_RETURN_IF_ERROR(op->run(&default_reader));
430     }
431   }
432 
433   // Check status of pool ops; this must come after the pool shuts down.
434   for (auto* op : pool_restore_ops) {
435     TF_RETURN_IF_ERROR(op->status);
436   }
437 
438   for (const RestoreOp& restore_op : restore_ops) {
439     if (restore_op.dtype != context->mutable_output(restore_op.idx)->dtype()) {
440       return errors::InvalidArgument(
441           "tensor_name = ", restore_op.tensor_name, "; expected dtype ",
442           DataTypeString(restore_op.dtype), " does not equal restored dtype ",
443           DataTypeString(context->mutable_output(restore_op.idx)->dtype()));
444     }
445   }
446 
447   return OkStatus();
448 }
449 
450 }  // namespace tensorflow
451