• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_fast_parsing.h"
16 
17 #include <vector>
18 
19 #include "absl/base/casts.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/example/example.pb.h"
22 #include "tensorflow/core/example/feature.pb.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/numeric_op.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/blocking_counter.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/monitoring/counter.h"
33 #include "tensorflow/core/platform/byte_order.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/util/presized_cuckoo_map.h"
37 #include "tensorflow/core/util/sparse/sparse_tensor.h"
38 
39 namespace tensorflow {
40 namespace example {
41 
42 namespace {
43 
// Small vector with 4 elements of inline storage before spilling to the
// heap; used for short per-feature value lists.
template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;
46 
47 template <typename T>
48 class LimitedArraySlice {
49  public:
50   using value_type = T;
51 
LimitedArraySlice(T * begin,size_t num_elements)52   LimitedArraySlice(T* begin, size_t num_elements)
53       : current_(begin), begin_(begin), end_(begin + num_elements) {}
54 
55   // May return negative if there were push_back calls after slice was filled.
EndDistance() const56   int64 EndDistance() const { return end_ - current_; }
57 
58   // Attempts to push value to the back of this. If the slice has
59   // already been filled, this method has no effect on the underlying data, but
60   // it changes the number returned by EndDistance into negative values.
push_back(T && value)61   void push_back(T&& value) {
62     if (EndDistance() > 0) *current_ = std::move(value);
63     ++current_;
64   }
65 
66   // "Constructs" an element at the back of this by resizing the slice, and
67   // returns a mutable reference to the new last element.
68   // REQUIRES: EndDistance() > 0.
construct_at_end()69   T& construct_at_end() {
70     DCHECK_GT(EndDistance(), 0);
71     return *(current_++);
72   }
73 
74   // Returns a mutable reference to the last element in the slice.
75   // REQUIRES: size() > 0.
back()76   T& back() { return *(current_ - 1); }
77 
78   // Returns the number of elements in the slice.
size() const79   size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
80 
81   // Attempts to resize the vector to the given size. It does so by advancing
82   // the pointer to the current element, possibly beyond the end of the slice.
83   // As a consequence, calling `size()` after `resize(x)` was called might
84   // return a value less than `x`.
resize(size_t size)85   void resize(size_t size) { current_ = begin_ + size; }
86 
87   // Returns the pointer to the underlying data buffer.
data()88   T* data() { return begin_; }
89 
90  private:
91   T* current_;
92   T* begin_;
93   T* end_;
94 };
95 
// Turns on aliasing for stream types that provide EnableAliasing(bool);
// selected via expression SFINAE on `a->EnableAliasing(true)`. Streams
// without that member fall through to the no-op overload below.
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}
100 
// No-op fallback for stream types without an EnableAliasing member.
template <typename A>
void EnableAliasing(A&& a) {}
103 
PeekTag(protobuf::io::CodedInputStream * stream)104 uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
105   DCHECK(stream != nullptr);
106   const void* ptr;
107   int size;
108   if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
109   return *static_cast<const uint8*>(ptr);
110 }
111 
kVarintTag(uint32 tag)112 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
kDelimitedTag(uint32 tag)113 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
kFixed32Tag(uint32 tag)114 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
115 
116 namespace parsed {
117 
118 // ParseDataType has to be called first, then appropriate ParseZzzzList.
119 class Feature {
120  public:
Feature()121   Feature() {}
Feature(StringPiece serialized)122   explicit Feature(StringPiece serialized) : serialized_(serialized) {}
123 
ParseDataType(DataType * dtype)124   Status ParseDataType(DataType* dtype) {
125     DCHECK(dtype != nullptr);
126     if (serialized_.empty()) {
127       *dtype = DT_INVALID;
128       return Status::OK();
129     }
130     uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
131     serialized_.remove_prefix(1);
132     switch (oneof_tag) {
133       case kDelimitedTag(1):
134         *dtype = DT_STRING;
135         break;
136       case kDelimitedTag(2):
137         *dtype = DT_FLOAT;
138         break;
139       case kDelimitedTag(3):
140         *dtype = DT_INT64;
141         break;
142       default:
143         // Initialize variable to avoid compiler warning
144         *dtype = DT_INVALID;
145         return errors::InvalidArgument("Unsupported datatype.");
146     }
147     return Status::OK();
148   }
149 
GetNumElementsInBytesList(int * num_elements)150   bool GetNumElementsInBytesList(int* num_elements) {
151     protobuf::io::CodedInputStream stream(
152         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
153     EnableAliasing(&stream);
154     uint32 length = 0;
155     if (!stream.ReadVarint32(&length)) return false;
156     auto limit = stream.PushLimit(length);
157     *num_elements = 0;
158     while (!stream.ExpectAtEnd()) {
159       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
160       uint32 bytes_length = 0;
161       if (!stream.ReadVarint32(&bytes_length)) return false;
162       if (!stream.Skip(bytes_length)) return false;
163       ++*num_elements;
164     }
165     stream.PopLimit(limit);
166     return true;
167   }
168 
169   // Helper methods
construct_at_end(LimitedArraySlice<tstring> * bytes_list)170   tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
171     if (bytes_list->EndDistance() <= 0) {
172       return nullptr;
173     }
174     return &bytes_list->construct_at_end();
175   }
construct_at_end(SmallVector<tstring> * bytes_list)176   tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
177     return &bytes_list->emplace_back();
178   }
179 
180   template <typename Result>
ParseBytesList(Result * bytes_list)181   bool ParseBytesList(Result* bytes_list) {
182     DCHECK(bytes_list != nullptr);
183 
184     protobuf::io::CodedInputStream stream(
185         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
186 
187     EnableAliasing(&stream);
188 
189     uint32 length;
190     if (!stream.ReadVarint32(&length)) return false;
191     auto limit = stream.PushLimit(length);
192 
193     while (!stream.ExpectAtEnd()) {
194       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
195       // parse string
196       uint32 bytes_length;
197       if (!stream.ReadVarint32(&bytes_length)) return false;
198       tstring* bytes = construct_at_end(bytes_list);
199       if (bytes == nullptr) return false;
200       bytes->resize_uninitialized(bytes_length);
201       if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
202     }
203     stream.PopLimit(limit);
204     return true;
205   }
206 
207   template <typename Result>
ParseFloatList(Result * float_list)208   bool ParseFloatList(Result* float_list) {
209     DCHECK(float_list != nullptr);
210     protobuf::io::CodedInputStream stream(
211         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
212     EnableAliasing(&stream);
213     uint32 length;
214     if (!stream.ReadVarint32(&length)) return false;
215     auto limit = stream.PushLimit(length);
216 
217     if (!stream.ExpectAtEnd()) {
218       uint8 peek_tag = PeekTag(&stream);
219       if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
220         return false;
221       }
222 
223       constexpr int32_t kNumFloatBytes = 4;
224       if (peek_tag == kDelimitedTag(1)) {                       // packed
225         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
226         uint32 packed_length;
227         if (!stream.ReadVarint32(&packed_length)) return false;
228         auto packed_limit = stream.PushLimit(packed_length);
229 
230         // Store the initial size to know the offset we have to start writing
231         // data from before resizing the output "vector".
232         const size_t initial_size = float_list->size();
233         float_list->resize(initial_size + packed_length / kNumFloatBytes);
234 
235         // If the result data type is float and we are on a little endian
236         // machine then we can simply memcpy the data from the proto into the
237         // result vector.
238         if (port::kLittleEndian &&
239             sizeof(typename Result::value_type) == kNumFloatBytes) {
240           // Calculate the length of the buffer available what can be less than
241           // what we requested in resize in case of a LimitedArraySlice.
242           const uint32 bytes_to_copy =
243               std::min(static_cast<uint32>((float_list->size() - initial_size) *
244                                            kNumFloatBytes),
245                        packed_length);
246           if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
247             return false;
248         } else {
249           int64_t index = initial_size;
250           while (!stream.ExpectAtEnd()) {
251             uint32 buffer32;
252             if (!stream.ReadLittleEndian32(&buffer32)) return false;
253             if (index < float_list->size()) {
254               float_list->data()[index] = absl::bit_cast<float>(buffer32);
255               ++index;
256             }
257           }
258         }
259 
260         stream.PopLimit(packed_limit);
261       } else {  // non-packed
262         const size_t initial_size = float_list->size();
263         // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
264         // the value.
265         const int64_t num_elements =
266             stream.BytesUntilLimit() / (1 + kNumFloatBytes);
267         float_list->resize(initial_size + num_elements);
268         int64_t index = initial_size;
269         while (!stream.ExpectAtEnd()) {
270           if (!stream.ExpectTag(kFixed32Tag(1))) return false;
271           uint32 buffer32;
272           if (!stream.ReadLittleEndian32(&buffer32)) return false;
273           float_list->data()[index] = absl::bit_cast<float>(buffer32);
274           ++index;
275         }
276       }
277     }
278 
279     stream.PopLimit(limit);
280     return true;
281   }
282 
283   template <typename Result>
ParseInt64List(Result * int64_list)284   bool ParseInt64List(Result* int64_list) {
285     DCHECK(int64_list != nullptr);
286     protobuf::io::CodedInputStream stream(
287         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
288     EnableAliasing(&stream);
289     uint32 length;
290     if (!stream.ReadVarint32(&length)) return false;
291     auto limit = stream.PushLimit(length);
292 
293     if (!stream.ExpectAtEnd()) {
294       uint8 peek_tag = PeekTag(&stream);
295       if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
296         return false;
297       }
298       if (peek_tag == kDelimitedTag(1)) {                       // packed
299         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
300         uint32 packed_length;
301         if (!stream.ReadVarint32(&packed_length)) return false;
302         auto packed_limit = stream.PushLimit(packed_length);
303 
304         while (!stream.ExpectAtEnd()) {
305           protobuf_uint64 n;  // There is no API for int64
306           if (!stream.ReadVarint64(&n)) return false;
307           int64_list->push_back(static_cast<int64>(n));
308         }
309 
310         stream.PopLimit(packed_limit);
311       } else {  // non-packed
312         while (!stream.ExpectAtEnd()) {
313           if (!stream.ExpectTag(kVarintTag(1))) return false;
314           protobuf_uint64 n;  // There is no API for int64
315           if (!stream.ReadVarint64(&n)) return false;
316           int64_list->push_back(static_cast<int64>(n));
317         }
318       }
319     }
320     stream.PopLimit(limit);
321     return true;
322   }
323 
GetSerialized() const324   StringPiece GetSerialized() const { return serialized_; }
325 
326  private:
327   // TODO(lew): Pair of uint8* would be more natural.
328   StringPiece serialized_;
329 };
330 
// One (feature name, lazily-parsed feature) pair from an Example's map.
using FeatureMapEntry = std::pair<StringPiece, Feature>;
// All feature-map entries of one Example, in stream order.
using Example = std::vector<FeatureMapEntry>;
333 
334 }  // namespace parsed
335 
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)336 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
337   uint32 data;
338   protobuf_uint64 dummy;
339   switch (stream->ReadTag() & 0x7) {
340     case 0:  // varint
341       if (!stream->ReadVarint32(&data)) return false;
342       return true;
343     case 1:  // fixed64
344       if (!stream->ReadLittleEndian64(&dummy)) return false;
345       return true;
346     case 2:  // length delimited
347       if (!stream->ReadVarint32(&data)) return false;
348       stream->Skip(data);
349       return true;
350     case 3:          // group begin
351       return false;  // groups not supported.
352     case 4:          // group end
353       return false;  // groups not supported.
354     case 5:          // fixed32
355       if (!stream->ReadLittleEndian32(&data)) return false;
356       return true;
357   }
358   return false;  // unrecognized tag type
359 }
360 
ParseString(protobuf::io::CodedInputStream * stream,StringPiece * result)361 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
362   DCHECK(stream != nullptr);
363   DCHECK(result != nullptr);
364   uint32 length;
365   if (!stream->ReadVarint32(&length)) return false;
366   if (length == 0) {
367     *result = StringPiece(nullptr, 0);
368     return true;
369   }
370   const void* stream_alias;
371   int stream_size;
372   if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
373     return false;
374   }
375   if (static_cast<uint32>(stream_size) < length) return false;
376   *result = StringPiece(static_cast<const char*>(stream_alias), length);
377   stream->Skip(length);
378   return true;
379 }
380 
ParseFeatureMapEntry(protobuf::io::CodedInputStream * stream,parsed::FeatureMapEntry * feature_map_entry)381 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
382                           parsed::FeatureMapEntry* feature_map_entry) {
383   DCHECK(stream != nullptr);
384   DCHECK(feature_map_entry != nullptr);
385   uint32 length;
386   if (!stream->ReadVarint32(&length)) return false;
387   auto limit = stream->PushLimit(length);
388   if (!stream->ExpectTag(kDelimitedTag(1))) return false;
389   if (!ParseString(stream, &feature_map_entry->first)) return false;
390   if (!stream->ExpectTag(kDelimitedTag(2))) return false;
391   StringPiece feature_string_piece;
392   if (!ParseString(stream, &feature_string_piece)) return false;
393   feature_map_entry->second = parsed::Feature(feature_string_piece);
394   if (!stream->ExpectAtEnd()) return false;
395   stream->PopLimit(limit);
396   return true;
397 }
398 
ParseFeatures(protobuf::io::CodedInputStream * stream,parsed::Example * example)399 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
400                    parsed::Example* example) {
401   DCHECK(stream != nullptr);
402   DCHECK(example != nullptr);
403   uint32 length;
404   if (!stream->ReadVarint32(&length)) return false;
405   auto limit = stream->PushLimit(length);
406   while (!stream->ExpectAtEnd()) {
407     parsed::FeatureMapEntry feature_map_entry;
408     if (!stream->ExpectTag(kDelimitedTag(1))) return false;
409     if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
410     example->push_back(std::move(feature_map_entry));
411   }
412   stream->PopLimit(limit);
413   return true;
414 }
415 
ParseExample(protobuf::io::CodedInputStream * stream,parsed::Example * example)416 bool ParseExample(protobuf::io::CodedInputStream* stream,
417                   parsed::Example* example) {
418   DCHECK(stream != nullptr);
419   DCHECK(example != nullptr);
420   // Loop over the input stream which may contain multiple serialized Example
421   // protos merged together as strings. This behavior is consistent with Proto's
422   // ParseFromString when string representations are concatenated.
423   while (!stream->ExpectAtEnd()) {
424     if (!stream->ExpectTag(kDelimitedTag(1))) {
425       if (!SkipExtraneousTag(stream)) return false;
426     } else {
427       if (!ParseFeatures(stream, example)) return false;
428     }
429   }
430   return true;
431 }
432 
ParseExample(StringPiece serialized,parsed::Example * example)433 bool ParseExample(StringPiece serialized, parsed::Example* example) {
434   DCHECK(example != nullptr);
435   protobuf::io::CodedInputStream stream(
436       reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
437   EnableAliasing(&stream);
438   return ParseExample(&stream, example);
439 }
440 
441 }  // namespace
442 
TestFastParse(const string & serialized,Example * example)443 bool TestFastParse(const string& serialized, Example* example) {
444   DCHECK(example != nullptr);
445   parsed::Example parsed_example;
446   if (!ParseExample(serialized, &parsed_example)) return false;
447   auto& features = *example->mutable_features();
448   size_t parsed_example_size = parsed_example.size();
449   for (size_t i = 0; i < parsed_example_size; ++i) {
450     // This is a logic that standard protobuf parsing is implementing.
451     // I.e. last entry in the map overwrites all the previous ones.
452     parsed::FeatureMapEntry& name_and_feature =
453         parsed_example[parsed_example_size - i - 1];
454     string name(name_and_feature.first);
455     if ((*features.mutable_feature()).count(name) > 0) continue;
456 
457     auto& value = (*features.mutable_feature())[name];
458     DataType dtype;
459     if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
460     switch (dtype) {
461       case DT_INVALID:
462         break;
463       case DT_STRING: {
464         SmallVector<tstring> list;
465         if (!name_and_feature.second.ParseBytesList(&list)) return false;
466         auto* result_list = value.mutable_bytes_list();
467         for (auto& bytes : list) {
468           result_list->add_value(bytes.data(), bytes.size());
469         }
470         break;
471       }
472       case DT_FLOAT: {
473         SmallVector<float> list;
474         if (!name_and_feature.second.ParseFloatList(&list)) return false;
475         auto* result_list = value.mutable_float_list();
476         for (float f : list) {
477           result_list->add_value(f);
478         }
479         break;
480       }
481       case DT_INT64: {
482         SmallVector<int64> list;
483         if (!name_and_feature.second.ParseInt64List(&list)) return false;
484         auto* result_list = value.mutable_int64_list();
485         for (int64_t i : list) {
486           result_list->add_value(i);
487         }
488         break;
489       }
490       default:
491         LOG(FATAL) << "Should not happen.";
492     }
493   }
494   return true;
495 }
496 
497 // -----------------------------------------------------------------------------
498 
499 namespace {
500 
// Shorthand for the parsing configuration used throughout this namespace.
using Config = FastParseExampleConfig;
502 
ParallelFor(const std::function<void (size_t)> & f,size_t n,thread::ThreadPool * thread_pool)503 void ParallelFor(const std::function<void(size_t)>& f, size_t n,
504                  thread::ThreadPool* thread_pool) {
505   if (n == 0) return;
506   if (thread_pool == nullptr) {
507     for (size_t i = 0; i < n; ++i) {
508       f(i);
509     }
510   } else {
511     BlockingCounter counter(n - 1);
512     for (size_t i = 1; i < n; ++i) {
513       thread_pool->Schedule([i, &f, &counter] {
514         f(i);
515         counter.DecrementCount();
516       });
517     }
518     f(0);
519     counter.Wait();
520   }
521 }
522 
// Enumeration for distinguishing feature types.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense.
// Dense must therefore remain the first (zero-valued) enumerator.
enum class Type { Dense, Sparse, Ragged };
527 
// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
// Accumulates the values of one feature across all examples in a minibatch.
struct SparseBuffer {
  // Features are in one of the 3 vectors below depending on config's dtype.
  // Other 2 vectors remain empty.
  SmallVector<tstring> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64> int64_list;

  // Features of example i are elements with indices
  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
  // appropriate xxxxx_list
  std::vector<size_t> example_end_indices;
};
541 
542 struct SeededHasher {
operator ()tensorflow::example::__anon549c22790211::SeededHasher543   uint64 operator()(StringPiece s) const {
544     return Hash64(s.data(), s.size(), seed);
545   }
546   uint64 seed{0xDECAFCAFFE};
547 };
548 
LogDenseFeatureDataLoss(StringPiece feature_name)549 void LogDenseFeatureDataLoss(StringPiece feature_name) {
550   LOG(WARNING) << "Data loss! Feature '" << feature_name
551                << "' is present in multiple concatenated "
552                   "tf.Examples. Ignoring all but last one.";
553   static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
554       "/tensorflow/core/util/example_proto_fast_parsing/"
555       "duplicated_dense_feature",
556       "Dense feature appears twice in a tf.Example");
557   duplicated_dense_feature->GetCell()->IncrementBy(1);
558 }
559 
LogSparseFeatureDataLoss(StringPiece feature_name)560 void LogSparseFeatureDataLoss(StringPiece feature_name) {
561   LOG(WARNING) << "Data loss! Feature '" << feature_name
562                << "' is present in multiple concatenated "
563                   "tf.Examples. Ignoring all but last one.";
564   static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
565       "/tensorflow/core/util/example_proto_fast_parsing/"
566       "duplicated_sparse_feature",
567       "Sparse feature appears twice in a tf.Example");
568   duplicated_sparse_feature->GetCell()->IncrementBy(1);
569 }
570 
FastParseSerializedExample(const tstring & serialized_example,const tstring & example_name,const size_t example_index,const Config & config,const PresizedCuckooMap<std::pair<size_t,Type>> & config_index,SeededHasher hasher,std::vector<Tensor> * output_dense,std::vector<SparseBuffer> * output_varlen_dense,std::vector<SparseBuffer> * output_sparse,std::vector<SparseBuffer> * output_ragged,PerExampleFeatureStats * output_stats)571 Status FastParseSerializedExample(
572     const tstring& serialized_example, const tstring& example_name,
573     const size_t example_index, const Config& config,
574     const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
575     SeededHasher hasher, std::vector<Tensor>* output_dense,
576     std::vector<SparseBuffer>* output_varlen_dense,
577     std::vector<SparseBuffer>* output_sparse,
578     std::vector<SparseBuffer>* output_ragged,
579     PerExampleFeatureStats* output_stats) {
580   DCHECK(output_dense != nullptr);
581   DCHECK(output_sparse != nullptr);
582   DCHECK(output_ragged != nullptr);
583   parsed::Example parsed_example;
584   if (!ParseExample(serialized_example, &parsed_example)) {
585     return errors::InvalidArgument("Could not parse example input, value: '",
586                                    serialized_example, "'");
587   }
588   std::vector<int64> sparse_feature_last_example(config.sparse.size(), -1);
589   std::vector<int64> dense_feature_last_example(config.dense.size(), -1);
590   std::vector<int64> ragged_feature_last_example(config.ragged.size(), -1);
591 
592   // Handle features present in the example.
593   const size_t parsed_example_size = parsed_example.size();
594 
595   if (output_stats) {
596     // TODO(b/111553342): This may over-count the number of features if there
597     // are duplicate keys in the feature map. Consider deduplicating the keys
598     // before computing the count.
599     output_stats->features_count = parsed_example_size;
600   }
601 
602   for (size_t i = 0; i < parsed_example_size; ++i) {
603     // This is a logic that standard protobuf parsing is implementing.
604     // I.e. last entry in the map overwrites all the previous ones.
605     parsed::FeatureMapEntry& name_and_feature =
606         parsed_example[parsed_example_size - i - 1];
607 
608     const StringPiece feature_name = name_and_feature.first;
609     parsed::Feature& feature = name_and_feature.second;
610 
611     std::pair<size_t, Type> d_and_type;
612     uint64 h = hasher(feature_name);
613     if (!config_index.Find(h, &d_and_type)) continue;
614 
615     size_t d = d_and_type.first;
616     bool is_dense = d_and_type.second == Type::Dense;
617     bool is_ragged = d_and_type.second == Type::Ragged;
618 
619     {
620       // Testing for PresizedCuckooMap collision.
621       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
622       const tstring& config_feature_name =
623           is_dense ? config.dense[d].feature_name
624                    : (is_ragged ? config.ragged[d].feature_name
625                                 : config.sparse[d].feature_name);
626       if (feature_name != config_feature_name) continue;
627     }
628 
629     auto example_error = [&](StringPiece suffix) {
630       return errors::InvalidArgument("Name: ", example_name,
631                                      ", Key: ", feature_name,
632                                      ", Index: ", example_index, ".  ", suffix);
633     };
634 
635     auto parse_error = [&] {
636       return example_error("Can't parse serialized Example.");
637     };
638 
639     DataType example_dtype;
640     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
641 
642     if (is_dense) {
643       if (example_dtype == DT_INVALID) continue;
644 
645       // If feature was already visited, skip.
646       // Compare comment at the beginning of the loop.
647       if (dense_feature_last_example[d] == example_index) {
648         LogDenseFeatureDataLoss(feature_name);
649         continue;
650       }
651       dense_feature_last_example[d] = example_index;
652 
653       if (example_dtype != config.dense[d].dtype) {
654         return example_error(strings::StrCat(
655             "Data types don't match. Data type: ",
656             DataTypeString(example_dtype),
657             " but expected type: ", DataTypeString(config.dense[d].dtype)));
658       }
659       if (!config.dense[d].variable_length) {
660         Tensor& out = (*output_dense)[d];
661 
662         const std::size_t num_elements = config.dense[d].elements_per_stride;
663         if (output_stats) {
664           // TODO(b/111553342): If desirable, we could add support for counting
665           // elements in the features that aren't parsed, but this could add
666           // considerable runtime cost.
667           output_stats->feature_values_count += num_elements;
668         }
669 
670         const std::size_t offset = example_index * num_elements;
671 
672         auto shape_error = [&](size_t size, StringPiece type_str) {
673           return example_error(strings::StrCat(
674               "Number of ", type_str,
675               " values != expected.  "
676               "Values size: ",
677               size,
678               " but output shape: ", config.dense[d].shape.DebugString()));
679         };
680 
681         switch (config.dense[d].dtype) {
682           case DT_INT64: {
683             auto out_p = out.flat<int64>().data() + offset;
684             LimitedArraySlice<int64> slice(out_p, num_elements);
685             if (!feature.ParseInt64List(&slice)) return parse_error();
686             if (slice.EndDistance() != 0) {
687               return shape_error(num_elements - slice.EndDistance(), "int64");
688             }
689             break;
690           }
691           case DT_FLOAT: {
692             auto out_p = out.flat<float>().data() + offset;
693             LimitedArraySlice<float> slice(out_p, num_elements);
694             if (!feature.ParseFloatList(&slice)) return parse_error();
695             if (slice.EndDistance() != 0) {
696               return shape_error(num_elements - slice.EndDistance(), "float");
697             }
698             break;
699           }
700           case DT_STRING: {
701             auto out_p = out.flat<tstring>().data() + offset;
702             LimitedArraySlice<tstring> slice(out_p, num_elements);
703             if (!feature.ParseBytesList(&slice)) return parse_error();
704             if (slice.EndDistance() != 0) {
705               return shape_error(num_elements - slice.EndDistance(), "bytes");
706             }
707             break;
708           }
709           default:
710             LOG(FATAL) << "Should not happen.";
711         }
712       } else {  // if variable length
713         SparseBuffer& out = (*output_varlen_dense)[d];
714 
715         const std::size_t num_elements = config.dense[d].elements_per_stride;
716 
717         if (example_dtype != DT_INVALID &&
718             example_dtype != config.dense[d].dtype) {
719           return example_error(strings::StrCat(
720               "Data types don't match. ",
721               "Expected type: ", DataTypeString(config.dense[d].dtype)));
722         }
723 
724         auto shape_error = [&](size_t size, StringPiece type_str) {
725           return example_error(strings::StrCat(
726               "Number of ", type_str,
727               " values is not a multiple of stride length. Saw ", size,
728               " values but output shape is: ",
729               config.dense[d].shape.DebugString()));
730         };
731 
732         switch (config.dense[d].dtype) {
733           case DT_INT64: {
734             if (example_dtype != DT_INVALID) {
735               if (!feature.ParseInt64List(&out.int64_list)) {
736                 return parse_error();
737               }
738               if (out.int64_list.size() % num_elements != 0) {
739                 return shape_error(out.int64_list.size(), "int64");
740               }
741             }
742             out.example_end_indices.push_back(out.int64_list.size());
743             break;
744           }
745           case DT_FLOAT: {
746             if (example_dtype != DT_INVALID) {
747               if (!feature.ParseFloatList(&out.float_list)) {
748                 return parse_error();
749               }
750               if (out.float_list.size() % num_elements != 0) {
751                 return shape_error(out.float_list.size(), "float");
752               }
753             }
754             out.example_end_indices.push_back(out.float_list.size());
755             break;
756           }
757           case DT_STRING: {
758             if (example_dtype != DT_INVALID) {
759               if (!feature.ParseBytesList(&out.bytes_list)) {
760                 return parse_error();
761               }
762               if (out.bytes_list.size() % num_elements != 0) {
763                 return shape_error(out.bytes_list.size(), "bytes");
764               }
765             }
766             out.example_end_indices.push_back(out.bytes_list.size());
767             break;
768           }
769           default:
770             LOG(FATAL) << "Should not happen.";
771         }
772 
773         if (output_stats) {
774           // Use `out.example_end_indices` to determine the feature-value count
775           // for this feature, because the preceding switch statement pushes
776           // the length of the appropriate feature list to that vector.
777           // TODO(b/111553342): If desirable, we could add support for counting
778           // elements in the features that aren't parsed, but this could add
779           // considerable runtime cost.
780           const size_t out_examples_count = out.example_end_indices.size();
781           if (out_examples_count == 1) {
782             output_stats->feature_values_count += out.example_end_indices[0];
783           } else {
784             output_stats->feature_values_count +=
785                 out.example_end_indices[out_examples_count - 1] -
786                 out.example_end_indices[out_examples_count - 2];
787           }
788         }
789       }
790     } else {
791       // Feature is sparse or ragged.
792       auto& last_example =
793           is_ragged ? ragged_feature_last_example : sparse_feature_last_example;
794 
795       // If feature was already visited, skip.
796       // Compare comment at the beginning of the loop.
797       if (last_example[d] == example_index) {
798         LogSparseFeatureDataLoss(feature_name);
799         continue;
800       }
801       last_example[d] = example_index;
802 
803       // Handle sparse features.
804       SparseBuffer& out = is_ragged ? (*output_ragged)[d] : (*output_sparse)[d];
805       DataType feature_dtype =
806           is_ragged ? config.ragged[d].dtype : config.sparse[d].dtype;
807       if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
808         return example_error(
809             strings::StrCat("Data types don't match. ",
810                             "Expected type: ", DataTypeString(feature_dtype),
811                             ", Actual type: ", DataTypeString(example_dtype)));
812       }
813 
814       switch (feature_dtype) {
815         case DT_INT64: {
816           if (example_dtype != DT_INVALID) {
817             if (!feature.ParseInt64List(&out.int64_list)) {
818               return parse_error();
819             }
820           }
821           out.example_end_indices.push_back(out.int64_list.size());
822           break;
823         }
824         case DT_FLOAT: {
825           if (example_dtype != DT_INVALID) {
826             if (!feature.ParseFloatList(&out.float_list)) {
827               return parse_error();
828             }
829           }
830           out.example_end_indices.push_back(out.float_list.size());
831           break;
832         }
833         case DT_STRING: {
834           if (example_dtype != DT_INVALID) {
835             if (!feature.ParseBytesList(&out.bytes_list)) {
836               return parse_error();
837             }
838           }
839           out.example_end_indices.push_back(out.bytes_list.size());
840           break;
841         }
842         default:
843           LOG(FATAL) << "Should not happen.";
844       }
845 
846       if (output_stats) {
847         // Use `out.example_end_indices` to determine the feature-value count
848         // for this feature, because the preceding switch statement pushes
849         // the length of the appropriate feature list to that vector.
850         // TODO(b/111553342): If desirable, we could add support for counting
851         // elements in the features that aren't parsed, but this could add
852         // considerable runtime cost.
853         const size_t out_examples_count = out.example_end_indices.size();
854         if (out_examples_count == 1) {
855           output_stats->feature_values_count += out.example_end_indices[0];
856         } else {
857           output_stats->feature_values_count +=
858               out.example_end_indices[out_examples_count - 1] -
859               out.example_end_indices[out_examples_count - 2];
860         }
861       }
862     }
863   }
864 
865   // Handle missing dense features for fixed strides.
866   for (size_t d = 0; d < config.dense.size(); ++d) {
867     if (config.dense[d].variable_length) continue;
868     if (dense_feature_last_example[d] == example_index) continue;
869     if (config.dense[d].default_value.NumElements() == 0) {
870       return errors::InvalidArgument(
871           "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
872           " (data type: ", DataTypeString(config.dense[d].dtype), ")",
873           " is required but could not be found.");
874     }
875     const Tensor& in = config.dense[d].default_value;
876     Tensor& out = (*output_dense)[d];
877     const std::size_t num_elements = in.shape().num_elements();
878     const std::size_t offset = example_index * num_elements;
879 
880     switch (config.dense[d].dtype) {
881       case DT_INT64: {
882         std::copy_n(in.flat<int64>().data(), num_elements,
883                     out.flat<int64>().data() + offset);
884         break;
885       }
886       case DT_FLOAT: {
887         std::copy_n(in.flat<float>().data(), num_elements,
888                     out.flat<float>().data() + offset);
889         break;
890       }
891       case DT_STRING: {
892         std::copy_n(in.flat<tstring>().data(), num_elements,
893                     out.flat<tstring>().data() + offset);
894         break;
895       }
896       default:
897         LOG(FATAL) << "Should not happen.";
898     }
899   }
900 
901   // Handle missing varlen dense features.
902   for (size_t d = 0; d < config.dense.size(); ++d) {
903     if (!config.dense[d].variable_length) continue;
904     if (dense_feature_last_example[d] == example_index) continue;
905     SparseBuffer& out = (*output_varlen_dense)[d];
906     size_t prev_example_end_index =
907         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
908     out.example_end_indices.push_back(prev_example_end_index);
909   }
910 
911   // Handle missing sparse features.
912   for (size_t d = 0; d < config.sparse.size(); ++d) {
913     if (sparse_feature_last_example[d] == example_index) continue;
914     SparseBuffer& out = (*output_sparse)[d];
915     size_t prev_example_end_index =
916         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
917     out.example_end_indices.push_back(prev_example_end_index);
918   }
919 
920   // Handle missing ragged features.
921   for (size_t d = 0; d < config.ragged.size(); ++d) {
922     if (ragged_feature_last_example[d] == example_index) continue;
923     SparseBuffer& out = (*output_ragged)[d];
924     size_t prev_example_end_index =
925         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
926     out.example_end_indices.push_back(prev_example_end_index);
927   }
928 
929   return Status::OK();
930 }
931 
CheckConfigDataType(DataType dtype)932 Status CheckConfigDataType(DataType dtype) {
933   switch (dtype) {
934     case DT_INT64:
935     case DT_FLOAT:
936     case DT_STRING:
937       return Status::OK();
938     default:
939       return errors::InvalidArgument("Invalid config dtype: ",
940                                      DataTypeString(dtype));
941   }
942 }
943 
944 // Use this in the "default" clause of switch statements when dispatching
945 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)946 inline void ReportUnexpectedDataType(DataType dtype) {
947   DCHECK(false)
948       << "Encountered unexpected DataType " << DataTypeString(dtype)
949       << "in variable that should have been checked by CheckConfigDataType().";
950 }
951 
CheckConfigDataTypes(const Config & config)952 Status CheckConfigDataTypes(const Config& config) {
953   // Check config so we can safely CHECK(false) in switches on config.*.dtype
954   for (auto& c : config.sparse) {
955     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
956   }
957   for (auto& c : config.dense) {
958     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
959   }
960   for (auto& c : config.ragged) {
961     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
962     if (!(c.splits_dtype == DT_INT32 || c.splits_dtype == DT_INT64)) {
963       return errors::InvalidArgument("Invalid ragged_split_type: ",
964                                      DataTypeString(c.splits_dtype));
965     }
966   }
967   return Status::OK();
968 }
969 
// Returns the typed value list stored inside a SparseBuffer. The primary
// template is declared but intentionally left undefined: only the
// int64/float/tstring specializations below (matching the dtypes accepted by
// CheckConfigDataType) are usable, so any other T fails at link time.
template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
  return buffer.int64_list;
}
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}
986 
// Copies the half-open range [first, last) into `dest`. The generic version
// copies element-by-element; a tstring specialization elsewhere in this file
// moves instead, since strings are expensive to copy and the source buffer
// is discarded after being flushed.
template <typename T>
void CopyOrMoveBlock(const T* first, const T* last, T* dest) {
  std::copy(first, last, dest);
}
// tstring specialization: moves instead of copies, because the parsed byte
// buffers are never read again after being flushed into the output tensor.
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
  std::move(b, e, t);
}
995 
// Fills the output tensor for variable-length dense feature `d` with the
// feature's default value, then overwrites the padding with the values parsed
// into `varlen_dense_buffers` (one SparseBuffer per minibatch). `values` is
// logically [batch_size, max_num_elements, stride] flattened to
// `num_elements` total elements, with `num_elements_per_minibatch` elements
// per batch entry.
template <typename T>
void FillAndCopyVarLen(
    const int d, const size_t num_elements,
    const size_t num_elements_per_minibatch, const Config& config,
    const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
    Tensor* values) {
  const Tensor& default_value = config.dense[d].default_value;

  // Copy-fill the tensors (creating the zero/fill-padding).
  std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
            default_value.flat<T>()(0));

  // Data is [batch_size, max_num_elements, data_stride_size]
  //   and num_elements_per_minibatch = max_num_elements * data_stride_size
  auto data = values->flat<T>().data();

  // Iterate over minibatch elements.
  for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
    const SparseBuffer& buffer = varlen_dense_buffers[i][d];
    // Number of examples being stored in this buffer.
    const auto& end_indices = buffer.example_end_indices;
    const size_t examples_in_buffer = end_indices.size();

    const auto& list = GetListFromBuffer<T>(buffer);
    auto list_ptr = list.begin();

    // end_indices holds cumulative end offsets into `list`, so each
    // example's size is the difference from the previous tally.
    size_t elements_tally = 0;
    // Iterate through all the examples stored in this buffer.
    for (size_t j = 0; j < examples_in_buffer; ++j) {
      // Number of elements stored for this example.
      const size_t num_elems = end_indices[j] - elements_tally;
      CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
      // Move forward this many elements in the varlen buffer.
      list_ptr += num_elems;
      // Move forward to the next minibatch entry in the values output.
      data += num_elements_per_minibatch;
      elements_tally = end_indices[j];
    }
    // Every buffered value must have been consumed exactly once.
    DCHECK(elements_tally == list.size());
  }
}
1038 
1039 // Thin vector like interface wrapper around a Tensor. This enable us to
1040 // directly populate a tensor during parsing instead of having to first create a
1041 // vactor and then copy the data over.
1042 template <typename T>
1043 class TensorVector {
1044  public:
1045   using value_type = T;
1046 
tensor()1047   const Tensor& tensor() {
1048     if (!tensor_.has_value()) {
1049       resize(0);
1050     }
1051     return *tensor_;
1052   }
1053 
size() const1054   int64 size() const {
1055     return tensor_.has_value() ? tensor_->NumElements() : 0;
1056   }
resize(int64_t new_size)1057   void resize(int64_t new_size) {
1058     DCHECK(!tensor_.has_value());
1059     tensor_ = Tensor(DataTypeToEnum<T>::v(), TensorShape({new_size}));
1060     data_ = tensor_->flat<T>().data();
1061   }
data()1062   T* data() { return data_; }
data() const1063   const T* data() const { return data_; }
1064 
1065  private:
1066   // Use absl::optional to avoid calling the default constructor of Tensor
1067   // unnecessarily.
1068   absl::optional<Tensor> tensor_;
1069 
1070   // Cached pointer to the raw data inside the tensor.
1071   T* data_ = nullptr;
1072 };
1073 
CountSparseFeatures(const std::vector<std::vector<SparseBuffer>> & sparse_buffers,size_t d,size_t * total_num_features,size_t * max_num_features)1074 void CountSparseFeatures(
1075     const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
1076     size_t* total_num_features, size_t* max_num_features) {
1077   for (auto& sparse_values_tmp : sparse_buffers) {
1078     const std::vector<size_t>& end_indices =
1079         sparse_values_tmp[d].example_end_indices;
1080     *total_num_features += end_indices.back();
1081     *max_num_features = std::max(*max_num_features, end_indices[0]);
1082     for (size_t i = 1; i < end_indices.size(); ++i) {
1083       size_t example_size = end_indices[i] - end_indices[i - 1];
1084       *max_num_features = std::max(*max_num_features, example_size);
1085     }
1086   }
1087 }
1088 
CopySparseBufferToTensor(DataType dtype,size_t offset,SparseBuffer * src,Tensor * dst)1089 void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
1090                               Tensor* dst) {
1091   switch (dtype) {
1092     case DT_INT64: {
1093       std::copy(src->int64_list.begin(), src->int64_list.end(),
1094                 dst->flat<int64>().data() + offset);
1095       break;
1096     }
1097     case DT_FLOAT: {
1098       std::copy(src->float_list.begin(), src->float_list.end(),
1099                 dst->flat<float>().data() + offset);
1100       break;
1101     }
1102     case DT_STRING: {
1103       std::move(src->bytes_list.begin(), src->bytes_list.end(),
1104                 dst->flat<tstring>().data() + offset);
1105       break;
1106     }
1107     default:
1108       ReportUnexpectedDataType(dtype);
1109   }
1110 }
1111 
1112 }  // namespace
1113 
// Parses a batch of serialized Example protos into the dense, sparse, and
// ragged outputs requested by `config`, writing them into `result`.
// `example_names`, when non-empty, supplies per-example names for error
// messages. Examples are grouped into minibatches that are processed in
// parallel via ParallelFor over `thread_pool`; per-minibatch buffers are
// merged into the final output tensors afterwards.
Status FastParseExample(const Config& config,
                        gtl::ArraySlice<tstring> serialized,
                        gtl::ArraySlice<tstring> example_names,
                        thread::ThreadPool* thread_pool, Result* result) {
  DCHECK(result != nullptr);
  // Check config so we can safely CHECK(false) in switches on config.*.dtype
  TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));

  if (config.collect_feature_stats) {
    result->feature_stats.resize(serialized.size());
  }

  size_t config_size =
      config.dense.size() + config.sparse.size() + config.ragged.size();
  SeededHasher hasher;
  // Build config index: maps hash(feature_name) -> (position in config,
  // feature kind). On a hash collision, reseed the hasher and rebuild the
  // whole map; give up after 1000 attempts.
  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
  bool ok = true;
  for (size_t i = 0; i < 1000; ++i) {
    for (size_t d = 0; d < config.dense.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
                                      {d, Type::Dense});
    }
    for (size_t d = 0; d < config.sparse.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
                                      {d, Type::Sparse});
    }
    for (size_t d = 0; d < config.ragged.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
                                      {d, Type::Ragged});
    }
    if (ok) break;
    LOG(WARNING) << "Collision found. This should happen only if you have "
                    "around 2^32 entries in your config.";
    hasher.seed++;
    config_index.Clear(config_size);
    ok = true;
  }
  if (!ok) {
    return errors::Internal(
        "Could not avoid collision. This should not happen.");
  }

  // Allocate dense output for fixed length dense values
  // (variable-length dense and sparse and ragged have to be buffered).
  std::vector<Tensor> fixed_dense_values(config.dense.size());
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    TensorShape out_shape;
    out_shape.AddDim(serialized.size());
    for (const int64_t dim : config.dense[d].shape.dim_sizes()) {
      out_shape.AddDim(dim);
    }
    fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
  }

  // This parameter affects performance in a big and data-dependent way.
  const size_t kMiniBatchSizeBytes = 50000;

  // Calculate number of minibatches.
  // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
  // Apply 'special logic' below for small and big regimes.
  const size_t num_minibatches = [&] {
    size_t result = 0;
    size_t minibatch_bytes = 0;
    for (size_t i = 0; i < serialized.size(); i++) {
      if (minibatch_bytes == 0) {  // start minibatch
        result++;
      }
      minibatch_bytes += serialized[i].size() + 1;
      if (minibatch_bytes > kMiniBatchSizeBytes) {
        minibatch_bytes = 0;
      }
    }
    // 'special logic': use at least min(8, #examples) minibatches so small
    // batches still parallelize, and at most 64 to bound scheduling overhead.
    const size_t min_minibatches = std::min<size_t>(8, serialized.size());
    const size_t max_minibatches = 64;
    return std::max<size_t>(min_minibatches,
                            std::min<size_t>(max_minibatches, result));
  }();

  auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
    return (serialized.size() * minibatch) / num_minibatches;
  };

  // TODO(lew): A big performance low-hanging fruit here is to improve
  //   num_minibatches calculation to take into account actual amount of work
  //   needed, as the size in bytes is not perfect. Linear combination of
  //   size in bytes and average number of features per example is promising.
  //   Even better: measure time instead of estimating, but this is too costly
  //   in small batches.
  //   Maybe accept outside parameter #num_minibatches?

  // Do minibatches in parallel.
  std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> ragged_buffers(num_minibatches);
  std::vector<Status> status_of_minibatch(num_minibatches);
  auto ProcessMiniBatch = [&](size_t minibatch) {
    sparse_buffers[minibatch].resize(config.sparse.size());
    varlen_dense_buffers[minibatch].resize(config.dense.size());
    ragged_buffers[minibatch].resize(config.ragged.size());
    size_t start = first_example_of_minibatch(minibatch);
    size_t end = first_example_of_minibatch(minibatch + 1);
    for (size_t e = start; e < end; ++e) {
      PerExampleFeatureStats* stats = nullptr;
      if (config.collect_feature_stats) {
        stats = &result->feature_stats[e];
      }
      // Fixed-length dense values go straight into `fixed_dense_values`
      // (disjoint example slices, so no synchronization is needed);
      // everything else is buffered per-minibatch and merged below.
      status_of_minibatch[minibatch] = FastParseSerializedExample(
          serialized[e],
          (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
          config_index, hasher, &fixed_dense_values,
          &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch],
          &ragged_buffers[minibatch], stats);
      if (!status_of_minibatch[minibatch].ok()) break;
    }
  };

  ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);

  // Propagate the first per-minibatch failure, if any.
  for (Status& status : status_of_minibatch) {
    TF_RETURN_IF_ERROR(status);
  }

  result->sparse_indices.reserve(config.sparse.size());
  result->sparse_values.reserve(config.sparse.size());
  result->sparse_shapes.reserve(config.sparse.size());
  result->dense_values.reserve(config.dense.size());
  result->ragged_values.reserve(config.ragged.size());
  result->ragged_splits.reserve(config.ragged.size());

  for (size_t d = 0; d < config.dense.size(); ++d) {
    result->dense_values.push_back(std::move(fixed_dense_values[d]));
  }

  // Merge SparseBuffers from all minibatches for every config.sparse.
  // Produces the COO representation: an [N, 2] int64 indices tensor holding
  // (example index, value index within example) rows, an [N] values tensor,
  // and a dense_shape of [batch_size, max values per example].
  auto MergeSparseMinibatches = [&](size_t d) {
    // Loop over minibatches
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(sparse_buffers, d, &total_num_features,
                        &max_num_features);

    TensorShape indices_shape;
    indices_shape.AddDim(total_num_features);
    indices_shape.AddDim(2);
    result->sparse_indices.emplace_back(DT_INT64, indices_shape);
    Tensor* indices = &result->sparse_indices.back();

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
    Tensor* values = &result->sparse_values.back();

    result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
    auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
    shapes_shape_t(0) = serialized.size();
    shapes_shape_t(1) = max_num_features;

    // `offset` tracks how many values earlier minibatches contributed, so
    // each buffer writes into its own disjoint region of indices/values.
    size_t offset = 0;
    for (size_t i = 0; i < sparse_buffers.size(); ++i) {
      SparseBuffer& buffer = sparse_buffers[i][d];

      // Update indices.
      size_t delta = 0;

      if (indices->NumElements() > 0) {
        int64* ix_p = &indices->matrix<int64>()(offset, 0);
        size_t example_index = first_example_of_minibatch(i);
        for (size_t example_end_index : buffer.example_end_indices) {
          size_t feature_index = 0;
          for (; delta < example_end_index; ++delta) {
            // Column 0: example index
            *ix_p = example_index;
            // Column 1: index of this value within its example
            *(ix_p + 1) = feature_index;
            ix_p += 2;
            ++feature_index;
          }
          ++example_index;
        }
      }

      CopySparseBufferToTensor(config.sparse[d].dtype, offset, &buffer, values);
      offset += delta;
    }
  };

  // Merge SparseBuffers from all minibatches for every config.ragged.
  auto MergeRaggedMinibatches = [&](size_t d) {
    // Loop over minibatches
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(ragged_buffers, d, &total_num_features,
                        &max_num_features);

    // Row splits have one entry per example plus a leading zero.
    TensorShape row_splits_shape;
    row_splits_shape.AddDim(serialized.size() + 1);
    result->ragged_splits.emplace_back(config.ragged[d].splits_dtype,
                                       row_splits_shape);
    Tensor* row_splits = &result->ragged_splits.back();
    if (config.ragged[d].splits_dtype == DT_INT64) {
      row_splits->flat<int64>()(0) = 0;
    } else {
      row_splits->flat<int32>()(0) = 0;
    }

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->ragged_values.emplace_back(config.ragged[d].dtype, values_shape);
    Tensor* values = &result->ragged_values.back();

    size_t values_offset = 0;
    size_t splits_offset = 0;
    for (size_t i = 0; i < ragged_buffers.size(); ++i) {
      SparseBuffer& buffer = ragged_buffers[i][d];
      if (buffer.example_end_indices.empty()) continue;

      // Update row_splits.  row_splits are formed by concatenating the example
      // end_indices (adjusting each to start after the previous one ends).
      if (config.ragged[d].splits_dtype == DT_INT64) {
        int64* row_splits_out = &row_splits->flat<int64>()(splits_offset);
        int64_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      } else {
        int32* row_splits_out = &row_splits->flat<int32>()(splits_offset);
        int32_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      }

      CopySparseBufferToTensor(config.ragged[d].dtype, values_offset, &buffer,
                               values);
      values_offset += buffer.example_end_indices.back();
      splits_offset += buffer.example_end_indices.size();
    }
  };

  // Merge SparseBuffers from all minibatches for every config.dense having
  // variable_length.
  auto MergeDenseVarLenMinibatches = [&](size_t d) {
    if (!config.dense[d].variable_length) return;

    // Loop over minibatches to find the longest example, which determines
    // the padded output width.
    size_t max_num_features = 0;
    for (auto& dense_values_tmp : varlen_dense_buffers) {
      std::vector<size_t>& end_indices =
          dense_values_tmp[d].example_end_indices;
      max_num_features = std::max(max_num_features, end_indices[0]);
      for (size_t i = 1; i < end_indices.size(); ++i) {
        size_t example_size = end_indices[i] - end_indices[i - 1];
        max_num_features = std::max(max_num_features, example_size);
      }
    }

    const size_t stride_size = config.dense[d].elements_per_stride;
    const size_t max_num_elements = max_num_features / stride_size;
    TensorShape values_shape;
    DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
    const size_t batch_size = serialized.size();
    values_shape.AddDim(batch_size);
    values_shape.AddDim(max_num_elements);
    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
      values_shape.AddDim(config.dense[d].shape.dim_size(i));
    }
    Tensor values(config.dense[d].dtype, values_shape);
    result->dense_values[d] = values;
    const size_t num_elements = values.NumElements();

    // Nothing to write, exit early.
    if (num_elements == 0) return;

    const size_t num_elements_per_minibatch = num_elements / batch_size;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_FLOAT: {
        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_STRING: {
        FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
                                   config, varlen_dense_buffers, &values);
        break;
      }
      default:
        ReportUnexpectedDataType(config.dense[d].dtype);
    }
  };

  for (size_t d = 0; d < config.dense.size(); ++d) {
    MergeDenseVarLenMinibatches(d);
  }

  for (size_t d = 0; d < config.sparse.size(); ++d) {
    MergeSparseMinibatches(d);
  }

  for (size_t d = 0; d < config.ragged.size(); ++d) {
    MergeRaggedMinibatches(d);
  }

  return Status::OK();
}
1427 
FastParseSingleExample(const Config & config,StringPiece serialized,Result * result)1428 Status FastParseSingleExample(const Config& config, StringPiece serialized,
1429                               Result* result) {
1430   DCHECK(result != nullptr);
1431   // Check config so we can safely CHECK(false) in switches on config.*.dtype
1432   TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1433 
1434   PerExampleFeatureStats* stats = nullptr;
1435   if (config.collect_feature_stats) {
1436     result->feature_stats.emplace_back();
1437     stats = &result->feature_stats.back();
1438   }
1439 
1440   // TODO(mrry): Cache the construction of this map at Op construction time.
1441   size_t config_size =
1442       config.dense.size() + config.sparse.size() + config.ragged.size();
1443   SeededHasher hasher;
1444   // Build config index.
1445   PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1446   bool ok = true;
1447   for (size_t i = 0; i < 1000; ++i) {
1448     for (size_t d = 0; d < config.dense.size(); ++d) {
1449       ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1450                                       {d, Type::Dense});
1451     }
1452     for (size_t d = 0; d < config.sparse.size(); ++d) {
1453       ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1454                                       {d, Type::Sparse});
1455     }
1456     for (size_t d = 0; d < config.ragged.size(); ++d) {
1457       ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1458                                       {d, Type::Ragged});
1459     }
1460     if (ok) break;
1461     LOG(WARNING) << "Collision found. This should happen only if you have "
1462                     "around 2^32 entries in your config.";
1463     hasher.seed++;
1464     config_index.Clear(config_size);
1465     ok = true;
1466   }
1467   if (!ok) {
1468     return errors::Internal(
1469         "Could not avoid collision. This should not happen.");
1470   }
1471 
1472   result->sparse_indices.reserve(config.sparse.size());
1473   result->sparse_values.reserve(config.sparse.size());
1474   result->sparse_shapes.reserve(config.sparse.size());
1475   result->dense_values.reserve(config.dense.size());
1476   result->ragged_values.reserve(config.ragged.size());
1477   result->ragged_splits.reserve(config.ragged.size());
1478 
1479   // Allocate dense output tensors.
1480   for (size_t d = 0; d < config.dense.size(); ++d) {
1481     if (!config.dense[d].variable_length) {
1482       TensorShape values_shape;
1483       if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1484         return errors::Internal(
1485             "Fixed-length shape was not a statically defined shape.");
1486       }
1487       result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1488     } else {
1489       // Variable-length tensor will be allocated later.
1490       result->dense_values.emplace_back();
1491     }
1492   }
1493 
1494   // Allocate sparse output tensors.
1495   for (size_t d = 0; d < config.sparse.size(); ++d) {
1496     // The dense_shape is always a vector of length 1.
1497     result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1498     // Variable-length tensors will be allocated later.
1499     result->sparse_indices.emplace_back();
1500     result->sparse_values.emplace_back();
1501   }
1502 
1503   // Allocate ragged output tensors.
1504   for (size_t d = 0; d < config.ragged.size(); ++d) {
1505     // Variable-length values tensors will be allocated later.
1506     result->ragged_values.emplace_back();
1507     // Splits tensors are empty (unused) for single (scalar) inputs.
1508     const auto splits_dtype = config.ragged[d].splits_dtype;
1509     result->ragged_splits.emplace_back(splits_dtype, TensorShape({0}));
1510   }
1511 
1512   parsed::Example parsed_example;
1513   if (!ParseExample(serialized, &parsed_example)) {
1514     return errors::InvalidArgument("Could not parse example input, value: '",
1515                                    serialized, "'");
1516   }
1517   std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1518   std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1519   std::vector<bool> ragged_feature_already_seen(config.ragged.size(), false);
1520 
1521   if (stats) {
1522     // TODO(b/111553342): This may over-count the number of features if there
1523     // are duplicate keys in the feature map. Consider deduplicating the keys
1524     // before computing the count.
1525     stats->features_count = parsed_example.size();
1526   }
1527 
1528   // Handle features present in the example.
1529   const size_t parsed_example_size = parsed_example.size();
1530   for (size_t i = 0; i < parsed_example_size; ++i) {
1531     // This is a logic that standard protobuf parsing is implementing.
1532     // I.e. last entry in the map overwrites all the previous ones.
1533     parsed::FeatureMapEntry& name_and_feature =
1534         parsed_example[parsed_example_size - i - 1];
1535 
1536     const StringPiece feature_name = name_and_feature.first;
1537     parsed::Feature& feature = name_and_feature.second;
1538 
1539     std::pair<size_t, Type> d_and_type;
1540     uint64 h = hasher(feature_name);
1541     if (!config_index.Find(h, &d_and_type)) continue;
1542 
1543     size_t d = d_and_type.first;
1544     bool is_dense = d_and_type.second == Type::Dense;
1545     bool is_sparse = d_and_type.second == Type::Sparse;
1546 
1547     {
1548       // Testing for PresizedCuckooMap collision.
1549       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
1550       const tstring& config_feature_name =
1551           is_dense ? config.dense[d].feature_name
1552                    : (is_sparse ? config.sparse[d].feature_name
1553                                 : config.ragged[d].feature_name);
1554       if (feature_name != config_feature_name) continue;
1555     }
1556 
1557     auto example_error = [feature_name](StringPiece suffix) {
1558       return errors::InvalidArgument("Key: ", feature_name, ".  ", suffix);
1559     };
1560 
1561     auto parse_error = [feature_name] {
1562       return errors::InvalidArgument("Key: ", feature_name,
1563                                      ".  Can't parse serialized Example.");
1564     };
1565 
1566     DataType example_dtype;
1567     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
1568     if (example_dtype == DT_INVALID) continue;
1569 
1570     if (is_dense && !config.dense[d].variable_length) {
1571       // If feature was already visited, skip.
1572       // Compare comment at the beginning of the loop.
1573       if (dense_feature_already_seen[d]) {
1574         LogDenseFeatureDataLoss(feature_name);
1575         continue;
1576       }
1577       dense_feature_already_seen[d] = true;
1578 
1579       if (example_dtype != config.dense[d].dtype) {
1580         return example_error(strings::StrCat(
1581             "Data types don't match. Data type: ",
1582             DataTypeString(example_dtype),
1583             " but expected type: ", DataTypeString(config.dense[d].dtype)));
1584       }
1585 
1586       Tensor* out = &result->dense_values[d];
1587       const std::size_t num_elements = config.dense[d].elements_per_stride;
1588       if (stats) {
1589         // TODO(b/111553342): If desirable, we could add support for counting
1590         // elements in the features that aren't parsed, but this could add
1591         // considerable runtime cost.
1592         stats->feature_values_count += num_elements;
1593       }
1594       switch (example_dtype) {
1595         case DT_INT64: {
1596           auto out_p = out->flat<int64>().data();
1597           LimitedArraySlice<int64> slice(out_p, num_elements);
1598           if (!feature.ParseInt64List(&slice)) return parse_error();
1599           if (slice.EndDistance() != 0) {
1600             return parse_error();
1601           }
1602           break;
1603         }
1604         case DT_FLOAT: {
1605           auto out_p = out->flat<float>().data();
1606           LimitedArraySlice<float> slice(out_p, num_elements);
1607           if (!feature.ParseFloatList(&slice)) return parse_error();
1608           if (slice.EndDistance() != 0) {
1609             return parse_error();
1610           }
1611           break;
1612         }
1613         case DT_STRING: {
1614           auto out_p = out->flat<tstring>().data();
1615           LimitedArraySlice<tstring> slice(out_p, num_elements);
1616           if (!feature.ParseBytesList(&slice)) return parse_error();
1617           if (slice.EndDistance() != 0) {
1618             return parse_error();
1619           }
1620           break;
1621         }
1622         default:
1623           ReportUnexpectedDataType(example_dtype);
1624       }
1625 
1626     } else {  // if variable length
1627       SmallVector<tstring> bytes_list;
1628       TensorVector<float> float_list;
1629       SmallVector<int64> int64_list;
1630 
1631       const size_t num_elements_divisor =
1632           is_dense ? config.dense[d].elements_per_stride : 1;
1633       size_t num_elements;
1634 
1635       if (is_dense) {
1636         // If feature was already visited, skip.
1637         // Compare comment at the beginning of the loop.
1638         if (dense_feature_already_seen[d]) {
1639           LogDenseFeatureDataLoss(feature_name);
1640           continue;
1641         }
1642         dense_feature_already_seen[d] = true;
1643         if (example_dtype != config.dense[d].dtype) {
1644           return example_error(strings::StrCat(
1645               "Data types don't match. Data type: ",
1646               DataTypeString(example_dtype),
1647               " but expected type: ", DataTypeString(config.dense[d].dtype)));
1648         }
1649       } else {
1650         // Feature is sparse or ragged.
1651         auto& feature_already_seen = is_sparse ? sparse_feature_already_seen
1652                                                : ragged_feature_already_seen;
1653         auto& feature_dtype =
1654             is_sparse ? config.sparse[d].dtype : config.ragged[d].dtype;
1655         // If feature was already visited, skip.
1656         // Compare comment at the beginning of the loop.
1657         if (feature_already_seen[d]) {
1658           LogSparseFeatureDataLoss(feature_name);
1659           continue;
1660         }
1661         feature_already_seen[d] = true;
1662 
1663         // Handle sparse features.
1664         if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
1665           return example_error(strings::StrCat(
1666               "Data types don't match. ",
1667               "Expected type: ", DataTypeString(feature_dtype),
1668               ", Actual type: ", DataTypeString(example_dtype)));
1669         }
1670       }
1671 
1672       switch (example_dtype) {
1673         case DT_INT64: {
1674           // TODO(mrry): Use the fact that the `int64_list` is packed to read
1675           // out the length and pre-allocate the output tensor.
1676           if (!feature.ParseInt64List(&int64_list)) return parse_error();
1677           num_elements = int64_list.size();
1678           break;
1679         }
1680         case DT_FLOAT: {
1681           if (!feature.ParseFloatList(&float_list)) return parse_error();
1682           num_elements = float_list.size();
1683           break;
1684         }
1685         case DT_STRING: {
1686           int actual_num_elements = 0;
1687           if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1688             return parse_error();
1689           }
1690           bytes_list.reserve(actual_num_elements);
1691           if (!feature.ParseBytesList(&bytes_list)) return parse_error();
1692           num_elements = bytes_list.size();
1693           break;
1694         }
1695         default:
1696           num_elements = 0;
1697           ReportUnexpectedDataType(example_dtype);
1698       }
1699 
1700       if (num_elements % num_elements_divisor != 0) {
1701         return parse_error();
1702       }
1703 
1704       if (stats) {
1705         stats->feature_values_count += num_elements;
1706       }
1707 
1708       Tensor* out;
1709       DataType out_dtype;
1710       TensorShape out_shape;
1711       if (is_dense) {
1712         out_shape.AddDim(num_elements / num_elements_divisor);
1713         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1714           out_shape.AddDim(config.dense[d].shape.dim_size(i));
1715         }
1716 
1717         out = &result->dense_values[d];
1718         out_dtype = config.dense[d].dtype;
1719       } else if (is_sparse) {
1720         Tensor* out_indices = &result->sparse_indices[d];
1721         Tensor* out_dense_shape = &result->sparse_shapes[d];
1722 
1723         // TODO(mrry): Investigate the possibility of not materializing
1724         // the indices (and perhaps dense_shape) until they are needed.
1725         *out_indices = Tensor(
1726             DT_INT64, TensorShape({static_cast<int64>(num_elements), 1}));
1727         auto indices_flat = out_indices->flat<int64>();
1728         for (size_t i = 0; i < num_elements; ++i) {
1729           indices_flat(i) = static_cast<int64>(i);
1730         }
1731 
1732         *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
1733         auto shapes_shape_t = out_dense_shape->vec<int64>();
1734         shapes_shape_t(0) = num_elements;
1735 
1736         out = &result->sparse_values[d];
1737         out_dtype = config.sparse[d].dtype;
1738         out_shape.AddDim(num_elements);
1739       } else {
1740         out = &result->ragged_values[d];
1741         out_dtype = config.ragged[d].dtype;
1742         out_shape.AddDim(num_elements);
1743       }
1744 
1745       switch (example_dtype) {
1746         case DT_INT64: {
1747           *out = Tensor(out_dtype, out_shape);
1748           CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
1749                           out->flat<int64>().data());
1750           break;
1751         }
1752         case DT_FLOAT: {
1753           if (!out->CopyFrom(float_list.tensor(), out_shape)) {
1754             return parse_error();
1755           }
1756           break;
1757         }
1758         case DT_STRING: {
1759           *out = Tensor(out_dtype, out_shape);
1760           CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(),
1761                           out->flat<tstring>().data());
1762           break;
1763         }
1764         default:
1765           ReportUnexpectedDataType(example_dtype);
1766       }
1767     }
1768   }
1769 
1770   // Handle missing dense features.
1771   for (size_t d = 0; d < config.dense.size(); ++d) {
1772     if (!dense_feature_already_seen[d]) {
1773       if (!config.dense[d].variable_length) {
1774         // Handle missing fixed-length dense feature.
1775         if (config.dense[d].default_value.NumElements() == 0) {
1776           return errors::InvalidArgument(
1777               "Feature: ", config.dense[d].feature_name,
1778               " (data type: ", DataTypeString(config.dense[d].dtype), ")",
1779               " is required but could not be found.");
1780         }
1781         result->dense_values[d] = config.dense[d].default_value;
1782       } else {
1783         // Handle missing varlen dense feature.
1784         TensorShape empty_shape;
1785         empty_shape.AddDim(0);
1786         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1787           empty_shape.AddDim(config.dense[d].shape.dim_size(i));
1788         }
1789         result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
1790       }
1791     }
1792   }
1793 
1794   // Handle missing sparse features.
1795   for (size_t d = 0; d < config.sparse.size(); ++d) {
1796     if (!sparse_feature_already_seen[d]) {
1797       result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
1798       result->sparse_values[d] =
1799           Tensor(config.sparse[d].dtype, TensorShape({0}));
1800       result->sparse_shapes[d].vec<int64>()(0) = 0;
1801     }
1802   }
1803 
1804   // Handle missing ragged features.
1805   for (size_t d = 0; d < config.ragged.size(); ++d) {
1806     if (!ragged_feature_already_seen[d]) {
1807       result->ragged_values[d] =
1808           Tensor(config.ragged[d].dtype, TensorShape({0}));
1809     }
1810   }
1811 
1812   return Status::OK();
1813 }
1814 
1815 // Private helper functions for FastParseSequenceExample.
1816 namespace {
1817 
1818 // A struct used by FastParseSequenceExample to hold the serialized proto
1819 // substrings for a single feature, plus some auxiliary information derived
1820 // from those protos (such as the total value length).
1821 struct FeatureProtos {
1822   // Proto substrings from each serialized SequenceExample that correspond
1823   // with this feature.  `protos_present` records whether the proto had a
1824   // value defined (even if that value is empty).
1825   std::vector<StringPiece> protos;
1826   std::vector<bool> protos_present;
1827 
1828   // Information derived from protos:
1829   size_t length;    // total length for ragged/sparse, max row length for dense.
1830   size_t num_rows;  // only populated for ragged sequence features.
1831 
1832   // Information from the config:
1833   Type type;  // Whether this feature is sparse, ragged, or dense.
1834   DataType dtype;
1835 };
1836 
1837 // Map from feature name to FeatureProtos for that feature.
1838 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
1839 
ExampleName(const gtl::ArraySlice<tstring> example_names,int n)1840 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
1841   return example_names.empty() ? "<unknown>" : example_names[n];
1842 }
1843 
// Return the number of bytes elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
//
// `stream` must be positioned at a length-delimited field with tag 1 whose
// payload is itself a sequence of length-delimited (tag 1) byte strings.
// When `out` is non-null, it must point to an array with room for every
// parsed element.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
                             tstring* out) {
  int num_elements = 0;
  uint32 length;
  // Consume the outer wrapper field (tag 1, length-delimited).
  if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    // Bound all reads to the wrapper's payload.
    auto limit = stream->PushLimit(length);
    while (!stream->ExpectAtEnd()) {
      uint32 bytes_length;
      // Each element is itself a length-delimited field with tag 1.
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&bytes_length)) {
        return -1;
      }
      if (out == nullptr) {
        // Counting-only mode: skip the payload without copying.
        stream->Skip(bytes_length);
      } else {
        // resize_uninitialized avoids zero-filling bytes that ReadRaw is
        // about to overwrite.
        out->resize_uninitialized(bytes_length);
        if (!stream->ReadRaw(out->data(), bytes_length)) {
          return -1;
        }
        // Advance to the output slot for the next element.
        out++;
      }
      num_elements++;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
