• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_fast_parsing.h"
16 
17 #include <vector>
18 
19 #include "absl/base/casts.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/example/example.pb.h"
22 #include "tensorflow/core/example/feature.pb.h"
23 #include "tensorflow/core/framework/allocator.h"
24 #include "tensorflow/core/framework/numeric_op.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/blocking_counter.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/monitoring/counter.h"
33 #include "tensorflow/core/platform/byte_order.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/util/presized_cuckoo_map.h"
37 #include "tensorflow/core/util/sparse/sparse_tensor.h"
38 
39 namespace tensorflow {
40 namespace example {
41 
42 namespace {
43 
// Inline-storage vector used for per-feature value lists: holds up to 4
// elements without heap allocation before spilling.
template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;
46 
// Fixed-capacity, vector-like writer over a caller-owned buffer. It never
// writes past the end of the underlying array; instead, overflow is recorded
// by advancing the write cursor past `limit_` so that EndDistance() turns
// negative and callers can detect "too many values".
template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  LimitedArraySlice(T* begin, size_t num_elements)
      : pos_(begin), begin_(begin), limit_(begin + num_elements) {}

  // Number of unfilled slots left. May be negative if there were push_back
  // calls after the slice was filled.
  int64_t EndDistance() const { return limit_ - pos_; }

  // Attempts to append `value`. Once the slice is full this no longer touches
  // the underlying data, but still advances the cursor so EndDistance()
  // reflects the overflow.
  void push_back(T&& value) {
    if (limit_ - pos_ > 0) *pos_ = std::move(value);
    ++pos_;
  }

  // Claims the next slot and returns a mutable reference to it.
  // REQUIRES: EndDistance() > 0.
  T& construct_at_end() {
    DCHECK_GT(EndDistance(), 0);
    return *(pos_++);
  }

  // Mutable reference to the most recently written element.
  // REQUIRES: size() > 0.
  T& back() { return *(pos_ - 1); }

  // Number of valid elements, capped at the buffer capacity even after an
  // overflowing resize()/push_back().
  size_t size() const { return std::min(pos_ - begin_, limit_ - begin_); }

  // Moves the write cursor to `size` elements past the start. The cursor may
  // land beyond the buffer, in which case a later size() call still reports
  // at most the capacity.
  void resize(size_t size) { pos_ = begin_ + size; }

  // Start of the underlying caller-owned buffer.
  T* data() { return begin_; }

 private:
  T* pos_;    // next slot to write (may exceed limit_ to record overflow)
  T* begin_;  // first element of the buffer
  T* limit_;  // one past the last usable slot
};
95 
// Turns on buffer aliasing when the stream type supports it, so parsed
// strings can point into the input buffer instead of copying. Expression
// SFINAE on the trailing return type selects this overload only when `A` has
// an `EnableAliasing(bool)` member.
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

// Fallback overload: stream types without EnableAliasing get a no-op.
template <typename A>
void EnableAliasing(A&& a) {}
103 
PeekTag(protobuf::io::CodedInputStream * stream)104 uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
105   DCHECK(stream != nullptr);
106   const void* ptr;
107   int size;
108   if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
109   return *static_cast<const uint8*>(ptr);
110 }
111 
// Protobuf wire-format tag bytes for field number `tag`:
// wire type 0 = varint (ints), 2 = length-delimited (strings/packed lists),
// 5 = fixed 32-bit (floats).
constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
115 
116 namespace parsed {
117 
// ParseDataType has to be called first, then appropriate ParseZzzzList.
//
// Zero-copy view over one serialized `Feature` proto. `serialized_` aliases
// the caller's input buffer, so a Feature must not outlive that buffer.
class Feature {
 public:
  Feature() {}
  explicit Feature(StringPiece serialized) : serialized_(serialized) {}

  // Identifies which oneof field the feature carries by consuming its first
  // tag byte: bytes_list (field 1) -> DT_STRING, float_list (field 2) ->
  // DT_FLOAT, int64_list (field 3) -> DT_INT64. An empty feature yields
  // DT_INVALID with an OK status; any other leading tag is an error.
  Status ParseDataType(DataType* dtype) {
    DCHECK(dtype != nullptr);
    if (serialized_.empty()) {
      *dtype = DT_INVALID;
      return OkStatus();
    }
    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
    serialized_.remove_prefix(1);
    switch (oneof_tag) {
      case kDelimitedTag(1):  // Feature.bytes_list
        *dtype = DT_STRING;
        break;
      case kDelimitedTag(2):  // Feature.float_list
        *dtype = DT_FLOAT;
        break;
      case kDelimitedTag(3):  // Feature.int64_list
        *dtype = DT_INT64;
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DT_INVALID;
        return errors::InvalidArgument("Unsupported datatype.");
    }
    return OkStatus();
  }

  // Counts the strings in a serialized BytesList by skipping over each
  // length-delimited element without copying it. Returns false on malformed
  // or truncated input.
  bool GetNumElementsInBytesList(int* num_elements) {
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length = 0;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      uint32 bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      if (!stream.Skip(bytes_length)) return false;
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Helper methods
  // Appends a slot to `bytes_list`, or returns nullptr when the fixed-size
  // slice has no capacity left (unlike SmallVector, it cannot grow).
  tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
    if (bytes_list->EndDistance() <= 0) {
      return nullptr;
    }
    return &bytes_list->construct_at_end();
  }
  tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
    return &bytes_list->emplace_back();
  }

  // Decodes a serialized BytesList into `bytes_list` (a SmallVector or a
  // LimitedArraySlice). Returns false on malformed input or when a
  // LimitedArraySlice runs out of capacity.
  template <typename Result>
  bool ParseBytesList(Result* bytes_list) {
    DCHECK(bytes_list != nullptr);

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());

    EnableAliasing(&stream);

    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      // parse string
      uint32 bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      tstring* bytes = construct_at_end(bytes_list);
      if (bytes == nullptr) return false;
      bytes->resize_uninitialized(bytes_length);
      if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Decodes a serialized FloatList into `float_list`. Handles both the
  // packed encoding (one length-delimited field holding raw little-endian
  // floats) and the unpacked one (a fixed32 field per value). Returns false
  // on malformed input.
  template <typename Result>
  bool ParseFloatList(Result* float_list) {
    DCHECK(float_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      constexpr int32_t kNumFloatBytes = 4;
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        // Store the initial size to know the offset we have to start writing
        // data from before resizing the output "vector".
        const size_t initial_size = float_list->size();
        float_list->resize(initial_size + packed_length / kNumFloatBytes);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        if (port::kLittleEndian &&
            sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Calculate the length of the buffer available what can be less than
          // what we requested in resize in case of a LimitedArraySlice.
          const uint32 bytes_to_copy =
              std::min(static_cast<uint32>((float_list->size() - initial_size) *
                                           kNumFloatBytes),
                       packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
            return false;
        } else {
          // Slow path (big-endian, or Result's value_type is not 4 bytes):
          // decode one fixed32 at a time, dropping values that exceed the
          // available output capacity.
          int64_t index = initial_size;
          while (!stream.ExpectAtEnd()) {
            uint32 buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) return false;
            if (index < float_list->size()) {
              float_list->data()[index] = absl::bit_cast<float>(buffer32);
              ++index;
            }
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        const size_t initial_size = float_list->size();
        // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
        // the value.
        const int64_t num_elements =
            stream.BytesUntilLimit() / (1 + kNumFloatBytes);
        float_list->resize(initial_size + num_elements);
        int64_t index = initial_size;
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->data()[index] = absl::bit_cast<float>(buffer32);
          ++index;
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  // Decodes a serialized Int64List into `int64_list` via push_back, handling
  // both packed and unpacked varint encodings. Returns false on malformed
  // input.
  template <typename Result>
  bool ParseInt64List(Result* int64_list) {
    DCHECK(int64_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64_t>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) return false;
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64_t>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  // Remaining serialized bytes (after ParseDataType has stripped the leading
  // oneof tag byte).
  StringPiece GetSerialized() const { return serialized_; }

 private:
  // TODO(lew): Pair of uint8* would be more natural.
  StringPiece serialized_;
};
330 
// One `name -> feature` entry of the Features map; the Feature value is kept
// in serialized form for lazy parsing.
using FeatureMapEntry = std::pair<StringPiece, Feature>;
// A parsed Example is the flat list of its map entries, in wire order
// (duplicates possible; last-one-wins is applied by consumers).
using Example = std::vector<FeatureMapEntry>;
333 
334 }  // namespace parsed
335 
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)336 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
337   uint32 data;
338   protobuf_uint64 dummy;
339   switch (stream->ReadTag() & 0x7) {
340     case 0:  // varint
341       if (!stream->ReadVarint32(&data)) return false;
342       return true;
343     case 1:  // fixed64
344       if (!stream->ReadLittleEndian64(&dummy)) return false;
345       return true;
346     case 2:  // length delimited
347       if (!stream->ReadVarint32(&data)) return false;
348       stream->Skip(data);
349       return true;
350     case 3:          // group begin
351       return false;  // groups not supported.
352     case 4:          // group end
353       return false;  // groups not supported.
354     case 5:          // fixed32
355       if (!stream->ReadLittleEndian32(&data)) return false;
356       return true;
357   }
358   return false;  // unrecognized tag type
359 }
360 
ParseString(protobuf::io::CodedInputStream * stream,StringPiece * result)361 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
362   DCHECK(stream != nullptr);
363   DCHECK(result != nullptr);
364   uint32 length;
365   if (!stream->ReadVarint32(&length)) return false;
366   if (length == 0) {
367     *result = StringPiece(nullptr, 0);
368     return true;
369   }
370   const void* stream_alias;
371   int stream_size;
372   if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
373     return false;
374   }
375   if (static_cast<uint32>(stream_size) < length) return false;
376   *result = StringPiece(static_cast<const char*>(stream_alias), length);
377   stream->Skip(length);
378   return true;
379 }
380 
ParseFeatureMapEntry(protobuf::io::CodedInputStream * stream,parsed::FeatureMapEntry * feature_map_entry)381 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
382                           parsed::FeatureMapEntry* feature_map_entry) {
383   DCHECK(stream != nullptr);
384   DCHECK(feature_map_entry != nullptr);
385   uint32 length;
386   if (!stream->ReadVarint32(&length)) return false;
387   auto limit = stream->PushLimit(length);
388 
389   // Protobufs allow an arbitrary order for the key and value fields.
390   for (int n = 0; n < 2; ++n) {
391     const uint32_t tag = stream->ReadTag();
392     switch (tag) {
393       case kDelimitedTag(1):
394         if (!ParseString(stream, &feature_map_entry->first)) return false;
395         break;
396 
397       case kDelimitedTag(2): {
398         StringPiece feature_string_piece;
399         if (!ParseString(stream, &feature_string_piece)) return false;
400         feature_map_entry->second = parsed::Feature(feature_string_piece);
401         break;
402       }
403 
404       default:
405         return false;
406     }
407   }
408 
409   if (!stream->ExpectAtEnd()) return false;
410   stream->PopLimit(limit);
411   return true;
412 }
413 
ParseFeatures(protobuf::io::CodedInputStream * stream,parsed::Example * example)414 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
415                    parsed::Example* example) {
416   DCHECK(stream != nullptr);
417   DCHECK(example != nullptr);
418   uint32 length;
419   if (!stream->ReadVarint32(&length)) return false;
420   auto limit = stream->PushLimit(length);
421   while (!stream->ExpectAtEnd()) {
422     parsed::FeatureMapEntry feature_map_entry;
423     if (!stream->ExpectTag(kDelimitedTag(1))) return false;
424     if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
425     example->push_back(std::move(feature_map_entry));
426   }
427   stream->PopLimit(limit);
428   return true;
429 }
430 
ParseExample(protobuf::io::CodedInputStream * stream,parsed::Example * example)431 bool ParseExample(protobuf::io::CodedInputStream* stream,
432                   parsed::Example* example) {
433   DCHECK(stream != nullptr);
434   DCHECK(example != nullptr);
435   // Loop over the input stream which may contain multiple serialized Example
436   // protos merged together as strings. This behavior is consistent with Proto's
437   // ParseFromString when string representations are concatenated.
438   while (!stream->ExpectAtEnd()) {
439     if (!stream->ExpectTag(kDelimitedTag(1))) {
440       if (!SkipExtraneousTag(stream)) return false;
441     } else {
442       if (!ParseFeatures(stream, example)) return false;
443     }
444   }
445   return true;
446 }
447 
// Convenience overload: wraps `serialized` in an aliasing-enabled
// CodedInputStream and parses it as an Example. The parsed entries alias
// `serialized`, which must outlive `example`.
bool ParseExample(StringPiece serialized, parsed::Example* example) {
  DCHECK(example != nullptr);
  protobuf::io::CodedInputStream stream(
      reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
  EnableAliasing(&stream);
  return ParseExample(&stream, example);
}
455 
456 }  // namespace
457 
// Test/reference entry point: parses `serialized` with the fast parser above
// and materializes the result into a regular Example proto. Returns false if
// the input cannot be parsed.
bool TestFastParse(const string& serialized, Example* example) {
  DCHECK(example != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized, &parsed_example)) return false;
  auto& features = *example->mutable_features();
  size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    // (Implemented by iterating back-to-front and skipping names already
    // inserted.)
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    string name(name_and_feature.first);
    if ((*features.mutable_feature()).count(name) > 0) continue;

    auto& value = (*features.mutable_feature())[name];
    DataType dtype;
    if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
    switch (dtype) {
      case DT_INVALID:
        // Empty feature: leave `value` default-constructed.
        break;
      case DT_STRING: {
        SmallVector<tstring> list;
        if (!name_and_feature.second.ParseBytesList(&list)) return false;
        auto* result_list = value.mutable_bytes_list();
        for (auto& bytes : list) {
          result_list->add_value(bytes.data(), bytes.size());
        }
        break;
      }
      case DT_FLOAT: {
        SmallVector<float> list;
        if (!name_and_feature.second.ParseFloatList(&list)) return false;
        auto* result_list = value.mutable_float_list();
        for (float f : list) {
          result_list->add_value(f);
        }
        break;
      }
      case DT_INT64: {
        SmallVector<int64_t> list;
        if (!name_and_feature.second.ParseInt64List(&list)) return false;
        auto* result_list = value.mutable_int64_list();
        for (int64_t i : list) {
          result_list->add_value(i);
        }
        break;
      }
      default:
        // ParseDataType only produces the four cases above.
        LOG(FATAL) << "Should not happen.";
    }
  }
  return true;
}
511 
512 // -----------------------------------------------------------------------------
513 
514 namespace {
515 
// Shorthand: the batch-parsing helpers below operate on FastParseExampleConfig.
using Config = FastParseExampleConfig;
517 
ParallelFor(const std::function<void (size_t)> & f,size_t n,thread::ThreadPool * thread_pool)518 void ParallelFor(const std::function<void(size_t)>& f, size_t n,
519                  thread::ThreadPool* thread_pool) {
520   if (n == 0) return;
521   if (thread_pool == nullptr) {
522     for (size_t i = 0; i < n; ++i) {
523       f(i);
524     }
525   } else {
526     BlockingCounter counter(n - 1);
527     for (size_t i = 1; i < n; ++i) {
528       thread_pool->Schedule([i, &f, &counter] {
529         f(i);
530         counter.DecrementCount();
531       });
532     }
533     f(0);
534     counter.Wait();
535   }
536 }
537 
// Enumeration for distinguishing feature types.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense (i.e.
// Dense must stay the first enumerator, with value 0).
enum class Type { Dense, Sparse, Ragged };
542 
// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
// Accumulates one feature's values across a minibatch of examples.
struct SparseBuffer {
  // Features are in one of the 3 vectors below depending on config's dtype.
  // Other 2 vectors remain empty.
  SmallVector<tstring> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64_t> int64_list;

  // Features of example i are elements with indices
  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
  // appropriate xxxxx_list
  std::vector<size_t> example_end_indices;
};
556 
// Hash functor for feature names. The fixed default seed keeps hashes
// consistent between building the config index and per-example lookups.
struct SeededHasher {
  uint64 operator()(StringPiece s) const {
    return Hash64(s.data(), s.size(), seed);
  }
  uint64 seed{0xDECAFCAFFE};
};
563 
LogDenseFeatureDataLoss(StringPiece feature_name)564 void LogDenseFeatureDataLoss(StringPiece feature_name) {
565   LOG(WARNING) << "Data loss! Feature '" << feature_name
566                << "' is present in multiple concatenated "
567                   "tf.Examples. Ignoring all but last one.";
568   static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
569       "/tensorflow/core/util/example_proto_fast_parsing/"
570       "duplicated_dense_feature",
571       "Dense feature appears twice in a tf.Example");
572   duplicated_dense_feature->GetCell()->IncrementBy(1);
573 }
574 
LogSparseFeatureDataLoss(StringPiece feature_name)575 void LogSparseFeatureDataLoss(StringPiece feature_name) {
576   LOG(WARNING) << "Data loss! Feature '" << feature_name
577                << "' is present in multiple concatenated "
578                   "tf.Examples. Ignoring all but last one.";
579   static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
580       "/tensorflow/core/util/example_proto_fast_parsing/"
581       "duplicated_sparse_feature",
582       "Sparse feature appears twice in a tf.Example");
583   duplicated_sparse_feature->GetCell()->IncrementBy(1);
584 }
585 
FastParseSerializedExample(const tstring & serialized_example,const tstring & example_name,const size_t example_index,const Config & config,const PresizedCuckooMap<std::pair<size_t,Type>> & config_index,SeededHasher hasher,std::vector<Tensor> * output_dense,std::vector<SparseBuffer> * output_varlen_dense,std::vector<SparseBuffer> * output_sparse,std::vector<SparseBuffer> * output_ragged,PerExampleFeatureStats * output_stats)586 Status FastParseSerializedExample(
587     const tstring& serialized_example, const tstring& example_name,
588     const size_t example_index, const Config& config,
589     const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
590     SeededHasher hasher, std::vector<Tensor>* output_dense,
591     std::vector<SparseBuffer>* output_varlen_dense,
592     std::vector<SparseBuffer>* output_sparse,
593     std::vector<SparseBuffer>* output_ragged,
594     PerExampleFeatureStats* output_stats) {
595   DCHECK(output_dense != nullptr);
596   DCHECK(output_sparse != nullptr);
597   DCHECK(output_ragged != nullptr);
598   parsed::Example parsed_example;
599   if (!ParseExample(serialized_example, &parsed_example)) {
600     return errors::InvalidArgument("Could not parse example input, value: '",
601                                    serialized_example, "'");
602   }
603   std::vector<int64_t> sparse_feature_last_example(config.sparse.size(), -1);
604   std::vector<int64_t> dense_feature_last_example(config.dense.size(), -1);
605   std::vector<int64_t> ragged_feature_last_example(config.ragged.size(), -1);
606 
607   // Handle features present in the example.
608   const size_t parsed_example_size = parsed_example.size();
609 
610   if (output_stats) {
611     // TODO(b/111553342): This may over-count the number of features if there
612     // are duplicate keys in the feature map. Consider deduplicating the keys
613     // before computing the count.
614     output_stats->features_count = parsed_example_size;
615   }
616 
617   for (size_t i = 0; i < parsed_example_size; ++i) {
618     // This is a logic that standard protobuf parsing is implementing.
619     // I.e. last entry in the map overwrites all the previous ones.
620     parsed::FeatureMapEntry& name_and_feature =
621         parsed_example[parsed_example_size - i - 1];
622 
623     const StringPiece feature_name = name_and_feature.first;
624     parsed::Feature& feature = name_and_feature.second;
625 
626     std::pair<size_t, Type> d_and_type;
627     uint64 h = hasher(feature_name);
628     if (!config_index.Find(h, &d_and_type)) continue;
629 
630     size_t d = d_and_type.first;
631     bool is_dense = d_and_type.second == Type::Dense;
632     bool is_ragged = d_and_type.second == Type::Ragged;
633 
634     {
635       // Testing for PresizedCuckooMap collision.
636       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
637       const tstring& config_feature_name =
638           is_dense ? config.dense[d].feature_name
639                    : (is_ragged ? config.ragged[d].feature_name
640                                 : config.sparse[d].feature_name);
641       if (feature_name != config_feature_name) continue;
642     }
643 
644     auto example_error = [&](StringPiece suffix) {
645       return errors::InvalidArgument("Name: ", example_name,
646                                      ", Key: ", feature_name,
647                                      ", Index: ", example_index, ".  ", suffix);
648     };
649 
650     auto parse_error = [&] {
651       return example_error("Can't parse serialized Example.");
652     };
653 
654     DataType example_dtype;
655     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
656 
657     if (is_dense) {
658       if (example_dtype == DT_INVALID) continue;
659 
660       // If feature was already visited, skip.
661       // Compare comment at the beginning of the loop.
662       if (dense_feature_last_example[d] == example_index) {
663         LogDenseFeatureDataLoss(feature_name);
664         continue;
665       }
666       dense_feature_last_example[d] = example_index;
667 
668       if (example_dtype != config.dense[d].dtype) {
669         return example_error(strings::StrCat(
670             "Data types don't match. Data type: ",
671             DataTypeString(example_dtype),
672             " but expected type: ", DataTypeString(config.dense[d].dtype)));
673       }
674       if (!config.dense[d].variable_length) {
675         Tensor& out = (*output_dense)[d];
676 
677         const std::size_t num_elements = config.dense[d].elements_per_stride;
678         if (output_stats) {
679           // TODO(b/111553342): If desirable, we could add support for counting
680           // elements in the features that aren't parsed, but this could add
681           // considerable runtime cost.
682           output_stats->feature_values_count += num_elements;
683         }
684 
685         const std::size_t offset = example_index * num_elements;
686 
687         auto shape_error = [&](size_t size, StringPiece type_str) {
688           return example_error(strings::StrCat(
689               "Number of ", type_str,
690               " values != expected.  "
691               "Values size: ",
692               size,
693               " but output shape: ", config.dense[d].shape.DebugString()));
694         };
695 
696         switch (config.dense[d].dtype) {
697           case DT_INT64: {
698             auto out_p = out.flat<int64_t>().data() + offset;
699             LimitedArraySlice<int64_t> slice(out_p, num_elements);
700             if (!feature.ParseInt64List(&slice)) return parse_error();
701             if (slice.EndDistance() != 0) {
702               return shape_error(num_elements - slice.EndDistance(), "int64");
703             }
704             break;
705           }
706           case DT_FLOAT: {
707             auto out_p = out.flat<float>().data() + offset;
708             LimitedArraySlice<float> slice(out_p, num_elements);
709             if (!feature.ParseFloatList(&slice)) return parse_error();
710             if (slice.EndDistance() != 0) {
711               return shape_error(num_elements - slice.EndDistance(), "float");
712             }
713             break;
714           }
715           case DT_STRING: {
716             auto out_p = out.flat<tstring>().data() + offset;
717             LimitedArraySlice<tstring> slice(out_p, num_elements);
718             if (!feature.ParseBytesList(&slice)) return parse_error();
719             if (slice.EndDistance() != 0) {
720               return shape_error(num_elements - slice.EndDistance(), "bytes");
721             }
722             break;
723           }
724           default:
725             LOG(FATAL) << "Should not happen.";
726         }
727       } else {  // if variable length
728         SparseBuffer& out = (*output_varlen_dense)[d];
729 
730         const std::size_t num_elements = config.dense[d].elements_per_stride;
731 
732         if (example_dtype != DT_INVALID &&
733             example_dtype != config.dense[d].dtype) {
734           return example_error(strings::StrCat(
735               "Data types don't match. ",
736               "Expected type: ", DataTypeString(config.dense[d].dtype)));
737         }
738 
739         auto shape_error = [&](size_t size, StringPiece type_str) {
740           return example_error(strings::StrCat(
741               "Number of ", type_str,
742               " values is not a multiple of stride length. Saw ", size,
743               " values but output shape is: ",
744               config.dense[d].shape.DebugString()));
745         };
746 
747         switch (config.dense[d].dtype) {
748           case DT_INT64: {
749             if (example_dtype != DT_INVALID) {
750               if (!feature.ParseInt64List(&out.int64_list)) {
751                 return parse_error();
752               }
753               if (out.int64_list.size() % num_elements != 0) {
754                 return shape_error(out.int64_list.size(), "int64");
755               }
756             }
757             out.example_end_indices.push_back(out.int64_list.size());
758             break;
759           }
760           case DT_FLOAT: {
761             if (example_dtype != DT_INVALID) {
762               if (!feature.ParseFloatList(&out.float_list)) {
763                 return parse_error();
764               }
765               if (out.float_list.size() % num_elements != 0) {
766                 return shape_error(out.float_list.size(), "float");
767               }
768             }
769             out.example_end_indices.push_back(out.float_list.size());
770             break;
771           }
772           case DT_STRING: {
773             if (example_dtype != DT_INVALID) {
774               if (!feature.ParseBytesList(&out.bytes_list)) {
775                 return parse_error();
776               }
777               if (out.bytes_list.size() % num_elements != 0) {
778                 return shape_error(out.bytes_list.size(), "bytes");
779               }
780             }
781             out.example_end_indices.push_back(out.bytes_list.size());
782             break;
783           }
784           default:
785             LOG(FATAL) << "Should not happen.";
786         }
787 
788         if (output_stats) {
789           // Use `out.example_end_indices` to determine the feature-value count
790           // for this feature, because the preceding switch statement pushes
791           // the length of the appropriate feature list to that vector.
792           // TODO(b/111553342): If desirable, we could add support for counting
793           // elements in the features that aren't parsed, but this could add
794           // considerable runtime cost.
795           const size_t out_examples_count = out.example_end_indices.size();
796           if (out_examples_count == 1) {
797             output_stats->feature_values_count += out.example_end_indices[0];
798           } else {
799             output_stats->feature_values_count +=
800                 out.example_end_indices[out_examples_count - 1] -
801                 out.example_end_indices[out_examples_count - 2];
802           }
803         }
804       }
805     } else {
806       // Feature is sparse or ragged.
807       auto& last_example =
808           is_ragged ? ragged_feature_last_example : sparse_feature_last_example;
809 
810       // If feature was already visited, skip.
811       // Compare comment at the beginning of the loop.
812       if (last_example[d] == example_index) {
813         LogSparseFeatureDataLoss(feature_name);
814         continue;
815       }
816       last_example[d] = example_index;
817 
818       // Handle sparse features.
819       SparseBuffer& out = is_ragged ? (*output_ragged)[d] : (*output_sparse)[d];
820       DataType feature_dtype =
821           is_ragged ? config.ragged[d].dtype : config.sparse[d].dtype;
822       if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
823         return example_error(
824             strings::StrCat("Data types don't match. ",
825                             "Expected type: ", DataTypeString(feature_dtype),
826                             ", Actual type: ", DataTypeString(example_dtype)));
827       }
828 
829       switch (feature_dtype) {
830         case DT_INT64: {
831           if (example_dtype != DT_INVALID) {
832             if (!feature.ParseInt64List(&out.int64_list)) {
833               return parse_error();
834             }
835           }
836           out.example_end_indices.push_back(out.int64_list.size());
837           break;
838         }
839         case DT_FLOAT: {
840           if (example_dtype != DT_INVALID) {
841             if (!feature.ParseFloatList(&out.float_list)) {
842               return parse_error();
843             }
844           }
845           out.example_end_indices.push_back(out.float_list.size());
846           break;
847         }
848         case DT_STRING: {
849           if (example_dtype != DT_INVALID) {
850             if (!feature.ParseBytesList(&out.bytes_list)) {
851               return parse_error();
852             }
853           }
854           out.example_end_indices.push_back(out.bytes_list.size());
855           break;
856         }
857         default:
858           LOG(FATAL) << "Should not happen.";
859       }
860 
861       if (output_stats) {
862         // Use `out.example_end_indices` to determine the feature-value count
863         // for this feature, because the preceding switch statement pushes
864         // the length of the appropriate feature list to that vector.
865         // TODO(b/111553342): If desirable, we could add support for counting
866         // elements in the features that aren't parsed, but this could add
867         // considerable runtime cost.
868         const size_t out_examples_count = out.example_end_indices.size();
869         if (out_examples_count == 1) {
870           output_stats->feature_values_count += out.example_end_indices[0];
871         } else {
872           output_stats->feature_values_count +=
873               out.example_end_indices[out_examples_count - 1] -
874               out.example_end_indices[out_examples_count - 2];
875         }
876       }
877     }
878   }
879 
880   // Handle missing dense features for fixed strides.
881   for (size_t d = 0; d < config.dense.size(); ++d) {
882     if (config.dense[d].variable_length) continue;
883     if (dense_feature_last_example[d] == example_index) continue;
884     if (config.dense[d].default_value.NumElements() == 0) {
885       return errors::InvalidArgument(
886           "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
887           " (data type: ", DataTypeString(config.dense[d].dtype), ")",
888           " is required but could not be found.");
889     }
890     const Tensor& in = config.dense[d].default_value;
891     Tensor& out = (*output_dense)[d];
892     const std::size_t num_elements = in.shape().num_elements();
893     const std::size_t offset = example_index * num_elements;
894 
895     switch (config.dense[d].dtype) {
896       case DT_INT64: {
897         std::copy_n(in.flat<int64_t>().data(), num_elements,
898                     out.flat<int64_t>().data() + offset);
899         break;
900       }
901       case DT_FLOAT: {
902         std::copy_n(in.flat<float>().data(), num_elements,
903                     out.flat<float>().data() + offset);
904         break;
905       }
906       case DT_STRING: {
907         std::copy_n(in.flat<tstring>().data(), num_elements,
908                     out.flat<tstring>().data() + offset);
909         break;
910       }
911       default:
912         LOG(FATAL) << "Should not happen.";
913     }
914   }
915 
916   // Handle missing varlen dense features.
917   for (size_t d = 0; d < config.dense.size(); ++d) {
918     if (!config.dense[d].variable_length) continue;
919     if (dense_feature_last_example[d] == example_index) continue;
920     SparseBuffer& out = (*output_varlen_dense)[d];
921     size_t prev_example_end_index =
922         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
923     out.example_end_indices.push_back(prev_example_end_index);
924   }
925 
926   // Handle missing sparse features.
927   for (size_t d = 0; d < config.sparse.size(); ++d) {
928     if (sparse_feature_last_example[d] == example_index) continue;
929     SparseBuffer& out = (*output_sparse)[d];
930     size_t prev_example_end_index =
931         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
932     out.example_end_indices.push_back(prev_example_end_index);
933   }
934 
935   // Handle missing ragged features.
936   for (size_t d = 0; d < config.ragged.size(); ++d) {
937     if (ragged_feature_last_example[d] == example_index) continue;
938     SparseBuffer& out = (*output_ragged)[d];
939     size_t prev_example_end_index =
940         out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
941     out.example_end_indices.push_back(prev_example_end_index);
942   }
943 
944   return OkStatus();
945 }
946 
CheckConfigDataType(DataType dtype)947 Status CheckConfigDataType(DataType dtype) {
948   switch (dtype) {
949     case DT_INT64:
950     case DT_FLOAT:
951     case DT_STRING:
952       return OkStatus();
953     default:
954       return errors::InvalidArgument("Invalid config dtype: ",
955                                      DataTypeString(dtype));
956   }
957 }
958 
959 // Use this in the "default" clause of switch statements when dispatching
960 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)961 inline void ReportUnexpectedDataType(DataType dtype) {
962   DCHECK(false)
963       << "Encountered unexpected DataType " << DataTypeString(dtype)
964       << "in variable that should have been checked by CheckConfigDataType().";
965 }
966 
CheckConfigDataTypes(const Config & config)967 Status CheckConfigDataTypes(const Config& config) {
968   // Check config so we can safely CHECK(false) in switches on config.*.dtype
969   for (auto& c : config.sparse) {
970     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
971   }
972   for (auto& c : config.dense) {
973     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
974   }
975   for (auto& c : config.ragged) {
976     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
977     if (!(c.splits_dtype == DT_INT32 || c.splits_dtype == DT_INT64)) {
978       return errors::InvalidArgument("Invalid ragged_split_type: ",
979                                      DataTypeString(c.splits_dtype));
980     }
981   }
982   return OkStatus();
983 }
984 
// Selects the SparseBuffer value list whose element type matches T. Declared
// (not defined) for the general case; only the three parser-supported types
// below have definitions, so using any other T is a link-time error.
template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

