1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/snapshot_utils.h"
17
18 #include "tensorflow/core/framework/tensor.pb.h"
19 #include "tensorflow/core/lib/core/status_test_util.h"
20 #include "tensorflow/core/lib/io/compression.h"
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25
26 namespace tensorflow {
27 namespace data {
28 namespace snapshot_util {
29 namespace {
30
GenerateTensorVector(tensorflow::DataTypeVector & dtypes,std::vector<Tensor> & tensors)31 void GenerateTensorVector(tensorflow::DataTypeVector& dtypes,
32 std::vector<Tensor>& tensors) {
33 std::string tensor_data(1024, 'a');
34 for (int i = 0; i < 10; ++i) {
35 Tensor t(tensor_data.data());
36 dtypes.push_back(t.dtype());
37 tensors.push_back(t);
38 }
39 }
40
SnapshotRoundTrip(std::string compression_type,int version)41 void SnapshotRoundTrip(std::string compression_type, int version) {
42 // Generate ground-truth tensors for writing and reading.
43 std::vector<Tensor> tensors;
44 tensorflow::DataTypeVector dtypes;
45 GenerateTensorVector(dtypes, tensors);
46
47 std::string filename;
48 EXPECT_TRUE(Env::Default()->LocalTempFilename(&filename));
49
50 std::unique_ptr<Writer> writer;
51 TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
52 compression_type, version, dtypes, &writer));
53
54 for (int i = 0; i < 100; ++i) {
55 TF_ASSERT_OK(writer->WriteTensors(tensors));
56 }
57 TF_ASSERT_OK(writer->Close());
58
59 std::unique_ptr<Reader> reader;
60 TF_ASSERT_OK(Reader::Create(Env::Default(), filename, compression_type,
61 version, dtypes, &reader));
62
63 for (int i = 0; i < 100; ++i) {
64 std::vector<Tensor> read_tensors;
65 TF_ASSERT_OK(reader->ReadTensors(&read_tensors));
66 EXPECT_EQ(tensors.size(), read_tensors.size());
67 for (int j = 0; j < read_tensors.size(); ++j) {
68 TensorProto proto;
69 TensorProto read_proto;
70
71 tensors[j].AsProtoTensorContent(&proto);
72 read_tensors[j].AsProtoTensorContent(&read_proto);
73
74 std::string proto_serialized, read_proto_serialized;
75 proto.AppendToString(&proto_serialized);
76 read_proto.AppendToString(&read_proto_serialized);
77 EXPECT_EQ(proto_serialized, read_proto_serialized);
78 }
79 }
80
81 TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
82 }
83
// Round-trips a snapshot file for every (compression, version) combination:
// version 1 first, then version 2, each with none, gzip and snappy.
TEST(SnapshotUtilTest, CombinationRoundTripTest) {
  const std::string compression_types[] = {
      io::compression::kNone, io::compression::kGzip, io::compression::kSnappy};
  for (int version : {1, 2}) {
    for (const std::string& compression_type : compression_types) {
      // Assertions inside SnapshotRoundTrip do not mention which combination
      // is under test; attach it so a failure is attributable.
      SCOPED_TRACE(::testing::Message()
                   << "compression_type=\"" << compression_type
                   << "\", version=" << version);
      SnapshotRoundTrip(compression_type, version);
    }
  }
}
93
SnapshotReaderBenchmarkLoop(::testing::benchmark::State & state,std::string compression_type,int version)94 void SnapshotReaderBenchmarkLoop(::testing::benchmark::State& state,
95 std::string compression_type, int version) {
96 tensorflow::DataTypeVector dtypes;
97 std::vector<Tensor> tensors;
98 GenerateTensorVector(dtypes, tensors);
99
100 std::string filename;
101 EXPECT_TRUE(Env::Default()->LocalTempFilename(&filename));
102
103 std::unique_ptr<Writer> writer;
104 TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
105 compression_type, version, dtypes, &writer));
106
107 for (auto s : state) {
108 writer->WriteTensors(tensors).IgnoreError();
109 }
110 TF_ASSERT_OK(writer->Close());
111
112 std::unique_ptr<Reader> reader;
113 TF_ASSERT_OK(Reader::Create(Env::Default(), filename, compression_type,
114 version, dtypes, &reader));
115
116 for (auto s : state) {
117 std::vector<Tensor> read_tensors;
118 reader->ReadTensors(&read_tensors).IgnoreError();
119 }
120
121 TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
122 }
123
SnapshotCustomReaderNoneBenchmark(::testing::benchmark::State & state)124 void SnapshotCustomReaderNoneBenchmark(::testing::benchmark::State& state) {
125 SnapshotReaderBenchmarkLoop(state, io::compression::kNone, 1);
126 }
127
SnapshotCustomReaderGzipBenchmark(::testing::benchmark::State & state)128 void SnapshotCustomReaderGzipBenchmark(::testing::benchmark::State& state) {
129 SnapshotReaderBenchmarkLoop(state, io::compression::kGzip, 1);
130 }
131
SnapshotCustomReaderSnappyBenchmark(::testing::benchmark::State & state)132 void SnapshotCustomReaderSnappyBenchmark(::testing::benchmark::State& state) {
133 SnapshotReaderBenchmarkLoop(state, io::compression::kSnappy, 1);
134 }
135
SnapshotTFRecordReaderNoneBenchmark(::testing::benchmark::State & state)136 void SnapshotTFRecordReaderNoneBenchmark(::testing::benchmark::State& state) {
137 SnapshotReaderBenchmarkLoop(state, io::compression::kNone, 2);
138 }
139
SnapshotTFRecordReaderGzipBenchmark(::testing::benchmark::State & state)140 void SnapshotTFRecordReaderGzipBenchmark(::testing::benchmark::State& state) {
141 SnapshotReaderBenchmarkLoop(state, io::compression::kGzip, 2);
142 }
143
144 BENCHMARK(SnapshotCustomReaderNoneBenchmark);
145 BENCHMARK(SnapshotCustomReaderGzipBenchmark);
146 BENCHMARK(SnapshotCustomReaderSnappyBenchmark);
147 BENCHMARK(SnapshotTFRecordReaderNoneBenchmark);
148 BENCHMARK(SnapshotTFRecordReaderGzipBenchmark);
149
SnapshotWriterBenchmarkLoop(::testing::benchmark::State & state,std::string compression_type,int version)150 void SnapshotWriterBenchmarkLoop(::testing::benchmark::State& state,
151 std::string compression_type, int version) {
152 tensorflow::DataTypeVector dtypes;
153 std::vector<Tensor> tensors;
154 GenerateTensorVector(dtypes, tensors);
155
156 std::string filename;
157 EXPECT_TRUE(Env::Default()->LocalTempFilename(&filename));
158
159 std::unique_ptr<Writer> writer;
160 TF_ASSERT_OK(Writer::Create(tensorflow::Env::Default(), filename,
161 compression_type, version, dtypes, &writer));
162
163 for (auto s : state) {
164 writer->WriteTensors(tensors).IgnoreError();
165 }
166 writer->Close().IgnoreError();
167
168 TF_ASSERT_OK(Env::Default()->DeleteFile(filename));
169 }
170
SnapshotCustomWriterNoneBenchmark(::testing::benchmark::State & state)171 void SnapshotCustomWriterNoneBenchmark(::testing::benchmark::State& state) {
172 SnapshotWriterBenchmarkLoop(state, io::compression::kNone, 1);
173 }
174
SnapshotCustomWriterGzipBenchmark(::testing::benchmark::State & state)175 void SnapshotCustomWriterGzipBenchmark(::testing::benchmark::State& state) {
176 SnapshotWriterBenchmarkLoop(state, io::compression::kGzip, 1);
177 }
178
SnapshotCustomWriterSnappyBenchmark(::testing::benchmark::State & state)179 void SnapshotCustomWriterSnappyBenchmark(::testing::benchmark::State& state) {
180 SnapshotWriterBenchmarkLoop(state, io::compression::kSnappy, 1);
181 }
182
SnapshotTFRecordWriterNoneBenchmark(::testing::benchmark::State & state)183 void SnapshotTFRecordWriterNoneBenchmark(::testing::benchmark::State& state) {
184 SnapshotWriterBenchmarkLoop(state, io::compression::kNone, 2);
185 }
186
SnapshotTFRecordWriterGzipBenchmark(::testing::benchmark::State & state)187 void SnapshotTFRecordWriterGzipBenchmark(::testing::benchmark::State& state) {
188 SnapshotWriterBenchmarkLoop(state, io::compression::kGzip, 2);
189 }
190
SnapshotTFRecordWriterSnappyBenchmark(::testing::benchmark::State & state)191 void SnapshotTFRecordWriterSnappyBenchmark(::testing::benchmark::State& state) {
192 SnapshotWriterBenchmarkLoop(state, io::compression::kSnappy, 2);
193 }
194
195 BENCHMARK(SnapshotCustomWriterNoneBenchmark);
196 BENCHMARK(SnapshotCustomWriterGzipBenchmark);
197 BENCHMARK(SnapshotCustomWriterSnappyBenchmark);
198 BENCHMARK(SnapshotTFRecordWriterNoneBenchmark);
199 BENCHMARK(SnapshotTFRecordWriterGzipBenchmark);
200 BENCHMARK(SnapshotTFRecordWriterSnappyBenchmark);
201
202 } // namespace
203 } // namespace snapshot_util
204 } // namespace data
205 } // namespace tensorflow
206