/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <memory>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/io_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {
namespace {

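// The Save op, as exercised by the test below, takes a scalar string
// filename, a 1-D string tensor of tensor names, and then one data tensor per
// name; MakeOp() declares one input of each dtype covered here.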
class SaveOpTest : public OpsTestBase {
 protected:
  void MakeOp() {
    TF_ASSERT_OK(
        NodeDefBuilder("myop", "Save")
            .Input(FakeInput())
            .Input(FakeInput())
            .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
                              DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_INT64,
                              DT_STRING, DT_COMPLEX64, DT_COMPLEX128, DT_HALF}))
            .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

TEST_F(SaveOpTest, Simple) {
  const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
  const string tensornames[] = {
      "tensor_bool", "tensor_int", "tensor_float", "tensor_double",
      "tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8",
      "tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64",
      "tensor_complex128", "tensor_half"};

  MakeOp();
  // Add a file name
  AddInput<tstring>(TensorShape({}),
                    [&filename](int x) -> tstring { return filename; });

  // Add the tensor names
  AddInput<tstring>(TensorShape({14}), [&tensornames](int x) -> tstring {
    return tensornames[x];
  });

  // Add a 1-d bool tensor
  AddInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; });

  // Add a 1-d integer tensor
  AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });

  // Add a 2-d float tensor
  AddInput<float>(TensorShape({2, 4}),
                  [](int x) -> float { return static_cast<float>(x) / 10; });

  // Add a 2-d double tensor
  AddInput<double>(TensorShape({2, 4}),
                   [](int x) -> double { return static_cast<double>(x) / 20; });

  // Add a 2-d qint8 tensor
  AddInput<qint8>(TensorShape({3, 2}),
                  [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });

  // Add a 2-d qint32 tensor
  AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
    return *reinterpret_cast<qint32*>(&x) * qint8(2);
  });

  // Add a 1-d uint8 tensor
  AddInput<uint8>(TensorShape({11}), [](int x) -> uint8 { return x + 1; });

  // Add a 1-d int8 tensor
  AddInput<int8>(TensorShape({7}), [](int x) -> int8 { return x - 7; });

  // Add a 1-d int16 tensor
  AddInput<int16>(TensorShape({7}), [](int x) -> int16 { return x - 8; });

  // Add a 1-d int64 tensor
  AddInput<int64>(TensorShape({9}), [](int x) -> int64 { return x - 9; });

  // Add a 1-d string tensor
  AddInput<tstring>(TensorShape({2}),
                    [](int x) -> tstring { return x ? "yes" : "no"; });

  // Add a 2-d complex64 tensor
  AddInput<complex64>(TensorShape({2, 3}), [](int x) -> complex64 {
    return complex64(100 + x, 200 + x);
  });

  // Add a 2-d complex128 tensor
  AddInput<complex128>(TensorShape({2, 3}), [](int x) -> complex128 {
    return complex128(100 + x, 200 + x);
  });

  // Add a 2-d half tensor
  AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
    return static_cast<Eigen::half>(x) / Eigen::half(2);
  });
  TF_ASSERT_OK(RunOpKernel());

  // Check that the checkpoint file is properly written
  checkpoint::TensorSliceReader reader(filename,
                                       checkpoint::OpenTableTensorSliceReader);
  TF_EXPECT_OK(reader.status());

  // We expect to find all saved tensors
  {
    // The 1-d bool tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_bool", &shape, &type));
    TensorShape expected({2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_BOOL, type);

    // We expect the tensor value to be correct.
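    // ("-" in a TensorSlice spec selects the full extent of that dimension;
    // the 2-D reads below use "-:-" in the same way.)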
    TensorSlice s = TensorSlice::ParseOrDie("-");
    bool data[2];
    std::fill_n(data, 2, false);
    EXPECT_TRUE(reader.CopySliceData("tensor_bool", s, data));
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ((i != 0), data[i]);
    }
  }

  {
    // The 1-d integer tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
    TensorShape expected({10});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT32, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int data[10];
    std::fill_n(data, 10, 0);
    EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
    for (int i = 0; i < 10; ++i) {
      EXPECT_EQ(i + 1, data[i]);
    }
  }

  {
    // The 2-d float tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_FLOAT, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    float data[8];
    std::fill_n(data, 8, 0);
    EXPECT_TRUE(reader.CopySliceData("tensor_float", s, data));
    for (int i = 0; i < 8; ++i) {
      EXPECT_EQ(static_cast<float>(i) / 10, data[i]);
    }
  }

  {
    // The 2-d double tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_DOUBLE, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    double data[8];
    std::fill_n(data, 8, 0);
    EXPECT_TRUE(reader.CopySliceData("tensor_double", s, data));
    for (int i = 0; i < 8; ++i) {
      EXPECT_EQ(static_cast<double>(i) / 20, data[i]);
    }
  }

  {
    // The 2-d qint8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
    TensorShape expected({3, 2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT8, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    qint8 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(*reinterpret_cast<qint8*>(&i), data[i]);
    }
  }

  {
    // The 2-d qint32 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT32, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    qint32 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_qint32", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2), data[i]);
    }
  }

  {
    // The 1-d uint8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_uint8", &shape, &type));
    TensorShape expected({11});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_UINT8, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    uint8 data[11];
    EXPECT_TRUE(reader.CopySliceData("tensor_uint8", s, data));
    for (int i = 0; i < 11; ++i) {
      EXPECT_EQ(i + 1, data[i]);
    }
  }

  {
    // The 1-d int8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int8", &shape, &type));
    TensorShape expected({7});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT8, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int8 data[7];
    EXPECT_TRUE(reader.CopySliceData("tensor_int8", s, data));
    for (int i = 0; i < 7; ++i) {
      EXPECT_EQ(i - 7, data[i]);
    }
  }

  {
    // The 1-d int16 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int16", &shape, &type));
    TensorShape expected({7});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT16, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int16 data[7];
    EXPECT_TRUE(reader.CopySliceData("tensor_int16", s, data));
    for (int i = 0; i < 7; ++i) {
      EXPECT_EQ(i - 8, data[i]);
    }
  }

  {
    // The 1-d int64 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int64", &shape, &type));
    TensorShape expected({9});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT64, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int64 data[9];
    EXPECT_TRUE(reader.CopySliceData("tensor_int64", s, data));
    for (int i = 0; i < 9; ++i) {
      EXPECT_EQ(i - 9, data[i]);
    }
  }

  {
    // The 1-d string tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_string", &shape, &type));
    TensorShape expected({2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_STRING, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    tstring data[2];
    EXPECT_TRUE(reader.CopySliceData("tensor_string", s, data));
    EXPECT_EQ("no", data[0]);
    EXPECT_EQ("yes", data[1]);
  }

  {
    // The 2-d complex64 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_complex64", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_COMPLEX64, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    complex64 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_complex64", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(100 + i, data[i].real());
      EXPECT_EQ(200 + i, data[i].imag());
    }
  }

  {
    // The 2-d complex128 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_complex128", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_COMPLEX128, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    complex128 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_complex128", s, data));
    for (int i = 0; i < 6; ++i) {
      EXPECT_EQ(100 + i, data[i].real());
      EXPECT_EQ(200 + i, data[i].imag());
    }
  }
  {
    // The 2-d half tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_half", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_HALF, type);

    // We expect the tensor value to be correct.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    Eigen::half data[8];
    std::fill_n(data, 8, Eigen::half(0));
    EXPECT_TRUE(reader.CopySliceData("tensor_half", s, data));
    for (int i = 0; i < 8; ++i) {
      EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2), data[i]);
    }
  }
}

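// SaveSlices takes the same filename and tensor-names inputs as Save, plus a
// third 1-D string input of shape-and-slice specs describing which slice of
// each (possibly larger) tensor is being written.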
class SaveSlicesOpTest : public OpsTestBase {
 protected:
  void MakeOp() {
    TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput(
                         {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

// Here we save only slices. We restore them into a larger tensor and check
// that the right slice is restored. It is quite tricky to check that the
// right slices are actually restored, so instead we just check that
// CopySliceData() returns true/false depending on the slice we ask for.
TEST_F(SaveSlicesOpTest, Slices) {
  const string filename = io::JoinPath(testing::TmpDir(), "tensor_slices");
  const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
                                "tensor_qint8", "tensor_qint32"};
  // Specifies that the data we save are slices of larger tensors.
  // See core/framework/tensor_slice.h for the slice syntax.
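  // Each spec below lists the full tensor shape ("dim0 dim1 ...") followed by
  // one colon-separated extent per dimension, where "start,length" selects a
  // sub-range and "-" covers the whole dimension.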
  const string tensorshapes[] = {
      "10 -",         // Full contents of a 10 element vector.
      "2 4 -:0,2",    // A 2x2 slice of a 2x4 tensor.
      "2 4 0,1:2,2",  // A 1x2 slice of a 2x4 tensor.
      "3 2 -:-",      // Full contents of a 3x2 tensor.
      "2 3 1,1:2,1"   // A 1x1 slice of a 2x3 tensor.
  };

  MakeOp();
  // Add a file name
  AddInput<tstring>(TensorShape({}),
                    [&filename](int x) -> tstring { return filename; });

  // Add the tensor names
  AddInput<tstring>(TensorShape({5}), [&tensornames](int x) -> tstring {
    return tensornames[x];
  });

  // Add the tensor shapes and slices
  AddInput<tstring>(TensorShape({5}), [&tensorshapes](int x) -> tstring {
    return tensorshapes[x];
  });

  // Add a 1-d integer tensor
  AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });

  // Add a 2-d float tensor
  AddInput<float>(TensorShape({2, 2}),
                  [](int x) -> float { return static_cast<float>(x) / 10; });

  // Add a 2-d double tensor
  AddInput<double>(TensorShape({1, 2}),
                   [](int x) -> double { return static_cast<double>(x) / 20; });

  // Add a 2-d qint8 tensor
  AddInput<qint8>(TensorShape({3, 2}),
                  [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });

  // Add a 2-d qint32 tensor
  AddInput<qint32>(TensorShape({1, 1}), [](int x) -> qint32 {
    return *reinterpret_cast<qint32*>(&x) * qint8(2);
  });

  TF_ASSERT_OK(RunOpKernel());

  // Check that the checkpoint file is properly written
  checkpoint::TensorSliceReader reader(filename,
                                       checkpoint::OpenTableTensorSliceReader);
  TF_EXPECT_OK(reader.status());

  // We expect to find all saved tensors
  {
    // The 1-d integer tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
    TensorShape expected({10});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_INT32, type);

    // We saved the full tensor so we should be able to read it all.
    TensorSlice s = TensorSlice::ParseOrDie("-");
    int data[10];
    EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
  }

  {
    // The 2-d float tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_FLOAT, type);

    // We saved the slice "-:0,2" so we should not be able to read the full
    // tensor.
    TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
    TensorSlice saved_slice = TensorSlice::ParseOrDie("-:0,2");
    float data[8];
    EXPECT_FALSE(reader.CopySliceData("tensor_float", full_slice, data));
    EXPECT_TRUE(reader.CopySliceData("tensor_float", saved_slice, data));
  }

  {
    // The 2-d double tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
    TensorShape expected({2, 4});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_DOUBLE, type);

    // We saved the slice "0,1:2,2" so we should not be able to read the full
    // tensor.
    TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
    TensorSlice saved_slice = TensorSlice::ParseOrDie("0,1:2,2");
    double data[8];
    EXPECT_FALSE(reader.CopySliceData("tensor_double", full_slice, data));
    EXPECT_TRUE(reader.CopySliceData("tensor_double", saved_slice, data));
  }

  {
    // The 2-d qint8 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
    TensorShape expected({3, 2});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT8, type);

    // We saved the full slice.
    TensorSlice s = TensorSlice::ParseOrDie("-:-");
    qint8 data[6];
    EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
  }

  {
    // The 2-d qint32 tensor
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
    TensorShape expected({2, 3});
    EXPECT_TRUE(shape.IsSameSize(expected));
    EXPECT_EQ(DT_QINT32, type);

    // We saved the slice "1,1:2,1" so we should not be able to read the full
    // tensor.
    TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
    TensorSlice saved_slice = TensorSlice::ParseOrDie("1,1:2,1");
    qint32 data[6];
    EXPECT_FALSE(reader.CopySliceData("tensor_qint32", full_slice, data));
    EXPECT_TRUE(reader.CopySliceData("tensor_qint32", saved_slice, data));
  }
}

class SaveOpSlices2Test : public OpsTestBase {
 protected:
  void MakeOp() {
    TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput())
                     .Input(FakeInput({DT_INT32, DT_INT32, DT_FLOAT}))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

TEST_F(SaveOpSlices2Test, TwoSlices) {
  const string filename = io::JoinPath(testing::TmpDir(), "three_slices");
  // We will save 2 slices of the tensor named "four_by_sixteen", which is
  // 4x16, and the full contents of the "small" tensor.
  const string tensornames[] = {"four_by_sixteen", "four_by_sixteen", "small"};
  const string tensorshapes[] = {
      // Slice specifications for the 2 slices of "four_by_sixteen"
      "4 16 0,2:-",  // 1st slice covers indices 0 and 1 in the first dim.
      "4 16 2,2:-",  // 2nd slice covers indices 2 and 3 in the first dim.
      ""             // We save the full "small" tensor.
  };
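  // (An empty shape-and-slice spec, as used for "small" above, saves the
  // corresponding tensor whole, just as the plain Save op would.)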

  MakeOp();
  // Add a file name
  AddInput<tstring>(TensorShape({}),
                    [&filename](int x) -> tstring { return filename; });

  // Add the tensor names
  AddInput<tstring>(TensorShape({3}), [&tensornames](int x) -> tstring {
    return tensornames[x];
  });

  // Add the tensor shapes and slices
  AddInput<tstring>(TensorShape({3}), [&tensorshapes](int x) -> tstring {
    return tensorshapes[x];
  });

  // Add an integer tensor for slice 0,2:- of a 4x16 tensor: It is 2x16.
  AddInput<int32>(TensorShape({2, 16}), [](int x) -> int32 { return x + 1; });

  // Add an integer tensor for slice 2,2:- of a 4x16 tensor: It is 2x16.
  AddInput<int32>(TensorShape({2, 16}),
                  [](int x) -> int32 { return 10 * (x + 1); });

  // Add a float tensor for "small"
  AddInput<float>(TensorShape({2, 4}),
                  [](int x) -> float { return static_cast<float>(x) / 10; });

  TF_ASSERT_OK(RunOpKernel());

  // Check that the checkpoint file is properly written
  checkpoint::TensorSliceReader reader(filename,
                                       checkpoint::OpenTableTensorSliceReader);
  TF_EXPECT_OK(reader.status());

  {
    // Reload the two slices of "four_by_sixteen" into that tensor.
    Tensor reloaded(DT_INT32, {4, 16});

    // We expect to find all slices
    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("four_by_sixteen", &shape, &type));
    EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
    EXPECT_EQ(type, reloaded.dtype());

    // Reload the whole tensor.
    EXPECT_TRUE(reader.CopySliceData("four_by_sixteen",
                                     TensorSlice(reloaded.dims()),
                                     reloaded.flat<int>().data()));

    {
      auto slice = reloaded.Slice(0, 2).flat<int>();
      for (int i = 0; i < slice.size(); ++i) {
        EXPECT_EQ(i + 1, slice(i));
      }
    }
    {
      auto slice = reloaded.Slice(2, 4).flat<int>();
      for (int i = 0; i < slice.size(); ++i) {
        EXPECT_EQ(10 * (i + 1), slice(i));
      }
    }
  }

  {
    // Reload the small float tensor.
    Tensor reloaded(DT_FLOAT, {2, 4});

    TensorShape shape;
    DataType type;
    EXPECT_TRUE(reader.HasTensor("small", &shape, &type));
    EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
    EXPECT_EQ(DT_FLOAT, reloaded.dtype());

    EXPECT_TRUE(reader.CopySliceData("small", TensorSlice(reloaded.dims()),
                                     reloaded.flat<float>().data()));

    for (int64_t i = 0; i < reloaded.NumElements(); ++i) {
      EXPECT_EQ(static_cast<float>(i) / 10, reloaded.flat<float>().data()[i]);
    }
  }
}

// Benchmark-related code below.

void BM_LargeTensorWrite(::testing::benchmark::State& state) {
  const int num_elements = state.range(0);

  // 4 * num_elements bytes total, since sizeof(float) == 4.
  Tensor tensor(DT_FLOAT, TensorShape({num_elements}));
  tensor.flat<float>().setZero();

  // Builds the graph.
  const tstring temp_filename =
      io::JoinPath(testing::TmpDir(), "benchmark_checkpoint");
  auto root = Scope::NewRootScope().ExitOnError();
  const tstring tensor_name = "my_tensor";
  ops::Save(root, temp_filename, {tensor_name}, {{tensor}});

  // Disables optimizations.
  SessionOptions session_options;
  session_options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(tensorflow::OptimizerOptions::L0);

  TF_CHECK_OK(root.status());
  Graph* g = new Graph(OpRegistry::Global());
  TF_CHECK_OK(root.ToGraph(g));
  VLOG(1) << "Save op's output path: " << temp_filename;
  VLOG(1) << "# nodes in Graph: " << g->num_nodes();

  test::Benchmark("cpu", g, &session_options, nullptr, nullptr, "",
                  /*old_benchmark_api=*/false)
      .Run(state);
}
BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);
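// Additional sizes could be benchmarked by chaining more Arg() calls, e.g.
// (hypothetical) BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 27) / 4);
// for a 128MB float tensor.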

}  // namespace
}  // namespace tensorflow
700