// int64 feature values are accumulated in SparseBuffer::int64_list.
template <>
const SmallVector<int64_t>& GetListFromBuffer<int64_t>(
    const SparseBuffer& buffer) {
  return buffer.int64_list;
}
// float feature values are accumulated in SparseBuffer::float_list.
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
// string (bytes) feature values are accumulated in SparseBuffer::bytes_list.
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}
1002 
// Transfers the half-open range [b, e) into the destination starting at `t`.
// For plain value types (int64, float) an element-wise copy is used.
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  std::copy(b, e, t);
}
// Strings are routed through std::move so heap-allocated character buffers
// can be stolen instead of duplicated; the source buffer is discarded after
// merging. NOTE(review): the source pointers are const, so the move here
// silently degrades to a copy — confirm whether the sources were meant to be
// non-const.
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
  std::move(b, e, t);
}
1011 
// Writes the values of variable-length dense feature `d`, gathered across all
// minibatch buffers, into the flattened padded output tensor `values`.
// `num_elements` is the total size of `values`; `num_elements_per_minibatch`
// is the number of output slots reserved per example. Examples shorter than
// the widest example are left holding the feature's default value (the
// initial fill below).
template <typename T>
void FillAndCopyVarLen(
    const int d, const size_t num_elements,
    const size_t num_elements_per_minibatch, const Config& config,
    const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
    Tensor* values) {
  const Tensor& default_value = config.dense[d].default_value;

  // Copy-fill the tensors (creating the zero/fill-padding)
  std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
            default_value.flat<T>()(0));

  // Data is [batch_size, max_num_elements, data_stride_size]
  //   and num_elements_per_minibatch = max_num_elements * data_stride_size
  auto data = values->flat<T>().data();

  // Iterate over minibatch elements. Minibatches cover consecutive example
  // ranges, so advancing `data` by one example slot per copied example keeps
  // output aligned with the global example index.
  for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
    const SparseBuffer& buffer = varlen_dense_buffers[i][d];
    // Number of examples being stored in this buffer
    const auto& end_indices = buffer.example_end_indices;
    const size_t examples_in_buffer = end_indices.size();

    const auto& list = GetListFromBuffer<T>(buffer);
    auto list_ptr = list.begin();

    // Running count of values consumed from `list`; end_indices holds
    // cumulative end offsets, so example j's size is
    // end_indices[j] - elements_tally.
    size_t elements_tally = 0;
    // Iterate through all the examples stored in this buffer.
    for (size_t j = 0; j < examples_in_buffer; ++j) {
      // Number of elements stored for this example.
      const size_t num_elems = end_indices[j] - elements_tally;
      CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
      // Move forward this many elements in the varlen buffer.
      list_ptr += num_elems;
      // Move forward to the next minibatch entry in the values output.
      data += num_elements_per_minibatch;
      elements_tally = end_indices[j];
    }
    // Every value accumulated in this buffer must have been consumed.
    DCHECK(elements_tally == list.size());
  }
}
1054 
// Thin vector-like interface wrapper around a Tensor. This enables us to
// directly populate a tensor during parsing instead of having to first create
// a vector and then copy the data over.
template <typename T>
class TensorVector {
 public:
  using value_type = T;

