• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <memory>
18 
19 #include "tensorflow/cc/ops/const_op.h"
20 #include "tensorflow/cc/ops/io_ops.h"
21 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
22 #include "tensorflow/core/framework/allocator.h"
23 #include "tensorflow/core/framework/fake_input.h"
24 #include "tensorflow/core/framework/node_def_builder.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/graph/graph_def_builder.h"
30 #include "tensorflow/core/kernels/ops_testutil.h"
31 #include "tensorflow/core/kernels/ops_util.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/test_benchmark.h"
37 #include "tensorflow/core/platform/types.h"
38 #include "tensorflow/core/protobuf/config.pb.h"
39 #include "tensorflow/core/public/session_options.h"
40 #include "tensorflow/core/util/tensor_slice_reader.h"
41 
42 namespace tensorflow {
43 namespace {
44 
45 class SaveOpTest : public OpsTestBase {
46  protected:
MakeOp()47   void MakeOp() {
48     TF_ASSERT_OK(
49         NodeDefBuilder("myop", "Save")
50             .Input(FakeInput())
51             .Input(FakeInput())
52             .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8,
53                               DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_INT64,
54                               DT_STRING, DT_COMPLEX64, DT_COMPLEX128, DT_HALF}))
55             .Finalize(node_def()));
56     TF_ASSERT_OK(InitOp());
57   }
58 };
59 
TEST_F(SaveOpTest,Simple)60 TEST_F(SaveOpTest, Simple) {
61   const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
62   const string tensornames[] = {
63       "tensor_bool",       "tensor_int",    "tensor_float",  "tensor_double",
64       "tensor_qint8",      "tensor_qint32", "tensor_uint8",  "tensor_int8",
65       "tensor_int16",      "tensor_int64",  "tensor_string", "tensor_complex64",
66       "tensor_complex128", "tensor_half"};
67 
68   MakeOp();
69   // Add a file name
70   AddInput<tstring>(TensorShape({}),
71                     [&filename](int x) -> tstring { return filename; });
72 
73   // Add the tensor names
74   AddInput<tstring>(TensorShape({14}), [&tensornames](int x) -> tstring {
75     return tensornames[x];
76   });
77 
78   // Add a 1-d bool tensor
79   AddInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; });
80 
81   // Add a 1-d integer tensor
82   AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
83 
84   // Add a 2-d float tensor
85   AddInput<float>(TensorShape({2, 4}),
86                   [](int x) -> float { return static_cast<float>(x) / 10; });
87 
88   // Add a 2-d double tensor
89   AddInput<double>(TensorShape({2, 4}),
90                    [](int x) -> double { return static_cast<double>(x) / 20; });
91 
92   // Add a 2-d qint8 tensor
93   AddInput<qint8>(TensorShape({3, 2}),
94                   [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
95 
96   // Add a 2-d qint32 tensor
97   AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
98     return *reinterpret_cast<qint32*>(&x) * qint8(2);
99   });
100 
101   // Add a 1-d uint8 tensor
102   AddInput<uint8>(TensorShape({11}), [](int x) -> uint8 { return x + 1; });
103 
104   // Add a 1-d int8 tensor
105   AddInput<int8>(TensorShape({7}), [](int x) -> int8 { return x - 7; });
106 
107   // Add a 1-d int16 tensor
108   AddInput<int16>(TensorShape({7}), [](int x) -> int16 { return x - 8; });
109 
110   // Add a 1-d int64 tensor
111   AddInput<int64>(TensorShape({9}), [](int x) -> int64 { return x - 9; });
112 
113   // Add a 1-d string tensor
114   AddInput<tstring>(TensorShape({2}),
115                     [](int x) -> tstring { return x ? "yes" : "no"; });
116 
117   // Add a 2-d complex64 tensor
118   AddInput<complex64>(TensorShape({2, 3}), [](int x) -> complex64 {
119     return complex64(100 + x, 200 + x);
120   });
121 
122   // Add a 2-d complex128 tensor
123   AddInput<complex128>(TensorShape({2, 3}), [](int x) -> complex128 {
124     return complex128(100 + x, 200 + x);
125   });
126 
127   // Add a 2-d half tensor
128   AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half {
129     return static_cast<Eigen::half>(x) / Eigen::half(2);
130   });
131   TF_ASSERT_OK(RunOpKernel());
132 
133   // Check that the checkpoint file is properly written
134   checkpoint::TensorSliceReader reader(filename,
135                                        checkpoint::OpenTableTensorSliceReader);
136   TF_EXPECT_OK(reader.status());
137 
138   // We expect to find all saved tensors
139   {
140     // The 1-d bool tensor
141     TensorShape shape;
142     DataType type;
143     EXPECT_TRUE(reader.HasTensor("tensor_bool", &shape, &type));
144     TensorShape expected({2});
145     EXPECT_TRUE(shape.IsSameSize(expected));
146     EXPECT_EQ(DT_BOOL, type);
147 
148     // We expect the tensor value to be correct.
149     TensorSlice s = TensorSlice::ParseOrDie("-");
150     bool data[2];
151     std::fill_n(data, 2, false);
152     EXPECT_TRUE(reader.CopySliceData("tensor_bool", s, data));
153     for (int i = 0; i < 2; ++i) {
154       EXPECT_EQ((i != 0), data[i]);
155     }
156   }
157 
158   {
159     // The 1-d integer tensor
160     TensorShape shape;
161     DataType type;
162     EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
163     TensorShape expected({10});
164     EXPECT_TRUE(shape.IsSameSize(expected));
165     EXPECT_EQ(DT_INT32, type);
166 
167     // We expect the tensor value to be correct.
168     TensorSlice s = TensorSlice::ParseOrDie("-");
169     int data[10];
170     std::fill_n(data, 10, 0);
171     EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
172     for (int i = 0; i < 10; ++i) {
173       EXPECT_EQ(i + 1, data[i]);
174     }
175   }
176 
177   {
178     // The 2-d float tensor
179     TensorShape shape;
180     DataType type;
181     EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
182     TensorShape expected({2, 4});
183     EXPECT_TRUE(shape.IsSameSize(expected));
184     EXPECT_EQ(DT_FLOAT, type);
185 
186     // We expect the tensor value to be correct.
187     TensorSlice s = TensorSlice::ParseOrDie("-:-");
188     float data[8];
189     std::fill_n(data, 8, 0);
190     EXPECT_TRUE(reader.CopySliceData("tensor_float", s, data));
191     for (int i = 0; i < 8; ++i) {
192       EXPECT_EQ(static_cast<float>(i) / 10, data[i]);
193     }
194   }
195 
196   {
197     // The 2-d double tensor
198     TensorShape shape;
199     DataType type;
200     EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
201     TensorShape expected({2, 4});
202     EXPECT_TRUE(shape.IsSameSize(expected));
203     EXPECT_EQ(DT_DOUBLE, type);
204 
205     // We expect the tensor value to be correct.
206     TensorSlice s = TensorSlice::ParseOrDie("-:-");
207     double data[8];
208     std::fill_n(data, 8, 0);
209     EXPECT_TRUE(reader.CopySliceData("tensor_double", s, data));
210     for (int i = 0; i < 8; ++i) {
211       EXPECT_EQ(static_cast<double>(i) / 20, data[i]);
212     }
213   }
214 
215   {
216     // The 2-d qint8 tensor
217     TensorShape shape;
218     DataType type;
219     EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
220     TensorShape expected({3, 2});
221     EXPECT_TRUE(shape.IsSameSize(expected));
222     EXPECT_EQ(DT_QINT8, type);
223 
224     // We expect the tensor value to be correct.
225     TensorSlice s = TensorSlice::ParseOrDie("-:-");
226     qint8 data[6];
227     EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
228     for (int i = 0; i < 6; ++i) {
229       EXPECT_EQ(*reinterpret_cast<qint8*>(&i), data[i]);
230     }
231   }
232 
233   {
234     // The 2-d qint32 tensor
235     TensorShape shape;
236     DataType type;
237     EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
238     TensorShape expected({2, 3});
239     EXPECT_TRUE(shape.IsSameSize(expected));
240     EXPECT_EQ(DT_QINT32, type);
241 
242     // We expect the tensor value to be correct.
243     TensorSlice s = TensorSlice::ParseOrDie("-:-");
244     qint32 data[6];
245     EXPECT_TRUE(reader.CopySliceData("tensor_qint32", s, data));
246     for (int i = 0; i < 6; ++i) {
247       EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2), data[i]);
248     }
249   }
250 
251   {
252     // The 1-d uint8 tensor
253     TensorShape shape;
254     DataType type;
255     EXPECT_TRUE(reader.HasTensor("tensor_uint8", &shape, &type));
256     TensorShape expected({11});
257     EXPECT_TRUE(shape.IsSameSize(expected));
258     EXPECT_EQ(DT_UINT8, type);
259 
260     // We expect the tensor value to be correct.
261     TensorSlice s = TensorSlice::ParseOrDie("-");
262     uint8 data[11];
263     EXPECT_TRUE(reader.CopySliceData("tensor_uint8", s, data));
264     for (int i = 0; i < 11; ++i) {
265       EXPECT_EQ(i + 1, data[i]);
266     }
267   }
268 
269   {
270     // The 1-d int8 tensor
271     TensorShape shape;
272     DataType type;
273     EXPECT_TRUE(reader.HasTensor("tensor_int8", &shape, &type));
274     TensorShape expected({7});
275     EXPECT_TRUE(shape.IsSameSize(expected));
276     EXPECT_EQ(DT_INT8, type);
277 
278     // We expect the tensor value to be correct.
279     TensorSlice s = TensorSlice::ParseOrDie("-");
280     int8 data[7];
281     EXPECT_TRUE(reader.CopySliceData("tensor_int8", s, data));
282     for (int i = 0; i < 7; ++i) {
283       EXPECT_EQ(i - 7, data[i]);
284     }
285   }
286 
287   {
288     // The 1-d int16 tensor
289     TensorShape shape;
290     DataType type;
291     EXPECT_TRUE(reader.HasTensor("tensor_int16", &shape, &type));
292     TensorShape expected({7});
293     EXPECT_TRUE(shape.IsSameSize(expected));
294     EXPECT_EQ(DT_INT16, type);
295 
296     // We expect the tensor value to be correct.
297     TensorSlice s = TensorSlice::ParseOrDie("-");
298     int16 data[7];
299     EXPECT_TRUE(reader.CopySliceData("tensor_int16", s, data));
300     for (int i = 0; i < 7; ++i) {
301       EXPECT_EQ(i - 8, data[i]);
302     }
303   }
304 
305   {
306     // The 1-d int64 tensor
307     TensorShape shape;
308     DataType type;
309     EXPECT_TRUE(reader.HasTensor("tensor_int64", &shape, &type));
310     TensorShape expected({9});
311     EXPECT_TRUE(shape.IsSameSize(expected));
312     EXPECT_EQ(DT_INT64, type);
313 
314     // We expect the tensor value to be correct.
315     TensorSlice s = TensorSlice::ParseOrDie("-");
316     int64 data[9];
317     EXPECT_TRUE(reader.CopySliceData("tensor_int64", s, data));
318     for (int i = 0; i < 9; ++i) {
319       EXPECT_EQ(i - 9, data[i]);
320     }
321   }
322 
323   {
324     // The 1-d string tensor
325     TensorShape shape;
326     DataType type;
327     EXPECT_TRUE(reader.HasTensor("tensor_string", &shape, &type));
328     TensorShape expected({2});
329     EXPECT_TRUE(shape.IsSameSize(expected));
330     EXPECT_EQ(DT_STRING, type);
331 
332     // We expect the tensor value to be correct.
333     TensorSlice s = TensorSlice::ParseOrDie("-");
334     tstring data[2];
335     EXPECT_TRUE(reader.CopySliceData("tensor_string", s, data));
336     EXPECT_EQ("no", data[0]);
337     EXPECT_EQ("yes", data[1]);
338   }
339 
340   {
341     // The 2-d complex64 tensor
342     TensorShape shape;
343     DataType type;
344     EXPECT_TRUE(reader.HasTensor("tensor_complex64", &shape, &type));
345     TensorShape expected({2, 3});
346     EXPECT_TRUE(shape.IsSameSize(expected));
347     EXPECT_EQ(DT_COMPLEX64, type);
348 
349     // We expect the tensor value to be correct.
350     TensorSlice s = TensorSlice::ParseOrDie("-:-");
351     complex64 data[6];
352     EXPECT_TRUE(reader.CopySliceData("tensor_complex64", s, data));
353     for (int i = 0; i < 6; ++i) {
354       EXPECT_EQ(100 + i, data[i].real());
355       EXPECT_EQ(200 + i, data[i].imag());
356     }
357   }
358 
359   {
360     // The 2-d complex128 tensor
361     TensorShape shape;
362     DataType type;
363     EXPECT_TRUE(reader.HasTensor("tensor_complex128", &shape, &type));
364     TensorShape expected({2, 3});
365     EXPECT_TRUE(shape.IsSameSize(expected));
366     EXPECT_EQ(DT_COMPLEX128, type);
367 
368     // We expect the tensor value to be correct.
369     TensorSlice s = TensorSlice::ParseOrDie("-:-");
370     complex128 data[6];
371     EXPECT_TRUE(reader.CopySliceData("tensor_complex128", s, data));
372     for (int i = 0; i < 6; ++i) {
373       EXPECT_EQ(100 + i, data[i].real());
374       EXPECT_EQ(200 + i, data[i].imag());
375     }
376   }
377   {
378     // The 2-d half tensor
379     TensorShape shape;
380     DataType type;
381     EXPECT_TRUE(reader.HasTensor("tensor_half", &shape, &type));
382     TensorShape expected({2, 4});
383     EXPECT_TRUE(shape.IsSameSize(expected));
384     EXPECT_EQ(DT_HALF, type);
385 
386     // We expect the tensor value to be correct.
387     TensorSlice s = TensorSlice::ParseOrDie("-:-");
388     Eigen::half data[8];
389     std::fill_n(data, 8, Eigen::half(0));
390     EXPECT_TRUE(reader.CopySliceData("tensor_half", s, data));
391     for (int i = 0; i < 8; ++i) {
392       EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2), data[i]);
393     }
394   }
395 }
396 
397 class SaveSlicesOpTest : public OpsTestBase {
398  protected:
MakeOp()399   void MakeOp() {
400     TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
401                      .Input(FakeInput())
402                      .Input(FakeInput())
403                      .Input(FakeInput())
404                      .Input(FakeInput(
405                          {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
406                      .Finalize(node_def()));
407     TF_ASSERT_OK(InitOp());
408   }
409 };
410 
411 // Here we save only slices.  We restore them in a larger tensor and we check
412 // that the right slice is restored.  It is quite tricky to check that the
413 // right slices are actually restored so instead we just check that
414 // CopySliceData() return true/false depending on the slice we ask for.
TEST_F(SaveSlicesOpTest,Slices)415 TEST_F(SaveSlicesOpTest, Slices) {
416   const string filename = io::JoinPath(testing::TmpDir(), "tensor_slices");
417   const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
418                                 "tensor_qint8", "tensor_qint32"};
419   // Specifies that the data we save are slices of larger tensors.
420   // See core/framework/tensor_slice.h for the slice syntax.
421   const string tensorshapes[] = {
422       "10 -",         // Full contents of a 10 element vector.
423       "2 4 -:0,2",    // A 2x2 slice of a 2x4 tensor.
424       "2 4 0,1:2,2",  // A 1x2 slice of a 2x4 tensor.
425       "3 2 -:-",      // Full contents of a 3x2 tensor.
426       "2 3 1,1:2,1"   // Another 1x1 slice of a2x3 tensor.
427   };
428 
429   MakeOp();
430   // Add a file name
431   AddInput<tstring>(TensorShape({}),
432                     [&filename](int x) -> tstring { return filename; });
433 
434   // Add the tensor names
435   AddInput<tstring>(TensorShape({5}), [&tensornames](int x) -> tstring {
436     return tensornames[x];
437   });
438 
439   // Add the tensor shapes and slices
440   AddInput<tstring>(TensorShape({5}), [&tensorshapes](int x) -> tstring {
441     return tensorshapes[x];
442   });
443 
444   // Add a 1-d integer tensor
445   AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
446 
447   // Add a 2-d float tensor
448   AddInput<float>(TensorShape({2, 2}),
449                   [](int x) -> float { return static_cast<float>(x) / 10; });
450 
451   // Add a 2-d double tensor
452   AddInput<double>(TensorShape({1, 2}),
453                    [](int x) -> double { return static_cast<double>(x) / 20; });
454 
455   // Add a 2-d qint8 tensor
456   AddInput<qint8>(TensorShape({3, 2}),
457                   [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
458 
459   // Add a 2-d qint32 tensor
460   AddInput<qint32>(TensorShape({1, 1}), [](int x) -> qint32 {
461     return *reinterpret_cast<qint32*>(&x) * qint8(2);
462   });
463 
464   TF_ASSERT_OK(RunOpKernel());
465 
466   // Check that the checkpoint file is properly written
467   checkpoint::TensorSliceReader reader(filename,
468                                        checkpoint::OpenTableTensorSliceReader);
469   TF_EXPECT_OK(reader.status());
470 
471   // We expect to find all saved tensors
472   {
473     // The 1-d integer tensor
474     TensorShape shape;
475     DataType type;
476     EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
477     TensorShape expected({10});
478     EXPECT_TRUE(shape.IsSameSize(expected));
479     EXPECT_EQ(DT_INT32, type);
480 
481     // We saved the full tensor so we should be able to read it all.
482     TensorSlice s = TensorSlice::ParseOrDie("-");
483     int data[10];
484     EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
485   }
486 
487   {
488     // The 2-d float tensor
489     TensorShape shape;
490     DataType type;
491     EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
492     TensorShape expected({2, 4});
493     EXPECT_TRUE(shape.IsSameSize(expected));
494     EXPECT_EQ(DT_FLOAT, type);
495 
496     // We saved the slice "-:0,2" so we should not be able to read the full
497     // tensor.
498     TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
499     TensorSlice saved_slice = TensorSlice::ParseOrDie("-:0,2");
500     float data[8];
501     EXPECT_FALSE(reader.CopySliceData("tensor_float", full_slice, data));
502     EXPECT_TRUE(reader.CopySliceData("tensor_float", saved_slice, data));
503   }
504 
505   {
506     // The 2-d double tensor
507     TensorShape shape;
508     DataType type;
509     EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
510     TensorShape expected({2, 4});
511     EXPECT_TRUE(shape.IsSameSize(expected));
512     EXPECT_EQ(DT_DOUBLE, type);
513 
514     // We saved the slice "0,1:2,2" so we should not be able to read the full
515     // tensor.
516     TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
517     TensorSlice saved_slice = TensorSlice::ParseOrDie("0,1:2,2");
518     double data[8];
519     EXPECT_FALSE(reader.CopySliceData("tensor_double", full_slice, data));
520     EXPECT_TRUE(reader.CopySliceData("tensor_double", saved_slice, data));
521   }
522 
523   {
524     // The 2-d qint8 tensor
525     TensorShape shape;
526     DataType type;
527     EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
528     TensorShape expected({3, 2});
529     EXPECT_TRUE(shape.IsSameSize(expected));
530     EXPECT_EQ(DT_QINT8, type);
531 
532     // We saved the full slice.
533     TensorSlice s = TensorSlice::ParseOrDie("-:-");
534     qint8 data[6];
535     EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
536   }
537 
538   {
539     // The 2-d qint32 tensor
540     TensorShape shape;
541     DataType type;
542     EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
543     TensorShape expected({2, 3});
544     EXPECT_TRUE(shape.IsSameSize(expected));
545     EXPECT_EQ(DT_QINT32, type);
546 
547     // We expect the tensor value to be correct.
548     TensorSlice s = TensorSlice::ParseOrDie("1,1:2,1");
549     TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
550     TensorSlice saved_slice = TensorSlice::ParseOrDie("1,1:2,1");
551     qint32 data[6];
552     EXPECT_FALSE(reader.CopySliceData("tensor_qint32", full_slice, data));
553     EXPECT_TRUE(reader.CopySliceData("tensor_qint32", saved_slice, data));
554   }
555 }
556 
557 class SaveOpSlices2Test : public OpsTestBase {
558  protected:
MakeOp()559   void MakeOp() {
560     TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
561                      .Input(FakeInput())
562                      .Input(FakeInput())
563                      .Input(FakeInput())
564                      .Input(FakeInput({DT_INT32, DT_INT32, DT_FLOAT}))
565                      .Finalize(node_def()));
566     TF_ASSERT_OK(InitOp());
567   }
568 };
569 
TEST_F(SaveOpSlices2Test,TwoSlices)570 TEST_F(SaveOpSlices2Test, TwoSlices) {
571   const string filename = io::JoinPath(testing::TmpDir(), "three_slices");
572   // We will save 2 slices of the tensor named "four_by_sixteen" which is 4x16,
573   // and one slice of the "small" tensor.
574   const string tensornames[] = {"four_by_sixteen", "four_by_sixteen", "small"};
575   const string tensorshapes[] = {
576       // Slice specifications for the 2 slices of "four_by_sixteen"
577       "4 16 0,2:-",  // 1st slice covers indices 0 and 1 in the first dim.
578       "4 16 2,2:-",  // 2nd slice covers indices 2 and 3 in the first dim.
579       ""             // We save the full "small" tensors.
580   };
581 
582   MakeOp();
583   // Add a file name
584   AddInput<tstring>(TensorShape({}),
585                     [&filename](int x) -> tstring { return filename; });
586 
587   // Add the tensor names
588   AddInput<tstring>(TensorShape({3}), [&tensornames](int x) -> tstring {
589     return tensornames[x];
590   });
591 
592   // Add the tensor shapes and slices
593   AddInput<tstring>(TensorShape({3}), [&tensorshapes](int x) -> tstring {
594     return tensorshapes[x];
595   });
596 
597   // Add an integer tensor for slice 0,2:- of a 4x16 tensor: It is 2x16.
598   AddInput<int32>(TensorShape({2, 16}), [](int x) -> int32 { return x + 1; });
599 
600   // Add an integer tensor for slice 2,2:- of a 4x16 tensor: It is 2x16.
601   AddInput<int32>(TensorShape({2, 16}),
602                   [](int x) -> int32 { return 10 * (x + 1); });
603 
604   // Add a float tensor for "small"
605   AddInput<float>(TensorShape({2, 4}),
606                   [](int x) -> float { return static_cast<float>(x) / 10; });
607 
608   TF_ASSERT_OK(RunOpKernel());
609 
610   // Check that the checkpoint file is properly written
611   checkpoint::TensorSliceReader reader(filename,
612                                        checkpoint::OpenTableTensorSliceReader);
613   TF_EXPECT_OK(reader.status());
614 
615   {
616     // Reload the two slices of "four_by_sixteen" into that tensor.
617     Tensor reloaded(DT_INT32, {4, 16});
618 
619     // We expect to find all slices
620     TensorShape shape;
621     DataType type;
622     EXPECT_TRUE(reader.HasTensor("four_by_sixteen", &shape, &type));
623     EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
624     EXPECT_EQ(type, reloaded.dtype());
625 
626     // Reload the whole tensor.
627     EXPECT_TRUE(reader.CopySliceData("four_by_sixteen",
628                                      TensorSlice(reloaded.dims()),
629                                      reloaded.flat<int>().data()));
630 
631     {
632       auto slice = reloaded.Slice(0, 2).flat<int>();
633       for (int i = 0; i < slice.size(); ++i) {
634         EXPECT_EQ(i + 1, slice(i));
635       }
636     }
637     {
638       auto slice = reloaded.Slice(2, 4).flat<int>();
639       for (int i = 0; i < slice.size(); ++i) {
640         EXPECT_EQ(10 * (i + 1), slice(i));
641       }
642     }
643   }
644 
645   {
646     // Reload the small float tensor.
647     Tensor reloaded(DT_FLOAT, {2, 4});
648 
649     TensorShape shape;
650     DataType type;
651     EXPECT_TRUE(reader.HasTensor("small", &shape, &type));
652     EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
653     EXPECT_EQ(DT_FLOAT, reloaded.dtype());
654 
655     EXPECT_TRUE(reader.CopySliceData("small", TensorSlice(reloaded.dims()),
656                                      reloaded.flat<float>().data()));
657 
658     for (int64_t i = 0; i < reloaded.NumElements(); ++i) {
659       EXPECT_EQ(static_cast<float>(i) / 10, reloaded.flat<float>().data()[i]);
660     }
661   }
662 }
663 
664 // Benchmark-related code below.
665 
BM_LargeTensorWrite(::testing::benchmark::State & state)666 void BM_LargeTensorWrite(::testing::benchmark::State& state) {
667   const int num_elements = state.range(0);
668 
669   // 4 * num_elements bytes total , since sizeof(float) == 4.
670   Tensor tensor(DT_FLOAT, TensorShape({num_elements}));
671   tensor.flat<float>().setZero();
672 
673   // Builds the graph.
674   const tstring temp_filename =
675       io::JoinPath(testing::TmpDir(), "benchmark_checkpoint");
676   auto root = Scope::NewRootScope().ExitOnError();
677   const tstring tensor_name = "my_tensor";
678   ops::Save(root, temp_filename, {tensor_name}, {{tensor}});
679 
680   // Disables optimizations.
681   SessionOptions session_options;
682   session_options.config.mutable_graph_options()
683       ->mutable_optimizer_options()
684       ->set_opt_level(tensorflow::OptimizerOptions::L0);
685 
686   TF_CHECK_OK(root.status());
687   Graph* g = new Graph(OpRegistry::Global());
688   TF_CHECK_OK(root.ToGraph(g));
689   VLOG(1) << "Save op's output path: " << temp_filename;
690   VLOG(1) << "# nodes in Graph: " << g->num_nodes();
691 
692   test::Benchmark("cpu", g, &session_options, nullptr, nullptr, "",
693                   /*old_benchmark_api*/ false)
694       .Run(state);
695 }
696 BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */);
697 
698 }  // namespace
699 }  // namespace tensorflow
700