• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_fast_parsing.h"
16 
17 #include <vector>
18 
19 #include "absl/base/casts.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/example/example.pb.h"
22 #include "tensorflow/core/example/feature.pb_text.h"
23 #include "tensorflow/core/framework/numeric_op.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/lib/core/blocking_counter.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/threadpool.h"
29 #include "tensorflow/core/lib/gtl/inlined_vector.h"
30 #include "tensorflow/core/lib/monitoring/counter.h"
31 #include "tensorflow/core/platform/byte_order.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/util/presized_cuckoo_map.h"
35 #include "tensorflow/core/util/sparse/sparse_tensor.h"
36 
37 namespace tensorflow {
38 namespace example {
39 
40 namespace {
41 
// Short per-feature value list: keeps up to 4 elements inline before
// falling back to heap allocation.
template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;
44 
// Turns on aliasing for types that support it (expression SFINAE on
// a->EnableAliasing(true)); aliasing lets parsed StringPieces point
// directly into the input buffer instead of copying.
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

// Fallback overload: no-op for stream types without EnableAliasing.
template <typename A>
void EnableAliasing(A&& a) {}
52 
PeekTag(protobuf::io::CodedInputStream * stream)53 uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
54   DCHECK(stream != nullptr);
55   const void* ptr;
56   int size;
57   if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
58   return *static_cast<const uint8*>(ptr);
59 }
60 
// Proto wire-format tag helpers: a tag byte is (field_number << 3) | wire_type.
constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }     // wire type 0: varint
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }  // wire type 2: length-delimited
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }    // wire type 5: fixed32
64 
65 namespace parsed {
66 
// A lazily parsed Feature proto: holds a view of the serialized bytes and
// decodes values on demand. ParseDataType has to be called first (it
// consumes the oneof tag byte), then the appropriate ParseZzzzList.
class Feature {
 public:
  Feature() {}
  // `serialized` is an aliasing view; the underlying buffer must outlive
  // this object.
  explicit Feature(StringPiece serialized) : serialized_(serialized) {}

  // Determines which oneof field of the Feature proto is set by inspecting
  // (and consuming) the first tag byte. An empty feature yields DT_INVALID
  // with an OK status — the feature is present but has no value list.
  Status ParseDataType(DataType* dtype) {
    DCHECK(dtype != nullptr);
    if (serialized_.empty()) {
      *dtype = DT_INVALID;
      return Status::OK();
    }
    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
    serialized_.remove_prefix(1);
    switch (oneof_tag) {
      case kDelimitedTag(1):  // bytes_list field
        *dtype = DT_STRING;
        break;
      case kDelimitedTag(2):  // float_list field
        *dtype = DT_FLOAT;
        break;
      case kDelimitedTag(3):  // int64_list field
        *dtype = DT_INT64;
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DT_INVALID;
        return errors::InvalidArgument("Unsupported datatype.");
    }
    return Status::OK();
  }

  // Counts the strings in the serialized BytesList without materializing
  // them (each length prefix is read, then the payload skipped). Returns
  // false on malformed input.
  bool GetNumElementsInBytesList(int* num_elements) {
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length = 0;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      uint32 bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      if (!stream.Skip(bytes_length)) return false;
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Appends every string of the BytesList to `bytes_list`. `Result` needs
  // only a vector-like push_back(string&&).
  template <typename Result>
  bool ParseBytesList(Result* bytes_list) {
    DCHECK(bytes_list != nullptr);

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());

    EnableAliasing(&stream);

    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      // parse string
      uint32 bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      string bytes;
      if (!stream.ReadString(&bytes, bytes_length)) return false;
      bytes_list->push_back(std::move(bytes));
    }
    stream.PopLimit(limit);
    return true;
  }

  // Appends every float of the FloatList to `float_list`. Handles both the
  // packed (length-delimited) and non-packed (repeated fixed32) wire
  // encodings.
  template <typename Result>
  bool ParseFloatList(Result* float_list) {
    DCHECK(float_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        constexpr int32 kNumFloatBytes = 4;
        if (port::kLittleEndian &&
            sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Store the initial size to know the offset we have to start writing
          // data from before resizing the output "vector".
          const size_t initial_size = float_list->size();
          float_list->resize(initial_size + packed_length / kNumFloatBytes);
          // Calculate the length of the buffer available what can be less than
          // what we requested in resize in case of a LimitedArraySlice.
          const uint32 bytes_to_copy =
              std::min(static_cast<uint32>((float_list->size() - initial_size) *
                                           kNumFloatBytes),
                       packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
            return false;
        } else {
          // Big-endian machine or non-float Result: decode one element at a
          // time.
          while (!stream.ExpectAtEnd()) {
            uint32 buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) return false;
            float_list->push_back(absl::bit_cast<float>(buffer32));
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->push_back(absl::bit_cast<float>(buffer32));
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  // Appends every int64 of the Int64List to `int64_list`. Handles both the
  // packed (length-delimited) and non-packed (repeated varint) encodings.
  template <typename Result>
  bool ParseInt64List(Result* int64_list) {
    DCHECK(int64_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) return false;
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  // Returns the (remaining) serialized bytes backing this feature.
  StringPiece GetSerialized() const { return serialized_; }

 private:
  // TODO(lew): Pair of uint8* would be more natural.
  StringPiece serialized_;
};
254 
// A parsed Example is the ordered list of (feature name, raw feature
// bytes) map entries exactly as they appeared on the wire.
using FeatureMapEntry = std::pair<StringPiece, Feature>;
using Example = std::vector<FeatureMapEntry>;
257 
258 }  // namespace parsed
259 
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)260 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
261   uint32 data;
262   protobuf_uint64 dummy;
263   switch (stream->ReadTag() & 0x7) {
264     case 0:  // varint
265       if (!stream->ReadVarint32(&data)) return false;
266       return true;
267     case 1:  // fixed64
268       if (!stream->ReadLittleEndian64(&dummy)) return false;
269       return true;
270     case 2:  // length delimited
271       if (!stream->ReadVarint32(&data)) return false;
272       stream->Skip(data);
273       return true;
274     case 3:          // group begin
275       return false;  // groups not supported.
276     case 4:          // group end
277       return false;  // groups not supported.
278     case 5:          // fixed32
279       if (!stream->ReadLittleEndian32(&data)) return false;
280       return true;
281   }
282   return false;  // unrecognized tag type
283 }
284 
ParseString(protobuf::io::CodedInputStream * stream,StringPiece * result)285 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
286   DCHECK(stream != nullptr);
287   DCHECK(result != nullptr);
288   uint32 length;
289   if (!stream->ReadVarint32(&length)) return false;
290   if (length == 0) {
291     *result = StringPiece(nullptr, 0);
292     return true;
293   }
294   const void* stream_alias;
295   int stream_size;
296   if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
297     return false;
298   }
299   if (static_cast<uint32>(stream_size) < length) return false;
300   *result = StringPiece(static_cast<const char*>(stream_alias), length);
301   stream->Skip(length);
302   return true;
303 }
304 
ParseFeatureMapEntry(protobuf::io::CodedInputStream * stream,parsed::FeatureMapEntry * feature_map_entry)305 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
306                           parsed::FeatureMapEntry* feature_map_entry) {
307   DCHECK(stream != nullptr);
308   DCHECK(feature_map_entry != nullptr);
309   uint32 length;
310   if (!stream->ReadVarint32(&length)) return false;
311   auto limit = stream->PushLimit(length);
312   if (!stream->ExpectTag(kDelimitedTag(1))) return false;
313   if (!ParseString(stream, &feature_map_entry->first)) return false;
314   if (!stream->ExpectTag(kDelimitedTag(2))) return false;
315   StringPiece feature_string_piece;
316   if (!ParseString(stream, &feature_string_piece)) return false;
317   feature_map_entry->second = parsed::Feature(feature_string_piece);
318   if (!stream->ExpectAtEnd()) return false;
319   stream->PopLimit(limit);
320   return true;
321 }
322 
ParseFeatures(protobuf::io::CodedInputStream * stream,parsed::Example * example)323 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
324                    parsed::Example* example) {
325   DCHECK(stream != nullptr);
326   DCHECK(example != nullptr);
327   uint32 length;
328   if (!stream->ReadVarint32(&length)) return false;
329   auto limit = stream->PushLimit(length);
330   while (!stream->ExpectAtEnd()) {
331     parsed::FeatureMapEntry feature_map_entry;
332     if (!stream->ExpectTag(kDelimitedTag(1))) return false;
333     if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
334     example->push_back(std::move(feature_map_entry));
335   }
336   stream->PopLimit(limit);
337   return true;
338 }
339 
ParseExample(protobuf::io::CodedInputStream * stream,parsed::Example * example)340 bool ParseExample(protobuf::io::CodedInputStream* stream,
341                   parsed::Example* example) {
342   DCHECK(stream != nullptr);
343   DCHECK(example != nullptr);
344   // Loop over the input stream which may contain multiple serialized Example
345   // protos merged together as strings. This behavior is consistent with Proto's
346   // ParseFromString when string representations are concatenated.
347   while (!stream->ExpectAtEnd()) {
348     if (!stream->ExpectTag(kDelimitedTag(1))) {
349       if (!SkipExtraneousTag(stream)) return false;
350     } else {
351       if (!ParseFeatures(stream, example)) return false;
352     }
353   }
354   return true;
355 }
356 
ParseExample(StringPiece serialized,parsed::Example * example)357 bool ParseExample(StringPiece serialized, parsed::Example* example) {
358   DCHECK(example != nullptr);
359   protobuf::io::CodedInputStream stream(
360       reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
361   EnableAliasing(&stream);
362   return ParseExample(&stream, example);
363 }
364 
365 }  // namespace
366 
// Test-only reference path: fast-parses `serialized` and converts the
// result into a regular Example proto. Returns false if the input cannot
// be parsed.
bool TestFastParse(const string& serialized, Example* example) {
  DCHECK(example != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized, &parsed_example)) return false;
  auto& features = *example->mutable_features();
  size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    string name(name_and_feature.first);
    // Iterating backwards means the first occurrence seen is the last one
    // on the wire; skip earlier duplicates of the same key.
    if ((*features.mutable_feature()).count(name) > 0) continue;

    auto& value = (*features.mutable_feature())[name];
    DataType dtype;
    if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
    switch (dtype) {
      case DT_INVALID:
        // Feature present but with no value list set; leave `value` empty.
        break;
      case DT_STRING: {
        SmallVector<string> list;
        if (!name_and_feature.second.ParseBytesList(&list)) return false;
        auto* result_list = value.mutable_bytes_list();
        for (auto& bytes : list) {
          auto* new_value = result_list->add_value();
          // Swap rather than copy the string contents into the proto.
          new_value->swap(bytes);
        }
        break;
      }
      case DT_FLOAT: {
        SmallVector<float> list;
        if (!name_and_feature.second.ParseFloatList(&list)) return false;
        auto* result_list = value.mutable_float_list();
        for (float f : list) {
          result_list->add_value(f);
        }
        break;
      }
      case DT_INT64: {
        SmallVector<int64> list;
        if (!name_and_feature.second.ParseInt64List(&list)) return false;
        auto* result_list = value.mutable_int64_list();
        for (int64 i : list) {
          result_list->add_value(i);
        }
        break;
      }
      default:
        // ParseDataType only produces the four cases above.
        LOG(FATAL) << "Should not happen.";
    }
  }
  return true;
}
421 
422 // -----------------------------------------------------------------------------
423 
424 namespace {
425 
426 using Config = FastParseExampleConfig;
427 
ParallelFor(const std::function<void (size_t)> & f,size_t n,thread::ThreadPool * thread_pool)428 void ParallelFor(const std::function<void(size_t)>& f, size_t n,
429                  thread::ThreadPool* thread_pool) {
430   if (n == 0) return;
431   if (thread_pool == nullptr) {
432     for (size_t i = 0; i < n; ++i) {
433       f(i);
434     }
435   } else {
436     BlockingCounter counter(n - 1);
437     for (size_t i = 1; i < n; ++i) {
438       thread_pool->Schedule([i, &f, &counter] {
439         f(i);
440         counter.DecrementCount();
441       });
442     }
443     f(0);
444     counter.Wait();
445   }
446 }
447 
// Whether a configured feature produces sparse or dense output.
enum class Type { Sparse, Dense };

// Accumulates the values of one sparse (or variable-length dense) feature
// across a minibatch of examples.
struct SparseBuffer {
  // Features are in one of the 3 vectors below depending on config's dtype.
  // Other 2 vectors remain empty.
  SmallVector<string> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64> int64_list;

  // Features of example i are elements with indices
  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
  // appropriate xxxxx_list
  std::vector<size_t> example_end_indices;
};
462 
// Hashes feature names with a fixed seed so the hash used to build the
// config lookup table and the hash used at parse time always agree.
struct SeededHasher {
  uint64 operator()(StringPiece s) const {
    return Hash64(s.data(), s.size(), seed);
  }
  uint64 seed{0xDECAFCAFFE};
};
469 
// Fixed-capacity, non-owning output "vector" over a caller-provided buffer.
// Provides the subset of std::vector's interface that the ParseZzzzList
// methods use (push_back/resize/size/data), so dense outputs can be written
// in place. Writes past capacity are dropped but still counted, so a
// negative EndDistance() signals that the input had too many values.
template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  LimitedArraySlice(T* begin, size_t num_elements)
      : current_(begin), begin_(begin), end_(begin + num_elements) {}

  // May return negative if there were push_back calls after slice was filled.
  int64 EndDistance() const { return end_ - current_; }

  // Attempts to push value to the back of this. If the slice has
  // already been filled, this method has no effect on the underlying data, but
  // it changes the number returned by EndDistance into negative values.
  void push_back(T&& value) {
    if (EndDistance() > 0) *current_ = std::move(value);
    ++current_;
  }

  // Returns the number of elements in the slice.
  size_t size() const { return std::min(current_ - begin_, end_ - begin_); }

  // Attempts to resize the vector to the given size. It does so by advancing
  // the pointer to the current element, possibly beyond the end of the slice.
  // As a consequence, calling `size()` after `resize(x)` was called might
  // return a value less than `x`.
  void resize(size_t size) { current_ = begin_ + size; }

  // Returns the pointer to the underlying data buffer.
  T* data() { return begin_; }

 private:
  T* current_;  // Next write position; may advance past end_.
  T* begin_;    // Start of the underlying buffer (not owned).
  T* end_;      // One past the last writable slot.
};
506 
LogDenseFeatureDataLoss(StringPiece feature_name)507 void LogDenseFeatureDataLoss(StringPiece feature_name) {
508   LOG(WARNING) << "Data loss! Feature '" << feature_name
509                << "' is present in multiple concatenated "
510                   "tf.Examples. Ignoring all but last one.";
511   static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
512       "/tensorflow/core/util/example_proto_fast_parsing/"
513       "duplicated_dense_feature",
514       "Dense feature appears twice in a tf.Example");
515   duplicated_dense_feature->GetCell()->IncrementBy(1);
516 }
517 
LogSparseFeatureDataLoss(StringPiece feature_name)518 void LogSparseFeatureDataLoss(StringPiece feature_name) {
519   LOG(WARNING) << "Data loss! Feature '" << feature_name
520                << "' is present in multiple concatenated "
521                   "tf.Examples. Ignoring all but last one.";
522   static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
523       "/tensorflow/core/util/example_proto_fast_parsing/"
524       "duplicated_sparse_feature",
525       "Sparse feature appears twice in a tf.Example");
526   duplicated_sparse_feature->GetCell()->IncrementBy(1);
527 }
528 
FastParseSerializedExample(const string & serialized_example,const string & example_name,const size_t example_index,const Config & config,const PresizedCuckooMap<std::pair<size_t,Type>> & config_index,SeededHasher hasher,std::vector<Tensor> * output_dense,std::vector<SparseBuffer> * output_varlen_dense,std::vector<SparseBuffer> * output_sparse,PerExampleFeatureStats * output_stats)529 Status FastParseSerializedExample(
530     const string& serialized_example, const string& example_name,
531     const size_t example_index, const Config& config,
532     const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
533     SeededHasher hasher, std::vector<Tensor>* output_dense,
534     std::vector<SparseBuffer>* output_varlen_dense,
535     std::vector<SparseBuffer>* output_sparse,
536     PerExampleFeatureStats* output_stats) {
537   DCHECK(output_dense != nullptr);
538   DCHECK(output_sparse != nullptr);
539   parsed::Example parsed_example;
540   if (!ParseExample(serialized_example, &parsed_example)) {
541     return errors::InvalidArgument("Could not parse example input, value: '",
542                                    serialized_example, "'");
543   }
544   std::vector<int64> sparse_feature_last_example(config.sparse.size(), -1);
545   std::vector<int64> dense_feature_last_example(config.dense.size(), -1);
546 
547   // Handle features present in the example.
548   const size_t parsed_example_size = parsed_example.size();
549 
550   if (output_stats) {
551     // TODO(b/111553342): This may over-count the number of features if there
552     // are duplicate keys in the feature map. Consider deduplicating the keys
553     // before computing the count.
554     output_stats->features_count = parsed_example_size;
555   }
556 
557   for (size_t i = 0; i < parsed_example_size; ++i) {
558     // This is a logic that standard protobuf parsing is implementing.
559     // I.e. last entry in the map overwrites all the previous ones.
560     parsed::FeatureMapEntry& name_and_feature =
561         parsed_example[parsed_example_size - i - 1];
562 
563     const StringPiece feature_name = name_and_feature.first;
564     parsed::Feature& feature = name_and_feature.second;
565 
566     std::pair<size_t, Type> d_and_type;
567     uint64 h = hasher(feature_name);
568     if (!config_index.Find(h, &d_and_type)) continue;
569 
570     size_t d = d_and_type.first;
571     bool is_dense = d_and_type.second == Type::Dense;
572 
573     {
574       // Testing for PresizedCuckooMap collision.
575       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
576       const string& config_feature_name = is_dense
577                                               ? config.dense[d].feature_name
578                                               : config.sparse[d].feature_name;
579       if (feature_name != config_feature_name) continue;
580     }
581 
582     auto example_error = [&](StringPiece suffix) {
583       return errors::InvalidArgument("Name: ", example_name,
584                                      ", Key: ", feature_name,
585                                      ", Index: ", example_index, ".  ", suffix);
586     };
587 
588     auto parse_error = [&] {
589       return example_error("Can't parse serialized Example.");
590     };
591 
592     DataType example_dtype;
593     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
594 
595     if (is_dense) {
596       if (example_dtype == DT_INVALID) continue;
597 
598       // If feature was already visited, skip.
599       // Compare comment at the beginning of the loop.
600       if (dense_feature_last_example[d] == example_index) {
601         LogDenseFeatureDataLoss(feature_name);
602         continue;
603       }
604       dense_feature_last_example[d] = example_index;
605 
606       if (example_dtype != config.dense[d].dtype) {
607         return example_error(strings::StrCat(
608             "Data types don't match. Data type: ",
609             DataTypeString(example_dtype),
610             " but expected type: ", DataTypeString(config.dense[d].dtype)));
611       }
612       if (!config.dense[d].variable_length) {
613         Tensor& out = (*output_dense)[d];
614 
615         const std::size_t num_elements = config.dense[d].elements_per_stride;
616         if (output_stats) {
617           // TODO(b/111553342): If desirable, we could add support for counting
618           // elements in the features that aren't parsed, but this could add
619           // considerable runtime cost.
620           output_stats->feature_values_count += num_elements;
621         }
622 
623         const std::size_t offset = example_index * num_elements;
624 
625         auto shape_error = [&](size_t size, StringPiece type_str) {
626           return example_error(strings::StrCat(
627               "Number of ", type_str,
628               " values != expected.  "
629               "Values size: ",
630               size,
631               " but output shape: ", config.dense[d].shape.DebugString()));
632         };
633 
634         switch (config.dense[d].dtype) {
635           case DT_INT64: {
636             auto out_p = out.flat<int64>().data() + offset;
637             LimitedArraySlice<int64> slice(out_p, num_elements);
638             if (!feature.ParseInt64List(&slice)) return parse_error();
639             if (slice.EndDistance() != 0) {
640               return shape_error(num_elements - slice.EndDistance(), "int64");
641             }
642             break;
643           }
644           case DT_FLOAT: {
645             auto out_p = out.flat<float>().data() + offset;
646             LimitedArraySlice<float> slice(out_p, num_elements);
647             if (!feature.ParseFloatList(&slice)) return parse_error();
648             if (slice.EndDistance() != 0) {
649               return shape_error(num_elements - slice.EndDistance(), "float");
650             }
651             break;
652           }
653           case DT_STRING: {
654             auto out_p = out.flat<string>().data() + offset;
655             LimitedArraySlice<string> slice(out_p, num_elements);
656             if (!feature.ParseBytesList(&slice)) return parse_error();
657             if (slice.EndDistance() != 0) {
658               return shape_error(num_elements - slice.EndDistance(), "bytes");
659             }
660             break;
661           }
662           default:
663             LOG(FATAL) << "Should not happen.";
664         }
665       } else {  // if variable length
666         SparseBuffer& out = (*output_varlen_dense)[d];
667 
668         const std::size_t num_elements = config.dense[d].elements_per_stride;
669 
670         if (example_dtype != DT_INVALID &&
671             example_dtype != config.dense[d].dtype) {
672           return example_error(strings::StrCat(
673               "Data types don't match. ",
674               "Expected type: ", DataTypeString(config.dense[d].dtype)));
675         }
676 
677         auto shape_error = [&](size_t size, StringPiece type_str) {
678           return example_error(strings::StrCat(
679               "Number of ", type_str,
680               " values is not a multiple of stride length. Saw ", size,
681               " values but output shape is: ",
682               config.dense[d].shape.DebugString()));
683         };
684 
685         switch (config.dense[d].dtype) {
686           case DT_INT64: {
687             if (example_dtype != DT_INVALID) {
688               if (!feature.ParseInt64List(&out.int64_list)) {
689                 return parse_error();
690               }
691               if (out.int64_list.size() % num_elements != 0) {
692                 return shape_error(out.int64_list.size(), "int64");
693               }
694             }
695             out.example_end_indices.push_back(out.int64_list.size());
696             break;
697           }
698           case DT_FLOAT: {
699             if (example_dtype != DT_INVALID) {
700               if (!feature.ParseFloatList(&out.float_list)) {
701                 return parse_error();
702               }
703               if (out.float_list.size() % num_elements != 0) {
704                 return shape_error(out.float_list.size(), "float");
705               }
706             }
707             out.example_end_indices.push_back(out.float_list.size());
708             break;
709           }
710           case DT_STRING: {
711             if (example_dtype != DT_INVALID) {
712               if (!feature.ParseBytesList(&out.bytes_list)) {
713                 return parse_error();
714               }
715               if (out.bytes_list.size() % num_elements != 0) {
716                 return shape_error(out.bytes_list.size(), "bytes");
717               }
718             }
719             out.example_end_indices.push_back(out.bytes_list.size());
720             break;
721           }
722           default:
723             LOG(FATAL) << "Should not happen.";
724         }
725 
726         if (output_stats) {
727           // Use `out.example_end_indices` to determine the feature-value count
728           // for this feature, because the preceding switch statement pushes
729           // the length of the appropriate feature list to that vector.
730           // TODO(b/111553342): If desirable, we could add support for counting
731           // elements in the features that aren't parsed, but this could add
732           // considerable runtime cost.
733           const size_t out_examples_count = out.example_end_indices.size();
734           if (out_examples_count == 1) {
735             output_stats->feature_values_count += out.example_end_indices[0];
736           } else {
737             output_stats->feature_values_count +=
738                 out.example_end_indices[out_examples_count - 1] -
739                 out.example_end_indices[out_examples_count - 2];
740           }
741         }
742       }
743     } else {
744       // If feature was already visited, skip.
745       // Compare comment at the beginning of the loop.
746       if (sparse_feature_last_example[d] == example_index) {
747         LogSparseFeatureDataLoss(feature_name);
748         continue;
749       }
750       sparse_feature_last_example[d] = example_index;
751 
752       // Handle sparse features.
753       SparseBuffer& out = (*output_sparse)[d];
754       if (example_dtype != DT_INVALID &&
755           example_dtype != config.sparse[d].dtype) {
756         return example_error(strings::StrCat(
757             "Data types don't match. ",
758             "Expected type: ", DataTypeString(config.sparse[d].dtype),
759             ", Actual type: ", DataTypeString(example_dtype)));
760       }
761 
762       switch (config.sparse[d].dtype) {
763         case DT_INT64: {
764           if (example_dtype != DT_INVALID) {
765             if (!feature.ParseInt64List(&out.int64_list)) {
766               return parse_error();
767             }
768           }
769           out.example_end_indices.push_back(out.int64_list.size());
770           break;
771         }
772         case DT_FLOAT: {
773           if (example_dtype != DT_INVALID) {
774             if (!feature.ParseFloatList(&out.float_list)) {
775               return parse_error();
776             }
777           }
778           out.example_end_indices.push_back(out.float_list.size());
779           break;
780         }
781         case DT_STRING: {
782           if (example_dtype != DT_INVALID) {
783             if (!feature.ParseBytesList(&out.bytes_list)) {
784               return parse_error();
785             }
786           }
787           out.example_end_indices.push_back(out.bytes_list.size());
788           break;
789         }
790         default:
791           LOG(FATAL) << "Should not happen.";
792       }
793 
794       if (output_stats) {
795         // Use `out.example_end_indices` to determine the feature-value count
796         // for this feature, because the preceding switch statement pushes
797         // the length of the appropriate feature list to that vector.
798         // TODO(b/111553342): If desirable, we could add support for counting
799         // elements in the features that aren't parsed, but this could add
800         // considerable runtime cost.
801         const size_t out_examples_count = out.example_end_indices.size();
802         if (out_examples_count == 1) {
803           output_stats->feature_values_count += out.example_end_indices[0];
804         } else {
805           output_stats->feature_values_count +=
806               out.example_end_indices[out_examples_count - 1] -
807               out.example_end_indices[out_examples_count - 2];
808         }
809       }
810     }
811   }
812 
813   // Handle missing dense features for fixed strides.
814   for (size_t d = 0; d < config.dense.size(); ++d) {
815     if (config.dense[d].variable_length) continue;
816     if (dense_feature_last_example[d] == example_index) continue;
817     if (config.dense[d].default_value.NumElements() == 0) {
818       return errors::InvalidArgument(
819           "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
820           " (data type: ", DataTypeString(config.dense[d].dtype), ")",
821           " is required but could not be found.");
822     }
823     const Tensor& in = config.dense[d].default_value;
824     Tensor& out = (*output_dense)[d];
825     const std::size_t num_elements = in.shape().num_elements();
826     const std::size_t offset = example_index * num_elements;
827 
828     switch (config.dense[d].dtype) {
829       case DT_INT64: {
830         std::copy_n(in.flat<int64>().data(), num_elements,
831                     out.flat<int64>().data() + offset);
832         break;
833       }
834       case DT_FLOAT: {
835         std::copy_n(in.flat<float>().data(), num_elements,
836                     out.flat<float>().data() + offset);
837         break;
838       }
839       case DT_STRING: {
840         std::copy_n(in.flat<string>().data(), num_elements,
841                     out.flat<string>().data() + offset);
842         break;
843       }
844       default:
845         LOG(FATAL) << "Should not happen.";
846     }
847   }
848 
849   // Handle missing varlen dense features.
850   for (size_t d = 0; d < config.dense.size(); ++d) {
851     if (!config.dense[d].variable_length) continue;
852     if (dense_feature_last_example[d] == example_index) continue;
853     SparseBuffer& out = (*output_varlen_dense)[d];
854     size_t prev_example_end_index =
855         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
856     out.example_end_indices.push_back(prev_example_end_index);
857   }
858 
859   // Handle missing sparse features.
860   for (size_t d = 0; d < config.sparse.size(); ++d) {
861     if (sparse_feature_last_example[d] == example_index) continue;
862     SparseBuffer& out = (*output_sparse)[d];
863     size_t prev_example_end_index =
864         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
865     out.example_end_indices.push_back(prev_example_end_index);
866   }
867 
868   return Status::OK();
869 }
870 
CheckConfigDataType(DataType dtype)871 Status CheckConfigDataType(DataType dtype) {
872   switch (dtype) {
873     case DT_INT64:
874     case DT_FLOAT:
875     case DT_STRING:
876       return Status::OK();
877     default:
878       return errors::InvalidArgument("Invalid config dtype: ",
879                                      DataTypeString(dtype));
880   }
881 }
882 
// Returns the list inside `buffer` that stores values of type T. Only the
// three specializations below are defined, matching the three dtypes accepted
// by CheckConfigDataType (DT_INT64, DT_FLOAT, DT_STRING).
template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
  return buffer.int64_list;
}
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
template <>
const SmallVector<string>& GetListFromBuffer<string>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}
899 
// Copies the half-open range [b, e) to the destination starting at `t`.
// Generic version performs an element-wise copy; the string specialization
// below uses std::move instead.
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  for (; b != e; ++b, ++t) {
    *t = *b;
  }
}
// String specialization: uses the std::move algorithm so the (potentially
// large) byte payloads are moved rather than copied.
// NOTE(review): the sources are `const string*`; std::move(first, last, out)
// over const elements binds `std::move(*first)` to `const string&&`, which
// selects the copy constructor — so this only truly moves when callers hand
// in non-const buffers. Confirm whether callers rely on the move here.
template <>
void CopyOrMoveBlock(const string* b, const string* e, string* t) {
  std::move(b, e, t);
}
908 
// Writes the variable-length dense column `d` into `values`, a tensor of
// shape [batch_size, max_num_elements, ...] with `num_elements` total
// elements. First fills the whole tensor with the column's default value
// (this creates the padding), then overwrites the leading elements of each
// example's row with the parsed values buffered per-minibatch in
// `varlen_dense_buffers`. `num_elements_per_minibatch` is the element stride
// between consecutive examples in `values`.
template <typename T>
void FillAndCopyVarLen(
    const int d, const size_t num_elements,
    const size_t num_elements_per_minibatch, const Config& config,
    const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
    Tensor* values) {
  const Tensor& default_value = config.dense[d].default_value;

  // Copy-fill the tensors (creating the zero/fill-padding)
  std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
            default_value.flat<T>()(0));

  // Data is [batch_size, max_num_elements, data_stride_size]
  //   and num_elements_per_minibatch = max_num_elements * data_stride_size
  auto data = values->flat<T>().data();

  // Iterate over minibatch elements
  for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
    const SparseBuffer& buffer = varlen_dense_buffers[i][d];
    // Number of examples being stored in this buffer
    const auto& end_indices = buffer.example_end_indices;
    const size_t examples_in_buffer = end_indices.size();
    // const size_t stride_size = config.dense[d].elements_per_stride;

    const auto& list = GetListFromBuffer<T>(buffer);
    auto list_ptr = list.begin();

    // `elements_tally` tracks how many elements of `list` have been consumed;
    // end_indices is cumulative, so end_indices[j] - elements_tally is the
    // element count of example j.
    size_t elements_tally = 0;
    // Iterate through all the examples stored in this buffer.
    for (size_t j = 0; j < examples_in_buffer; ++j) {
      // Number of elements stored for this example.
      const size_t num_elems = end_indices[j] - elements_tally;
      CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
      // Move forward this many elements in the varlen buffer.
      list_ptr += num_elems;
      // Move forward to the next minibatch entry in the values output.
      data += num_elements_per_minibatch;
      elements_tally = end_indices[j];
    }
    // Every buffered element must have been consumed exactly once.
    DCHECK(elements_tally == list.size());
  }
}
951 
952 }  // namespace
953 
// Parses a batch of serialized Example protos into `result` according to
// `config`. The batch is split into size-balanced minibatches that are parsed
// in parallel on `thread_pool`; fixed-length dense features are written
// directly into preallocated tensors, while sparse and variable-length dense
// features are buffered per-minibatch and merged afterwards.
// `example_names` (optional, may be empty) supplies names for error messages.
Status FastParseExample(const Config& config,
                        gtl::ArraySlice<string> serialized,
                        gtl::ArraySlice<string> example_names,
                        thread::ThreadPool* thread_pool, Result* result) {
  DCHECK(result != nullptr);
  // Check config so we can safely CHECK(false) in switches on config.*.dtype
  for (auto& c : config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }
  for (auto& c : config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }

  // One stats slot per input example; filled in by the per-example parser.
  if (config.collect_feature_stats) {
    result->feature_stats.resize(serialized.size());
  }

  size_t config_size = config.dense.size() + config.sparse.size();
  SeededHasher hasher;
  // Build config index: hash(feature_name) -> (config position, Dense|Sparse).
  // If the seeded hash collides, retry with a new seed (up to 1000 attempts).
  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
  bool ok = true;
  for (size_t i = 0; i < 1000; ++i) {
    for (size_t d = 0; d < config.dense.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
                                      {d, Type::Dense});
    }
    for (size_t d = 0; d < config.sparse.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
                                      {d, Type::Sparse});
    }
    if (ok) break;
    LOG(WARNING) << "Collision found. This should happen only if you have "
                    "around 2^32 entries in your config.";
    hasher.seed++;
    config_index.Clear(config_size);
  }
  if (!ok) {
    return errors::Internal(
        "Could not avoid collision. This should not happen.");
  }

  // Allocate dense output for fixed length dense values
  // (variable-length dense and sparse have to be buffered).
  std::vector<Tensor> fixed_dense_values(config.dense.size());
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    TensorShape out_shape;
    // Output shape is [batch_size] + configured per-example shape.
    out_shape.AddDim(serialized.size());
    for (const int64 dim : config.dense[d].shape.dim_sizes()) {
      out_shape.AddDim(dim);
    }
    fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
  }

  // This parameter affects performance in a big and data-dependent way.
  const size_t kMiniBatchSizeBytes = 50000;

  // Calculate number of minibatches.
  // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
  // Apply 'special logic' below for small and big regimes.
  // NOTE(review): the local `result` below shadows the `result` output
  // parameter (it is only the minibatch count inside this lambda).
  const size_t num_minibatches = [&] {
    size_t result = 0;
    size_t minibatch_bytes = 0;
    for (size_t i = 0; i < serialized.size(); i++) {
      if (minibatch_bytes == 0) {  // start minibatch
        result++;
      }
      minibatch_bytes += serialized[i].size() + 1;
      if (minibatch_bytes > kMiniBatchSizeBytes) {
        minibatch_bytes = 0;
      }
    }
    // 'special logic': clamp into [min(8, batch), 64].
    const size_t min_minibatches = std::min<size_t>(8, serialized.size());
    const size_t max_minibatches = 64;
    return std::max<size_t>(min_minibatches,
                            std::min<size_t>(max_minibatches, result));
  }();

  // Examples are assigned to minibatches by even index partitioning.
  auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
    return (serialized.size() * minibatch) / num_minibatches;
  };

  // TODO(lew): A big performance low-hanging fruit here is to improve
  //   num_minibatches calculation to take into account actual amount of work
  //   needed, as the size in bytes is not perfect. Linear combination of
  //   size in bytes and average number of features per example is promising.
  //   Even better: measure time instead of estimating, but this is too costly
  //   in small batches.
  //   Maybe accept outside parameter #num_minibatches?

  // Do minibatches in parallel. Each minibatch gets private sparse/varlen
  // buffers and a private status slot, so no synchronization is needed.
  std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
  std::vector<Status> status_of_minibatch(num_minibatches);
  auto ProcessMiniBatch = [&](size_t minibatch) {
    sparse_buffers[minibatch].resize(config.sparse.size());
    varlen_dense_buffers[minibatch].resize(config.dense.size());
    size_t start = first_example_of_minibatch(minibatch);
    size_t end = first_example_of_minibatch(minibatch + 1);
    for (size_t e = start; e < end; ++e) {
      PerExampleFeatureStats* stats = nullptr;
      if (config.collect_feature_stats) {
        stats = &result->feature_stats[e];
      }
      status_of_minibatch[minibatch] = FastParseSerializedExample(
          serialized[e],
          (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
          config_index, hasher, &fixed_dense_values,
          &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch], stats);
      // Stop this minibatch on the first failing example.
      if (!status_of_minibatch[minibatch].ok()) break;
    }
  };

  ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);

  // Propagate the first error from any minibatch.
  for (Status& status : status_of_minibatch) {
    TF_RETURN_IF_ERROR(status);
  }

  for (size_t d = 0; d < config.dense.size(); ++d) {
    result->dense_values.push_back(std::move(fixed_dense_values[d]));
  }

  // Merge SparseBuffers from all minibatches for every config.sparse.
  auto MergeSparseMinibatches = [&](size_t d) {
    // Loop over minibatches to compute the total value count and the largest
    // per-example value count (the second dim of the sparse dense_shape).
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    for (auto& sparse_values_tmp : sparse_buffers) {
      // example_end_indices is cumulative, so back() is this minibatch's
      // total and adjacent differences are per-example counts.
      const std::vector<size_t>& end_indices =
          sparse_values_tmp[d].example_end_indices;
      total_num_features += end_indices.back();
      max_num_features = std::max(max_num_features, end_indices[0]);
      for (size_t i = 1; i < end_indices.size(); ++i) {
        size_t example_size = end_indices[i] - end_indices[i - 1];
        max_num_features = std::max(max_num_features, example_size);
      }
    }

    TensorShape indices_shape;
    indices_shape.AddDim(total_num_features);
    indices_shape.AddDim(2);
    result->sparse_indices.emplace_back(DT_INT64, indices_shape);
    Tensor* indices = &result->sparse_indices.back();

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
    Tensor* values = &result->sparse_values.back();

    // dense_shape of the output SparseTensor: [batch_size, max_num_features].
    result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
    auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
    shapes_shape_t(0) = serialized.size();
    shapes_shape_t(1) = max_num_features;

    // Concatenate every minibatch's buffer into the output tensors.
    size_t offset = 0;
    for (size_t i = 0; i < sparse_buffers.size(); ++i) {
      const SparseBuffer& buffer = sparse_buffers[i][d];

      // Update indices.
      int64* ix_p = &indices->matrix<int64>()(offset, 0);
      size_t delta = 0;
      size_t example_index = first_example_of_minibatch(i);
      for (size_t example_end_index : buffer.example_end_indices) {
        size_t feature_index = 0;
        for (; delta < example_end_index; ++delta) {
          // Column 0: example index
          *ix_p = example_index;
          // Column 1: the feature index buffer example
          *(ix_p + 1) = feature_index;
          ix_p += 2;
          ++feature_index;
        }
        ++example_index;
      }

      // Copy values over.
      switch (config.sparse[d].dtype) {
        case DT_INT64: {
          std::copy(buffer.int64_list.begin(), buffer.int64_list.end(),
                    values->flat<int64>().data() + offset);
          break;
        }
        case DT_FLOAT: {
          std::copy(buffer.float_list.begin(), buffer.float_list.end(),
                    values->flat<float>().data() + offset);
          break;
        }
        case DT_STRING: {
          // Strings are moved out of the buffer to avoid copying payloads.
          std::move(buffer.bytes_list.begin(), buffer.bytes_list.end(),
                    values->flat<string>().data() + offset);
          break;
        }
        default:
          LOG(FATAL) << "Should not happen.";
      }

      offset += delta;
    }
  };

  // Merge SparseBuffers from all minibatches for every config.dense having
  // variable_length.
  auto MergeDenseVarLenMinibatches = [&](size_t d) {
    if (!config.dense[d].variable_length) return;

    // Loop over minibatches to find the largest per-example element count,
    // which determines the padded second dimension of the output.
    size_t max_num_features = 0;
    for (auto& dense_values_tmp : varlen_dense_buffers) {
      std::vector<size_t>& end_indices =
          dense_values_tmp[d].example_end_indices;
      max_num_features = std::max(max_num_features, end_indices[0]);
      for (size_t i = 1; i < end_indices.size(); ++i) {
        size_t example_size = end_indices[i] - end_indices[i - 1];
        max_num_features = std::max(max_num_features, example_size);
      }
    }

    const size_t stride_size = config.dense[d].elements_per_stride;
    const size_t max_num_elements = max_num_features / stride_size;
    TensorShape values_shape;
    DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
    // Output shape: [batch_size, max_num_elements] + trailing config dims.
    const size_t batch_size = serialized.size();
    values_shape.AddDim(batch_size);
    values_shape.AddDim(max_num_elements);
    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
      values_shape.AddDim(config.dense[d].shape.dim_size(i));
    }
    Tensor values(config.dense[d].dtype, values_shape);
    result->dense_values[d] = values;
    const size_t num_elements = values.NumElements();

    // Nothing to write, exit early.
    if (num_elements == 0) return;

    const size_t num_elements_per_minibatch = num_elements / batch_size;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_FLOAT: {
        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_STRING: {
        FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch,
                                  config, varlen_dense_buffers, &values);
        break;
      }
      default:
        LOG(FATAL) << "Should not happen.";
    }
  };

  for (size_t d = 0; d < config.dense.size(); ++d) {
    MergeDenseVarLenMinibatches(d);
  }

  for (size_t d = 0; d < config.sparse.size(); ++d) {
    MergeSparseMinibatches(d);
  }

  return Status::OK();
}
1224 
FastParseSingleExample(const Config & config,const string & serialized,Result * result)1225 Status FastParseSingleExample(const Config& config, const string& serialized,
1226                               Result* result) {
1227   DCHECK(result != nullptr);
1228   // Check config so we can safely CHECK(false) in switches on config.*.dtype
1229   for (auto& c : config.sparse) {
1230     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1231   }
1232   for (auto& c : config.dense) {
1233     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1234   }
1235 
1236   PerExampleFeatureStats* stats = nullptr;
1237   if (config.collect_feature_stats) {
1238     result->feature_stats.emplace_back();
1239     stats = &result->feature_stats.back();
1240   }
1241 
1242   // TODO(mrry): Cache the construction of this map at Op construction time.
1243   size_t config_size = config.dense.size() + config.sparse.size();
1244   SeededHasher hasher;
1245   // Build config index.
1246   PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1247   bool ok = true;
1248   for (size_t i = 0; i < 1000; ++i) {
1249     for (size_t d = 0; d < config.dense.size(); ++d) {
1250       ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1251                                       {d, Type::Dense});
1252     }
1253     for (size_t d = 0; d < config.sparse.size(); ++d) {
1254       ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1255                                       {d, Type::Sparse});
1256     }
1257     if (ok) break;
1258     LOG(WARNING) << "Collision found. This should happen only if you have "
1259                     "around 2^32 entries in your config.";
1260     hasher.seed++;
1261     config_index.Clear(config_size);
1262   }
1263   if (!ok) {
1264     return errors::Internal(
1265         "Could not avoid collision. This should not happen.");
1266   }
1267 
1268   // Allocate dense output tensors.
1269   for (size_t d = 0; d < config.dense.size(); ++d) {
1270     if (!config.dense[d].variable_length) {
1271       TensorShape values_shape;
1272       if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1273         return errors::Internal(
1274             "Fixed-length shape was not a statically defined shape.");
1275       }
1276       result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1277     } else {
1278       // Variable-length tensor will be allocated later.
1279       result->dense_values.emplace_back();
1280     }
1281   }
1282 
1283   // Allocate sparse output tensors.
1284   for (size_t d = 0; d < config.sparse.size(); ++d) {
1285     // The dense_shape is always a vector of length 1.
1286     result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1287     // Variable-length tensors will be allocated later.
1288     result->sparse_indices.emplace_back();
1289     result->sparse_values.emplace_back();
1290   }
1291 
1292   parsed::Example parsed_example;
1293   if (!ParseExample(serialized, &parsed_example)) {
1294     return errors::InvalidArgument("Could not parse example input, value: '",
1295                                    serialized, "'");
1296   }
1297   std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1298   std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1299 
1300   if (stats) {
1301     // TODO(b/111553342): This may over-count the number of features if there
1302     // are duplicate keys in the feature map. Consider deduplicating the keys
1303     // before computing the count.
1304     stats->features_count = parsed_example.size();
1305   }
1306 
1307   // Handle features present in the example.
1308   const size_t parsed_example_size = parsed_example.size();
1309   for (size_t i = 0; i < parsed_example_size; ++i) {
1310     // This is a logic that standard protobuf parsing is implementing.
1311     // I.e. last entry in the map overwrites all the previous ones.
1312     parsed::FeatureMapEntry& name_and_feature =
1313         parsed_example[parsed_example_size - i - 1];
1314 
1315     const StringPiece feature_name = name_and_feature.first;
1316     parsed::Feature& feature = name_and_feature.second;
1317 
1318     std::pair<size_t, Type> d_and_type;
1319     uint64 h = hasher(feature_name);
1320     if (!config_index.Find(h, &d_and_type)) continue;
1321 
1322     size_t d = d_and_type.first;
1323     bool is_dense = d_and_type.second == Type::Dense;
1324 
1325     {
1326       // Testing for PresizedCuckooMap collision.
1327       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
1328       const string& config_feature_name = is_dense
1329                                               ? config.dense[d].feature_name
1330                                               : config.sparse[d].feature_name;
1331       if (feature_name != config_feature_name) continue;
1332     }
1333 
1334     auto example_error = [feature_name](StringPiece suffix) {
1335       return errors::InvalidArgument("Key: ", feature_name, ".  ", suffix);
1336     };
1337 
1338     auto parse_error = [feature_name] {
1339       return errors::InvalidArgument("Key: ", feature_name,
1340                                      ".  Can't parse serialized Example.");
1341     };
1342 
1343     DataType example_dtype;
1344     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
1345     if (example_dtype == DT_INVALID) continue;
1346 
1347     if (is_dense && !config.dense[d].variable_length) {
1348       // If feature was already visited, skip.
1349       // Compare comment at the beginning of the loop.
1350       if (dense_feature_already_seen[d]) {
1351         LogDenseFeatureDataLoss(feature_name);
1352         continue;
1353       }
1354       dense_feature_already_seen[d] = true;
1355 
1356       if (example_dtype != config.dense[d].dtype) {
1357         return example_error(strings::StrCat(
1358             "Data types don't match. Data type: ",
1359             DataTypeString(example_dtype),
1360             " but expected type: ", DataTypeString(config.dense[d].dtype)));
1361       }
1362 
1363       Tensor* out = &result->dense_values[d];
1364       const std::size_t num_elements = config.dense[d].elements_per_stride;
1365       if (stats) {
1366         // TODO(b/111553342): If desirable, we could add support for counting
1367         // elements in the features that aren't parsed, but this could add
1368         // considerable runtime cost.
1369         stats->feature_values_count += num_elements;
1370       }
1371       switch (example_dtype) {
1372         case DT_INT64: {
1373           auto out_p = out->flat<int64>().data();
1374           LimitedArraySlice<int64> slice(out_p, num_elements);
1375           if (!feature.ParseInt64List(&slice)) return parse_error();
1376           if (slice.EndDistance() != 0) {
1377             return parse_error();
1378           }
1379           break;
1380         }
1381         case DT_FLOAT: {
1382           auto out_p = out->flat<float>().data();
1383           LimitedArraySlice<float> slice(out_p, num_elements);
1384           if (!feature.ParseFloatList(&slice)) return parse_error();
1385           if (slice.EndDistance() != 0) {
1386             return parse_error();
1387           }
1388           break;
1389         }
1390         case DT_STRING: {
1391           auto out_p = out->flat<string>().data();
1392           LimitedArraySlice<string> slice(out_p, num_elements);
1393           if (!feature.ParseBytesList(&slice)) return parse_error();
1394           if (slice.EndDistance() != 0) {
1395             return parse_error();
1396           }
1397           break;
1398         }
1399         default:
1400           LOG(FATAL) << "Should not happen.";
1401       }
1402 
1403     } else {  // if variable length
1404       SparseBuffer out_temp;
1405       const size_t num_elements_divisor =
1406           is_dense ? config.dense[d].elements_per_stride : 1;
1407       size_t num_elements;
1408 
1409       if (is_dense) {
1410         // If feature was already visited, skip.
1411         // Compare comment at the beginning of the loop.
1412         if (dense_feature_already_seen[d]) {
1413           LogDenseFeatureDataLoss(feature_name);
1414           continue;
1415         }
1416         dense_feature_already_seen[d] = true;
1417         if (example_dtype != config.dense[d].dtype) {
1418           return example_error(strings::StrCat(
1419               "Data types don't match. Data type: ",
1420               DataTypeString(example_dtype),
1421               " but expected type: ", DataTypeString(config.dense[d].dtype)));
1422         }
1423       } else {
1424         // If feature was already visited, skip.
1425         // Compare comment at the beginning of the loop.
1426         if (sparse_feature_already_seen[d]) {
1427           LogSparseFeatureDataLoss(feature_name);
1428           continue;
1429         }
1430         sparse_feature_already_seen[d] = true;
1431 
1432         // Handle sparse features.
1433         if (example_dtype != DT_INVALID &&
1434             example_dtype != config.sparse[d].dtype) {
1435           return example_error(strings::StrCat(
1436               "Data types don't match. ",
1437               "Expected type: ", DataTypeString(config.sparse[d].dtype),
1438               ", Actual type: ", DataTypeString(example_dtype)));
1439         }
1440       }
1441 
1442       switch (example_dtype) {
1443         case DT_INT64: {
1444           // TODO(mrry): Use the fact that the `int64_list` is packed to read
1445           // out the length and pre-allocate the output tensor.
1446           if (!feature.ParseInt64List(&out_temp.int64_list))
1447             return parse_error();
1448           num_elements = out_temp.int64_list.size();
1449           break;
1450         }
1451         case DT_FLOAT: {
1452           // TODO(mrry): Use the fact that the `float_list` is packed to read
1453           // out the length and pre-allocate the output tensor.
1454           if (!feature.ParseFloatList(&out_temp.float_list))
1455             return parse_error();
1456           num_elements = out_temp.float_list.size();
1457           break;
1458         }
1459         case DT_STRING: {
1460           int actual_num_elements = 0;
1461           if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1462             return parse_error();
1463           }
1464           out_temp.bytes_list.reserve(actual_num_elements);
1465           if (!feature.ParseBytesList(&out_temp.bytes_list))
1466             return parse_error();
1467           num_elements = out_temp.bytes_list.size();
1468           break;
1469         }
1470         default:
1471           LOG(FATAL) << "Should not happen. " << DataTypeString(example_dtype);
1472       }
1473 
1474       if (num_elements % num_elements_divisor != 0) {
1475         return parse_error();
1476       }
1477 
1478       if (stats) {
1479         stats->feature_values_count += num_elements;
1480       }
1481 
1482       Tensor* out;
1483       if (is_dense) {
1484         TensorShape values_shape;
1485         values_shape.AddDim(num_elements / num_elements_divisor);
1486         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1487           values_shape.AddDim(config.dense[d].shape.dim_size(i));
1488         }
1489 
1490         out = &result->dense_values[d];
1491         *out = Tensor(config.dense[d].dtype, values_shape);
1492 
1493       } else {
1494         Tensor* out_indices = &result->sparse_indices[d];
1495         Tensor* out_dense_shape = &result->sparse_shapes[d];
1496         out = &result->sparse_values[d];
1497 
1498         // TODO(mrry): Investigate the possibility of not materializing
1499         // the indices (and perhaps dense_shape) until they are needed.
1500         *out_indices = Tensor(
1501             DT_INT64, TensorShape({static_cast<int64>(num_elements), 1}));
1502         auto indices_flat = out_indices->flat<int64>();
1503         for (size_t i = 0; i < num_elements; ++i) {
1504           indices_flat(i) = static_cast<int64>(i);
1505         }
1506 
1507         *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
1508         auto shapes_shape_t = out_dense_shape->vec<int64>();
1509         shapes_shape_t(0) = num_elements;
1510 
1511         *out = Tensor(config.sparse[d].dtype,
1512                       TensorShape({static_cast<int64>(num_elements)}));
1513       }
1514 
1515       switch (example_dtype) {
1516         case DT_INT64: {
1517           CopyOrMoveBlock(out_temp.int64_list.begin(),
1518                           out_temp.int64_list.end(), out->flat<int64>().data());
1519           break;
1520         }
1521         case DT_FLOAT: {
1522           CopyOrMoveBlock(out_temp.float_list.begin(),
1523                           out_temp.float_list.end(), out->flat<float>().data());
1524           break;
1525         }
1526         case DT_STRING: {
1527           CopyOrMoveBlock(out_temp.bytes_list.begin(),
1528                           out_temp.bytes_list.end(),
1529                           out->flat<string>().data());
1530           break;
1531         }
1532         default:
1533           LOG(FATAL) << "Should not happen.";
1534       }
1535     }
1536   }
1537 
1538   // Handle missing dense features.
1539   for (size_t d = 0; d < config.dense.size(); ++d) {
1540     if (!dense_feature_already_seen[d]) {
1541       if (!config.dense[d].variable_length) {
1542         // Handle missing fixed-length dense feature.
1543         if (config.dense[d].default_value.NumElements() == 0) {
1544           return errors::InvalidArgument(
1545               "Feature: ", config.dense[d].feature_name,
1546               " (data type: ", DataTypeString(config.dense[d].dtype), ")",
1547               " is required but could not be found.");
1548         }
1549         result->dense_values[d] = config.dense[d].default_value;
1550       } else {
1551         // Handle missing varlen dense feature.
1552         TensorShape empty_shape;
1553         empty_shape.AddDim(0);
1554         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1555           empty_shape.AddDim(config.dense[d].shape.dim_size(i));
1556         }
1557         result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
1558       }
1559     }
1560   }
1561 
1562   // Handle missing sparse features.
1563   for (size_t d = 0; d < config.sparse.size(); ++d) {
1564     if (!sparse_feature_already_seen[d]) {
1565       result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
1566       result->sparse_values[d] =
1567           Tensor(config.sparse[d].dtype, TensorShape({0}));
1568       result->sparse_shapes[d].vec<int64>()(0) = 0;
1569     }
1570   }
1571 
1572   return Status::OK();
1573 }
1574 
1575 // Return the number of bytes elements parsed, or -1 on error. If out is null,
1576 // this method simply counts the number of elements without any copying.
ParseBytesFeature(protobuf::io::CodedInputStream * stream,string * out)1577 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
1578                              string* out) {
1579   int num_elements = 0;
1580   uint32 length;
1581   if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
1582     return -1;
1583   }
1584   if (length > 0) {
1585     auto limit = stream->PushLimit(length);
1586     while (!stream->ExpectAtEnd()) {
1587       uint32 bytes_length;
1588       if (!stream->ExpectTag(kDelimitedTag(1)) ||
1589           !stream->ReadVarint32(&bytes_length) ||
1590           (out != nullptr && !stream->ReadString(out++, bytes_length))) {
1591         return -1;
1592       }
1593       if (out == nullptr) {
1594         stream->Skip(bytes_length);
1595       }
1596       num_elements++;
1597     }
1598     stream->PopLimit(limit);
1599   }
1600   return num_elements;
1601 }
1602 
// Writes `num_to_pad` zero-valued floats starting at `out`.
inline void PadFloatFeature(int num_to_pad, float* out) {
  while (num_to_pad > 0) {
    *out++ = 0.0;
    --num_to_pad;
  }
}
1608 
// Writes `num_to_pad` zero-valued int64s starting at `out`.
inline void PadInt64Feature(int num_to_pad, int64* out) {
  while (num_to_pad > 0) {
    *out++ = 0;
    --num_to_pad;
  }
}
1614 
1615 // Return the number of float elements parsed, or -1 on error. If out is null,
1616 // this method simply counts the number of elements without any copying.
ParseFloatFeature(protobuf::io::CodedInputStream * stream,float * out)1617 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
1618                              float* out) {
1619   int num_elements = 0;
1620   uint32 length;
1621   if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
1622     return -1;
1623   }
1624   if (length > 0) {
1625     auto limit = stream->PushLimit(length);
1626     uint8 peek_tag = PeekTag(stream);
1627     if (peek_tag == kDelimitedTag(1)) {  // packed
1628       uint32 packed_length;
1629       if (!stream->ExpectTag(kDelimitedTag(1)) ||
1630           !stream->ReadVarint32(&packed_length)) {
1631         return -1;
1632       }
1633       auto packed_limit = stream->PushLimit(packed_length);
1634       while (!stream->ExpectAtEnd()) {
1635         uint32 buffer32;
1636         if (!stream->ReadLittleEndian32(&buffer32)) {
1637           return -1;
1638         }
1639         if (out != nullptr) {
1640           *out++ = absl::bit_cast<float>(buffer32);
1641         }
1642         num_elements++;
1643       }
1644       stream->PopLimit(packed_limit);
1645     } else if (peek_tag == kFixed32Tag(1)) {
1646       while (!stream->ExpectAtEnd()) {
1647         uint32 buffer32;
1648         if (!stream->ExpectTag(kFixed32Tag(1)) ||
1649             !stream->ReadLittleEndian32(&buffer32)) {
1650           return -1;
1651         }
1652         if (out != nullptr) {
1653           *out++ = absl::bit_cast<float>(buffer32);
1654         }
1655         num_elements++;
1656       }
1657     } else {
1658       // Unknown tag.
1659       return -1;
1660     }
1661     stream->PopLimit(limit);
1662   }
1663   return num_elements;
1664 }
1665 
1666 // Return the number of int64 elements parsed, or -1 on error. If out is null,
1667 // this method simply counts the number of elements without any copying.
ParseInt64Feature(protobuf::io::CodedInputStream * stream,int64 * out)1668 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
1669                              int64* out) {
1670   int num_elements = 0;
1671   uint32 length;
1672   if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
1673     return -1;
1674   }
1675   if (length > 0) {
1676     auto limit = stream->PushLimit(length);
1677     uint8 peek_tag = PeekTag(stream);
1678     if (peek_tag == kDelimitedTag(1)) {  // packed
1679       uint32 packed_length;
1680       if (!stream->ExpectTag(kDelimitedTag(1)) ||
1681           !stream->ReadVarint32(&packed_length)) {
1682         return -1;
1683       }
1684       auto packed_limit = stream->PushLimit(packed_length);
1685       while (!stream->ExpectAtEnd()) {
1686         protobuf_uint64 n;  // There is no API for int64
1687         if (!stream->ReadVarint64(&n)) {
1688           return -1;
1689         }
1690         if (out != nullptr) {
1691           *out++ = n;
1692         }
1693         num_elements++;
1694       }
1695       stream->PopLimit(packed_limit);
1696     } else if (peek_tag == kVarintTag(1)) {
1697       while (!stream->ExpectAtEnd()) {
1698         protobuf_uint64 n;  // There is no API for int64
1699         if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
1700           return -1;
1701         }
1702         if (out != nullptr) {
1703           *out++ = n;
1704         }
1705         num_elements++;
1706       }
1707     } else {
1708       // Unknown tag.
1709       return -1;
1710     }
1711     stream->PopLimit(limit);
1712   }
1713   return num_elements;
1714 }
1715 
ParseDataType(protobuf::io::CodedInputStream * stream)1716 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
1717   uint8 peek_tag = PeekTag(stream);
1718   switch (peek_tag) {
1719     case kDelimitedTag(1):
1720       return DT_STRING;
1721     case kDelimitedTag(2):
1722       return DT_FLOAT;
1723     case kDelimitedTag(3):
1724       return DT_INT64;
1725     default:
1726       return DT_INVALID;
1727   }
1728 }
1729 
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)1730 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
1731                              DataType dtype) {
1732   switch (dtype) {
1733     case DT_STRING:
1734       if (!stream->ExpectTag(kDelimitedTag(1))) {
1735         return false;
1736       }
1737       break;
1738     case DT_FLOAT:
1739       if (!stream->ExpectTag(kDelimitedTag(2))) {
1740         return false;
1741       }
1742       break;
1743     case DT_INT64:
1744       if (!stream->ExpectTag(kDelimitedTag(3))) {
1745         return false;
1746       }
1747       break;
1748     default:
1749       return false;
1750   }
1751   uint32 length;
1752   return stream->ReadVarint32(&length) && length == 0;
1753 }
1754 
1755 // TODO(sundberg): Use the threadpool to parallelize example parsing.
1756 // TODO(b/111553342): Support extracting feature statistics from the examples.
FastParseSequenceExample(const FastParseExampleConfig & context_config,const FastParseExampleConfig & feature_list_config,gtl::ArraySlice<string> serialized,gtl::ArraySlice<string> example_names,thread::ThreadPool * thread_pool,Result * context_result,Result * feature_list_result,std::vector<Tensor> * dense_feature_lengths)1757 Status FastParseSequenceExample(
1758     const FastParseExampleConfig& context_config,
1759     const FastParseExampleConfig& feature_list_config,
1760     gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
1761     thread::ThreadPool* thread_pool, Result* context_result,
1762     Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
1763   int num_examples = serialized.size();
1764   DCHECK(context_result != nullptr);
1765   DCHECK(feature_list_result != nullptr);
1766   DCHECK(dense_feature_lengths != nullptr);
1767   size_t num_context_features =
1768       context_config.sparse.size() + context_config.dense.size();
1769   absl::flat_hash_map<StringPiece, bool> context_is_sparse;
1770   context_is_sparse.reserve(num_context_features);
1771   absl::flat_hash_map<StringPiece, std::pair<DataType, size_t>>
1772       context_feature_type_and_lengths;
1773   context_feature_type_and_lengths.reserve(num_context_features);
1774   if (!example_names.empty() && example_names.size() != num_examples) {
1775     return errors::InvalidArgument(
1776         "example_names must be empty or have the correct number of elements");
1777   }
1778   for (auto& c : context_config.sparse) {
1779     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1780     context_feature_type_and_lengths[c.feature_name] =
1781         std::make_pair(c.dtype, 0);
1782     context_is_sparse[c.feature_name] = true;
1783   }
1784   for (auto& c : context_config.dense) {
1785     if (context_is_sparse[c.feature_name]) {
1786       return errors::InvalidArgument("Context feature " + c.feature_name +
1787                                      " cannot be both dense and sparse");
1788     }
1789     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1790     context_feature_type_and_lengths[c.feature_name] =
1791         std::make_pair(c.dtype, c.default_value.NumElements());
1792     if (c.default_value.NumElements() > 0) {
1793       if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
1794         return errors::InvalidArgument("Default value for context feature ",
1795                                        c.feature_name,
1796                                        " has an incorrect shape: saw ",
1797                                        c.default_value.shape().DebugString(),
1798                                        " but expected ", c.shape.DebugString());
1799       }
1800     }
1801   }
1802   size_t num_sequence_features =
1803       feature_list_config.sparse.size() + feature_list_config.dense.size();
1804   absl::flat_hash_map<StringPiece, bool> sequence_is_sparse;
1805   sequence_is_sparse.reserve(num_sequence_features);
1806   absl::flat_hash_map<StringPiece, std::pair<DataType, size_t>>
1807       sequence_feature_type_and_lengths;
1808   sequence_feature_type_and_lengths.reserve(num_sequence_features);
1809   for (auto& c : feature_list_config.sparse) {
1810     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1811     sequence_feature_type_and_lengths[c.feature_name] =
1812         std::make_pair(c.dtype, 0);
1813     sequence_is_sparse[c.feature_name] = true;
1814   }
1815   for (auto& c : feature_list_config.dense) {
1816     if (sequence_is_sparse[c.feature_name]) {
1817       return errors::InvalidArgument("Sequence feature " + c.feature_name +
1818                                      " cannot be both dense and sparse");
1819     }
1820     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1821     sequence_feature_type_and_lengths[c.feature_name] =
1822         std::make_pair(c.dtype, 0);
1823   }
1824 
1825   std::vector<absl::flat_hash_map<StringPiece, StringPiece>>
1826       all_context_features(num_examples);
1827   std::vector<absl::flat_hash_map<StringPiece, StringPiece>>
1828       all_sequence_features(num_examples);
1829   const string kUnknown = "<unknown>";
1830   for (int d = 0; d < num_examples; d++) {
1831     const string& example = serialized[d];
1832     const string& example_name =
1833         example_names.empty() ? kUnknown : example_names[d];
1834     auto* context_features = &all_context_features[d];
1835     auto* sequence_features = &all_sequence_features[d];
1836 
1837     protobuf::io::CodedInputStream stream(
1838         reinterpret_cast<const uint8*>(example.data()), example.size());
1839     // Not clear what this does. Why not stream.EnableAliasing()?
1840     EnableAliasing(&stream);
1841 
1842     // Extract pointers to all features within this serialized example.
1843     while (!stream.ExpectAtEnd()) {
1844       absl::flat_hash_map<StringPiece, StringPiece>* features = nullptr;
1845       const absl::flat_hash_map<StringPiece, std::pair<DataType, size_t>>*
1846           config = nullptr;
1847       if (stream.ExpectTag(kDelimitedTag(1))) {
1848         // Context
1849         features = context_features;
1850         config = &context_feature_type_and_lengths;
1851       } else if (stream.ExpectTag(kDelimitedTag(2))) {
1852         // Sequence
1853         features = sequence_features;
1854         config = &sequence_feature_type_and_lengths;
1855       } else if (!SkipExtraneousTag(&stream)) {
1856         return errors::InvalidArgument(
1857             "Invalid protocol message input, example id: ", example_name);
1858       }
1859       if (features != nullptr) {
1860         uint32 length;
1861         if (!stream.ReadVarint32(&length)) {
1862           return errors::InvalidArgument(
1863               "Invalid protocol message input, example id: ", example_name);
1864         }
1865         auto limit = stream.PushLimit(length);
1866         while (!stream.ExpectAtEnd()) {
1867           StringPiece key, value;
1868           uint32 length;
1869           if (!stream.ExpectTag(kDelimitedTag(1)) ||
1870               !stream.ReadVarint32(&length)) {
1871             return errors::InvalidArgument(
1872                 "Invalid protocol message input, example id: ", example_name);
1873           }
1874           auto limit = stream.PushLimit(length);
1875           if (!stream.ExpectTag(kDelimitedTag(1)) ||
1876               !ParseString(&stream, &key) ||
1877               !stream.ExpectTag(kDelimitedTag(2)) ||
1878               !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
1879             return errors::InvalidArgument(
1880                 "Invalid protocol message input, example id: ", example_name);
1881           }
1882           stream.PopLimit(limit);
1883           // Only save if this feature was requested.
1884           if (config->count(key) > 0) {
1885             (*features)[key] = value;
1886           }
1887         }
1888         stream.PopLimit(limit);
1889       }
1890     }
1891 
1892     for (const auto& c : *context_features) {
1893       size_t num_elements = 0;
1894       if (!c.second.empty()) {
1895         protobuf::io::CodedInputStream stream(
1896             reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
1897         EnableAliasing(&stream);
1898         DataType dtype = context_feature_type_and_lengths[c.first].first;
1899         int64 num;
1900         switch (dtype) {
1901           case DT_STRING:
1902             num = ParseBytesFeature(&stream, nullptr);
1903             break;
1904           case DT_FLOAT:
1905             num = ParseFloatFeature(&stream, nullptr);
1906             break;
1907           case DT_INT64:
1908             num = ParseInt64Feature(&stream, nullptr);
1909             break;
1910           default:
1911             num = -1;
1912             break;
1913         }
1914         if (num == -1) {
1915           return errors::InvalidArgument("Error in context feature ", c.first,
1916                                          " in example ", example_name);
1917         }
1918         num_elements += num;
1919       }
1920       if (context_is_sparse[c.first]) {
1921         context_feature_type_and_lengths[c.first].second += num_elements;
1922       } else {
1923         size_t current_max = context_feature_type_and_lengths[c.first].second;
1924         context_feature_type_and_lengths[c.first].second =
1925             std::max(current_max, num_elements);
1926       }
1927     }
1928     for (const auto& c : *sequence_features) {
1929       size_t num_elements = 0;
1930       if (!c.second.empty()) {
1931         protobuf::io::CodedInputStream stream(
1932             reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
1933         EnableAliasing(&stream);
1934         DataType dtype = sequence_feature_type_and_lengths[c.first].first;
1935         while (!stream.ExpectAtEnd()) {
1936           uint32 feature_length;
1937           if (!stream.ExpectTag(kDelimitedTag(1)) ||
1938               !stream.ReadVarint32(&feature_length)) {
1939             return errors::InvalidArgument("Error in sequence feature ",
1940                                            c.first, " in example ",
1941                                            example_name);
1942           }
1943           if (feature_length > 2) {
1944             auto limit = stream.PushLimit(feature_length);
1945             int64 num;
1946             switch (dtype) {
1947               case DT_STRING:
1948                 num = ParseBytesFeature(&stream, nullptr);
1949                 break;
1950               case DT_FLOAT:
1951                 num = ParseFloatFeature(&stream, nullptr);
1952                 break;
1953               case DT_INT64:
1954                 num = ParseInt64Feature(&stream, nullptr);
1955                 break;
1956               default:
1957                 num = -1;
1958                 break;
1959             }
1960             if (num == -1) {
1961               return errors::InvalidArgument("Error in sequence feature ",
1962                                              c.first, " in example ",
1963                                              example_name);
1964             }
1965             num_elements += num;
1966             stream.PopLimit(limit);
1967           } else if (feature_length == 2) {
1968             if (!SkipEmptyFeature(&stream, dtype)) {
1969               return errors::InvalidArgument("Error in sequence feature ",
1970                                              c.first, " in example ",
1971                                              example_name);
1972             }
1973           } else if (feature_length != 0) {
1974             return errors::InvalidArgument("Error in sequence feature ",
1975                                            c.first, " in example ",
1976                                            example_name);
1977           }
1978         }
1979       }
1980       if (sequence_is_sparse[c.first]) {
1981         sequence_feature_type_and_lengths[c.first].second += num_elements;
1982       } else {
1983         size_t current_max = sequence_feature_type_and_lengths[c.first].second;
1984         sequence_feature_type_and_lengths[c.first].second =
1985             std::max(current_max, num_elements);
1986       }
1987     }
1988   }
1989 
1990   // Allocate memory.
1991   context_result->sparse_values.resize(context_config.sparse.size());
1992   context_result->sparse_indices.resize(context_config.sparse.size());
1993   context_result->sparse_shapes.resize(context_config.sparse.size());
1994   context_result->dense_values.resize(context_config.dense.size());
1995   feature_list_result->sparse_values.resize(feature_list_config.sparse.size());
1996   feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
1997   feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
1998   feature_list_result->dense_values.resize(feature_list_config.dense.size());
1999   dense_feature_lengths->resize(feature_list_config.dense.size());
2000 
2001   int t = 0;
2002   for (const auto& c : context_config.dense) {
2003     TensorShape dense_shape, example_shape;
2004     DataType dtype = c.dtype;
2005     const size_t expected_max_elements =
2006         context_feature_type_and_lengths[c.feature_name].second;
2007     if (!c.shape.AsTensorShape(&example_shape) ||
2008         expected_max_elements != example_shape.num_elements()) {
2009       return errors::InvalidArgument(
2010           "Inconsistent number of elements for feature ", c.feature_name, ": ",
2011           expected_max_elements, " vs ", dense_shape.num_elements());
2012     }
2013     dense_shape.AddDim(num_examples);
2014     for (const int dim : c.shape.dim_sizes()) {
2015       dense_shape.AddDim(dim);
2016     }
2017     context_result->dense_values[t] = Tensor(dtype, dense_shape);
2018 
2019     // TODO(sundberg): Refactor to reduce code duplication, and add bounds
2020     // checking for the outputs.
2021     string* out_bytes = nullptr;
2022     float* out_float = nullptr;
2023     int64* out_int64 = nullptr;
2024     switch (dtype) {
2025       case DT_STRING:
2026         out_bytes = context_result->dense_values[t].flat<string>().data();
2027         break;
2028       case DT_FLOAT:
2029         out_float = context_result->dense_values[t].flat<float>().data();
2030         break;
2031       case DT_INT64:
2032         out_int64 = context_result->dense_values[t].flat<int64>().data();
2033         break;
2034       default:
2035         return errors::InvalidArgument("Unexpected dtype ", dtype,
2036                                        " in feature ", c.feature_name);
2037     }
2038     t++;
2039 
2040     // Fill in the values.
2041     for (int e = 0; e < num_examples; e++) {
2042       size_t num_elements = 0;
2043       const auto feature_iter = all_context_features[e].find(c.feature_name);
2044       const string& example_name =
2045           example_names.empty() ? kUnknown : example_names[e];
2046       if (feature_iter == all_context_features[e].end()) {
2047         // Copy the default value, if present. If not, return an error.
2048         if (c.default_value.NumElements() == 0) {
2049           return errors::InvalidArgument(
2050               "Feature: ", c.feature_name,
2051               " (data type: ", DataTypeString(c.dtype), ")",
2052               " is required but could not be found.");
2053         }
2054         const string* in_bytes = nullptr;
2055         const float* in_float = nullptr;
2056         const int64* in_int64 = nullptr;
2057         size_t num = 0;
2058         switch (dtype) {
2059           case DT_STRING:
2060             in_bytes = c.default_value.flat<string>().data();
2061             num = c.default_value.NumElements();
2062             for (int p = 0; p < num; p++) {
2063               *out_bytes++ = *in_bytes++;
2064             }
2065             break;
2066           case DT_FLOAT:
2067             in_float = c.default_value.flat<float>().data();
2068             num = c.default_value.NumElements();
2069             for (int p = 0; p < num; p++) {
2070               *out_float++ = *in_float++;
2071             }
2072             break;
2073           case DT_INT64:
2074             in_int64 = c.default_value.flat<int64>().data();
2075             num = c.default_value.NumElements();
2076             for (int p = 0; p < num; p++) {
2077               *out_int64++ = *in_int64++;
2078             }
2079             break;
2080           default:
2081             return errors::InvalidArgument("Unexpected dtype ", dtype,
2082                                            " in example ", example_name);
2083         }
2084         num_elements += num;
2085       } else if (!feature_iter->second.empty()) {
2086         const auto& feature = feature_iter->second;
2087         protobuf::io::CodedInputStream stream(
2088             reinterpret_cast<const uint8*>(feature.data()), feature.size());
2089         EnableAliasing(&stream);
2090         size_t num_added;
2091         switch (dtype) {
2092           case DT_STRING:
2093             num_added = ParseBytesFeature(&stream, out_bytes);
2094             out_bytes += num_added;
2095             break;
2096           case DT_FLOAT:
2097             num_added = ParseFloatFeature(&stream, out_float);
2098             out_float += num_added;
2099             break;
2100           case DT_INT64:
2101             num_added = ParseInt64Feature(&stream, out_int64);
2102             out_int64 += num_added;
2103             break;
2104           default:
2105             return errors::InvalidArgument("Unexpected dtype ", dtype,
2106                                            " in example ", example_name);
2107         }
2108         num_elements += num_added;
2109       }
2110       if (num_elements != expected_max_elements) {
2111         return errors::InvalidArgument(
2112             "Unexpected number of elements in example ", example_name);
2113       }
2114     }
2115   }
2116   t = 0;
2117   for (const auto& c : context_config.sparse) {
2118     TensorShape indices_shape, values_shape;
2119     DataType dtype = c.dtype;
2120     size_t expected_num_elements =
2121         context_feature_type_and_lengths[c.feature_name].second;
2122     indices_shape.AddDim(expected_num_elements);
2123     indices_shape.AddDim(2);
2124     values_shape.AddDim(expected_num_elements);
2125     context_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
2126     context_result->sparse_values[t] = Tensor(dtype, values_shape);
2127     context_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({2}));
2128     // TODO(sundberg): Refactor to reduce code duplication, and add bounds
2129     // checking for the outputs.
2130     string* out_bytes = nullptr;
2131     float* out_float = nullptr;
2132     int64* out_int64 = nullptr;
2133     switch (dtype) {
2134       case DT_STRING:
2135         out_bytes = context_result->sparse_values[t].flat<string>().data();
2136         break;
2137       case DT_FLOAT:
2138         out_float = context_result->sparse_values[t].flat<float>().data();
2139         break;
2140       case DT_INT64:
2141         out_int64 = context_result->sparse_values[t].flat<int64>().data();
2142         break;
2143       default:
2144         return errors::InvalidArgument("Unexpected dtype ", dtype,
2145                                        " in feature ", c.feature_name);
2146     }
2147     int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
2148     auto out_shape = context_result->sparse_shapes[t].vec<int64>();
2149     t++;
2150 
2151     // Fill in the values.
2152     size_t num_elements = 0;
2153     size_t max_num_cols = 0;
2154     for (int e = 0; e < num_examples; e++) {
2155       const auto& feature = all_context_features[e][c.feature_name];
2156       const string& example_name =
2157           example_names.empty() ? kUnknown : example_names[e];
2158       if (!feature.empty()) {
2159         protobuf::io::CodedInputStream stream(
2160             reinterpret_cast<const uint8*>(feature.data()), feature.size());
2161         EnableAliasing(&stream);
2162         size_t num_added;
2163         switch (dtype) {
2164           case DT_STRING:
2165             num_added = ParseBytesFeature(&stream, out_bytes);
2166             out_bytes += num_added;
2167             break;
2168           case DT_FLOAT:
2169             num_added = ParseFloatFeature(&stream, out_float);
2170             out_float += num_added;
2171             break;
2172           case DT_INT64:
2173             num_added = ParseInt64Feature(&stream, out_int64);
2174             out_int64 += num_added;
2175             break;
2176           default:
2177             return errors::InvalidArgument("Unexpected dtype ", dtype,
2178                                            " in example ", example_name);
2179         }
2180         num_elements += num_added;
2181         max_num_cols = std::max(max_num_cols, num_added);
2182         for (int i = 0; i < num_added; i++) {
2183           *out_indices++ = e;
2184           *out_indices++ = i;
2185         }
2186       }
2187     }
2188     if (num_elements != expected_num_elements) {
2189       return errors::InvalidArgument(
2190           "Unexpected total number of elements in feature ", c.feature_name);
2191     }
2192     out_shape(0) = num_examples;
2193     out_shape(1) = max_num_cols;
2194   }
2195   t = 0;
2196   TensorShape dense_length_shape({num_examples});
2197   for (const auto& c : feature_list_config.dense) {
2198     TensorShape dense_shape, row_shape;
2199     DataType dtype = c.dtype;
2200     const size_t expected_max_elements =
2201         sequence_feature_type_and_lengths[c.feature_name].second;
2202     if (!c.shape.AsTensorShape(&row_shape) ||
2203         expected_max_elements !=
2204             (expected_max_elements / row_shape.num_elements()) *
2205                 row_shape.num_elements()) {
2206       return errors::InvalidArgument("Unexpected shape error in feature ",
2207                                      c.feature_name);
2208     }
2209     int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
2210     dense_shape.AddDim(num_examples);
2211     dense_shape.AddDim(expected_max_rows);
2212     for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
2213       dense_shape.AddDim(dim);
2214     }
2215     feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
2216     (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape);
2217     int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
2218 
2219     string* out_bytes = nullptr;
2220     float* out_float = nullptr;
2221     int64* out_int64 = nullptr;
2222     switch (dtype) {
2223       case DT_STRING:
2224         out_bytes = feature_list_result->dense_values[t].flat<string>().data();
2225         break;
2226       case DT_FLOAT:
2227         out_float = feature_list_result->dense_values[t].flat<float>().data();
2228         break;
2229       case DT_INT64:
2230         out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
2231         break;
2232       default:
2233         return errors::InvalidArgument("Unexpected dtype ", dtype,
2234                                        " in feature ", c.feature_name);
2235     }
2236     t++;
2237 
2238     // Fill in the values.
2239     for (int e = 0; e < num_examples; e++) {
2240       size_t num_elements = 0, num_rows = 0;
2241       const auto feature_iter = all_sequence_features[e].find(c.feature_name);
2242       const string& example_name =
2243           example_names.empty() ? kUnknown : example_names[e];
2244       if (feature_iter == all_sequence_features[e].end()) {
2245         // Return an error if this feature was not allowed to be missing.
2246         // Otherwise, we'll pad as needed below.
2247         if (!c.variable_length) {
2248           return errors::InvalidArgument("Missing feature ", c.feature_name,
2249                                          " in example ", example_name);
2250         }
2251       } else if (!feature_iter->second.empty()) {
2252         const auto& feature = feature_iter->second;
2253         protobuf::io::CodedInputStream stream(
2254             reinterpret_cast<const uint8*>(feature.data()), feature.size());
2255         EnableAliasing(&stream);
2256         while (!stream.ExpectAtEnd()) {
2257           uint32 feature_length;
2258           if (!stream.ExpectTag(kDelimitedTag(1)) ||
2259               !stream.ReadVarint32(&feature_length)) {
2260             return errors::InvalidArgument("Error in sequence feature ",
2261                                            c.feature_name, " in example ",
2262                                            example_name);
2263           }
2264           auto limit = stream.PushLimit(feature_length);
2265           size_t num_added;
2266           switch (dtype) {
2267             case DT_STRING:
2268               num_added = ParseBytesFeature(&stream, out_bytes);
2269               out_bytes += num_added;
2270               break;
2271             case DT_FLOAT:
2272               num_added = ParseFloatFeature(&stream, out_float);
2273               out_float += num_added;
2274               break;
2275             case DT_INT64:
2276               num_added = ParseInt64Feature(&stream, out_int64);
2277               out_int64 += num_added;
2278               break;
2279             default:
2280               return errors::InvalidArgument("Unexpected dtype ", dtype,
2281                                              " in example ", example_name);
2282           }
2283           num_elements += num_added;
2284           num_rows++;
2285           if (num_added != row_shape.num_elements()) {
2286             return errors::InvalidArgument(
2287                 "Unexpected number of elements in feature ", c.feature_name,
2288                 ", example ", example_name);
2289           }
2290           stream.PopLimit(limit);
2291         }
2292       }
2293       *out_lengths++ = num_rows;
2294       // Pad as necessary.
2295       int num_to_pad = expected_max_elements - num_elements;
2296       switch (dtype) {
2297         case DT_STRING:
2298           out_bytes += num_to_pad;
2299           break;
2300         case DT_FLOAT:
2301           PadFloatFeature(num_to_pad, out_float);
2302           out_float += num_to_pad;
2303           break;
2304         case DT_INT64:
2305           PadInt64Feature(num_to_pad, out_int64);
2306           out_int64 += num_to_pad;
2307           break;
2308         default:
2309           return errors::InvalidArgument("Unexpected dtype ", dtype,
2310                                          " in example ", example_name);
2311       }
2312     }
2313   }
2314   t = 0;
2315   for (const auto& c : feature_list_config.sparse) {
2316     TensorShape indices_shape, values_shape;
2317     DataType dtype = c.dtype;
2318     size_t expected_num_elements =
2319         sequence_feature_type_and_lengths[c.feature_name].second;
2320     indices_shape.AddDim(expected_num_elements);
2321     indices_shape.AddDim(3);
2322     values_shape.AddDim(expected_num_elements);
2323     feature_list_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
2324     feature_list_result->sparse_values[t] = Tensor(dtype, values_shape);
2325     feature_list_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({3}));
2326 
2327     string* out_bytes = nullptr;
2328     float* out_float = nullptr;
2329     int64* out_int64 = nullptr;
2330     switch (dtype) {
2331       case DT_STRING:
2332         out_bytes = feature_list_result->sparse_values[t].flat<string>().data();
2333         break;
2334       case DT_FLOAT:
2335         out_float = feature_list_result->sparse_values[t].flat<float>().data();
2336         break;
2337       case DT_INT64:
2338         out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
2339         break;
2340       default:
2341         return errors::InvalidArgument("Unexpected dtype ", dtype,
2342                                        " in feature ", c.feature_name);
2343     }
2344     int64* out_indices =
2345         feature_list_result->sparse_indices[t].flat<int64>().data();
2346     auto out_shape = feature_list_result->sparse_shapes[t].vec<int64>();
2347     t++;
2348 
2349     // Fill in the values.
2350     size_t num_elements = 0;
2351     size_t max_num_rows = 0;
2352     size_t max_num_cols = 0;
2353     for (int e = 0; e < num_examples; e++) {
2354       const auto& feature = all_sequence_features[e][c.feature_name];
2355       const string& example_name =
2356           example_names.empty() ? kUnknown : example_names[e];
2357       if (!feature.empty()) {
2358         protobuf::io::CodedInputStream stream(
2359             reinterpret_cast<const uint8*>(feature.data()), feature.size());
2360         EnableAliasing(&stream);
2361         size_t num_rows = 0;
2362         while (!stream.ExpectAtEnd()) {
2363           uint32 feature_length;
2364           if (!stream.ExpectTag(kDelimitedTag(1)) ||
2365               !stream.ReadVarint32(&feature_length)) {
2366             return errors::InvalidArgument("Error in sequence feature ",
2367                                            c.feature_name, " in example ",
2368                                            example_name);
2369           }
2370           if (feature_length > 2) {
2371             auto limit = stream.PushLimit(feature_length);
2372             size_t num_added;
2373             switch (dtype) {
2374               case DT_STRING:
2375                 num_added = ParseBytesFeature(&stream, out_bytes);
2376                 out_bytes += num_added;
2377                 break;
2378               case DT_FLOAT:
2379                 num_added = ParseFloatFeature(&stream, out_float);
2380                 out_float += num_added;
2381                 break;
2382               case DT_INT64:
2383                 num_added = ParseInt64Feature(&stream, out_int64);
2384                 out_int64 += num_added;
2385                 break;
2386               default:
2387                 return errors::InvalidArgument("Unexpected dtype ", dtype,
2388                                                " in example ", example_name);
2389             }
2390             num_elements += num_added;
2391             max_num_cols = std::max(max_num_cols, num_added);
2392             for (int i = 0; i < num_added; i++) {
2393               *out_indices++ = e;
2394               *out_indices++ = num_rows;
2395               *out_indices++ = i;
2396             }
2397             stream.PopLimit(limit);
2398           } else if (feature_length == 2) {
2399             if (!SkipEmptyFeature(&stream, dtype)) {
2400               return errors::InvalidArgument("Error in sequence feature ",
2401                                              c.feature_name, " in example ",
2402                                              example_name);
2403             }
2404           } else if (feature_length != 0) {
2405             return errors::InvalidArgument("Error in sequence feature ",
2406                                            c.feature_name, " in example ",
2407                                            example_name);
2408           }
2409           num_rows++;
2410         }
2411         max_num_rows = std::max(max_num_rows, num_rows);
2412       }
2413     }
2414     if (num_elements != expected_num_elements) {
2415       return errors::InvalidArgument(
2416           "Unexpected number of elements in feature ", c.feature_name);
2417     }
2418     out_shape(0) = num_examples;
2419     out_shape(1) = max_num_rows;
2420     out_shape(2) = max_num_cols;
2421   }
2422 
2423   return Status::OK();
2424 }
2425 
2426 }  // namespace example
2427 }  // namespace tensorflow
2428