  // Returns the wrapped tensor, lazily allocating an empty (0-element) one if
  // resize() was never called.
  const Tensor& tensor() {
    if (!tensor_.has_value()) {
      resize(0);
    }
    return *tensor_;
  }

  // Number of elements currently held; 0 before any resize().
  int64_t size() const {
    return tensor_.has_value() ? tensor_->NumElements() : 0;
  }
  // Allocates the backing 1-D tensor. May only be called once — the wrapper
  // does not support re-allocation (enforced by the DCHECK).
  void resize(int64_t new_size) {
    DCHECK(!tensor_.has_value());
    tensor_ = Tensor(DataTypeToEnum<T>::v(), TensorShape({new_size}));
    data_ = tensor_->flat<T>().data();
  }
  // Raw pointer into the tensor's buffer; nullptr until resize() is called.
  T* data() { return data_; }
  const T* data() const { return data_; }

 private:
  // Use absl::optional to avoid calling the default constructor of Tensor
  // unnecessarily.
  absl::optional<Tensor> tensor_;

  // Cached pointer to the raw data inside the tensor.
  T* data_ = nullptr;
};
1089 
CountSparseFeatures(const std::vector<std::vector<SparseBuffer>> & sparse_buffers,size_t d,size_t * total_num_features,size_t * max_num_features)1090 void CountSparseFeatures(
1091     const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
1092     size_t* total_num_features, size_t* max_num_features) {
1093   for (auto& sparse_values_tmp : sparse_buffers) {
1094     const std::vector<size_t>& end_indices =
1095         sparse_values_tmp[d].example_end_indices;
1096     *total_num_features += end_indices.back();
1097     *max_num_features = std::max(*max_num_features, end_indices[0]);
1098     for (size_t i = 1; i < end_indices.size(); ++i) {
1099       size_t example_size = end_indices[i] - end_indices[i - 1];
1100       *max_num_features = std::max(*max_num_features, example_size);
1101     }
1102   }
1103 }
1104 
// Writes the values accumulated in `src` into `dst` starting at element
// `offset`. Numeric lists are copied; strings are moved since the buffer is
// discarded after merging. `dtype` must have passed CheckConfigDataType(),
// otherwise the default branch fires a DCHECK.
void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
                              Tensor* dst) {
  switch (dtype) {
    case DT_INT64: {
      std::copy(src->int64_list.begin(), src->int64_list.end(),
                dst->flat<int64_t>().data() + offset);
      break;
    }
    case DT_FLOAT: {
      std::copy(src->float_list.begin(), src->float_list.end(),
                dst->flat<float>().data() + offset);
      break;
    }
    case DT_STRING: {
      // Move to steal the string buffers rather than duplicating them.
      std::move(src->bytes_list.begin(), src->bytes_list.end(),
                dst->flat<tstring>().data() + offset);
      break;
    }
    default:
      ReportUnexpectedDataType(dtype);
  }
}
1127 
1128 }  // namespace
1129 
FastParseExample(const Config & config,gtl::ArraySlice<tstring> serialized,gtl::ArraySlice<tstring> example_names,thread::ThreadPool * thread_pool,Result * result)1130 Status FastParseExample(const Config& config,
1131                         gtl::ArraySlice<tstring> serialized,
1132                         gtl::ArraySlice<tstring> example_names,
1133                         thread::ThreadPool* thread_pool, Result* result) {
1134   DCHECK(result != nullptr);
1135   // Check config so we can safely CHECK(false) in switches on config.*.dtype
1136   TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1137 
1138   if (config.collect_feature_stats) {
1139     result->feature_stats.resize(serialized.size());
1140   }
1141 
1142   size_t config_size =
1143       config.dense.size() + config.sparse.size() + config.ragged.size();
1144   SeededHasher hasher;
1145   // Build config index.
1146   PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1147   bool ok = true;
1148   for (size_t i = 0; i < 1000; ++i) {
1149     for (size_t d = 0; d < config.dense.size(); ++d) {
1150       ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1151                                       {d, Type::Dense});
1152     }
1153     for (size_t d = 0; d < config.sparse.size(); ++d) {
1154       ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1155                                       {d, Type::Sparse});
1156     }
1157     for (size_t d = 0; d < config.ragged.size(); ++d) {
1158       ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1159                                       {d, Type::Ragged});
1160     }
1161     if (ok) break;
1162     LOG(WARNING) << "Collision found. This should happen only if you have "
1163                     "around 2^32 entries in your config.";
1164     hasher.seed++;
1165     config_index.Clear(config_size);
1166     ok = true;
1167   }
1168   if (!ok) {
1169     return errors::Internal(
1170         "Could not avoid collision. This should not happen.");
1171   }
1172 
1173   // Allocate dense output for fixed length dense values
1174   // (variable-length dense and sparse and ragged have to be buffered).
1175   std::vector<Tensor> fixed_dense_values(config.dense.size());
1176   for (size_t d = 0; d < config.dense.size(); ++d) {
1177     if (config.dense[d].variable_length) continue;
1178     TensorShape out_shape;
1179     out_shape.AddDim(serialized.size());
1180     for (const int64_t dim : config.dense[d].shape.dim_sizes()) {
1181       out_shape.AddDim(dim);
1182     }
1183     fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
1184   }
1185 
1186   // This parameter affects performance in a big and data-dependent way.
1187   const size_t kMiniBatchSizeBytes = 50000;
1188 
1189   // Calculate number of minibatches.
1190   // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
1191   // Apply 'special logic' below for small and big regimes.
1192   const size_t num_minibatches = [&] {
1193     size_t result = 0;
1194     size_t minibatch_bytes = 0;
1195     for (size_t i = 0; i < serialized.size(); i++) {
1196       if (minibatch_bytes == 0) {  // start minibatch
1197         result++;
1198       }
1199       minibatch_bytes += serialized[i].size() + 1;
1200       if (minibatch_bytes > kMiniBatchSizeBytes) {
1201         minibatch_bytes = 0;
1202       }
1203     }
1204     // 'special logic'
1205     const size_t min_minibatches = std::min<size_t>(8, serialized.size());
1206     const size_t max_minibatches = 64;
1207     return std::max<size_t>(min_minibatches,
1208                             std::min<size_t>(max_minibatches, result));
1209   }();
1210 
1211   auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
1212     return (serialized.size() * minibatch) / num_minibatches;
1213   };
1214 
1215   // TODO(lew): A big performance low-hanging fruit here is to improve
1216   //   num_minibatches calculation to take into account actual amount of work
1217   //   needed, as the size in bytes is not perfect. Linear combination of
1218   //   size in bytes and average number of features per example is promising.
1219   //   Even better: measure time instead of estimating, but this is too costly
1220   //   in small batches.
1221   //   Maybe accept outside parameter #num_minibatches?
1222 
1223   // Do minibatches in parallel.
1224   std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
1225   std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
1226   std::vector<std::vector<SparseBuffer>> ragged_buffers(num_minibatches);
1227   std::vector<Status> status_of_minibatch(num_minibatches);
1228   auto ProcessMiniBatch = [&](size_t minibatch) {
1229     sparse_buffers[minibatch].resize(config.sparse.size());
1230     varlen_dense_buffers[minibatch].resize(config.dense.size());
1231     ragged_buffers[minibatch].resize(config.ragged.size());
1232     size_t start = first_example_of_minibatch(minibatch);
1233     size_t end = first_example_of_minibatch(minibatch + 1);
1234     for (size_t e = start; e < end; ++e) {
1235       PerExampleFeatureStats* stats = nullptr;
1236       if (config.collect_feature_stats) {
1237         stats = &result->feature_stats[e];
1238       }
1239       status_of_minibatch[minibatch] = FastParseSerializedExample(
1240           serialized[e],
1241           (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
1242           config_index, hasher, &fixed_dense_values,
1243           &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch],
1244           &ragged_buffers[minibatch], stats);
1245       if (!status_of_minibatch[minibatch].ok()) break;
1246     }
1247   };
1248 
1249   ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);
1250 
1251   for (Status& status : status_of_minibatch) {
1252     TF_RETURN_IF_ERROR(status);
1253   }
1254 
1255   result->sparse_indices.reserve(config.sparse.size());
1256   result->sparse_values.reserve(config.sparse.size());
1257   result->sparse_shapes.reserve(config.sparse.size());
1258   result->dense_values.reserve(config.dense.size());
1259   result->ragged_values.reserve(config.ragged.size());
1260   result->ragged_splits.reserve(config.ragged.size());
1261 
1262   for (size_t d = 0; d < config.dense.size(); ++d) {
1263     result->dense_values.push_back(std::move(fixed_dense_values[d]));
1264   }
1265 
1266   // Merge SparseBuffers from all minibatches for every config.sparse.
1267   auto MergeSparseMinibatches = [&](size_t d) {
1268     // Loop over minibatches
1269     size_t total_num_features = 0;
1270     size_t max_num_features = 0;
1271     CountSparseFeatures(sparse_buffers, d, &total_num_features,
1272                         &max_num_features);
1273 
1274     TensorShape indices_shape;
1275     indices_shape.AddDim(total_num_features);
1276     indices_shape.AddDim(2);
1277     result->sparse_indices.emplace_back(DT_INT64, indices_shape);
1278     Tensor* indices = &result->sparse_indices.back();
1279 
1280     TensorShape values_shape;
1281     values_shape.AddDim(total_num_features);
1282     result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
1283     Tensor* values = &result->sparse_values.back();
1284 
1285     result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
1286     auto shapes_shape_t = result->sparse_shapes.back().vec<int64_t>();
1287     shapes_shape_t(0) = serialized.size();
1288     shapes_shape_t(1) = max_num_features;
1289 
1290     size_t offset = 0;
1291     for (size_t i = 0; i < sparse_buffers.size(); ++i) {
1292       SparseBuffer& buffer = sparse_buffers[i][d];
1293 
1294       // Update indices.
1295       size_t delta = 0;
1296 
1297       if (indices->NumElements() > 0) {
1298         int64* ix_p = &indices->matrix<int64_t>()(offset, 0);
1299         size_t example_index = first_example_of_minibatch(i);
1300         for (size_t example_end_index : buffer.example_end_indices) {
1301           size_t feature_index = 0;
1302           for (; delta < example_end_index; ++delta) {
1303             // Column 0: example index
1304             *ix_p = example_index;
1305             // Column 1: the feature index buffer example
1306             *(ix_p + 1) = feature_index;
1307             ix_p += 2;
1308             ++feature_index;
1309           }
1310           ++example_index;
1311         }
1312       }
1313 
1314       CopySparseBufferToTensor(config.sparse[d].dtype, offset, &buffer, values);
1315       offset += delta;
1316     }
1317   };
1318 
1319   // Merge SparseBuffers from all minibatches for every config.ragged.
1320   auto MergeRaggedMinibatches = [&](size_t d) {
1321     // Loop over minibatches
1322     size_t total_num_features = 0;
1323     size_t max_num_features = 0;
1324     CountSparseFeatures(ragged_buffers, d, &total_num_features,
1325                         &max_num_features);
1326 
1327     TensorShape row_splits_shape;
1328     row_splits_shape.AddDim(serialized.size() + 1);
1329     result->ragged_splits.emplace_back(config.ragged[d].splits_dtype,
1330                                        row_splits_shape);
1331     Tensor* row_splits = &result->ragged_splits.back();
1332     if (config.ragged[d].splits_dtype == DT_INT64) {
1333       row_splits->flat<int64_t>()(0) = 0;
1334     } else {
1335       row_splits->flat<int32>()(0) = 0;
1336     }
1337 
1338     TensorShape values_shape;
1339     values_shape.AddDim(total_num_features);
1340     result->ragged_values.emplace_back(config.ragged[d].dtype, values_shape);
1341     Tensor* values = &result->ragged_values.back();
1342 
1343     size_t values_offset = 0;
1344     size_t splits_offset = 0;
1345     for (size_t i = 0; i < ragged_buffers.size(); ++i) {
1346       SparseBuffer& buffer = ragged_buffers[i][d];
1347       if (buffer.example_end_indices.empty()) continue;
1348 
1349       // Update row_splits.  row_splits are formed by concatenating the example
1350       // end_indices (adjusting each to start after the previous one ends).
1351       if (config.ragged[d].splits_dtype == DT_INT64) {
1352         int64* row_splits_out = &row_splits->flat<int64_t>()(splits_offset);
1353         int64_t start = *row_splits_out;
1354         for (size_t example_end_index : buffer.example_end_indices) {
1355           *++row_splits_out = start + example_end_index;
1356         }
1357       } else {
1358         int32* row_splits_out = &row_splits->flat<int32>()(splits_offset);
1359         int32_t start = *row_splits_out;
1360         for (size_t example_end_index : buffer.example_end_indices) {
1361           *++row_splits_out = start + example_end_index;
1362         }
1363       }
1364 
1365       CopySparseBufferToTensor(config.ragged[d].dtype, values_offset, &buffer,
1366                                values);
1367       values_offset += buffer.example_end_indices.back();
1368       splits_offset += buffer.example_end_indices.size();
1369     }
1370   };
1371 
1372   // Merge SparseBuffers from all minibatches for every config.dense having
1373   // variable_length.
1374   auto MergeDenseVarLenMinibatches = [&](size_t d) {
1375     if (!config.dense[d].variable_length) return;
1376 
1377     // Loop over minibatches
1378     size_t max_num_features = 0;
1379     for (auto& dense_values_tmp : varlen_dense_buffers) {
1380       std::vector<size_t>& end_indices =
1381           dense_values_tmp[d].example_end_indices;
1382       max_num_features = std::max(max_num_features, end_indices[0]);
1383       for (size_t i = 1; i < end_indices.size(); ++i) {
1384         size_t example_size = end_indices[i] - end_indices[i - 1];
1385         max_num_features = std::max(max_num_features, example_size);
1386       }
1387     }
1388 
1389     const size_t stride_size = config.dense[d].elements_per_stride;
1390     const size_t max_num_elements = max_num_features / stride_size;
1391     TensorShape values_shape;
1392     DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
1393     const size_t batch_size = serialized.size();
1394     values_shape.AddDim(batch_size);
1395     values_shape.AddDim(max_num_elements);
1396     for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1397       values_shape.AddDim(config.dense[d].shape.dim_size(i));
1398     }
1399     Tensor values(config.dense[d].dtype, values_shape);
1400     result->dense_values[d] = values;
1401     const size_t num_elements = values.NumElements();
1402 
1403     // Nothing to write, exit early.
1404     if (num_elements == 0) return;
1405 
1406     const size_t num_elements_per_minibatch = num_elements / batch_size;
1407 
1408     switch (config.dense[d].dtype) {
1409       case DT_INT64: {
1410         FillAndCopyVarLen<int64_t>(d, num_elements, num_elements_per_minibatch,
1411                                    config, varlen_dense_buffers, &values);
1412         break;
1413       }
1414       case DT_FLOAT: {
1415         FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
1416                                  config, varlen_dense_buffers, &values);
1417         break;
1418       }
1419       case DT_STRING: {
1420         FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
1421                                    config, varlen_dense_buffers, &values);
1422         break;
1423       }
1424       default:
1425         ReportUnexpectedDataType(config.dense[d].dtype);
1426     }
1427   };
1428 
1429   for (size_t d = 0; d < config.dense.size(); ++d) {
1430     MergeDenseVarLenMinibatches(d);
1431   }
1432 
1433   for (size_t d = 0; d < config.sparse.size(); ++d) {
1434     MergeSparseMinibatches(d);
1435   }
1436 
1437   for (size_t d = 0; d < config.ragged.size(); ++d) {
1438     MergeRaggedMinibatches(d);
1439   }
1440 
1441   return OkStatus();
1442 }
1443 
FastParseSingleExample(const Config & config,StringPiece serialized,Result * result)1444 Status FastParseSingleExample(const Config& config, StringPiece serialized,
1445                               Result* result) {
1446   DCHECK(result != nullptr);
1447   // Check config so we can safely CHECK(false) in switches on config.*.dtype
1448   TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1449 
1450   PerExampleFeatureStats* stats = nullptr;
1451   if (config.collect_feature_stats) {
1452     result->feature_stats.emplace_back();
1453     stats = &result->feature_stats.back();
1454   }
1455 
1456   // TODO(mrry): Cache the construction of this map at Op construction time.
1457   size_t config_size =
1458       config.dense.size() + config.sparse.size() + config.ragged.size();
1459   SeededHasher hasher;
1460   // Build config index.
1461   PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1462   bool ok = true;
1463   for (size_t i = 0; i < 1000; ++i) {
1464     for (size_t d = 0; d < config.dense.size(); ++d) {
1465       ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1466                                       {d, Type::Dense});
1467     }
1468     for (size_t d = 0; d < config.sparse.size(); ++d) {
1469       ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1470                                       {d, Type::Sparse});
1471     }
1472     for (size_t d = 0; d < config.ragged.size(); ++d) {
1473       ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1474                                       {d, Type::Ragged});
1475     }
1476     if (ok) break;
1477     LOG(WARNING) << "Collision found. This should happen only if you have "
1478                     "around 2^32 entries in your config.";
1479     hasher.seed++;
1480     config_index.Clear(config_size);
1481     ok = true;
1482   }
1483   if (!ok) {
1484     return errors::Internal(
1485         "Could not avoid collision. This should not happen.");
1486   }
1487 
1488   result->sparse_indices.reserve(config.sparse.size());
1489   result->sparse_values.reserve(config.sparse.size());
1490   result->sparse_shapes.reserve(config.sparse.size());
1491   result->dense_values.reserve(config.dense.size());
1492   result->ragged_values.reserve(config.ragged.size());
1493   result->ragged_splits.reserve(config.ragged.size());
1494 
1495   // Allocate dense output tensors.
1496   for (size_t d = 0; d < config.dense.size(); ++d) {
1497     if (!config.dense[d].variable_length) {
1498       TensorShape values_shape;
1499       if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1500         return errors::Internal(
1501             "Fixed-length shape was not a statically defined shape.");
1502       }
1503       result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1504     } else {
1505       // Variable-length tensor will be allocated later.
1506       result->dense_values.emplace_back();
1507     }
1508   }
1509 
1510   // Allocate sparse output tensors.
1511   for (size_t d = 0; d < config.sparse.size(); ++d) {
1512     // The dense_shape is always a vector of length 1.
1513     result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1514     // Variable-length tensors will be allocated later.
1515     result->sparse_indices.emplace_back();
1516     result->sparse_values.emplace_back();
1517   }
1518 
1519   // Allocate ragged output tensors.
1520   for (size_t d = 0; d < config.ragged.size(); ++d) {
1521     // Variable-length values tensors will be allocated later.
1522     result->ragged_values.emplace_back();
1523     // Splits tensors are empty (unused) for single (scalar) inputs.
1524     const auto splits_dtype = config.ragged[d].splits_dtype;
1525     result->ragged_splits.emplace_back(splits_dtype, TensorShape({0}));
1526   }
1527 
1528   parsed::Example parsed_example;
1529   if (!ParseExample(serialized, &parsed_example)) {
1530     return errors::InvalidArgument("Could not parse example input, value: '",
1531                                    serialized, "'");
1532   }
1533   std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1534   std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1535   std::vector<bool> ragged_feature_already_seen(config.ragged.size(), false);
1536 
1537   if (stats) {
1538     // TODO(b/111553342): This may over-count the number of features if there
1539     // are duplicate keys in the feature map. Consider deduplicating the keys
1540     // before computing the count.
1541     stats->features_count = parsed_example.size();
1542   }
1543 
1544   // Handle features present in the example.
1545   const size_t parsed_example_size = parsed_example.size();
1546   for (size_t i = 0; i < parsed_example_size; ++i) {
1547     // This is a logic that standard protobuf parsing is implementing.
1548     // I.e. last entry in the map overwrites all the previous ones.
1549     parsed::FeatureMapEntry& name_and_feature =
1550         parsed_example[parsed_example_size - i - 1];
1551 
1552     const StringPiece feature_name = name_and_feature.first;
1553     parsed::Feature& feature = name_and_feature.second;
1554 
1555     std::pair<size_t, Type> d_and_type;
1556     uint64 h = hasher(feature_name);
1557     if (!config_index.Find(h, &d_and_type)) continue;
1558 
1559     size_t d = d_and_type.first;
1560     bool is_dense = d_and_type.second == Type::Dense;
1561     bool is_sparse = d_and_type.second == Type::Sparse;
1562 
1563     {
1564       // Testing for PresizedCuckooMap collision.
1565       // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
1566       const tstring& config_feature_name =
1567           is_dense ? config.dense[d].feature_name
1568                    : (is_sparse ? config.sparse[d].feature_name
1569                                 : config.ragged[d].feature_name);
1570       if (feature_name != config_feature_name) continue;
1571     }
1572 
1573     auto example_error = [feature_name](StringPiece suffix) {
1574       return errors::InvalidArgument("Key: ", feature_name, ".  ", suffix);
1575     };
1576 
1577     auto parse_error = [feature_name] {
1578       return errors::InvalidArgument("Key: ", feature_name,
1579                                      ".  Can't parse serialized Example.");
1580     };
1581 
1582     DataType example_dtype;
1583     TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
1584     if (example_dtype == DT_INVALID) continue;
1585 
1586     if (is_dense && !config.dense[d].variable_length) {
1587       // If feature was already visited, skip.
1588       // Compare comment at the beginning of the loop.
1589       if (dense_feature_already_seen[d]) {
1590         LogDenseFeatureDataLoss(feature_name);
1591         continue;
1592       }
1593       dense_feature_already_seen[d] = true;
1594 
1595       if (example_dtype != config.dense[d].dtype) {
1596         return example_error(strings::StrCat(
1597             "Data types don't match. Data type: ",
1598             DataTypeString(example_dtype),
1599             " but expected type: ", DataTypeString(config.dense[d].dtype)));
1600       }
1601 
1602       Tensor* out = &result->dense_values[d];
1603       const std::size_t num_elements = config.dense[d].elements_per_stride;
1604       if (stats) {
1605         // TODO(b/111553342): If desirable, we could add support for counting
1606         // elements in the features that aren't parsed, but this could add
1607         // considerable runtime cost.
1608         stats->feature_values_count += num_elements;
1609       }
1610       switch (example_dtype) {
1611         case DT_INT64: {
1612           auto out_p = out->flat<int64_t>().data();
1613           LimitedArraySlice<int64_t> slice(out_p, num_elements);
1614           if (!feature.ParseInt64List(&slice)) return parse_error();
1615           if (slice.EndDistance() != 0) {
1616             return parse_error();
1617           }
1618           break;
1619         }
1620         case DT_FLOAT: {
1621           auto out_p = out->flat<float>().data();
1622           LimitedArraySlice<float> slice(out_p, num_elements);
1623           if (!feature.ParseFloatList(&slice)) return parse_error();
1624           if (slice.EndDistance() != 0) {
1625             return parse_error();
1626           }
1627           break;
1628         }
1629         case DT_STRING: {
1630           auto out_p = out->flat<tstring>().data();
1631           LimitedArraySlice<tstring> slice(out_p, num_elements);
1632           if (!feature.ParseBytesList(&slice)) return parse_error();
1633           if (slice.EndDistance() != 0) {
1634             return parse_error();
1635           }
1636           break;
1637         }
1638         default:
1639           ReportUnexpectedDataType(example_dtype);
1640       }
1641 
1642     } else {  // if variable length
1643       SmallVector<tstring> bytes_list;
1644       TensorVector<float> float_list;
1645       SmallVector<int64_t> int64_list;
1646 
1647       const size_t num_elements_divisor =
1648           is_dense ? config.dense[d].elements_per_stride : 1;
1649       size_t num_elements;
1650 
1651       if (is_dense) {
1652         // If feature was already visited, skip.
1653         // Compare comment at the beginning of the loop.
1654         if (dense_feature_already_seen[d]) {
1655           LogDenseFeatureDataLoss(feature_name);
1656           continue;
1657         }
1658         dense_feature_already_seen[d] = true;
1659         if (example_dtype != config.dense[d].dtype) {
1660           return example_error(strings::StrCat(
1661               "Data types don't match. Data type: ",
1662               DataTypeString(example_dtype),
1663               " but expected type: ", DataTypeString(config.dense[d].dtype)));
1664         }
1665       } else {
1666         // Feature is sparse or ragged.
1667         auto& feature_already_seen = is_sparse ? sparse_feature_already_seen
1668                                                : ragged_feature_already_seen;
1669         auto& feature_dtype =
1670             is_sparse ? config.sparse[d].dtype : config.ragged[d].dtype;
1671         // If feature was already visited, skip.
1672         // Compare comment at the beginning of the loop.
1673         if (feature_already_seen[d]) {
1674           LogSparseFeatureDataLoss(feature_name);
1675           continue;
1676         }
1677         feature_already_seen[d] = true;
1678 
1679         // Handle sparse features.
1680         if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
1681           return example_error(strings::StrCat(
1682               "Data types don't match. ",
1683               "Expected type: ", DataTypeString(feature_dtype),
1684               ", Actual type: ", DataTypeString(example_dtype)));
1685         }
1686       }
1687 
1688       switch (example_dtype) {
1689         case DT_INT64: {
1690           // TODO(mrry): Use the fact that the `int64_list` is packed to read
1691           // out the length and pre-allocate the output tensor.
1692           if (!feature.ParseInt64List(&int64_list)) return parse_error();
1693           num_elements = int64_list.size();
1694           break;
1695         }
1696         case DT_FLOAT: {
1697           if (!feature.ParseFloatList(&float_list)) return parse_error();
1698           num_elements = float_list.size();
1699           break;
1700         }
1701         case DT_STRING: {
1702           int actual_num_elements = 0;
1703           if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1704             return parse_error();
1705           }
1706           bytes_list.reserve(actual_num_elements);
1707           if (!feature.ParseBytesList(&bytes_list)) return parse_error();
1708           num_elements = bytes_list.size();
1709           break;
1710         }
1711         default:
1712           num_elements = 0;
1713           ReportUnexpectedDataType(example_dtype);
1714       }
1715 
1716       if (num_elements % num_elements_divisor != 0) {
1717         return parse_error();
1718       }
1719 
1720       if (stats) {
1721         stats->feature_values_count += num_elements;
1722       }
1723 
1724       Tensor* out;
1725       DataType out_dtype;
1726       TensorShape out_shape;
1727       if (is_dense) {
1728         out_shape.AddDim(num_elements / num_elements_divisor);
1729         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1730           out_shape.AddDim(config.dense[d].shape.dim_size(i));
1731         }
1732 
1733         out = &result->dense_values[d];
1734         out_dtype = config.dense[d].dtype;
1735       } else if (is_sparse) {
1736         Tensor* out_indices = &result->sparse_indices[d];
1737         Tensor* out_dense_shape = &result->sparse_shapes[d];
1738 
1739         // TODO(mrry): Investigate the possibility of not materializing
1740         // the indices (and perhaps dense_shape) until they are needed.
1741         *out_indices = Tensor(
1742             DT_INT64, TensorShape({static_cast<int64_t>(num_elements), 1}));
1743         auto indices_flat = out_indices->flat<int64_t>();
1744         for (size_t i = 0; i < num_elements; ++i) {
1745           indices_flat(i) = static_cast<int64_t>(i);
1746         }
1747 
1748         *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
1749         auto shapes_shape_t = out_dense_shape->vec<int64_t>();
1750         shapes_shape_t(0) = num_elements;
1751 
1752         out = &result->sparse_values[d];
1753         out_dtype = config.sparse[d].dtype;
1754         out_shape.AddDim(num_elements);
1755       } else {
1756         out = &result->ragged_values[d];
1757         out_dtype = config.ragged[d].dtype;
1758         out_shape.AddDim(num_elements);
1759       }
1760 
1761       switch (example_dtype) {
1762         case DT_INT64: {
1763           *out = Tensor(out_dtype, out_shape);
1764           CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
1765                           out->flat<int64_t>().data());
1766           break;
1767         }
1768         case DT_FLOAT: {
1769           if (!out->CopyFrom(float_list.tensor(), out_shape)) {
1770             return parse_error();
1771           }
1772           break;
1773         }
1774         case DT_STRING: {
1775           *out = Tensor(out_dtype, out_shape);
1776           CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(),
1777                           out->flat<tstring>().data());
1778           break;
1779         }
1780         default:
1781           ReportUnexpectedDataType(example_dtype);
1782       }
1783     }
1784   }
1785 
1786   // Handle missing dense features.
1787   for (size_t d = 0; d < config.dense.size(); ++d) {
1788     if (!dense_feature_already_seen[d]) {
1789       if (!config.dense[d].variable_length) {
1790         // Handle missing fixed-length dense feature.
1791         if (config.dense[d].default_value.NumElements() == 0) {
1792           return errors::InvalidArgument(
1793               "Feature: ", config.dense[d].feature_name,
1794               " (data type: ", DataTypeString(config.dense[d].dtype), ")",
1795               " is required but could not be found.");
1796         }
1797         result->dense_values[d] = config.dense[d].default_value;
1798       } else {
1799         // Handle missing varlen dense feature.
1800         TensorShape empty_shape;
1801         empty_shape.AddDim(0);
1802         for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1803           empty_shape.AddDim(config.dense[d].shape.dim_size(i));
1804         }
1805         result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
1806       }
1807     }
1808   }
1809 
1810   // Handle missing sparse features.
1811   for (size_t d = 0; d < config.sparse.size(); ++d) {
1812     if (!sparse_feature_already_seen[d]) {
1813       result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
1814       result->sparse_values[d] =
1815           Tensor(config.sparse[d].dtype, TensorShape({0}));
1816       result->sparse_shapes[d].vec<int64_t>()(0) = 0;
1817     }
1818   }
1819 
1820   // Handle missing ragged features.
1821   for (size_t d = 0; d < config.ragged.size(); ++d) {
1822     if (!ragged_feature_already_seen[d]) {
1823       result->ragged_values[d] =
1824           Tensor(config.ragged[d].dtype, TensorShape({0}));
1825     }
1826   }
1827 
1828   return OkStatus();
1829 }
1830 
1831 // Private helper functions for FastParseSequenceExample.
1832 namespace {
1833 
1834 // A struct used by FastParseSequenceExample to hold the serialized proto
1835 // substrings for a single feature, plus some auxiliary information derived
1836 // from those protos (such as the total value length).
1837 struct FeatureProtos {
1838   // Proto substrings from each serialized SequenceExample that correspond
1839   // with this feature.  `protos_present` records whether the proto had a
1840   // value defined (even if that value is empty).
1841   std::vector<StringPiece> protos;
1842   std::vector<bool> protos_present;
1843 
1844   // Information derived from protos:
1845   size_t length;    // total length for ragged/sparse, max row length for dense.
1846   size_t num_rows;  // only populated for ragged sequence features.
1847 
1848   // Information from the config:
1849   Type type;  // Whether this feature is sparse, ragged, or dense.
1850   DataType dtype;
1851 };
1852 
1853 // Map from feature name to FeatureProtos for that feature.
1854 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
1855 
ExampleName(const gtl::ArraySlice<tstring> example_names,int n)1856 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
1857   return example_names.empty() ? "<unknown>" : example_names[n];
1858 }
1859 
// Return the number of bytes elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
                             tstring* out) {
  int num_elements = 0;
  uint32 length;
  // The bytes list is wire-encoded as a length-delimited field 1; bail out
  // if the next tag does not match.
  if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    while (!stream->ExpectAtEnd()) {
      // Each element is itself a length-delimited byte string (field 1).
      uint32 bytes_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&bytes_length)) {
        return -1;
      }
      if (out == nullptr) {
        // Counting mode: skip the payload without copying it.
        stream->Skip(bytes_length);
      } else {
        // Copy the payload into the next output slot; the caller must have
        // sized the output array for all elements.
        out->resize_uninitialized(bytes_length);
        if (!stream->ReadRaw(out->data(), bytes_length)) {
          return -1;
        }
        out++;
      }
      num_elements++;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
