• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/snapshot_utils.h"
17 
18 #include "tensorflow/core/framework/tensor.pb.h"
19 #include "tensorflow/core/lib/core/status_test_util.h"
20 #include "tensorflow/core/lib/io/compression.h"
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 
26 namespace tensorflow {
27 namespace data {
28 namespace snapshot_util {
29 namespace {
30 
GenerateTensorVector(tensorflow::DataTypeVector & dtypes,std::vector<Tensor> & tensors)31 void GenerateTensorVector(tensorflow::DataTypeVector& dtypes,
32                           std::vector<Tensor>& tensors) {
33   std::string tensor_data(1024, 'a');
34   for (int i = 0; i < 10; ++i) {
35     Tensor t(tensor_data.data());
36     dtypes.push_back(t.dtype());
37     tensors.push_back(t);
38   }
39 }
40 
SnapshotRoundTrip(std::string compression_type,int version)41 void SnapshotRoundTrip(std::string compression_type, int version) {
42   // Generate ground-truth tensors for writing and reading.
43   std::vector<Tensor> tensors;
44   tensorflow::DataTypeVector dtypes;
45   GenerateTensorVector(dtypes, tensors);
46 
47   std::string filename;
48   EXPECT_TRUE(Env::Default()->LocalTempFilename(&filename));
49 
50   std::unique_ptr<Writer> writer;
51   TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
52                               compression_type, version, dtypes, &writer));
53 
54   for (int i = 0; i < 100; ++i) {
55     TF_ASSERT_OK(writer->WriteTensors(tensors));
56   }
57   TF_ASSERT_OK(writer->Close());
58 
59   std::unique_ptr<Reader> reader;
60   TF_ASSERT_OK(Reader::Create(Env::Default(), filename, compression_type,
61                               version, dtypes, &reader));
62 
63   for (int i = 0; i < 100; ++i) {
64     std::vector<Tensor> read_tensors;
65     TF_ASSERT_OK(reader->ReadTensors(&read_tensors));
66     EXPECT_EQ(tensors.size(), read_tensors.size());
67     for (int j = 0; j < read_tensors.size(); ++j) {
68       TensorProto proto;
69       TensorProto read_proto;
70 
71       tensors[j].AsProtoTensorContent(&proto);
72       read_tensors[j].AsProtoTensorContent(&read_proto);
73 
74       std::string proto_serialized, read_proto_serialized;
75       proto.AppendToString(&proto_serialized);
76       read_proto.AppendToString(&read_proto_serialized);
77       EXPECT_EQ(proto_serialized, read_proto_serialized);
78     }
79   }
80 
81   TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
82 }
83 
// Round-trips tensors through each combination of the None/Gzip/Snappy
// compression types with snapshot versions 1 and 2. SCOPED_TRACE labels any
// failure reported inside SnapshotRoundTrip with the combination that
// produced it. Without it, all combinations report identical source locations.
TEST(SnapshotUtilTest, CombinationRoundTripTest) {
  const std::string kCompressionTypes[] = {io::compression::kNone,
                                           io::compression::kGzip,
                                           io::compression::kSnappy};
  for (int version : {1, 2}) {
    for (const std::string& compression_type : kCompressionTypes) {
      SCOPED_TRACE(::testing::Message()
                   << "compression_type=\"" << compression_type
                   << "\", version=" << version);
      SnapshotRoundTrip(compression_type, version);
    }
  }
}
93 
SnapshotReaderBenchmarkLoop(::testing::benchmark::State & state,std::string compression_type,int version)94 void SnapshotReaderBenchmarkLoop(::testing::benchmark::State& state,
95                                  std::string compression_type, int version) {
96   tensorflow::DataTypeVector dtypes;
97   std::vector<Tensor> tensors;
98   GenerateTensorVector(dtypes, tensors);
99 
100   std::string filename;
101   EXPECT_TRUE(Env::Default()->LocalTempFilename(&filename));
102 
103   std::unique_ptr<Writer> writer;
104   TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
105                               compression_type, version, dtypes, &writer));
106 
107   for (auto s : state) {
108     writer->WriteTensors(tensors).IgnoreError();
109   }
110   TF_ASSERT_OK(writer->Close());
111 
112   std::unique_ptr<Reader> reader;
113   TF_ASSERT_OK(Reader::Create(Env::Default(), filename, compression_type,
114                               version, dtypes, &reader));
115 
116   for (auto s : state) {
117     std::vector<Tensor> read_tensors;
118     reader->ReadTensors(&read_tensors).IgnoreError();
119   }
120 
121   TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
122 }
123 
SnapshotCustomReaderNoneBenchmark(::testing::benchmark::State & state)124 void SnapshotCustomReaderNoneBenchmark(::testing::benchmark::State& state) {
125   SnapshotReaderBenchmarkLoop(state, io::compression::kNone, 1);
126 }
127 
SnapshotCustomReaderGzipBenchmark(::testing::benchmark::State & state)128 void SnapshotCustomReaderGzipBenchmark(::testing::benchmark::State& state) {
129   SnapshotReaderBenchmarkLoop(state, io::compression::kGzip, 1);
130 }
131 
SnapshotCustomReaderSnappyBenchmark(::testing::benchmark::State & state)132 void SnapshotCustomReaderSnappyBenchmark(::testing::benchmark::State& state) {
133   SnapshotReaderBenchmarkLoop(state, io::compression::kSnappy, 1);
134 }
135 
SnapshotTFRecordReaderNoneBenchmark(::testing::benchmark::State & state)136 void SnapshotTFRecordReaderNoneBenchmark(::testing::benchmark::State& state) {
137   SnapshotReaderBenchmarkLoop(state, io::compression::kNone, 2);
138 }
139 
SnapshotTFRecordReaderGzipBenchmark(::testing::benchmark::State & state)140 void SnapshotTFRecordReaderGzipBenchmark(::testing::benchmark::State& state) {
141   SnapshotReaderBenchmarkLoop(state, io::compression::kGzip, 2);
142 }
143 
144 BENCHMARK(SnapshotCustomReaderNoneBenchmark);
145 BENCHMARK(SnapshotCustomReaderGzipBenchmark);
146 BENCHMARK(SnapshotCustomReaderSnappyBenchmark);
147 BENCHMARK(SnapshotTFRecordReaderNoneBenchmark);
148 BENCHMARK(SnapshotTFRecordReaderGzipBenchmark);
149 
SnapshotWriterBenchmarkLoop(::testing::benchmark::State & state,std::string compression_type,int version)150 void SnapshotWriterBenchmarkLoop(::testing::benchmark::State& state,
151                                  std::string compression_type, int version) {
152   tensorflow::DataTypeVector dtypes;
153   std::vector<Tensor> tensors;
154   GenerateTensorVector(dtypes, tensors);
155 
156   std::string filename;
157   EXPECT_TRUE(Env::Default()->LocalTempFilename(&filename));
158 
159   std::unique_ptr<Writer> writer;
160   TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
161                               compression_type, version, dtypes, &writer));
162 
163   for (auto s : state) {
164     writer->WriteTensors(tensors).IgnoreError();
165   }
166   writer->Close().IgnoreError();
167 
168   TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
169 }
170 
SnapshotCustomWriterNoneBenchmark(::testing::benchmark::State & state)171 void SnapshotCustomWriterNoneBenchmark(::testing::benchmark::State& state) {
172   SnapshotWriterBenchmarkLoop(state, io::compression::kNone, 1);
173 }
174 
SnapshotCustomWriterGzipBenchmark(::testing::benchmark::State & state)175 void SnapshotCustomWriterGzipBenchmark(::testing::benchmark::State& state) {
176   SnapshotWriterBenchmarkLoop(state, io::compression::kGzip, 1);
177 }
178 
SnapshotCustomWriterSnappyBenchmark(::testing::benchmark::State & state)179 void SnapshotCustomWriterSnappyBenchmark(::testing::benchmark::State& state) {
180   SnapshotWriterBenchmarkLoop(state, io::compression::kSnappy, 1);
181 }
182 
SnapshotTFRecordWriterNoneBenchmark(::testing::benchmark::State & state)183 void SnapshotTFRecordWriterNoneBenchmark(::testing::benchmark::State& state) {
184   SnapshotWriterBenchmarkLoop(state, io::compression::kNone, 2);
185 }
186 
SnapshotTFRecordWriterGzipBenchmark(::testing::benchmark::State & state)187 void SnapshotTFRecordWriterGzipBenchmark(::testing::benchmark::State& state) {
188   SnapshotWriterBenchmarkLoop(state, io::compression::kGzip, 2);
189 }
190 
SnapshotTFRecordWriterSnappyBenchmark(::testing::benchmark::State & state)191 void SnapshotTFRecordWriterSnappyBenchmark(::testing::benchmark::State& state) {
192   SnapshotWriterBenchmarkLoop(state, io::compression::kSnappy, 2);
193 }
194 
195 BENCHMARK(SnapshotCustomWriterNoneBenchmark);
196 BENCHMARK(SnapshotCustomWriterGzipBenchmark);
197 BENCHMARK(SnapshotCustomWriterSnappyBenchmark);
198 BENCHMARK(SnapshotTFRecordWriterNoneBenchmark);
199 BENCHMARK(SnapshotTFRecordWriterGzipBenchmark);
200 BENCHMARK(SnapshotTFRecordWriterSnappyBenchmark);
201 
202 }  // namespace
203 }  // namespace snapshot_util
204 }  // namespace data
205 }  // namespace tensorflow
206