• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/example_proto_fast_parsing.h"

#include <memory>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/util/example_proto_fast_parsing_test.pb.h"
28 
29 namespace tensorflow {
30 namespace example {
31 namespace {
32 
// Feature-name constants, one per (dense/sparse, dtype) pair. None of them is
// referenced by the tests in this file.
constexpr char kDenseInt64Key[]{"dense_int64"};
constexpr char kDenseFloatKey[]{"dense_float"};
constexpr char kDenseStringKey[]{"dense_string"};

constexpr char kSparseInt64Key[]{"sparse_int64"};
constexpr char kSparseFloatKey[]{"sparse_float"};
constexpr char kSparseStringKey[]{"sparse_string"};
40 
// Renders `serialized` as a double-quoted C-style string literal in which every
// byte is written as a two-digit lowercase hex escape; e.g. the bytes
// {0x0a, 0xff} become "\x0a\xff" (quotes included). TestCorrectness logs this
// form for inputs on which the two parsers disagree.
//
// Takes the input by const reference (no copy) and appends digits directly
// instead of building a temporary string per byte.
std::string SerializedToReadable(const std::string& serialized) {
  static constexpr char kHexDigits[] = "0123456789abcdef";
  std::string result;
  // Two quotes plus four characters ("\xNN") per input byte.
  result.reserve(serialized.size() * 4 + 2);
  result += '"';
  // Iterate as unsigned so bytes >= 0x80 index the table correctly.
  for (const unsigned char byte : serialized) {
    result += "\\x";
    result += kHexDigits[byte >> 4];
    result += kHexDigits[byte & 0x0f];
  }
  result += '"';
  return result;
}
49 
50 template <class T>
Serialize(const T & example)51 string Serialize(const T& example) {
52   string serialized;
53   example.SerializeToString(&serialized);
54   return serialized;
55 }
56 
57 // Tests that serialized gets parsed identically by TestFastParse(..)
58 // and the regular Example.ParseFromString(..).
TestCorrectness(const string & serialized)59 void TestCorrectness(const string& serialized) {
60   Example example;
61   Example fast_example;
62   EXPECT_TRUE(example.ParseFromString(serialized));
63   example.DiscardUnknownFields();
64   EXPECT_TRUE(TestFastParse(serialized, &fast_example));
65   EXPECT_EQ(example.DebugString(), fast_example.DebugString());
66   if (example.DebugString() != fast_example.DebugString()) {
67     LOG(ERROR) << "Bad serialized: " << SerializedToReadable(serialized);
68   }
69 }
70 
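// Disabled: fast parsing does not differentiate between an Example with no
// `features` field at all and one with an empty `features` message (the
// latter is covered by EmptyFeatures below), so presumably the two parsers'
// outputs would not match for this input.
// TEST(FastParse, EmptyExample) {
//   Example example;
//   TestCorrectness(Serialize(example));
// }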
76 
TEST(FastParse,IgnoresPrecedingUnknownTopLevelFields)77 TEST(FastParse, IgnoresPrecedingUnknownTopLevelFields) {
78   ExampleWithExtras example;
79   (*example.mutable_features()->mutable_feature())["age"]
80       .mutable_int64_list()
81       ->add_value(13);
82   example.set_extra1("some_str");
83   example.set_extra2(123);
84   example.set_extra3(234);
85   example.set_extra4(345);
86   example.set_extra5(4.56);
87   example.add_extra6(5.67);
88   example.add_extra6(6.78);
89   (*example.mutable_extra7()->mutable_feature())["extra7"]
90       .mutable_int64_list()
91       ->add_value(1337);
92 
93   Example context;
94   (*context.mutable_features()->mutable_feature())["zipcode"]
95       .mutable_int64_list()
96       ->add_value(94043);
97 
98   TestCorrectness(strings::StrCat(Serialize(example), Serialize(context)));
99 }
100 
TEST(FastParse,IgnoresTrailingUnknownTopLevelFields)101 TEST(FastParse, IgnoresTrailingUnknownTopLevelFields) {
102   Example example;
103   (*example.mutable_features()->mutable_feature())["age"]
104       .mutable_int64_list()
105       ->add_value(13);
106 
107   ExampleWithExtras context;
108   (*context.mutable_features()->mutable_feature())["zipcode"]
109       .mutable_int64_list()
110       ->add_value(94043);
111   context.set_extra1("some_str");
112   context.set_extra2(123);
113   context.set_extra3(234);
114   context.set_extra4(345);
115   context.set_extra5(4.56);
116   context.add_extra6(5.67);
117   context.add_extra6(6.78);
118   (*context.mutable_extra7()->mutable_feature())["extra7"]
119       .mutable_int64_list()
120       ->add_value(1337);
121 
122   TestCorrectness(strings::StrCat(Serialize(example), Serialize(context)));
123 }
124 
// Two serialized Examples with disjoint feature keys, concatenated.
TEST(FastParse, SingleInt64WithContext) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(13);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["zipcode"].mutable_int64_list()->add_value(94043);

  const string serialized =
      strings::StrCat(Serialize(example), Serialize(context));
  TestCorrectness(serialized);
}
138 
// Two serialized Examples that both define the key "age", concatenated.
TEST(FastParse, DenseInt64WithContext) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(0);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["age"].mutable_int64_list()->add_value(15);

  const string serialized =
      strings::StrCat(Serialize(example), Serialize(context));

  // Standard deserialization of the concatenation yields exactly `context`:
  // the later "age" entry replaces the earlier one. This is surprising, but the
  // Servo team asked for the fast parser to replicate it; in the future this
  // input should be rejected with an error instead.
  Example deserialized;
  EXPECT_TRUE(deserialized.ParseFromString(serialized));
  EXPECT_EQ(deserialized.DebugString(), context.DebugString());

  TestCorrectness(serialized);
}
162 
// Hand-encoded Example {features {feature {key: "age" value {int64_list
// {value: 13}}}}}. The two tests differ only in how Int64List.value is
// encoded. The byte strings were previously swapped relative to the test
// names.

// Non-packed: Int64List.value is written as tag 0x08 (field 1, wire type 0 =
// varint) followed directly by the value 0x0d (13).
TEST(FastParse, NonPacked) {
  TestCorrectness(
      "\x0a\x0d\x0a\x0b\x0a\x03\x61\x67\x65\x12\x04\x1a\x02\x08\x0d");
}

// Packed: Int64List.value is written as tag 0x0a (field 1, wire type 2 =
// length-delimited), payload length 0x01, then the varint 0x0d (13).
TEST(FastParse, Packed) {
  TestCorrectness(
      "\x0a\x0e\x0a\x0c\x0a\x03\x61\x67\x65\x12\x05\x1a\x03\x0a\x01\x0d");
}
172 
// An Example whose `features` submessage is present but holds no entries.
TEST(FastParse, EmptyFeatures) {
  Example example;
  // Called only for its side effect of marking `features` as set.
  static_cast<void>(example.mutable_features());
  const string serialized = Serialize(example);
  TestCorrectness(serialized);
}
178 
TestCorrectnessJson(const string & json)179 void TestCorrectnessJson(const string& json) {
180   auto resolver = protobuf::util::NewTypeResolverForDescriptorPool(
181       "type.googleapis.com", protobuf::DescriptorPool::generated_pool());
182   string serialized;
183   auto s = protobuf::util::JsonToBinaryString(
184       resolver, "type.googleapis.com/tensorflow.Example", json, &serialized);
185   EXPECT_TRUE(s.ok()) << s;
186   delete resolver;
187   TestCorrectness(serialized);
188 }
189 
// One feature of each list kind, written as entries of a single `feature` map
// object. The JSON previously repeated the 'feature' key once per entry;
// duplicate object keys have parser-defined behavior (RFC 8259), so a single
// object is used instead. 'WW8=' and 'WW8K' are base64 for "Yo" and "Yo\n".
TEST(FastParse, JsonUnivalent) {
  TestCorrectnessJson(
      "{'features': {"
      "  'feature': {"
      "    'age': {'int64_list': {'value': [0]} }, "
      "    'flo': {'float_list': {'value': [1.1]} }, "
      "    'byt': {'bytes_list': {'value': ['WW8='] }}"
      "  }"
      "}}");
}

// Same layout as JsonUnivalent, but every list holds several values.
TEST(FastParse, JsonMultivalent) {
  TestCorrectnessJson(
      "{'features': {"
      "  'feature': {"
      "    'age': {'int64_list': {'value': [0, 13, 23]} }, "
      "    'flo': {'float_list': {'value': [1.1, 1.2, 1.3]} }, "
      "    'byt': {'bytes_list': {'value': ['WW8=', 'WW8K'] }}"
      "  }"
      "}}");
}
207 
// A single Example with one int64 feature holding one value.
TEST(FastParse, SingleInt64) {
  Example example;
  Int64List* const ages =
      (*example.mutable_features()->mutable_feature())["age"]
          .mutable_int64_list();
  ages->add_value(13);
  const string serialized = Serialize(example);
  TestCorrectness(serialized);
}
215 
ExampleWithSomeFeatures()216 static string ExampleWithSomeFeatures() {
217   Example example;
218 
219   (*example.mutable_features()->mutable_feature())[""];
220 
221   (*example.mutable_features()->mutable_feature())["empty_bytes_list"]
222       .mutable_bytes_list();
223   (*example.mutable_features()->mutable_feature())["empty_float_list"]
224       .mutable_float_list();
225   (*example.mutable_features()->mutable_feature())["empty_int64_list"]
226       .mutable_int64_list();
227 
228   BytesList* bytes_list =
229       (*example.mutable_features()->mutable_feature())["bytes_list"]
230           .mutable_bytes_list();
231   bytes_list->add_value("bytes1");
232   bytes_list->add_value("bytes2");
233 
234   FloatList* float_list =
235       (*example.mutable_features()->mutable_feature())["float_list"]
236           .mutable_float_list();
237   float_list->add_value(1.0);
238   float_list->add_value(2.0);
239 
240   Int64List* int64_list =
241       (*example.mutable_features()->mutable_feature())["int64_list"]
242           .mutable_int64_list();
243   int64_list->add_value(3);
244   int64_list->add_value(270);
245   int64_list->add_value(86942);
246 
247   return Serialize(example);
248 }
249 
// Mixed empty and non-empty lists of every kind, plus an empty feature key.
TEST(FastParse, SomeFeatures) {
  const string serialized = ExampleWithSomeFeatures();
  TestCorrectness(serialized);
}
251 
AddDenseFeature(const char * feature_name,DataType dtype,PartialTensorShape shape,bool variable_length,size_t elements_per_stride,FastParseExampleConfig * out_config)252 static void AddDenseFeature(const char* feature_name, DataType dtype,
253                             PartialTensorShape shape, bool variable_length,
254                             size_t elements_per_stride,
255                             FastParseExampleConfig* out_config) {
256   out_config->dense.emplace_back();
257   auto& new_feature = out_config->dense.back();
258   new_feature.feature_name = feature_name;
259   new_feature.dtype = dtype;
260   new_feature.shape = std::move(shape);
261   new_feature.default_value = Tensor(dtype, {});
262   new_feature.variable_length = variable_length;
263   new_feature.elements_per_stride = elements_per_stride;
264 }
265 
AddSparseFeature(const char * feature_name,DataType dtype,FastParseExampleConfig * out_config)266 static void AddSparseFeature(const char* feature_name, DataType dtype,
267                              FastParseExampleConfig* out_config) {
268   out_config->sparse.emplace_back();
269   auto& new_feature = out_config->sparse.back();
270   new_feature.feature_name = feature_name;
271   new_feature.dtype = dtype;
272 }
273 
// With collect_feature_stats enabled, FastParseExample must report one
// PerExampleFeatureStats per input example, and FastParseSingleExample
// exactly one. This is checked for dense, variable-length, sparse and mixed
// configs.
TEST(FastParse, StatsCollection) {
  const size_t kNumExamples = 13;
  std::vector<tstring> serialized(kNumExamples, ExampleWithSomeFeatures());

  FastParseExampleConfig config_dense;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense);
  AddDenseFeature("float_list", DT_FLOAT, {2}, false, 2, &config_dense);
  AddDenseFeature("int64_list", DT_INT64, {3}, false, 3, &config_dense);
  config_dense.collect_feature_stats = true;

  FastParseExampleConfig config_varlen;
  AddDenseFeature("bytes_list", DT_STRING, {-1}, true, 1, &config_varlen);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_varlen);
  AddDenseFeature("int64_list", DT_INT64, {-1}, true, 1, &config_varlen);
  config_varlen.collect_feature_stats = true;

  FastParseExampleConfig config_sparse;
  AddSparseFeature("bytes_list", DT_STRING, &config_sparse);
  AddSparseFeature("float_list", DT_FLOAT, &config_sparse);
  AddSparseFeature("int64_list", DT_INT64, &config_sparse);
  config_sparse.collect_feature_stats = true;

  FastParseExampleConfig config_mixed;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_mixed);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_mixed);
  AddSparseFeature("int64_list", DT_INT64, &config_mixed);
  config_mixed.collect_feature_stats = true;

  // ExampleWithSomeFeatures() has 7 feature-map entries, and its non-empty
  // lists hold 2 + 2 + 3 = 7 values, so both expected counts below are 7.
  //
  // Iterate over pointers: a braced list of the configs themselves would copy
  // every config into the initializer_list.
  for (const FastParseExampleConfig* config :
       {&config_dense, &config_varlen, &config_sparse, &config_mixed}) {
    {
      Result result;
      TF_CHECK_OK(FastParseExample(*config, serialized, {}, nullptr, &result));
      EXPECT_EQ(kNumExamples, result.feature_stats.size());
      for (const PerExampleFeatureStats& stats : result.feature_stats) {
        EXPECT_EQ(7, stats.features_count);
        EXPECT_EQ(7, stats.feature_values_count);
      }
    }

    {
      Result result;
      TF_CHECK_OK(FastParseSingleExample(*config, serialized[0], &result));
      EXPECT_EQ(1, result.feature_stats.size());
      EXPECT_EQ(7, result.feature_stats[0].features_count);
      EXPECT_EQ(7, result.feature_stats[0].feature_values_count);
    }
  }
}
323 
RandStr(random::SimplePhilox * rng)324 string RandStr(random::SimplePhilox* rng) {
325   static const char key_char_lookup[] =
326       "0123456789{}~`!@#$%^&*()"
327       "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
328       "abcdefghijklmnopqrstuvwxyz";
329   auto len = 1 + rng->Rand32() % 200;
330   string str;
331   str.reserve(len);
332   while (len-- > 0) {
333     str.push_back(
334         key_char_lookup[rng->Rand32() % (sizeof(key_char_lookup) /
335                                          sizeof(key_char_lookup[0]))]);
336   }
337   return str;
338 }
339 
Fuzz(random::SimplePhilox * rng)340 void Fuzz(random::SimplePhilox* rng) {
341   // Generate keys.
342   auto num_keys = 1 + rng->Rand32() % 100;
343   std::unordered_set<string> unique_keys;
344   for (auto i = 0; i < num_keys; ++i) {
345     unique_keys.emplace(RandStr(rng));
346   }
347 
348   // Generate serialized example.
349   Example example;
350   string serialized_example;
351   auto num_concats = 1 + rng->Rand32() % 4;
352   std::vector<Feature::KindCase> feat_types(
353       {Feature::kBytesList, Feature::kFloatList, Feature::kInt64List});
354   std::vector<string> all_keys(unique_keys.begin(), unique_keys.end());
355   while (num_concats--) {
356     example.Clear();
357     auto num_active_keys = 1 + rng->Rand32() % all_keys.size();
358 
359     // Generate features.
360     for (auto i = 0; i < num_active_keys; ++i) {
361       auto fkey = all_keys[rng->Rand32() % all_keys.size()];
362       auto ftype_idx = rng->Rand32() % feat_types.size();
363       auto num_features = 1 + rng->Rand32() % 5;
364       switch (static_cast<Feature::KindCase>(feat_types[ftype_idx])) {
365         case Feature::kBytesList: {
366           BytesList* bytes_list =
367               (*example.mutable_features()->mutable_feature())[fkey]
368                   .mutable_bytes_list();
369           while (num_features--) {
370             bytes_list->add_value(RandStr(rng));
371           }
372           break;
373         }
374         case Feature::kFloatList: {
375           FloatList* float_list =
376               (*example.mutable_features()->mutable_feature())[fkey]
377                   .mutable_float_list();
378           while (num_features--) {
379             float_list->add_value(rng->RandFloat());
380           }
381           break;
382         }
383         case Feature::kInt64List: {
384           Int64List* int64_list =
385               (*example.mutable_features()->mutable_feature())[fkey]
386                   .mutable_int64_list();
387           while (num_features--) {
388             int64_list->add_value(rng->Rand64());
389           }
390           break;
391         }
392         default: {
393           LOG(QFATAL);
394           break;
395         }
396       }
397     }
398     serialized_example += example.SerializeAsString();
399   }
400 
401   // Test correctness.
402   TestCorrectness(serialized_example);
403 }
404 
// Runs Fuzz() 200 times from a fixed seed, so failures are reproducible.
TEST(FastParse, FuzzTest) {
  constexpr uint64 kSeed = 1337;
  constexpr int kNumRuns = 200;
  random::PhiloxRandom philox(kSeed);
  random::SimplePhilox rng(&philox);
  // Logs kNumRuns - 1 down to 0, the number of runs remaining after this one.
  for (int runs_left = kNumRuns - 1; runs_left >= 0; --runs_left) {
    LOG(INFO) << "runs left: " << runs_left;
    Fuzz(&rng);
  }
}
415 
// FastParseExample on zero serialized inputs (and zero names) must succeed.
TEST(TestFastParseExample, Empty) {
  FastParseExampleConfig config;
  config.sparse.push_back({"test", DT_STRING});
  Result result;
  const gtl::ArraySlice<tstring> no_serialized;
  const gtl::ArraySlice<tstring> no_names;
  const Status status =
      FastParseExample(config, no_serialized, no_names, nullptr, &result);
  EXPECT_TRUE(status.ok()) << status;
}
425 
426 }  // namespace
427 }  // namespace example
428 }  // namespace tensorflow
429