1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/save_restore_tensor.h"
17 #include <numeric>
18 #include <unordered_map>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/core/framework/bounds_check.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/threadpool.h"
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/lib/strings/stringprintf.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
34 #include "tensorflow/core/util/tensor_slice_reader.h"
35 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
36 #include "tensorflow/core/util/tensor_slice_writer.h"
37
38 namespace tensorflow {
39
SaveTensors(OpKernelContext * context,checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,bool save_slices)40 void SaveTensors(
41 OpKernelContext* context,
42 checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
43 bool save_slices) {
44 const Tensor& filename_t = context->input(0);
45 {
46 const int64_t size = filename_t.NumElements();
47 OP_REQUIRES(
48 context, size == 1,
49 errors::InvalidArgument(
50 "Input 0 (filename) must be a string scalar; got a tensor of ",
51 size, "elements"));
52 }
53
54 // Path, names, and slices if save_slices is true.
55 const int kFixedInputs = save_slices ? 3 : 2;
56 const Tensor& tensor_names_t = context->input(1);
57 OP_REQUIRES(context,
58 FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
59 std::numeric_limits<int>::max()),
60 errors::InvalidArgument("Too many inputs to SaveTensors"));
61 const int N = static_cast<int>(tensor_names_t.NumElements());
62 const tstring* tensor_shapes_and_slices_ptr = nullptr;
63 if (save_slices) {
64 const Tensor& tensor_shapes_and_slices_t = context->input(2);
65 OP_REQUIRES(
66 context,
67 tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N),
68 errors::InvalidArgument("Expected ", N,
69 " elements for the tensor "
70 "shapes and slices but got ",
71 tensor_shapes_and_slices_t.NumElements()));
72 tensor_shapes_and_slices_ptr =
73 tensor_shapes_and_slices_t.flat<tstring>().data();
74 }
75 OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
76 errors::InvalidArgument("Expected totally ", N + kFixedInputs,
77 " inputs as input #1 (which is a string "
78 "tensor of saved names) contains ",
79 N, " names, but received ",
80 context->num_inputs(), " inputs"));
81
82 VLOG(1) << "About to save tensors to file " << filename_t.flat<tstring>()(0)
83 << "...";
84 checkpoint::TensorSliceWriter writer(filename_t.flat<tstring>()(0),
85 std::move(builder_func));
86
87 Status s;
88 auto tensor_names_flat = tensor_names_t.flat<tstring>();
89
90 // Process tensors in sorted name order. This allows us to avoid seeking
91 // during restoration in the common case where we are restoring a full
92 // checkpoint.
93 std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
94 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
95 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
96 [&tensor_names_flat](size_t a, size_t b) {
97 return tensor_names_flat(a) < tensor_names_flat(b);
98 });
99
100 for (const size_t i : sorted_name_idx) {
101 const string& name = tensor_names_flat(i);
102 const Tensor& input = context->input(i + kFixedInputs);
103 TensorShape shape(input.shape());
104 TensorSlice slice(input.dims());
105 if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
106 const tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
107 TensorShape slice_shape;
108 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
109 shape_spec, &shape, &slice, &slice_shape));
110 OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
111 errors::InvalidArgument(
112 "Slice in shape_and_slice "
113 "specification does not match the "
114 "shape of the tensor to save: ",
115 shape_spec, ", tensor: ", input.shape().DebugString()));
116 }
117
118 #define WRITER_ADD(T) \
119 case DataTypeToEnum<T>::value: \
120 s = writer.Add(name, shape, slice, input.flat<T>().data()); \
121 break;
122
123 switch (input.dtype()) {
124 TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
125 default:
126 context->SetStatus(errors::Unimplemented("Saving data type ",
127 DataTypeString(input.dtype()),
128 " not yet supported"));
129 return;
130 }
131 #undef WRITER_ADD
132 if (!s.ok()) {
133 context->SetStatus(s);
134 return;
135 }
136 }
137
138 s = writer.Finish();
139 if (!s.ok()) {
140 context->SetStatus(s);
141 }
142 }
143
RestoreTensor(OpKernelContext * context,checkpoint::TensorSliceReader::OpenTableFunction open_func,int preferred_shard,bool restore_slice,int restore_index)144 void RestoreTensor(OpKernelContext* context,
145 checkpoint::TensorSliceReader::OpenTableFunction open_func,
146 int preferred_shard, bool restore_slice, int restore_index) {
147 const Tensor& file_pattern_t = context->input(0);
148 {
149 const int64_t size = file_pattern_t.NumElements();
150 OP_REQUIRES(
151 context, size == 1,
152 errors::InvalidArgument(
153 "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
154 size, " elements"));
155 }
156 const string& file_pattern = file_pattern_t.flat<tstring>()(0);
157
158 const Tensor& tensor_name_t = context->input(1);
159 {
160 const int64_t size = tensor_name_t.NumElements();
161 OP_REQUIRES(context, size > restore_index,
162 errors::InvalidArgument(
163 "Input 1 (file_pattern) must be a have at least ",
164 restore_index + 1, " elements"));
165 }
166 const string& tensor_name = tensor_name_t.flat<tstring>()(restore_index);
167
168 // If we cannot find a cached reader we will allocate our own.
169 std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
170
171 const checkpoint::TensorSliceReader* reader = nullptr;
172
173 if (context->slice_reader_cache()) {
174 reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
175 preferred_shard);
176 }
177 if (!reader) {
178 allocated_reader.reset(new checkpoint::TensorSliceReader(
179 file_pattern, open_func, preferred_shard));
180 reader = allocated_reader.get();
181 }
182 OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
183
184 // Get the shape and type from the save file.
185 DataType type;
186 TensorShape saved_shape;
187 OP_REQUIRES(
188 context, reader->HasTensor(tensor_name, &saved_shape, &type),
189 errors::NotFound("Tensor name \"", tensor_name,
190 "\" not found in checkpoint files ", file_pattern));
191 OP_REQUIRES(
192 context, type == context->expected_output_dtype(restore_index),
193 errors::InvalidArgument("Expected to restore a tensor of type ",
194 DataTypeString(context->expected_output_dtype(0)),
195 ", got a tensor of type ", DataTypeString(type),
196 " instead: tensor_name = ", tensor_name));
197
198 // Shape of the output and slice to load.
199 TensorShape output_shape(saved_shape);
200 TensorSlice slice_to_load(saved_shape.dims());
201 if (restore_slice) {
202 const tstring& shape_spec =
203 context->input(2).flat<tstring>()(restore_index);
204 if (!shape_spec.empty()) {
205 TensorShape parsed_shape;
206 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
207 shape_spec, &parsed_shape, &slice_to_load,
208 &output_shape));
209 OP_REQUIRES(
210 context, parsed_shape.IsSameSize(saved_shape),
211 errors::InvalidArgument(
212 "Shape in shape_and_slice spec does not match the shape in the "
213 "save file: ",
214 parsed_shape.DebugString(),
215 ", save file shape: ", saved_shape.DebugString()));
216 }
217 }
218
219 Tensor* t = nullptr;
220 OP_REQUIRES_OK(context,
221 context->allocate_output(restore_index, output_shape, &t));
222
223 if (output_shape.num_elements() == 0) return;
224
225 #define READER_COPY(T) \
226 case DataTypeToEnum<T>::value: \
227 OP_REQUIRES(context, \
228 reader->CopySliceData(tensor_name, slice_to_load, \
229 t->flat<T>().data()), \
230 errors::InvalidArgument("Error copying slice data")); \
231 break;
232
233 switch (type) {
234 TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
235 default:
236 context->SetStatus(errors::Unimplemented(
237 "Restoring data type ", DataTypeString(type), " not yet supported"));
238 }
239 #undef READER_COPY
240 }
241
242 namespace {
243
244 // Tensors larger than this threshold will be restored from a thread-pool.
245 const int64_t kLargeShapeThreshold = 16 << 20; // 16M
246
247 // A restore operation for a single tensor. Small tensors may be restored
248 // directly from the op thread to improve read locality. Large tensors can be
249 // restored from a thread pool: this requires creating a separate BundleReader
250 // for each restore.
251 struct RestoreOp {
252 RestoreOp& operator=(const RestoreOp&) = delete;
253
should_run_in_pooltensorflow::__anon7987729c0211::RestoreOp254 bool should_run_in_pool(BundleReader* reader) const {
255 TensorShape restored_full_shape;
256
257 // Ignore status here; we'll catch the error later.
258 if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
259 return false;
260 }
261
262 return restored_full_shape.num_elements() > kLargeShapeThreshold;
263 }
264
265 // Run this restore operation using a new BundleReader.
run_with_new_readertensorflow::__anon7987729c0211::RestoreOp266 void run_with_new_reader() {
267 BundleReader reader(Env::Default(), reader_prefix);
268 if (!reader.status().ok()) {
269 status = reader.status();
270 return;
271 }
272
273 status = run(&reader);
274 }
275
runtensorflow::__anon7987729c0211::RestoreOp276 Status run(BundleReader* reader) {
277 TensorShape restored_full_shape;
278 TF_RETURN_IF_ERROR(
279 reader->LookupTensorShape(tensor_name, &restored_full_shape));
280
281 VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
282 << restored_full_shape.num_elements();
283 Tensor* restored_tensor;
284 if (shape_and_slice.empty()) {
285 // Lookup the full tensor.
286 TF_RETURN_IF_ERROR(
287 context->allocate_output(idx, restored_full_shape, &restored_tensor));
288 TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
289 } else {
290 // Lookup the slice.
291 TensorShape parsed_full_shape;
292 TensorSlice parsed_slice;
293 TensorShape parsed_slice_shape;
294
295 TF_RETURN_IF_ERROR(
296 checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
297 &parsed_slice, &parsed_slice_shape));
298
299 if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
300 return errors::InvalidArgument(
301 "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
302 parsed_full_shape.DebugString(),
303 " does not match the shape stored in checkpoint: ",
304 restored_full_shape.DebugString());
305 }
306 TF_RETURN_IF_ERROR(
307 context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
308 TF_RETURN_IF_ERROR(
309 reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
310 }
311 if (VLOG_IS_ON(5)) {
312 if (restored_tensor->dtype() == DT_FLOAT) {
313 const float* t_data = restored_tensor->flat<float>().data();
314 float min = std::numeric_limits<float>::infinity();
315 float max = -std::numeric_limits<float>::infinity();
316 double avg = 0.0;
317 for (int i = 0; i < restored_tensor->NumElements(); ++i) {
318 if (t_data[i] < min) min = t_data[i];
319 if (t_data[i] > max) max = t_data[i];
320 avg += t_data[i];
321 }
322 VLOG(5) << " min " << min << " max " << max << " avg "
323 << avg / restored_tensor->NumElements() << " total elts "
324 << restored_tensor->NumElements();
325 }
326 }
327 VLOG(1) << "Done restoring tensor " << idx << " : " << tensor_name << " : "
328 << restored_full_shape.num_elements();
329 return Status::OK();
330 }
331
332 OpKernelContext* context;
333 size_t idx;
334 string tensor_name;
335 string shape_and_slice;
336 string reader_prefix;
337
338 ::tensorflow::Status status;
339 };
340
341 } // namespace
342
RestoreTensorsV2(OpKernelContext * context,const Tensor & prefix,const Tensor & tensor_names,const Tensor & shape_and_slices,gtl::ArraySlice<DataType> dtypes)343 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
344 const Tensor& tensor_names,
345 const Tensor& shape_and_slices,
346 gtl::ArraySlice<DataType> dtypes) {
347 const string& prefix_string = prefix.scalar<tstring>()();
348
349 const auto& tensor_names_flat = tensor_names.flat<tstring>();
350 const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();
351
352 // Sort lookup keys to improve locality when reading multiple tensors.
353 std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
354 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
355 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
356 [&tensor_names_flat](size_t a, size_t b) {
357 return tensor_names_flat(a) < tensor_names_flat(b);
358 });
359
360 std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
361 std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;
362
363 BundleReader default_reader(Env::Default(), prefix_string);
364 TF_RETURN_IF_ERROR(default_reader.status());
365
366 std::vector<string> mismatched_errors;
367 for (const size_t i : sorted_name_idx) {
368 TensorShape restored_full_shape;
369 DataType original_dtype;
370 const string& tensor_name = tensor_names_flat(i);
371 TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
372 tensor_name, &original_dtype, &restored_full_shape));
373 if (dtypes[i] != original_dtype) {
374 string error_msg = strings::StrCat(
375 "tensor_name = ", tensor_name, "; expected dtype ",
376 DataTypeString(dtypes[i]), " does not equal original dtype ",
377 DataTypeString(original_dtype));
378 mismatched_errors.emplace_back(error_msg);
379 }
380 }
381 if (!mismatched_errors.empty()) {
382 const string error_msg = absl::StrJoin(mismatched_errors, "\n");
383 return errors::InvalidArgument(error_msg);
384 }
385
386 for (auto i : sorted_name_idx) {
387 const string& tensor_name = tensor_names_flat(i);
388 const string& shape_and_slice = shape_and_slices_flat(i);
389 auto op =
390 new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
391 if (op->should_run_in_pool(&default_reader)) {
392 pool_restore_ops.emplace_back(op);
393 } else {
394 direct_restore_ops.emplace_back(op);
395 }
396 }
397
398 {
399 // Schedule any threaded operations first, skipping thread pool creation if
400 // we don't have any expensive operations.
401 std::unique_ptr<thread::ThreadPool> reader_pool;
402 if (!pool_restore_ops.empty()) {
403 reader_pool.reset(
404 new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
405 for (auto& op : pool_restore_ops) {
406 reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
407 }
408 }
409
410 // Read small tensors from the op thread
411 for (auto& op : direct_restore_ops) {
412 TF_RETURN_IF_ERROR(op->run(&default_reader));
413 }
414 }
415
416 // Check status of pool ops; this must come after the pool shuts down.
417 for (auto& op : pool_restore_ops) {
418 TF_RETURN_IF_ERROR(op->status);
419 }
420
421 for (auto i : sorted_name_idx) {
422 const string& tensor_name = tensor_names_flat(i);
423 if (dtypes[i] != context->mutable_output(i)->dtype()) {
424 return errors::InvalidArgument(
425 "tensor_name = ", tensor_name, "; expected dtype ",
426 DataTypeString(dtypes[i]), " does not equal restored dtype ",
427 DataTypeString(context->mutable_output(i)->dtype()));
428 }
429 }
430
431 return Status::OK();
432 }
433
434 } // namespace tensorflow
435