1876 
// Writes `num_to_pad` zero-valued floats starting at `out`.  A non-positive
// count writes nothing.
inline void PadFloatFeature(int num_to_pad, float* out) {
  for (int remaining = num_to_pad; remaining > 0; --remaining) {
    *out++ = 0.0;
  }
}
1882 
// Writes `num_to_pad` zero-valued int64s starting at `out`.  A non-positive
// count writes nothing.
inline void PadInt64Feature(int num_to_pad, int64* out) {
  for (int remaining = num_to_pad; remaining > 0; --remaining) {
    *out++ = 0;
  }
}
1888 
// Return the number of float elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
//
// `stream` must be positioned at a length-delimited field with tag 2.  The
// payload holds the values either packed (a single length-delimited
// sub-field with tag 1 containing consecutive fixed32s) or unpacked
// (repeated fixed32 fields, each with tag 1).
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
                             float* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    // Bound all reads to the outer field's payload.
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          // Reinterpret the 4 little-endian bytes as an IEEE-754 float.
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kFixed32Tag(1)) {  // unpacked
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ExpectTag(kFixed32Tag(1)) ||
            !stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
1939 
// Return the number of int64 elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
//
// `stream` must be positioned at a length-delimited field with tag 3.  The
// payload holds the values either packed (a single length-delimited
// sub-field with tag 1 containing consecutive varints) or unpacked (repeated
// varint fields, each with tag 1).
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
                             int64* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    // Bound all reads to the outer field's payload.
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kVarintTag(1)) {  // unpacked
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
1989 
1990 // Parses the next feature on `stream` into `out` starting at `out_offset`.
1991 // Updates `out_offset`, and returns the number of values added.
1992 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)1993 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
1994                         Tensor* out, size_t* out_offset) {
1995   int delta;
1996   switch (dtype) {
1997     case DT_STRING:
1998       delta =
1999           ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
2000       break;
2001     case DT_FLOAT:
2002       delta =
2003           ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
2004       break;
2005     case DT_INT64:
2006       delta =
2007           ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
2008       break;
2009     default:
2010       ReportUnexpectedDataType(dtype);
2011       delta = 0;
2012   }
2013   if (delta > 0) {
2014     *out_offset += delta;
2015   }
2016   return delta;
2017 }
2018 
2019 // Returns the length of the next feature on `stream`.
2020 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)2021 inline int GetFeatureLength(DataType dtype,
2022                             protobuf::io::CodedInputStream* stream) {
2023   switch (dtype) {
2024     case DT_STRING:
2025       return ParseBytesFeature(stream, nullptr);
2026     case DT_FLOAT:
2027       return ParseFloatFeature(stream, nullptr);
2028     case DT_INT64:
2029       return ParseInt64Feature(stream, nullptr);
2030     default:
2031       ReportUnexpectedDataType(dtype);
2032       return -1;
2033   }
2034 }
2035 
ParseDataType(protobuf::io::CodedInputStream * stream)2036 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
2037   uint8 peek_tag = PeekTag(stream);
2038   switch (peek_tag) {
2039     case kDelimitedTag(1):
2040       return DT_STRING;
2041     case kDelimitedTag(2):
2042       return DT_FLOAT;
2043     case kDelimitedTag(3):
2044       return DT_INT64;
2045     default:
2046       return DT_INVALID;
2047   }
2048 }
2049 
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)2050 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
2051                              DataType dtype) {
2052   switch (dtype) {
2053     case DT_STRING:
2054       if (!stream->ExpectTag(kDelimitedTag(1))) {
2055         return false;
2056       }
2057       break;
2058     case DT_FLOAT:
2059       if (!stream->ExpectTag(kDelimitedTag(2))) {
2060         return false;
2061       }
2062       break;
2063     case DT_INT64:
2064       if (!stream->ExpectTag(kDelimitedTag(3))) {
2065         return false;
2066       }
2067       break;
2068     default:
2069       return false;
2070   }
2071   uint32 length;
2072   return stream->ReadVarint32(&length) && length == 0;
2073 }
2074 
// Reads an example proto, and extracts a StringPiece pointer to each feature.
//
// For each serialized SequenceExample in `examples`, walks its context
// (tag 1) and sequence (tag 2) sub-messages and, for every feature-map key
// that was pre-inserted into `context_features` / `sequence_features`,
// records a StringPiece over the feature's serialized value at index `d`
// (the example's position) and sets protos_present[d].  Keys that were not
// requested are ignored.  Returns InvalidArgument on malformed input.
Status ExtractFeaturesFromSequenceExamples(
    const gtl::ArraySlice<tstring> examples,
    const gtl::ArraySlice<tstring> example_names,
    FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) {
  for (int d = 0; d < examples.size(); d++) {
    const tstring& example = examples[d];
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(example.data()), example.size());
    // NOTE(review): presumably lets the StringPieces produced by ParseString
    // below alias `example`'s buffer instead of copying; the helper is
    // defined earlier in this file -- confirm why it is used rather than
    // stream.EnableAliasing().
    EnableAliasing(&stream);

    // Extract pointers to all features within this serialized example.
    while (!stream.ExpectAtEnd()) {
      FeatureProtosMap* features = nullptr;
      if (stream.ExpectTag(kDelimitedTag(1))) {
        // Context
        features = context_features;
      } else if (stream.ExpectTag(kDelimitedTag(2))) {
        // Sequence
        features = sequence_features;
      } else if (!SkipExtraneousTag(&stream)) {
        return errors::InvalidArgument(
            "Invalid protocol message input, example id: ",
            ExampleName(example_names, d));
      }
      if (features != nullptr) {
        // Read the sub-message's length and bound all reads to its payload.
        uint32 length;
        if (!stream.ReadVarint32(&length)) {
          return errors::InvalidArgument(
              "Invalid protocol message input, example id: ",
              ExampleName(example_names, d));
        }
        auto limit = stream.PushLimit(length);
        while (!stream.ExpectAtEnd()) {
          StringPiece key, value;
          uint32 length;  // Intentionally shadows the outer `length`.
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&length)) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ",
                ExampleName(example_names, d));
          }
          auto limit = stream.PushLimit(length);
          // Each map entry must contain exactly a key (tag 1) followed by a
          // value (tag 2), filling the entry's payload.
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !ParseString(&stream, &key) ||
              !stream.ExpectTag(kDelimitedTag(2)) ||
              !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ",
                ExampleName(example_names, d));
          }
          stream.PopLimit(limit);
          // Only save if this feature was requested.
          auto feature_iter = features->find(key);
          if (feature_iter != features->end()) {
            auto& feature = feature_iter->second;
            feature.protos[d] = value;
            feature.protos_present[d] = true;
          }
        }
        stream.PopLimit(limit);
      }
    }
  }
  return Status::OK();
}
2142 
2143 // Populates context_features[k].length based on context_features[k].protos
2144 // (for all k).
GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * context_features)2145 Status GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2146                                 FeatureProtosMap* context_features) {
2147   for (auto& c : *context_features) {
2148     FeatureProtos& feature = c.second;
2149     for (int d = 0; d < feature.protos.size(); ++d) {
2150       const auto& proto = feature.protos[d];
2151       if (proto.empty()) continue;
2152       protobuf::io::CodedInputStream stream(
2153           reinterpret_cast<const uint8*>(proto.data()), proto.size());
2154       EnableAliasing(&stream);
2155       int num_elements = GetFeatureLength(feature.dtype, &stream);
2156       if (num_elements < 0) {
2157         return errors::InvalidArgument(
2158             "Name: ", ExampleName(example_names, d),
2159             ", Context feature: ", c.first,
2160             ".  Data types don't match. Expected type: ",
2161             DataTypeString(feature.dtype));
2162       }
2163       switch (feature.type) {
2164         case Type::Sparse:  // intentional fall-through
2165         case Type::Ragged:
2166           feature.length += num_elements;
2167           break;
2168         case Type::Dense:
2169           feature.length =
2170               std::max(feature.length, static_cast<size_t>(num_elements));
2171           break;
2172       }
2173     }
2174   }
2175   return Status::OK();
2176 }
2177 
// Populates sequence_features[k].length and sequence_features[k].num_rows based
// on sequence_features[k].protos (for all k).
Status GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,
                                 FeatureProtosMap* sequence_features) {
  for (auto& c : *sequence_features) {
    FeatureProtos& feature = c.second;
    for (int d = 0; d < feature.protos.size(); ++d) {
      const auto& proto = feature.protos[d];
      // An empty substring means example `d` did not contain this feature.
      if (proto.empty()) continue;

      size_t num_rows = 0;
      size_t num_elements = 0;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(proto.data()), proto.size());
      EnableAliasing(&stream);
      // Each iteration consumes one row: one length-delimited entry (tag 1)
      // of the serialized feature list.
      while (!stream.ExpectAtEnd()) {
        uint32 feature_bytes;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_bytes)) {
          return errors::InvalidArgument("Error in sequence feature ", c.first,
                                         " in example ",
                                         ExampleName(example_names, d));
        }
        if (feature_bytes > 2) {
          // Non-empty row: count its values without copying them.
          auto limit = stream.PushLimit(feature_bytes);
          int delta = GetFeatureLength(feature.dtype, &stream);
          if (delta < 0) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, d),
                ", Feature list: ", c.first, ", Index: ", num_rows,
                ".  Data types don't match. Expected type: ",
                DataTypeString(feature.dtype));
          }
          num_elements += delta;
          stream.PopLimit(limit);
        } else if (feature_bytes == 2) {
          // A 2-byte row is expected to be a value-list tag plus a zero
          // length, i.e. an explicitly empty row; SkipEmptyFeature also
          // verifies the tag matches `feature.dtype`.
          if (!SkipEmptyFeature(&stream, feature.dtype)) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, d),
                ", Feature list: ", c.first, ", Index: ", num_rows,
                ".  Data types don't match. Expected type: ",
                DataTypeString(feature.dtype));
          }
        } else if (feature_bytes != 0) {
          // feature_bytes == 1 is rejected as malformed.
          return errors::InvalidArgument("Error in sequence feature ", c.first,
                                         " in example ",
                                         ExampleName(example_names, d));
        }
        ++num_rows;
      }
      switch (feature.type) {
        case Type::Sparse:
          feature.length += num_elements;
          break;
        case Type::Ragged:
          // Ragged features also accumulate the row count (used for splits).
          feature.length += num_elements;
          feature.num_rows += num_rows;
          break;
        case Type::Dense:
          // Dense features track the maximum per-example element count.
          feature.length = std::max(feature.length, num_elements);
          break;
      }
    }
  }
  return Status::OK();
}
2244 
2245 // Copies src into dst[dst_offset:dst_offset+src.size], and then increments
2246 // dst_offset by src.size.
CopyTensorIntoTensor(DataType dtype,const Tensor & src,Tensor * dst,size_t * dst_offset)2247 void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst,
2248                           size_t* dst_offset) {
2249   size_t src_size = src.NumElements();
2250   switch (dtype) {
2251     case DT_INT64: {
2252       auto src_t = src.flat<int64>().data();
2253       std::copy(src_t, src_t + src_size,
2254                 dst->flat<int64>().data() + *dst_offset);
2255       break;
2256     }
2257     case DT_FLOAT: {
2258       auto src_t = src.flat<float>().data();
2259       std::copy(src_t, src_t + src_size,
2260                 dst->flat<float>().data() + *dst_offset);
2261       break;
2262     }
2263     case DT_STRING: {
2264       auto src_t = src.flat<tstring>().data();
2265       std::copy(src_t, src_t + src_size,
2266                 dst->flat<tstring>().data() + *dst_offset);
2267       break;
2268     }
2269     default:
2270       ReportUnexpectedDataType(dtype);
2271   }
2272   *dst_offset += src_size;
2273 }
2274 
2275 // Parses dense features in `context_features`, and writes their parsed
2276 // values to `context_results`.
Status ParseContextDenseFeatures(const FeatureProtosMap& context_features,
                                 const FastParseExampleConfig& context_config,
                                 gtl::ArraySlice<tstring> example_names,
                                 bool is_batch, int num_examples,
                                 Allocator* allocator, Result* context_result) {
  for (int t = 0; t < context_config.dense.size(); ++t) {
    const auto& c = context_config.dense[t];
    // NOTE(review): the find() result is dereferenced unchecked -- presumably
    // `context_features` was pre-populated with every configured feature name
    // before this is called; confirm with the caller.
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape dense_shape, example_shape;
    DataType dtype = c.dtype;
    // Number of values recorded for this feature by the earlier scanning
    // pass. For a dense feature this must match the configured per-example
    // shape exactly.
    const size_t data_max_elements = feature.length;
    if (!c.shape.AsTensorShape(&example_shape) ||
        data_max_elements != example_shape.num_elements()) {
      return errors::InvalidArgument(
          "Inconsistent max number of elements for feature ", c.feature_name,
          ": expected ", example_shape.num_elements(), ", but found ",
          data_max_elements);
    }
    // Output shape is [num_examples, c.shape...] when batched, else c.shape.
    if (is_batch) {
      dense_shape.AddDim(num_examples);
    }
    for (const int dim : c.shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);

    Tensor& out = context_result->dense_values[t];
    size_t out_offset = 0;  // Write position (in elements) into `out`.

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0;
      const auto& feature_proto = feature.protos[e];
      if (!feature.protos_present[e]) {
        // Copy the default value, if present. If not, return an error.
        if (c.default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: ", c.feature_name,
              " (data type: ", DataTypeString(c.dtype), ")",
              " is required but could not be found.");
        }
        CopyTensorIntoTensor(dtype, c.default_value, &out, &out_offset);
        num_elements += c.default_value.NumElements();
      } else if (!feature_proto.empty()) {
        // Parse the serialized Feature directly from the byte range recorded
        // by the scanning pass. Aliasing lets parsed data refer back into the
        // input buffer rather than being copied.
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        num_elements += ParseFeature(dtype, &stream, &out, &out_offset);
      }
      // Every example (or its default) must supply exactly the configured
      // number of values; otherwise `out` would be under- or over-filled.
      if (num_elements != data_max_elements) {
        return errors::InvalidArgument(
            "Unexpected number of elements in example ",
            ExampleName(example_names, e));
      }
    }
  }
  return Status::OK();
}
2337 
2338 // Parses sparse features in `context_features`, and writes their parsed
2339 // values to `context_results`.
Status ParseContextSparseFeatures(const FeatureProtosMap& context_features,
                                  const FastParseExampleConfig& context_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator,
                                  Result* context_result) {
  for (int t = 0; t < context_config.sparse.size(); ++t) {
    const auto& c = context_config.sparse[t];
    // NOTE(review): unchecked find() -- the map is presumably pre-populated
    // with every configured feature name; confirm with the caller.
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    // Total number of values across all examples, as counted by the earlier
    // scanning pass; used to pre-size the output tensors exactly.
    size_t expected_num_elements = feature.length;
    // Indices are [num_values, 2] = (example, position) when batched, or
    // [num_values, 1] = (position) for a single example.
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(is_batch ? 2 : 1);
    values_shape.AddDim(expected_num_elements);
    context_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    context_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 2 : 1}));
    Tensor& out_values = context_result->sparse_values[t];
    size_t out_values_offset = 0;
    int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
    auto out_shape = context_result->sparse_shapes[t].vec<int64>();

    // Fill in the values.
    size_t num_elements = 0;  // Running total of values written.
    size_t max_num_cols = 0;  // Widest per-example value count seen.
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      EnableAliasing(&stream);
      size_t num_added =
          ParseFeature(dtype, &stream, &out_values, &out_values_offset);
      num_elements += num_added;
      max_num_cols = std::max(max_num_cols, num_added);
      // Emit one index row per value: (e, i) when batched, else just (i).
      for (int i = 0; i < num_added; i++) {
        if (is_batch) *out_indices++ = e;
        *out_indices++ = i;
      }
    }
    // The tensors were sized from the scanning pass; a mismatch means the
    // serialized input changed or was miscounted.
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    // Dense bounding shape of the sparse tensor.
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_cols;
    } else {
      out_shape(0) = max_num_cols;
    }
  }
  return Status::OK();
}
2398 
2399 // Parses ragged features in `context_features`, and writes their parsed
2400 // values to `context_results`.
Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features,
                                  const FastParseExampleConfig& context_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator,
                                  Result* context_result) {
  for (int t = 0; t < context_config.ragged.size(); ++t) {
    const auto& c = context_config.ragged[t];
    // NOTE(review): unchecked find() -- the map is presumably pre-populated
    // with every configured feature name; confirm with the caller.
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape values_shape, splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    // Total number of values across all examples from the scanning pass.
    size_t expected_num_elements = feature.length;
    values_shape.AddDim(expected_num_elements);
    // Row splits are only produced in the batched case (one boundary per
    // example plus the leading zero); otherwise splits_shape stays empty.
    if (is_batch) {
      splits_shape.AddDim(num_examples + 1);
    }
    context_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    context_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, splits_shape);
    Tensor& out_values = context_result->ragged_values[t];
    size_t out_values_offset = 0;
    // Exactly one of these write cursors is non-null (batched case only),
    // selected by the configured splits dtype.
    int32* int32_splits =
        is_batch && splits_dtype == DT_INT32
            ? context_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64* int64_splits =
        is_batch && splits_dtype == DT_INT64
            ? context_result->ragged_splits[t].vec<int64>().data()
            : nullptr;
    // Row-splits conventionally start with 0.
    if (int32_splits) {
      *int32_splits++ = 0;
    } else if (int64_splits) {
      *int64_splits++ = 0;
    }

    // Fill in the values.
    size_t split = 0;  // = total number of elements we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        size_t num_added =
            ParseFeature(dtype, &stream, &out_values, &out_values_offset);
        split += num_added;
      }
      // Record the cumulative element count after each example, even for
      // examples that contributed no values (empty rows).
      if (int32_splits) {
        *int32_splits++ = split;
      } else if (int64_splits) {
        *int64_splits++ = split;
      }
    }
    if (split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    // Sanity check: the splits cursor must have advanced exactly
    // num_examples + 1 slots (leading zero + one per example).
    if (int32_splits || int64_splits) {
      int actual_splits =
          int32_splits
              ? int32_splits -
                    context_result->ragged_splits[t].vec<int32>().data()
              : int64_splits -
                    context_result->ragged_splits[t].vec<int64>().data();
      if (actual_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return Status::OK();
}
2476 
2477 // Parses dense features in `sequence_features`, and writes their parsed
2478 // values to `sequence_result`.
ParseSequenceDenseFeatures(const FeatureProtosMap & sequence_features,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> example_names,bool is_batch,int num_examples,Allocator * allocator,Result * sequence_result,std::vector<Tensor> * dense_feature_lengths)2479 Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features,
2480                                   const FastParseExampleConfig& sequence_config,
2481                                   gtl::ArraySlice<tstring> example_names,
2482                                   bool is_batch, int num_examples,
2483                                   Allocator* allocator, Result* sequence_result,
2484                                   std::vector<Tensor>* dense_feature_lengths) {
2485   TensorShape dense_length_shape;
2486   if (is_batch) {
2487     dense_length_shape.AddDim(num_examples);
2488   }
2489   for (int t = 0; t < sequence_config.dense.size(); ++t) {
2490     const auto& c = sequence_config.dense[t];
2491     const FeatureProtos& feature =
2492         sequence_features.find(c.feature_name)->second;
2493     TensorShape dense_shape, row_shape;
2494     DataType dtype = c.dtype;
2495     const size_t expected_max_elements = feature.length;
2496     if (!c.shape.AsTensorShape(&row_shape) ||
2497         expected_max_elements !=
2498             (expected_max_elements / row_shape.num_elements()) *
2499                 row_shape.num_elements()) {
2500       PartialTensorShape total_shape = row_shape;
2501       total_shape.InsertDim(0, -1);
2502       return errors::InvalidArgument(
2503           "Feature list '", c.feature_name,
2504           "' has an unexpected number of values.  Total values size: ",
2505           expected_max_elements,
2506           " is not consistent with output shape: ", total_shape.DebugString());
2507     }
2508     int64_t expected_max_rows =
2509         expected_max_elements / row_shape.num_elements();
2510     if (is_batch) {
2511       dense_shape.AddDim(num_examples);
2512     }
2513     dense_shape.AddDim(expected_max_rows);
2514     for (const int dim : sequence_config.dense[t].shape.dim_sizes()) {
2515       dense_shape.AddDim(dim);
2516     }
2517     sequence_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
2518     (*dense_feature_lengths)[t] =
2519         Tensor(allocator, DT_INT64, dense_length_shape);
2520     int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
2521 
2522     tstring* out_bytes = nullptr;
2523     float* out_float = nullptr;
2524     int64* out_int64 = nullptr;
2525     switch (dtype) {
2526       case DT_STRING:
2527         out_bytes = sequence_result->dense_values[t].flat<tstring>().data();
2528         break;
2529       case DT_FLOAT:
2530         out_float = sequence_result->dense_values[t].flat<float>().data();
2531         break;
2532       case DT_INT64:
2533         out_int64 = sequence_result->dense_values[t].flat<int64>().data();
2534         break;
2535       default:
2536         ReportUnexpectedDataType(dtype);
2537     }
2538 
2539     // Fill in the values.
2540     for (int e = 0; e < num_examples; e++) {
2541       size_t num_elements = 0, num_rows = 0;
2542       const auto& feature_proto = feature.protos[e];
2543       if (!feature.protos_present[e]) {
2544         // Return an error if this feature was not allowed to be missing.
2545         // Otherwise, we'll pad as needed below.
2546         if (!c.variable_length) {
2547           return errors::InvalidArgument(
2548               "Name: ", ExampleName(example_names, e), ", Feature list '",
2549               c.feature_name,
2550               "' is required but could not be found.  "
2551               "Did you mean to include it in "
2552               "feature_list_dense_missing_assumed_empty or "
2553               "feature_list_dense_defaults?");
2554         }
2555       } else if (!feature_proto.empty()) {
2556         protobuf::io::CodedInputStream stream(
2557             reinterpret_cast<const uint8*>(feature_proto.data()),
2558             feature_proto.size());
2559         EnableAliasing(&stream);
2560         while (!stream.ExpectAtEnd()) {
2561           uint32 feature_length;
2562           if (!stream.ExpectTag(kDelimitedTag(1)) ||
2563               !stream.ReadVarint32(&feature_length)) {
2564             return errors::InvalidArgument("Error in sequence feature ",
2565                                            c.feature_name, " in example ",
2566                                            ExampleName(example_names, e));
2567           }
2568           auto limit = stream.PushLimit(feature_length);
2569           int num_added = 0;
2570           if (feature_length > 2) {
2571             switch (dtype) {
2572               case DT_STRING:
2573                 num_added = ParseBytesFeature(&stream, out_bytes);
2574                 out_bytes += num_added;
2575                 break;
2576               case DT_FLOAT:
2577                 num_added = ParseFloatFeature(&stream, out_float);
2578                 out_float += num_added;
2579                 break;
2580               case DT_INT64:
2581                 num_added = ParseInt64Feature(&stream, out_int64);
2582                 out_int64 += num_added;
2583                 break;
2584               default:
2585                 ReportUnexpectedDataType(dtype);
2586                 num_added = 0;
2587             }
2588             if (num_added < 0) {
2589               // This should be unreachable -- we already scanned the feature in
2590               // GetSequenceFeatureLengths, and it hasn't changed since then.
2591               return errors::InvalidArgument("Error in sequence feature ",
2592                                              c.feature_name, " in example ",
2593                                              ExampleName(example_names, e));
2594             }
2595           }
2596           if (num_added != row_shape.num_elements()) {
2597             return errors::InvalidArgument(
2598                 "Name: ", ExampleName(example_names, e),
2599                 ", Key: ", c.feature_name, ", Index: ", num_rows,
2600                 ".  Number of values != expected.  values size: ", num_added,
2601                 " but output shape: ", row_shape.DebugString());
2602           }
2603           num_elements += num_added;
2604           num_rows++;
2605           stream.PopLimit(limit);
2606         }
2607       }
2608       *out_lengths++ = num_rows;
2609       // Pad as necessary.
2610       int num_to_pad = expected_max_elements - num_elements;
2611       switch (dtype) {
2612         case DT_STRING:
2613           out_bytes += num_to_pad;
2614           break;
2615         case DT_FLOAT:
2616           PadFloatFeature(num_to_pad, out_float);
2617           out_float += num_to_pad;
2618           break;
2619         case DT_INT64:
2620           PadInt64Feature(num_to_pad, out_int64);
2621           out_int64 += num_to_pad;
2622           break;
2623         default:
2624           ReportUnexpectedDataType(dtype);
2625       }
2626     }
2627   }
2628   return Status::OK();
2629 }
2630 
2631 // Parses sparse features in `sequence_features`, and writes their parsed
2632 // values to `sequence_result`.
Status ParseSequenceSparseFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.sparse.size(); ++t) {
    const auto& c = sequence_config.sparse[t];
    // NOTE(review): unchecked find() -- the map is presumably pre-populated
    // with every configured feature name; confirm with the caller.
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    // Total value count for this feature from the scanning pass.
    size_t expected_num_elements = feature.length;
    // Indices are [num_values, 3] = (example, row, col) when batched, or
    // [num_values, 2] = (row, col) for a single example.
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(is_batch ? 3 : 2);
    values_shape.AddDim(expected_num_elements);
    sequence_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    sequence_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 3 : 2}));

    // Raw write cursor into the values tensor; exactly one of these is set,
    // matching the configured dtype.
    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->sparse_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->sparse_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->sparse_values[t].flat<int64>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }
    int64* out_indices =
        sequence_result->sparse_indices[t].flat<int64>().data();
    auto out_shape = sequence_result->sparse_shapes[t].vec<int64>();

    // Fill in the values.
    size_t num_elements = 0;   // Running total of values written.
    size_t max_num_rows = 0;   // Longest FeatureList seen (rows).
    size_t max_num_cols = 0;   // Widest row seen (values per row).
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      EnableAliasing(&stream);
      size_t num_rows = 0;
      // Each row of the FeatureList is a length-delimited Feature message
      // under field number 1.
      while (!stream.ExpectAtEnd()) {
        uint32 feature_length;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_length)) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        if (feature_length > 2) {
          // Non-empty payload: parse the values and emit one sparse index
          // per value.
          auto limit = stream.PushLimit(feature_length);
          size_t num_added;
          switch (dtype) {
            case DT_STRING:
              num_added = ParseBytesFeature(&stream, out_bytes);
              out_bytes += num_added;
              break;
            case DT_FLOAT:
              num_added = ParseFloatFeature(&stream, out_float);
              out_float += num_added;
              break;
            case DT_INT64:
              num_added = ParseInt64Feature(&stream, out_int64);
              out_int64 += num_added;
              break;
            default:
              ReportUnexpectedDataType(dtype);
              num_added = 0;
          }
          num_elements += num_added;
          max_num_cols = std::max(max_num_cols, num_added);
          for (int i = 0; i < num_added; i++) {
            if (is_batch) *out_indices++ = e;
            *out_indices++ = num_rows;
            *out_indices++ = i;
          }
          stream.PopLimit(limit);
        } else if (feature_length == 2) {
          // A 2-byte payload encodes an empty (but typed) Feature; it still
          // counts as a row.
          if (!SkipEmptyFeature(&stream, dtype)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
        } else if (feature_length != 0) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        num_rows++;
      }
      max_num_rows = std::max(max_num_rows, num_rows);
    }
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements in feature ", c.feature_name);
    }
    // Dense bounding shape of the sparse tensor.
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_rows;
      out_shape(2) = max_num_cols;
    } else {
      out_shape(0) = max_num_rows;
      out_shape(1) = max_num_cols;
    }
  }
  return Status::OK();
}
2758 
2759 // Parses ragged features in `sequence_features`, and writes their parsed
2760 // values to `sequence_result`.
Status ParseSequenceRaggedFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.ragged.size(); ++t) {
    const auto& c = sequence_config.ragged[t];
    // NOTE(review): unchecked find() -- the map is presumably pre-populated
    // with every configured feature name; confirm with the caller.
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape values_shape, inner_splits_shape, outer_splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    // Totals from the scanning pass: number of values and number of rows
    // (FeatureList entries) across all examples.
    size_t expected_num_elements = feature.length;
    size_t expected_num_rows = feature.num_rows;
    values_shape.AddDim(expected_num_elements);
    // Inner splits mark row boundaries within the flat values; outer splits
    // mark example boundaries within the rows (batched case only).
    inner_splits_shape.AddDim(expected_num_rows + 1);
    if (is_batch) {
      outer_splits_shape.AddDim(num_examples + 1);
    }
    sequence_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, inner_splits_shape);
    sequence_result->ragged_outer_splits[t] =
        Tensor(allocator, splits_dtype, outer_splits_shape);
    Tensor& out_values = sequence_result->ragged_values[t];
    size_t out_values_offset = 0;
    // Write cursors, selected by splits dtype. Inner cursors are always set;
    // outer cursors only in the batched case.
    int32* int32_inner_splits =
        splits_dtype == DT_INT32
            ? sequence_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64* int64_inner_splits =
        splits_dtype == DT_INT64
            ? sequence_result->ragged_splits[t].vec<int64>().data()
            : nullptr;
    int32* int32_outer_splits =
        is_batch && splits_dtype == DT_INT32
            ? sequence_result->ragged_outer_splits[t].vec<int32>().data()
            : nullptr;
    int64* int64_outer_splits =
        is_batch && splits_dtype == DT_INT64
            ? sequence_result->ragged_outer_splits[t].vec<int64>().data()
            : nullptr;
    // Row-splits conventionally start with 0.
    if (int32_inner_splits) {
      *int32_inner_splits++ = 0;
    } else if (int64_inner_splits) {
      *int64_inner_splits++ = 0;
    }
    if (int32_outer_splits) {
      *int32_outer_splits++ = 0;
    } else if (int64_outer_splits) {
      *int64_outer_splits++ = 0;
    }

    // Fill in the values.
    size_t inner_split = 0;  // total number of elements we've seen so far
    size_t outer_split = 0;  // total number of rows we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        // Each row of the FeatureList is a length-delimited Feature message
        // under field number 1.
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          if (feature_length > 2) {
            // Non-empty payload: parse values into the flat values tensor.
            auto limit = stream.PushLimit(feature_length);
            size_t num_added =
                ParseFeature(dtype, &stream, &out_values, &out_values_offset);
            inner_split += num_added;
            stream.PopLimit(limit);
          } else if (feature_length == 2) {
            // A 2-byte payload encodes an empty (but typed) Feature; it
            // still counts as a row with zero values.
            if (!SkipEmptyFeature(&stream, dtype)) {
              // This should be unreachable -- we already scanned the feature in
              // GetSequenceFeatureLengths, and it hasn't changed since then.
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.feature_name, " in example ",
                                             ExampleName(example_names, e));
            }
          } else if (feature_length != 0) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          // One inner split per row, recording the cumulative value count.
          if (int32_inner_splits) {
            *int32_inner_splits++ = inner_split;
          } else if (int64_inner_splits) {
            *int64_inner_splits++ = inner_split;
          }
          outer_split++;
        }
      }
      // One outer split per example, recording the cumulative row count.
      if (int32_outer_splits) {
        *int32_outer_splits++ = outer_split;
      } else if (int64_outer_splits) {
        *int64_outer_splits++ = outer_split;
      }
    }
    // The outputs were pre-sized from the scanning pass; mismatches mean the
    // serialized input changed or was miscounted.
    if (outer_split != expected_num_rows) {
      return errors::InvalidArgument("Unexpected number of rows for feature ",
                                     c.feature_name);
    }
    if (inner_split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements for feature ", c.feature_name);
    }

    // Sanity check: the inner cursor must have advanced exactly
    // expected_num_rows + 1 slots (leading zero + one per row).
    if (int32_inner_splits || int64_inner_splits) {
      const auto& inner_splits = sequence_result->ragged_splits[t];
      int num_inner_splits =
          int32_inner_splits
              ? int32_inner_splits - inner_splits.vec<int32>().data()
              : int64_inner_splits - inner_splits.vec<int64>().data();
      if (num_inner_splits != expected_num_rows + 1) {
        return errors::InvalidArgument("Unexpected number of rows for feature ",
                                       c.feature_name);
      }
    }
    // Sanity check: the outer cursor must have advanced exactly
    // num_examples + 1 slots (leading zero + one per example).
    if (int32_outer_splits || int64_outer_splits) {
      const auto& outer_splits = sequence_result->ragged_outer_splits[t];
      int num_outer_splits =
          int32_outer_splits
              ? int32_outer_splits - outer_splits.vec<int32>().data()
              : int64_outer_splits - outer_splits.vec<int64>().data();
      if (num_outer_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return Status::OK();
}
2903 
2904 }  // namespace
2905 
2906 // TODO(sundberg): Use the threadpool to parallelize example parsing.
2907 // TODO(b/111553342): Support extracting feature statistics from the examples.
FastParseSequenceExample(const FastParseExampleConfig & context_config,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> serialized,gtl::ArraySlice<tstring> example_names,thread::ThreadPool * thread_pool,Result * context_result,Result * sequence_result,std::vector<Tensor> * dense_feature_lengths,bool is_batch)2908 Status FastParseSequenceExample(const FastParseExampleConfig& context_config,
2909                                 const FastParseExampleConfig& sequence_config,
2910                                 gtl::ArraySlice<tstring> serialized,
2911                                 gtl::ArraySlice<tstring> example_names,
2912                                 thread::ThreadPool* thread_pool,
2913                                 Result* context_result, Result* sequence_result,
2914                                 std::vector<Tensor>* dense_feature_lengths,
2915                                 bool is_batch) {
2916   int num_examples = serialized.size();
2917   DCHECK(context_result != nullptr);
2918   DCHECK(sequence_result != nullptr);
2919   DCHECK(dense_feature_lengths != nullptr);
2920   size_t num_context_features = context_config.sparse.size() +
2921                                 context_config.dense.size() +
2922                                 context_config.ragged.size();
2923   FeatureProtosMap context_features;
2924   context_features.reserve(num_context_features);
2925 
2926   if (!example_names.empty() && example_names.size() != num_examples) {
2927     return errors::InvalidArgument(
2928         "example_names must be empty or have the correct number of elements");
2929   }
2930   for (auto& c : context_config.sparse) {
2931     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2932     FeatureProtos& feature = context_features[c.feature_name];
2933     feature.dtype = c.dtype;
2934     feature.length = 0;
2935     feature.type = Type::Sparse;
2936     feature.protos.resize(num_examples);
2937     feature.protos_present.resize(num_examples);
2938   }
2939   for (auto& c : context_config.ragged) {
2940     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2941     FeatureProtos& feature = context_features[c.feature_name];
2942     if (feature.type == Type::Sparse) {
2943       return errors::InvalidArgument("Context feature " + c.feature_name +
2944                                      " cannot be both ragged and sparse");
2945     }
2946     feature.dtype = c.dtype;
2947     feature.length = 0;
2948     feature.type = Type::Ragged;
2949     feature.protos.resize(num_examples);
2950     feature.protos_present.resize(num_examples);
2951   }
2952   for (auto& c : context_config.dense) {
2953     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2954     FeatureProtos& feature = context_features[c.feature_name];
2955     if (feature.type != Type::Dense) {
2956       return errors::InvalidArgument("Context feature " + c.feature_name +
2957                                      " cannot be both dense and sparse");
2958     }
2959     if (c.default_value.NumElements() > 0) {
2960       if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
2961         return errors::InvalidArgument("Default value for context feature ",
2962                                        c.feature_name,
2963                                        " has an incorrect shape: saw ",
2964                                        c.default_value.shape().DebugString(),
2965                                        " but expected ", c.shape.DebugString());
2966       }
2967     }
2968     feature.dtype = c.dtype;
2969     feature.length = c.default_value.NumElements();
2970     feature.protos.resize(num_examples);
2971     feature.protos_present.resize(num_examples);
2972   }
2973   size_t num_sequence_features = sequence_config.sparse.size() +
2974                                  sequence_config.dense.size() +
2975                                  sequence_config.ragged.size();
2976   FeatureProtosMap sequence_features;
2977   sequence_features.reserve(num_sequence_features);
2978   for (auto& c : sequence_config.sparse) {
2979     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2980     FeatureProtos& feature = sequence_features[c.feature_name];
2981     feature.dtype = c.dtype;
2982     feature.length = 0;
2983     feature.type = Type::Sparse;
2984     feature.protos.resize(num_examples);
2985     feature.protos_present.resize(num_examples);
2986   }
2987   for (auto& c : sequence_config.ragged) {
2988     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2989     FeatureProtos& feature = sequence_features[c.feature_name];
2990     if (feature.type == Type::Sparse) {
2991       return errors::InvalidArgument("Sequence feature " + c.feature_name +
2992                                      " cannot be both ragged and sparse");
2993     }
2994     feature.dtype = c.dtype;
2995     feature.length = 0;
2996     feature.type = Type::Ragged;
2997     feature.protos.resize(num_examples);
2998     feature.protos_present.resize(num_examples);
2999   }
3000   for (auto& c : sequence_config.dense) {
3001     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
3002     FeatureProtos& feature = sequence_features[c.feature_name];
3003     if (feature.type != Type::Dense) {
3004       return errors::InvalidArgument("Sequence feature " + c.feature_name +
3005                                      " cannot be both dense and sparse");
3006     }
3007     feature.dtype = c.dtype;
3008     feature.length = 0;
3009     feature.protos.resize(num_examples);
3010     feature.protos_present.resize(num_examples);
3011   }
3012 
3013   // Find the serialized proto substrings for each feature.
3014   TF_RETURN_IF_ERROR(ExtractFeaturesFromSequenceExamples(
3015       serialized, example_names, &context_features, &sequence_features));
3016 
3017   // Scan through the protos to determine how much memory we need to allocate.
3018   TF_RETURN_IF_ERROR(
3019       GetContextFeatureLengths(example_names, &context_features));
3020   TF_RETURN_IF_ERROR(
3021       GetSequenceFeatureLengths(example_names, &sequence_features));
3022 
3023   // Allocate memory.
3024   context_result->sparse_values.resize(context_config.sparse.size());
3025   context_result->sparse_indices.resize(context_config.sparse.size());
3026   context_result->sparse_shapes.resize(context_config.sparse.size());
3027   context_result->dense_values.resize(context_config.dense.size());
3028   context_result->ragged_values.resize(context_config.ragged.size());
3029   context_result->ragged_splits.resize(context_config.ragged.size());
3030   context_result->ragged_outer_splits.resize(context_config.ragged.size());
3031   sequence_result->sparse_values.resize(sequence_config.sparse.size());
3032   sequence_result->sparse_indices.resize(sequence_config.sparse.size());
3033   sequence_result->sparse_shapes.resize(sequence_config.sparse.size());
3034   sequence_result->dense_values.resize(sequence_config.dense.size());
3035   sequence_result->ragged_values.resize(sequence_config.ragged.size());
3036   sequence_result->ragged_splits.resize(sequence_config.ragged.size());
3037   sequence_result->ragged_outer_splits.resize(sequence_config.ragged.size());
3038   dense_feature_lengths->resize(sequence_config.dense.size());
3039 
3040   // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction,
3041   // to avoid lock contention in `tensorflow::cpu_allocator()`.
3042   Allocator* allocator = tensorflow::cpu_allocator();
3043 
3044   TF_RETURN_IF_ERROR(ParseContextDenseFeatures(
3045       context_features, context_config, example_names, is_batch, num_examples,
3046       allocator, context_result));
3047   TF_RETURN_IF_ERROR(ParseContextSparseFeatures(
3048       context_features, context_config, example_names, is_batch, num_examples,
3049       allocator, context_result));
3050   TF_RETURN_IF_ERROR(ParseContextRaggedFeatures(
3051       context_features, context_config, example_names, is_batch, num_examples,
3052       allocator, context_result));
3053   TF_RETURN_IF_ERROR(ParseSequenceDenseFeatures(
3054       sequence_features, sequence_config, example_names, is_batch, num_examples,
3055       allocator, sequence_result, dense_feature_lengths));
3056   TF_RETURN_IF_ERROR(ParseSequenceSparseFeatures(
3057       sequence_features, sequence_config, example_names, is_batch, num_examples,
3058       allocator, sequence_result));
3059   TF_RETURN_IF_ERROR(ParseSequenceRaggedFeatures(
3060       sequence_features, sequence_config, example_names, is_batch, num_examples,
3061       allocator, sequence_result));
3062 
3063   return Status::OK();
3064 }
3065 
3066 }  // namespace example
3067 }  // namespace tensorflow
3068