1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/kernels/save_restore_tensor.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
37
38 namespace tensorflow {
39
SaveTensors(OpKernelContext * context,checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,bool save_slices)40 void SaveTensors(
41 OpKernelContext* context,
42 checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
43 bool save_slices) {
44 const Tensor& filename_t = context->input(0);
45 {
46 const int64 size = filename_t.NumElements();
47 OP_REQUIRES(
48 context, size == 1,
49 errors::InvalidArgument(
50 "Input 0 (filename) must be a string scalar; got a tensor of ",
51 size, "elements"));
52 }
53
54 // Path, names, and slices if save_slices is true.
55 const int kFixedInputs = save_slices ? 3 : 2;
56 const Tensor& tensor_names_t = context->input(1);
57 OP_REQUIRES(context,
58 FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
59 std::numeric_limits<int>::max()),
60 errors::InvalidArgument("Too many inputs to SaveTensors"));
61 const int N = static_cast<int>(tensor_names_t.NumElements());
62 const tstring* tensor_shapes_and_slices_ptr = nullptr;
63 if (save_slices) {
64 const Tensor& tensor_shapes_and_slices_t = context->input(2);
65 OP_REQUIRES(
66 context,
67 tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N),
68 errors::InvalidArgument("Expected ", N,
69 " elements for the tensor "
70 "shapes and slices but got ",
71 tensor_shapes_and_slices_t.NumElements()));
72 tensor_shapes_and_slices_ptr =
73 tensor_shapes_and_slices_t.flat<tstring>().data();
74 }
75 OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
76 errors::InvalidArgument("Expected totally ", N + kFixedInputs,
77 " inputs as input #1 (which is a string "
78 "tensor of saved names) contains ",
79 N, " names, but received ",
80 context->num_inputs(), " inputs"));
81
82 VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0)
83 << "...";
84 checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0),
85 std::move(builder_func));
86
87 Status s;
88 auto tensor_names_flat = tensor_names_t.flat<tstring>();
89
90 // Process tensors in sorted name order. This allows us to avoid seeking
91 // during restoration in the common case where we are restoring a full
92 // checkpoint.
93 std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
94 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
95 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
96 [&tensor_names_flat](size_t a, size_t b) {
97 return tensor_names_flat(a) < tensor_names_flat(b);
98 });
99
100 for (const size_t i : sorted_name_idx) {
101 const string& name = tensor_names_flat(i);
102 const Tensor& input = context->input(i + kFixedInputs);
103 TensorShape shape(input.shape());
104 TensorSlice slice(input.dims());
105 if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
106 const tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
107 TensorShape slice_shape;
108 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
109 shape_spec, &shape, &slice, &slice_shape));
110 OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
111 errors::InvalidArgument(
112 "Slice in shape_and_slice "
113 "specification does not match the "
114 "shape of the tensor to save: ",
115 shape_spec, ", tensor: ", input.shape().DebugString()));
116 }
117
118 #define WRITER_ADD(T) \
119 case DataTypeToEnum<T>::value: \
120 s = writer.Add(name, shape, slice, input.flat<T>().data()); \
121 break;
122
123 switch (input.dtype()) {
124 TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
125 default:
126 context->SetStatus(errors::Unimplemented("Saving data type ",
127 DataTypeString(input.dtype()),
128 " not yet supported"));
129 return;
130 }
131 #undef WRITER_ADD
132 if (!s.ok()) {
133 context->SetStatus(s);
134 return;
135 }
136 }
137
138 s = writer.Finish();
139 if (!s.ok()) {
140 context->SetStatus(s);
141 }
142 }
143
RestoreTensor(OpKernelContext * context,checkpoint::TensorSliceReader::OpenTableFunction open_func,int preferred_shard,bool restore_slice,int restore_index)144 void RestoreTensor(OpKernelContext* context,
145 checkpoint::TensorSliceReader::OpenTableFunction open_func,
146 int preferred_shard, bool restore_slice, int restore_index) {
147 const Tensor& file_pattern_t = context->input(0);
148 {
149 const int64 size = file_pattern_t.NumElements();
150 OP_REQUIRES(
151 context, size == 1,
152 errors::InvalidArgument(
153 "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
154 size, "elements"));
155 }
156 const string& file_pattern = file_pattern_t.flat<tstring>()(0);
157
158 const Tensor& tensor_name_t = context->input(1);
159 const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index);
160
161 // If we cannot find a cached reader we will allocate our own.
162 std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
163
164 const checkpoint::TensorSliceReader* reader = nullptr;
165
166 if (context->slice_reader_cache()) {
167 reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
168 preferred_shard);
169 }
170 if (!reader) {
171 allocated_reader.reset(new checkpoint::TensorSliceReader(
172 file_pattern, open_func, preferred_shard));
173 reader = allocated_reader.get();
174 }
175 OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
176
177 // Get the shape and type from the save file.
178 DataType type;
179 TensorShape saved_shape;
180 OP_REQUIRES(
181 context, reader->HasTensor(tensor_name, &saved_shape, &type),
182 errors::NotFound("Tensor name \"", tensor_name,
183 "\" not found in checkpoint files ", file_pattern));
184 OP_REQUIRES(
185 context, type == context->expected_output_dtype(restore_index),
186 errors::InvalidArgument("Expected to restore a tensor of type ",
187 DataTypeString(context->expected_output_dtype(0)),
188 ", got a tensor of type ", DataTypeString(type),
189 " instead: tensor_name = ", tensor_name));
190
191 // Shape of the output and slice to load.
192 TensorShape output_shape(saved_shape);
193 TensorSlice slice_to_load(saved_shape.dims());
194 if (restore_slice) {
195 const tstring& shape_spec =
196 context->input(2).flat<tstring>()(restore_index);
197 if (!shape_spec.empty()) {
198 TensorShape parsed_shape;
199 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
200 shape_spec, &parsed_shape, &slice_to_load,
201 &output_shape));
202 OP_REQUIRES(
203 context, parsed_shape.IsSameSize(saved_shape),
204 errors::InvalidArgument(
205 "Shape in shape_and_slice spec does not match the shape in the "
206 "save file: ",
207 parsed_shape.DebugString(),
208 ", save file shape: ", saved_shape.DebugString()));
209 }
210 }
211
212 Tensor* t = nullptr;
213 OP_REQUIRES_OK(context,
214 context->allocate_output(restore_index, output_shape, &t));
215
216 if (output_shape.num_elements() == 0) return;
217
218 #define READER_COPY(T) \
219 case DataTypeToEnum<T>::value: \
220 OP_REQUIRES(context, \
221 reader->CopySliceData(tensor_name, slice_to_load, \
222 t->flat<T>().data()), \
223 errors::InvalidArgument("Error copying slice data")); \
224 break;
225
226 switch (type) {
227 TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
228 default:
229 context->SetStatus(errors::Unimplemented(
230 "Restoring data type ", DataTypeString(type), " not yet supported"));
231 }
232 #undef READER_COPY
233 }
234
235 namespace {
236
237 // Tensors larger than this threshold will be restored from a thread-pool.
238 const int64 kLargeShapeThreshold = 16 << 20; // 16M
239
240 // A restore operation for a single tensor. Small tensors may be restored
241 // directly from the op thread to improve read locality. Large tensors can be
242 // restored from a thread pool: this requires creating a separate BundleReader
243 // for each restore.
244 struct RestoreOp {
245 RestoreOp& operator=(const RestoreOp&) = delete;
246
should_run_in_pooltensorflow::__anon11207f170211::RestoreOp247 bool should_run_in_pool(BundleReader* reader) const {
248 TensorShape restored_full_shape;
249
250 // Ignore status here; we'll catch the error later.
251 if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
252 return false;
253 }
254
255 return restored_full_shape.num_elements() > kLargeShapeThreshold;
256 }
257
258 // Run this restore operation using a new BundleReader.
run_with_new_readertensorflow::__anon11207f170211::RestoreOp259 void run_with_new_reader() {
260 BundleReader reader(Env::Default(), reader_prefix);
261 if (!reader.status().ok()) {
262 status = reader.status();
263 return;
264 }
265
266 status = run(&reader);
267 }
268
runtensorflow::__anon11207f170211::RestoreOp269 Status run(BundleReader* reader) {
270 TensorShape restored_full_shape;
271 TF_RETURN_IF_ERROR(
272 reader->LookupTensorShape(tensor_name, &restored_full_shape));
273
274 VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
275 << restored_full_shape.num_elements();
276 Tensor* restored_tensor;
277 if (shape_and_slice.empty()) {
278 // Lookup the full tensor.
279 TF_RETURN_IF_ERROR(
280 context->allocate_output(idx, restored_full_shape, &restored_tensor));
281 TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
282 } else {
283 // Lookup the slice.
284 TensorShape parsed_full_shape;
285 TensorSlice parsed_slice;
286 TensorShape parsed_slice_shape;
287
288 TF_RETURN_IF_ERROR(
289 checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
290 &parsed_slice, &parsed_slice_shape));
291
292 if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
293 return errors::InvalidArgument(
294 "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
295 parsed_full_shape.DebugString(),
296 " does not match the shape stored in checkpoint: ",
297 restored_full_shape.DebugString());
298 }
299 TF_RETURN_IF_ERROR(
300 context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
301 TF_RETURN_IF_ERROR(
302 reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
303 }
304 if (VLOG_IS_ON(5)) {
305 if (restored_tensor->dtype() == DT_FLOAT) {
306 const float* t_data = restored_tensor->flat<float>().data();
307 float min = std::numeric_limits<float>::infinity();
308 float max = -std::numeric_limits<float>::infinity();
309 double avg = 0.0;
310 for (int i = 0; i < restored_tensor->NumElements(); ++i) {
311 if (t_data[i] < min) min = t_data[i];
312 if (t_data[i] > max) max = t_data[i];
313 avg += t_data[i];
314 }
315 VLOG(5) << " min " << min << " max " << max << " avg "
316 << avg / restored_tensor->NumElements() << " total elts "
317 << restored_tensor->NumElements();
318 }
319 }
320 VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : "
321 << restored_full_shape.num_elements();
322 return Status::OK();
323 }
324
325 OpKernelContext* context;
326 size_t idx;
327 string tensor_name;
328 string shape_and_slice;
329 string reader_prefix;
330
331 ::tensorflow::Status status;
332 };
333
334 } // namespace
335
RestoreTensorsV2(OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices,gtl::ArraySlice<DataType> dtypes)336 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
337 const Tensor& tensor_names,
338 const Tensor& shape_and_slices,
339 gtl::ArraySlice<DataType> dtypes) {
340 const string& prefix_string = prefix.scalar<tstring>()();
341
342 const auto& tensor_names_flat = tensor_names.flat<tstring>();
343 const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();
344
345 // Sort lookup keys to improve locality when reading multiple tensors.
346 std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
347 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
348 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
349 [&tensor_names_flat](size_t a, size_t b) {
350 return tensor_names_flat(a) < tensor_names_flat(b);
351 });
352
353 std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
354 std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;
355
356 BundleReader default_reader(Env::Default(), prefix_string);
357 TF_RETURN_IF_ERROR(default_reader.status());
358
359 std::vector<string> mismatched_errors;
360 for (const size_t i : sorted_name_idx) {
361 TensorShape restored_full_shape;
362 DataType original_dtype;
363 const string& tensor_name = tensor_names_flat(i);
364 TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
365 tensor_name, &original_dtype, &restored_full_shape));
366 if (dtypes[i] != original_dtype) {
367 string error_msg = strings::StrCat(
368 "tensor_name = ", tensor_name, "; expected dtype ",
369 DataTypeString(dtypes[i]), " does not equal original dtype ",
370 DataTypeString(original_dtype));
371 mismatched_errors.emplace_back(error_msg);
372 }
373 }
374 if (!mismatched_errors.empty()) {
375 const string error_msg = absl::StrJoin(mismatched_errors, "\n");
376 return errors::InvalidArgument(error_msg);
377 }
378
379 for (auto i : sorted_name_idx) {
380 const string& tensor_name = tensor_names_flat(i);
381 const string& shape_and_slice = shape_and_slices_flat(i);
382 auto op =
383 new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
384 if (op->should_run_in_pool(&default_reader)) {
385 pool_restore_ops.emplace_back(op);
386 } else {
387 direct_restore_ops.emplace_back(op);
388 }
389 }
390
391 {
392 // Schedule any threaded operations first, skipping thread pool creation if
393 // we don't have any expensive operations.
394 std::unique_ptr<thread::ThreadPool> reader_pool;
395 if (!pool_restore_ops.empty()) {
396 reader_pool.reset(
397 new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
398 for (auto& op : pool_restore_ops) {
399 reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
400 }
401 }
402
403 // Read small tensors from the op thread
404 for (auto& op : direct_restore_ops) {
405 TF_RETURN_IF_ERROR(op->run(&default_reader));
406 }
407 }
408
409 // Check status of pool ops; this must come after the pool shuts down.
410 for (auto& op : pool_restore_ops) {
411 TF_RETURN_IF_ERROR(op->status);
412 }
413
414 for (auto i : sorted_name_idx) {
415 const string& tensor_name = tensor_names_flat(i);
416 if (dtypes[i] != context->mutable_output(i)->dtype()) {
417 return errors::InvalidArgument(
418 "tensor_name = ", tensor_name, "; expected dtype ",
419 DataTypeString(dtypes[i]), " does not equal restored dtype ",
420 DataTypeString(context->mutable_output(i)->dtype()));
421 }
422 }
423
424 return Status::OK();
425 }
426
427 } // namespace tensorflow
428