1892 
// Zero-fills `num_to_pad` float slots starting at `out`.
inline void PadFloatFeature(int num_to_pad, float* out) {
  for (int i = 0; i < num_to_pad; ++i) {
    out[i] = 0.0;
  }
}
1898 
// Zero-fills `num_to_pad` int64 slots starting at `out`.
inline void PadInt64Feature(int num_to_pad, int64_t* out) {
  for (int i = 0; i < num_to_pad; ++i) {
    out[i] = 0;
  }
}
1904 
// Return the number of float elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
                             float* out) {
  int num_elements = 0;
  uint32 length;
  // The float list is wire-encoded as a length-delimited field 2.
  if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      // Packed encoding: a single length-delimited blob of little-endian
      // 32-bit values, no per-element tags.
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          // Reinterpret the raw 32 bits as an IEEE float (no numeric
          // conversion).
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kFixed32Tag(1)) {
      // Unpacked encoding: each element carries its own fixed32 tag.
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ExpectTag(kFixed32Tag(1)) ||
            !stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
1955 
// Return the number of int64 elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
                             int64_t* out) {
  int num_elements = 0;
  uint32 length;
  // The int64 list is wire-encoded as a length-delimited field 3.
  if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      // Packed encoding: one length-delimited blob of varints, no
      // per-element tags.
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kVarintTag(1)) {
      // Unpacked encoding: each element carries its own varint tag.
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
2005 
2006 // Parses the next feature on `stream` into `out` starting at `out_offset`.
2007 // Updates `out_offset`, and returns the number of values added.
2008 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)2009 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
2010                         Tensor* out, size_t* out_offset) {
2011   int delta;
2012   switch (dtype) {
2013     case DT_STRING:
2014       delta =
2015           ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
2016       break;
2017     case DT_FLOAT:
2018       delta =
2019           ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
2020       break;
2021     case DT_INT64:
2022       delta =
2023           ParseInt64Feature(stream, out->flat<int64_t>().data() + *out_offset);
2024       break;
2025     default:
2026       ReportUnexpectedDataType(dtype);
2027       delta = 0;
2028   }
2029   if (delta > 0) {
2030     *out_offset += delta;
2031   }
2032   return delta;
2033 }
2034 
2035 // Returns the length of the next feature on `stream`.
2036 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)2037 inline int GetFeatureLength(DataType dtype,
2038                             protobuf::io::CodedInputStream* stream) {
2039   switch (dtype) {
2040     case DT_STRING:
2041       return ParseBytesFeature(stream, nullptr);
2042     case DT_FLOAT:
2043       return ParseFloatFeature(stream, nullptr);
2044     case DT_INT64:
2045       return ParseInt64Feature(stream, nullptr);
2046     default:
2047       ReportUnexpectedDataType(dtype);
2048       return -1;
2049   }
2050 }
2051 
ParseDataType(protobuf::io::CodedInputStream * stream)2052 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
2053   uint8 peek_tag = PeekTag(stream);
2054   switch (peek_tag) {
2055     case kDelimitedTag(1):
2056       return DT_STRING;
2057     case kDelimitedTag(2):
2058       return DT_FLOAT;
2059     case kDelimitedTag(3):
2060       return DT_INT64;
2061     default:
2062       return DT_INVALID;
2063   }
2064 }
2065 
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)2066 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
2067                              DataType dtype) {
2068   switch (dtype) {
2069     case DT_STRING:
2070       if (!stream->ExpectTag(kDelimitedTag(1))) {
2071         return false;
2072       }
2073       break;
2074     case DT_FLOAT:
2075       if (!stream->ExpectTag(kDelimitedTag(2))) {
2076         return false;
2077       }
2078       break;
2079     case DT_INT64:
2080       if (!stream->ExpectTag(kDelimitedTag(3))) {
2081         return false;
2082       }
2083       break;
2084     default:
2085       return false;
2086   }
2087   uint32 length;
2088   return stream->ReadVarint32(&length) && length == 0;
2089 }
2090 
// Reads an example proto, and extracts a StringPiece pointer to each feature.
Status ExtractFeaturesFromSequenceExamples(
    const gtl::ArraySlice<tstring> examples,
    const gtl::ArraySlice<tstring> example_names,
    FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) {
  for (int d = 0; d < examples.size(); d++) {
    const tstring& example = examples[d];
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(example.data()), example.size());
    // Let the extracted StringPieces alias `example`'s buffer instead of
    // copying.  NOTE(review): EnableAliasing is a file-local helper declared
    // earlier in this file; presumably it forwards to
    // stream.EnableAliasing() -- the extracted pieces must not outlive
    // `examples`.
    EnableAliasing(&stream);

    // Extract pointers to all features within this serialized example.
    while (!stream.ExpectAtEnd()) {
      FeatureProtosMap* features = nullptr;
      if (stream.ExpectTag(kDelimitedTag(1))) {
        // Context
        features = context_features;
      } else if (stream.ExpectTag(kDelimitedTag(2))) {
        // Sequence
        features = sequence_features;
      } else if (!SkipExtraneousTag(&stream)) {
        return errors::InvalidArgument(
            "Invalid protocol message input, example id: ",
            ExampleName(example_names, d));
      }
      if (features != nullptr) {
        uint32 length;
        if (!stream.ReadVarint32(&length)) {
          return errors::InvalidArgument(
              "Invalid protocol message input, example id: ",
              ExampleName(example_names, d));
        }
        auto limit = stream.PushLimit(length);
        // Each map entry is a submessage holding a string key (field 1) and
        // a serialized value (field 2).
        while (!stream.ExpectAtEnd()) {
          StringPiece key, value;
          uint32 length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&length)) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ",
                ExampleName(example_names, d));
          }
          auto limit = stream.PushLimit(length);
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !ParseString(&stream, &key) ||
              !stream.ExpectTag(kDelimitedTag(2)) ||
              !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ",
                ExampleName(example_names, d));
          }
          stream.PopLimit(limit);
          // Only save if this feature was requested (i.e. it already has an
          // entry in the map built from the parse config).
          auto feature_iter = features->find(key);
          if (feature_iter != features->end()) {
            auto& feature = feature_iter->second;
            feature.protos[d] = value;
            feature.protos_present[d] = true;
          }
        }
        stream.PopLimit(limit);
      }
    }
  }
  return OkStatus();
}
2158 
2159 // Populates context_features[k].length based on context_features[k].protos
2160 // (for all k).
GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,FeatureProtosMap * context_features)2161 Status GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2162                                 FeatureProtosMap* context_features) {
2163   for (auto& c : *context_features) {
2164     FeatureProtos& feature = c.second;
2165     for (int d = 0; d < feature.protos.size(); ++d) {
2166       const auto& proto = feature.protos[d];
2167       if (proto.empty()) continue;
2168       protobuf::io::CodedInputStream stream(
2169           reinterpret_cast<const uint8*>(proto.data()), proto.size());
2170       EnableAliasing(&stream);
2171       int num_elements = GetFeatureLength(feature.dtype, &stream);
2172       if (num_elements < 0) {
2173         return errors::InvalidArgument(
2174             "Name: ", ExampleName(example_names, d),
2175             ", Context feature: ", c.first,
2176             ".  Data types don't match. Expected type: ",
2177             DataTypeString(feature.dtype));
2178       }
2179       switch (feature.type) {
2180         case Type::Sparse:  // intentional fall-through
2181         case Type::Ragged:
2182           feature.length += num_elements;
2183           break;
2184         case Type::Dense:
2185           feature.length =
2186               std::max(feature.length, static_cast<size_t>(num_elements));
2187           break;
2188       }
2189     }
2190   }
2191   return OkStatus();
2192 }
2193 
// Populates sequence_features[k].length and sequence_features[k].num_rows based
// on sequence_features[k].protos (for all k).
Status GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,
                                 FeatureProtosMap* sequence_features) {
  for (auto& c : *sequence_features) {
    FeatureProtos& feature = c.second;
    for (int d = 0; d < feature.protos.size(); ++d) {
      const auto& proto = feature.protos[d];
      // Absent feature for this example: contributes nothing.
      if (proto.empty()) continue;

      size_t num_rows = 0;
      size_t num_elements = 0;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(proto.data()), proto.size());
      EnableAliasing(&stream);
      // A feature list is a sequence of length-delimited Feature messages
      // (field 1); each message counts as one row.
      while (!stream.ExpectAtEnd()) {
        uint32 feature_bytes;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_bytes)) {
          return errors::InvalidArgument("Error in sequence feature ", c.first,
                                         " in example ",
                                         ExampleName(example_names, d));
        }
        if (feature_bytes > 2) {
          // Non-trivial payload: count the values it contains.
          auto limit = stream.PushLimit(feature_bytes);
          int delta = GetFeatureLength(feature.dtype, &stream);
          if (delta < 0) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, d),
                ", Feature list: ", c.first, ", Index: ", num_rows,
                ".  Data types don't match. Expected type: ",
                DataTypeString(feature.dtype));
          }
          num_elements += delta;
          stream.PopLimit(limit);
        } else if (feature_bytes == 2) {
          // A 2-byte payload can only be an empty typed list (one tag byte
          // plus a zero length); verify the type tag matches `dtype`.
          if (!SkipEmptyFeature(&stream, feature.dtype)) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, d),
                ", Feature list: ", c.first, ", Index: ", num_rows,
                ".  Data types don't match. Expected type: ",
                DataTypeString(feature.dtype));
          }
        } else if (feature_bytes != 0) {
          // A 1-byte Feature payload cannot be well-formed.
          return errors::InvalidArgument("Error in sequence feature ", c.first,
                                         " in example ",
                                         ExampleName(example_names, d));
        }
        ++num_rows;
      }
      // Fold this example's counts into the per-feature totals.
      switch (feature.type) {
        case Type::Sparse:
          feature.length += num_elements;
          break;
        case Type::Ragged:
          feature.length += num_elements;
          feature.num_rows += num_rows;
          break;
        case Type::Dense:
          // Dense features track the maximum element count across examples.
          feature.length = std::max(feature.length, num_elements);
          break;
      }
    }
  }
  return OkStatus();
}
2260 
2261 // Copies src into dst[dst_offset:dst_offset+src.size], and then increments
2262 // dst_offset by src.size.
CopyTensorIntoTensor(DataType dtype,const Tensor & src,Tensor * dst,size_t * dst_offset)2263 void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst,
2264                           size_t* dst_offset) {
2265   size_t src_size = src.NumElements();
2266   switch (dtype) {
2267     case DT_INT64: {
2268       auto src_t = src.flat<int64_t>().data();
2269       std::copy(src_t, src_t + src_size,
2270                 dst->flat<int64_t>().data() + *dst_offset);
2271       break;
2272     }
2273     case DT_FLOAT: {
2274       auto src_t = src.flat<float>().data();
2275       std::copy(src_t, src_t + src_size,
2276                 dst->flat<float>().data() + *dst_offset);
2277       break;
2278     }
2279     case DT_STRING: {
2280       auto src_t = src.flat<tstring>().data();
2281       std::copy(src_t, src_t + src_size,
2282                 dst->flat<tstring>().data() + *dst_offset);
2283       break;
2284     }
2285     default:
2286       ReportUnexpectedDataType(dtype);
2287   }
2288   *dst_offset += src_size;
2289 }
2290 
// Parses dense features in `context_features`, and writes their parsed
// values to `context_result->dense_values`.  Every example must supply
// exactly `c.shape.num_elements()` values for each configured dense
// feature, or fall back to `c.default_value` when the feature is absent.
Status ParseContextDenseFeatures(const FeatureProtosMap& context_features,
                                 const FastParseExampleConfig& context_config,
                                 gtl::ArraySlice<tstring> example_names,
                                 bool is_batch, int num_examples,
                                 Allocator* allocator, Result* context_result) {
  for (int t = 0; t < context_config.dense.size(); ++t) {
    const auto& c = context_config.dense[t];
    // The map was pre-populated from this same config, so the lookup is
    // expected to succeed.
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape dense_shape, example_shape;
    DataType dtype = c.dtype;
    // Element count for this feature, precomputed by the earlier counting
    // pass; it must match the configured per-example shape exactly.
    const size_t data_max_elements = feature.length;
    if (!c.shape.AsTensorShape(&example_shape) ||
        data_max_elements != example_shape.num_elements()) {
      return errors::InvalidArgument(
          "Inconsistent max number of elements for feature ", c.feature_name,
          ": expected ", example_shape.num_elements(), ", but found ",
          data_max_elements);
    }
    // Output shape is [num_examples, <c.shape>] when batched, else <c.shape>.
    if (is_batch) {
      dense_shape.AddDim(num_examples);
    }
    for (const int dim : c.shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);

    Tensor& out = context_result->dense_values[t];
    size_t out_offset = 0;  // write cursor into `out`, in elements

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0;
      const auto& feature_proto = feature.protos[e];
      if (!feature.protos_present[e]) {
        // Copy the default value, if present. If not, return an error.
        if (c.default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: ", c.feature_name,
              " (data type: ", DataTypeString(c.dtype), ")",
              " is required but could not be found.");
        }
        CopyTensorIntoTensor(dtype, c.default_value, &out, &out_offset);
        num_elements += c.default_value.NumElements();
      } else if (!feature_proto.empty()) {
        // Parse the serialized Feature proto directly from its byte range.
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        // Aliasing lets parsed strings reference the input buffer instead
        // of being copied.
        EnableAliasing(&stream);
        num_elements += ParseFeature(dtype, &stream, &out, &out_offset);
      }
      // Dense features are fixed-size: every example must contribute
      // exactly the number of elements implied by the configured shape.
      if (num_elements != data_max_elements) {
        return errors::InvalidArgument(
            "Unexpected number of elements in example ",
            ExampleName(example_names, e));
      }
    }
  }
  return OkStatus();
}
2353 
// Parses sparse features in `context_features`, and writes their parsed
// values to `context_result` as a SparseTensor triplet per feature:
// indices, values, and dense shape.
Status ParseContextSparseFeatures(const FeatureProtosMap& context_features,
                                  const FastParseExampleConfig& context_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator,
                                  Result* context_result) {
  for (int t = 0; t < context_config.sparse.size(); ++t) {
    const auto& c = context_config.sparse[t];
    // The map was pre-populated from this same config, so the lookup is
    // expected to succeed.
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    // Total value count across all examples, precomputed by the earlier
    // counting pass.
    size_t expected_num_elements = feature.length;
    // Each index row is [example, position] when batched, else [position].
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(is_batch ? 2 : 1);
    values_shape.AddDim(expected_num_elements);
    context_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    context_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 2 : 1}));
    Tensor& out_values = context_result->sparse_values[t];
    size_t out_values_offset = 0;  // write cursor into `out_values`
    int64_t* out_indices =
        context_result->sparse_indices[t].flat<int64_t>().data();
    auto out_shape = context_result->sparse_shapes[t].vec<int64_t>();

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_cols = 0;  // widest example seen; becomes the dense shape
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      // Aliasing lets parsed strings reference the input buffer instead of
      // being copied.
      EnableAliasing(&stream);
      size_t num_added =
          ParseFeature(dtype, &stream, &out_values, &out_values_offset);
      num_elements += num_added;
      max_num_cols = std::max(max_num_cols, num_added);
      // Emit one index row per parsed value.
      for (int i = 0; i < num_added; i++) {
        if (is_batch) *out_indices++ = e;
        *out_indices++ = i;
      }
    }
    // The counting pass fixed the tensor sizes, so the totals must agree.
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    // Dense shape: [num_examples, max_num_cols] when batched, else
    // [max_num_cols].
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_cols;
    } else {
      out_shape(0) = max_num_cols;
    }
  }
  return OkStatus();
}
2415 
// Parses ragged features in `context_features`, and writes their parsed
// values to `context_result`.  Each feature yields a values tensor and,
// when batched, a row-splits tensor partitioning the values by example.
Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features,
                                  const FastParseExampleConfig& context_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator,
                                  Result* context_result) {
  for (int t = 0; t < context_config.ragged.size(); ++t) {
    const auto& c = context_config.ragged[t];
    // The map was pre-populated from this same config, so the lookup is
    // expected to succeed.
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape values_shape, splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    // Total value count across examples, precomputed by the counting pass.
    size_t expected_num_elements = feature.length;
    values_shape.AddDim(expected_num_elements);
    // Splits are only produced in batch mode: one boundary per example,
    // plus the leading zero.
    if (is_batch) {
      splits_shape.AddDim(num_examples + 1);
    }
    context_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    context_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, splits_shape);
    Tensor& out_values = context_result->ragged_values[t];
    size_t out_values_offset = 0;  // write cursor into `out_values`
    // At most one of these split cursors is non-null, selecting the
    // configured splits dtype (batch mode only).
    int32* int32_splits =
        is_batch && splits_dtype == DT_INT32
            ? context_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_splits =
        is_batch && splits_dtype == DT_INT64
            ? context_result->ragged_splits[t].vec<int64_t>().data()
            : nullptr;
    // Row splits always begin with zero.
    if (int32_splits) {
      *int32_splits++ = 0;
    } else if (int64_splits) {
      *int64_splits++ = 0;
    }

    // Fill in the values.
    size_t split = 0;  // = total number of elements we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        // Aliasing lets parsed strings reference the input buffer instead
        // of being copied.
        EnableAliasing(&stream);
        size_t num_added =
            ParseFeature(dtype, &stream, &out_values, &out_values_offset);
        split += num_added;
      }
      // Record the running total as this example's end boundary.
      if (int32_splits) {
        *int32_splits++ = split;
      } else if (int64_splits) {
        *int64_splits++ = split;
      }
    }
    if (split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    // Sanity check: exactly num_examples + 1 split entries must have been
    // written (pointer distance from the start of the splits buffer).
    if (int32_splits || int64_splits) {
      int actual_splits =
          int32_splits
              ? int32_splits -
                    context_result->ragged_splits[t].vec<int32>().data()
              : int64_splits -
                    context_result->ragged_splits[t].vec<int64_t>().data();
      if (actual_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return OkStatus();
}
2493 
// Parses dense features in `sequence_features`, and writes their parsed
// values to `sequence_result`.  A serialized FeatureList is a sequence of
// length-delimited Feature protos (one per row); every present row must
// hold exactly `c.shape.num_elements()` values.  The number of rows per
// example is recorded in `dense_feature_lengths`, and short examples are
// padded up to the longest example's element count.
Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features,
                                  const FastParseExampleConfig& sequence_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator, Result* sequence_result,
                                  std::vector<Tensor>* dense_feature_lengths) {
  // Per-example row counts: shape [num_examples] when batched, scalar
  // otherwise.
  TensorShape dense_length_shape;
  if (is_batch) {
    dense_length_shape.AddDim(num_examples);
  }
  for (int t = 0; t < sequence_config.dense.size(); ++t) {
    const auto& c = sequence_config.dense[t];
    // The map was pre-populated from this same config, so the lookup is
    // expected to succeed.
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape dense_shape, row_shape;
    DataType dtype = c.dtype;
    // Maximum total element count across examples, precomputed by the
    // counting pass; it must be a whole number of rows.
    const size_t expected_max_elements = feature.length;
    if (!c.shape.AsTensorShape(&row_shape) ||
        expected_max_elements !=
            (expected_max_elements / row_shape.num_elements()) *
                row_shape.num_elements()) {
      PartialTensorShape total_shape = row_shape;
      total_shape.InsertDim(0, -1);
      return errors::InvalidArgument(
          "Feature list '", c.feature_name,
          "' has an unexpected number of values.  Total values size: ",
          expected_max_elements,
          " is not consistent with output shape: ", total_shape.DebugString());
    }
    int64_t expected_max_rows =
        expected_max_elements / row_shape.num_elements();
    // Output shape: [num_examples?, expected_max_rows, <row_shape>].
    if (is_batch) {
      dense_shape.AddDim(num_examples);
    }
    dense_shape.AddDim(expected_max_rows);
    for (const int dim : sequence_config.dense[t].shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    sequence_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
    (*dense_feature_lengths)[t] =
        Tensor(allocator, DT_INT64, dense_length_shape);
    int64_t* out_lengths = (*dense_feature_lengths)[t].flat<int64_t>().data();

    // Exactly one of these write cursors is non-null, matching `dtype`.
    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64_t* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->dense_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->dense_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->dense_values[t].flat<int64_t>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0, num_rows = 0;
      const auto& feature_proto = feature.protos[e];
      if (!feature.protos_present[e]) {
        // Return an error if this feature was not allowed to be missing.
        // Otherwise, we'll pad as needed below.
        if (!c.variable_length) {
          return errors::InvalidArgument(
              "Name: ", ExampleName(example_names, e), ", Feature list '",
              c.feature_name,
              "' is required but could not be found.  "
              "Did you mean to include it in "
              "feature_list_dense_missing_assumed_empty or "
              "feature_list_dense_defaults?");
        }
      } else if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        // Aliasing lets parsed strings reference the input buffer instead
        // of being copied.
        EnableAliasing(&stream);
        // Each iteration consumes one length-delimited Feature (one row)
        // from the FeatureList; field 1 is the repeated `feature` field.
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          auto limit = stream.PushLimit(feature_length);
          int num_added = 0;
          // A serialized Feature longer than 2 bytes holds at least one
          // value; shorter encodings represent an empty value list.
          if (feature_length > 2) {
            switch (dtype) {
              case DT_STRING:
                num_added = ParseBytesFeature(&stream, out_bytes);
                out_bytes += num_added;
                break;
              case DT_FLOAT:
                num_added = ParseFloatFeature(&stream, out_float);
                out_float += num_added;
                break;
              case DT_INT64:
                num_added = ParseInt64Feature(&stream, out_int64);
                out_int64 += num_added;
                break;
              default:
                ReportUnexpectedDataType(dtype);
                num_added = 0;
            }
            if (num_added < 0) {
              // This should be unreachable -- we already scanned the feature in
              // GetSequenceFeatureLengths, and it hasn't changed since then.
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.feature_name, " in example ",
                                             ExampleName(example_names, e));
            }
          }
          // Dense sequence rows are fixed-size; any mismatch is an error.
          // NOTE(review): when feature_length <= 2 nothing is consumed
          // inside the pushed limit, which appears to rely on such rows
          // failing this check (row_shape.num_elements() > 0) -- confirm
          // behavior for zero-element row shapes.
          if (num_added != row_shape.num_elements()) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, e),
                ", Key: ", c.feature_name, ", Index: ", num_rows,
                ".  Number of values != expected.  values size: ", num_added,
                " but output shape: ", row_shape.DebugString());
          }
          num_elements += num_added;
          num_rows++;
          stream.PopLimit(limit);
        }
      }
      *out_lengths++ = num_rows;
      // Pad as necessary.
      int num_to_pad = expected_max_elements - num_elements;
      switch (dtype) {
        case DT_STRING:
          // Skip over the padding region; string elements keep whatever
          // value the tensor was initialized with.
          out_bytes += num_to_pad;
          break;
        case DT_FLOAT:
          PadFloatFeature(num_to_pad, out_float);
          out_float += num_to_pad;
          break;
        case DT_INT64:
          PadInt64Feature(num_to_pad, out_int64);
          out_int64 += num_to_pad;
          break;
        default:
          ReportUnexpectedDataType(dtype);
      }
    }
  }
  return OkStatus();
}
2647 
// Parses sparse features in `sequence_features`, and writes their parsed
// values to `sequence_result`.  Indices carry one entry per value:
// (example, row, col) in batch mode, (row, col) otherwise, where `row` is
// the position within the FeatureList and `col` the position within the
// row's value list.
Status ParseSequenceSparseFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.sparse.size(); ++t) {
    const auto& c = sequence_config.sparse[t];
    // The map was pre-populated from this same config, so the lookup is
    // expected to succeed.
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    // Total value count, precomputed by the counting pass.
    size_t expected_num_elements = feature.length;
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(is_batch ? 3 : 2);
    values_shape.AddDim(expected_num_elements);
    sequence_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    sequence_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 3 : 2}));

    // Exactly one of these write cursors is non-null, matching `dtype`.
    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64_t* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->sparse_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->sparse_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->sparse_values[t].flat<int64_t>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }
    int64_t* out_indices =
        sequence_result->sparse_indices[t].flat<int64_t>().data();
    auto out_shape = sequence_result->sparse_shapes[t].vec<int64_t>();

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_rows = 0;  // longest FeatureList seen
    size_t max_num_cols = 0;  // widest row seen
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      // Aliasing lets parsed strings reference the input buffer instead of
      // being copied.
      EnableAliasing(&stream);
      size_t num_rows = 0;
      // Each iteration consumes one length-delimited Feature (one row)
      // from the FeatureList; field 1 is the repeated `feature` field.
      while (!stream.ExpectAtEnd()) {
        uint32 feature_length;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_length)) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        // A serialized Feature longer than 2 bytes holds at least one
        // value; exactly 2 bytes encodes an empty value list.
        if (feature_length > 2) {
          auto limit = stream.PushLimit(feature_length);
          size_t num_added;
          switch (dtype) {
            case DT_STRING:
              num_added = ParseBytesFeature(&stream, out_bytes);
              out_bytes += num_added;
              break;
            case DT_FLOAT:
              num_added = ParseFloatFeature(&stream, out_float);
              out_float += num_added;
              break;
            case DT_INT64:
              num_added = ParseInt64Feature(&stream, out_int64);
              out_int64 += num_added;
              break;
            default:
              ReportUnexpectedDataType(dtype);
              num_added = 0;
          }
          num_elements += num_added;
          max_num_cols = std::max(max_num_cols, num_added);
          // Emit one (example?, row, col) index tuple per parsed value.
          for (int i = 0; i < num_added; i++) {
            if (is_batch) *out_indices++ = e;
            *out_indices++ = num_rows;
            *out_indices++ = i;
          }
          stream.PopLimit(limit);
        } else if (feature_length == 2) {
          // Consume an empty Feature so the stream stays aligned.
          if (!SkipEmptyFeature(&stream, dtype)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
        } else if (feature_length != 0) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        // Empty rows still advance the row index.
        num_rows++;
      }
      max_num_rows = std::max(max_num_rows, num_rows);
    }
    // The counting pass fixed the tensor sizes, so the totals must agree.
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements in feature ", c.feature_name);
    }
    // Dense shape: [num_examples?, max_num_rows, max_num_cols].
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_rows;
      out_shape(2) = max_num_cols;
    } else {
      out_shape(0) = max_num_rows;
      out_shape(1) = max_num_cols;
    }
  }
  return OkStatus();
}
2775 
// Parses ragged features in `sequence_features`, and writes their parsed
// values to `sequence_result`.  Each feature yields a two-level ragged
// structure: inner splits partition the flat values into rows (one row per
// Feature in the FeatureList), and, in batch mode, outer splits partition
// those rows by example.
Status ParseSequenceRaggedFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.ragged.size(); ++t) {
    const auto& c = sequence_config.ragged[t];
    // The map was pre-populated from this same config, so the lookup is
    // expected to succeed.
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape values_shape, inner_splits_shape, outer_splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    // Totals precomputed by the counting pass: flat value count and row
    // count across all examples.
    size_t expected_num_elements = feature.length;
    size_t expected_num_rows = feature.num_rows;
    values_shape.AddDim(expected_num_elements);
    inner_splits_shape.AddDim(expected_num_rows + 1);
    // Outer splits exist only in batch mode: one boundary per example plus
    // the leading zero.
    if (is_batch) {
      outer_splits_shape.AddDim(num_examples + 1);
    }
    sequence_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, inner_splits_shape);
    sequence_result->ragged_outer_splits[t] =
        Tensor(allocator, splits_dtype, outer_splits_shape);
    Tensor& out_values = sequence_result->ragged_values[t];
    size_t out_values_offset = 0;  // write cursor into `out_values`
    // Exactly one inner-splits cursor is non-null, matching splits_dtype;
    // at most one outer-splits cursor is non-null (batch mode only).
    int32* int32_inner_splits =
        splits_dtype == DT_INT32
            ? sequence_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_inner_splits =
        splits_dtype == DT_INT64
            ? sequence_result->ragged_splits[t].vec<int64_t>().data()
            : nullptr;
    int32* int32_outer_splits =
        is_batch && splits_dtype == DT_INT32
            ? sequence_result->ragged_outer_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_outer_splits =
        is_batch && splits_dtype == DT_INT64
            ? sequence_result->ragged_outer_splits[t].vec<int64_t>().data()
            : nullptr;
    // Both split vectors begin with a leading zero.
    if (int32_inner_splits) {
      *int32_inner_splits++ = 0;
    } else if (int64_inner_splits) {
      *int64_inner_splits++ = 0;
    }
    if (int32_outer_splits) {
      *int32_outer_splits++ = 0;
    } else if (int64_outer_splits) {
      *int64_outer_splits++ = 0;
    }

    // Fill in the values.
    size_t inner_split = 0;  // total number of elements we've seen so far
    size_t outer_split = 0;  // total number of rows we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        // Aliasing lets parsed strings reference the input buffer instead
        // of being copied.
        EnableAliasing(&stream);
        // Each iteration consumes one length-delimited Feature (one row)
        // from the FeatureList; field 1 is the repeated `feature` field.
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          // A serialized Feature longer than 2 bytes holds at least one
          // value; exactly 2 bytes encodes an empty value list.
          if (feature_length > 2) {
            auto limit = stream.PushLimit(feature_length);
            size_t num_added =
                ParseFeature(dtype, &stream, &out_values, &out_values_offset);
            inner_split += num_added;
            stream.PopLimit(limit);
          } else if (feature_length == 2) {
            // Consume an empty Feature so the stream stays aligned.
            if (!SkipEmptyFeature(&stream, dtype)) {
              // This should be unreachable -- we already scanned the feature in
              // GetSequenceFeatureLengths, and it hasn't changed since then.
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.feature_name, " in example ",
                                             ExampleName(example_names, e));
            }
          } else if (feature_length != 0) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          // Close this row: record the running element total, and count the
          // row toward the outer partition (empty rows included).
          if (int32_inner_splits) {
            *int32_inner_splits++ = inner_split;
          } else if (int64_inner_splits) {
            *int64_inner_splits++ = inner_split;
          }
          outer_split++;
        }
      }
      // Close this example: record the running row total.
      if (int32_outer_splits) {
        *int32_outer_splits++ = outer_split;
      } else if (int64_outer_splits) {
        *int64_outer_splits++ = outer_split;
      }
    }
    // The counting pass fixed the tensor sizes, so the totals must agree.
    if (outer_split != expected_num_rows) {
      return errors::InvalidArgument("Unexpected number of rows for feature ",
                                     c.feature_name);
    }
    if (inner_split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements for feature ", c.feature_name);
    }

    // Sanity check: exactly expected_num_rows + 1 inner splits were written
    // (pointer distance from the start of the splits buffer).
    if (int32_inner_splits || int64_inner_splits) {
      const auto& inner_splits = sequence_result->ragged_splits[t];
      int num_inner_splits =
          int32_inner_splits
              ? int32_inner_splits - inner_splits.vec<int32>().data()
              : int64_inner_splits - inner_splits.vec<int64_t>().data();
      if (num_inner_splits != expected_num_rows + 1) {
        return errors::InvalidArgument("Unexpected number of rows for feature ",
                                       c.feature_name);
      }
    }
    // Sanity check: exactly num_examples + 1 outer splits were written.
    if (int32_outer_splits || int64_outer_splits) {
      const auto& outer_splits = sequence_result->ragged_outer_splits[t];
      int num_outer_splits =
          int32_outer_splits
              ? int32_outer_splits - outer_splits.vec<int32>().data()
              : int64_outer_splits - outer_splits.vec<int64_t>().data();
      if (num_outer_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return OkStatus();
}
2920 
2921 }  // namespace
2922 
2923 // TODO(sundberg): Use the threadpool to parallelize example parsing.
2924 // TODO(b/111553342): Support extracting feature statistics from the examples.
FastParseSequenceExample(const FastParseExampleConfig & context_config,const FastParseExampleConfig & sequence_config,gtl::ArraySlice<tstring> serialized,gtl::ArraySlice<tstring> example_names,thread::ThreadPool * thread_pool,Result * context_result,Result * sequence_result,std::vector<Tensor> * dense_feature_lengths,bool is_batch)2925 Status FastParseSequenceExample(const FastParseExampleConfig& context_config,
2926                                 const FastParseExampleConfig& sequence_config,
2927                                 gtl::ArraySlice<tstring> serialized,
2928                                 gtl::ArraySlice<tstring> example_names,
2929                                 thread::ThreadPool* thread_pool,
2930                                 Result* context_result, Result* sequence_result,
2931                                 std::vector<Tensor>* dense_feature_lengths,
2932                                 bool is_batch) {
2933   int num_examples = serialized.size();
2934   DCHECK(context_result != nullptr);
2935   DCHECK(sequence_result != nullptr);
2936   DCHECK(dense_feature_lengths != nullptr);
2937   size_t num_context_features = context_config.sparse.size() +
2938                                 context_config.dense.size() +
2939                                 context_config.ragged.size();
2940   FeatureProtosMap context_features;
2941   context_features.reserve(num_context_features);
2942 
2943   if (!example_names.empty() && example_names.size() != num_examples) {
2944     return errors::InvalidArgument(
2945         "example_names must be empty or have the correct number of elements");
2946   }
2947   for (auto& c : context_config.sparse) {
2948     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2949     FeatureProtos& feature = context_features[c.feature_name];
2950     feature.dtype = c.dtype;
2951     feature.length = 0;
2952     feature.type = Type::Sparse;
2953     feature.protos.resize(num_examples);
2954     feature.protos_present.resize(num_examples);
2955   }
2956   for (auto& c : context_config.ragged) {
2957     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2958     FeatureProtos& feature = context_features[c.feature_name];
2959     if (feature.type == Type::Sparse) {
2960       return errors::InvalidArgument("Context feature " + c.feature_name +
2961                                      " cannot be both ragged and sparse");
2962     }
2963     feature.dtype = c.dtype;
2964     feature.length = 0;
2965     feature.type = Type::Ragged;
2966     feature.protos.resize(num_examples);
2967     feature.protos_present.resize(num_examples);
2968   }
2969   for (auto& c : context_config.dense) {
2970     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2971     FeatureProtos& feature = context_features[c.feature_name];
2972     if (feature.type != Type::Dense) {
2973       return errors::InvalidArgument("Context feature " + c.feature_name +
2974                                      " cannot be both dense and sparse");
2975     }
2976     if (c.default_value.NumElements() > 0) {
2977       if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
2978         return errors::InvalidArgument("Default value for context feature ",
2979                                        c.feature_name,
2980                                        " has an incorrect shape: saw ",
2981                                        c.default_value.shape().DebugString(),
2982                                        " but expected ", c.shape.DebugString());
2983       }
2984     }
2985     feature.dtype = c.dtype;
2986     feature.length = c.default_value.NumElements();
2987     feature.protos.resize(num_examples);
2988     feature.protos_present.resize(num_examples);
2989   }
2990   size_t num_sequence_features = sequence_config.sparse.size() +
2991                                  sequence_config.dense.size() +
2992                                  sequence_config.ragged.size();
2993   FeatureProtosMap sequence_features;
2994   sequence_features.reserve(num_sequence_features);
2995   for (auto& c : sequence_config.sparse) {
2996     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2997     FeatureProtos& feature = sequence_features[c.feature_name];
2998     feature.dtype = c.dtype;
2999     feature.length = 0;
3000     feature.type = Type::Sparse;
3001     feature.protos.resize(num_examples);
3002     feature.protos_present.resize(num_examples);
3003   }
3004   for (auto& c : sequence_config.ragged) {
3005     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
3006     FeatureProtos& feature = sequence_features[c.feature_name];
3007     if (feature.type == Type::Sparse) {
3008       return errors::InvalidArgument("Sequence feature " + c.feature_name +
3009                                      " cannot be both ragged and sparse");
3010     }
3011     feature.dtype = c.dtype;
3012     feature.length = 0;
3013     feature.type = Type::Ragged;
3014     feature.protos.resize(num_examples);
3015     feature.protos_present.resize(num_examples);
3016   }
3017   for (auto& c : sequence_config.dense) {
3018     TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
3019     FeatureProtos& feature = sequence_features[c.feature_name];
3020     if (feature.type != Type::Dense) {
3021       return errors::InvalidArgument("Sequence feature " + c.feature_name +
3022                                      " cannot be both dense and sparse");
3023     }
3024     feature.dtype = c.dtype;
3025     feature.length = 0;
3026     feature.protos.resize(num_examples);
3027     feature.protos_present.resize(num_examples);
3028   }
3029 
3030   // Find the serialized proto substrings for each feature.
3031   TF_RETURN_IF_ERROR(ExtractFeaturesFromSequenceExamples(
3032       serialized, example_names, &context_features, &sequence_features));
3033 
3034   // Scan through the protos to determine how much memory we need to allocate.
3035   TF_RETURN_IF_ERROR(
3036       GetContextFeatureLengths(example_names, &context_features));
3037   TF_RETURN_IF_ERROR(
3038       GetSequenceFeatureLengths(example_names, &sequence_features));
3039 
3040   // Allocate memory.
3041   context_result->sparse_values.resize(context_config.sparse.size());
3042   context_result->sparse_indices.resize(context_config.sparse.size());
3043   context_result->sparse_shapes.resize(context_config.sparse.size());
3044   context_result->dense_values.resize(context_config.dense.size());
3045   context_result->ragged_values.resize(context_config.ragged.size());
3046   context_result->ragged_splits.resize(context_config.ragged.size());
3047   context_result->ragged_outer_splits.resize(context_config.ragged.size());
3048   sequence_result->sparse_values.resize(sequence_config.sparse.size());
3049   sequence_result->sparse_indices.resize(sequence_config.sparse.size());
3050   sequence_result->sparse_shapes.resize(sequence_config.sparse.size());
3051   sequence_result->dense_values.resize(sequence_config.dense.size());
3052   sequence_result->ragged_values.resize(sequence_config.ragged.size());
3053   sequence_result->ragged_splits.resize(sequence_config.ragged.size());
3054   sequence_result->ragged_outer_splits.resize(sequence_config.ragged.size());
3055   dense_feature_lengths->resize(sequence_config.dense.size());
3056 
3057   // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction,
3058   // to avoid lock contention in `tensorflow::cpu_allocator()`.
3059   Allocator* allocator = tensorflow::cpu_allocator();
3060 
3061   TF_RETURN_IF_ERROR(ParseContextDenseFeatures(
3062       context_features, context_config, example_names, is_batch, num_examples,
3063       allocator, context_result));
3064   TF_RETURN_IF_ERROR(ParseContextSparseFeatures(
3065       context_features, context_config, example_names, is_batch, num_examples,
3066       allocator, context_result));
3067   TF_RETURN_IF_ERROR(ParseContextRaggedFeatures(
3068       context_features, context_config, example_names, is_batch, num_examples,
3069       allocator, context_result));
3070   TF_RETURN_IF_ERROR(ParseSequenceDenseFeatures(
3071       sequence_features, sequence_config, example_names, is_batch, num_examples,
3072       allocator, sequence_result, dense_feature_lengths));
3073   TF_RETURN_IF_ERROR(ParseSequenceSparseFeatures(
3074       sequence_features, sequence_config, example_names, is_batch, num_examples,
3075       allocator, sequence_result));
3076   TF_RETURN_IF_ERROR(ParseSequenceRaggedFeatures(
3077       sequence_features, sequence_config, example_names, is_batch, num_examples,
3078       allocator, sequence_result));
3079 
3080   return OkStatus();
3081 }
3082 
3083 }  // namespace example
3084 }  // namespace tensorflow
3085