• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
16 #define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
17 #include <algorithm>
18 #include <vector>
19 
20 #include "absl/base/casts.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/example/example.pb.h"
23 #include "tensorflow/core/example/feature.pb.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/numeric_op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/monitoring/counter.h"
34 #include "tensorflow/core/platform/byte_order.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/protobuf.h"
37 #include "tensorflow/core/util/example_proto_fast_parsing.h"
38 #include "tensorflow/core/util/presized_cuckoo_map.h"
39 #include "tensorflow/core/util/sparse/sparse_tensor.h"
40 
41 namespace tensorflow {
42 namespace example {
43 
// Vector type used for per-feature value lists (see SparseBuffer and
// parsed::Feature below). The `4` is gtl::InlinedVector's inline capacity:
// up to four elements are stored without a separate allocation.
44 template <typename T>
45 using SmallVector = gtl::InlinedVector<T, 4>;
46 
// A fixed-capacity, non-owning view over a caller-provided buffer that mimics
// the subset of the vector interface used by the Parse*List helpers.
//
// Writes beyond the capacity are discarded, but the logical write position
// keeps advancing so callers can detect overflow: EndDistance() becomes
// negative once the slice has been over-filled.
//
// The write position is tracked as an index rather than a pointer so that
// over-filling never forms a pointer more than one past the end of the
// buffer (pointer arithmetic beyond one-past-the-end is undefined behavior).
template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  LimitedArraySlice(T* begin, size_t num_elements)
      : begin_(begin), capacity_(num_elements), position_(0) {}

  // Number of free slots left. May return negative if there were push_back
  // (or resize) calls after the slice was filled.
  int64_t EndDistance() const {
    return static_cast<int64_t>(capacity_) - static_cast<int64_t>(position_);
  }

  // Attempts to push value to the back of this. If the slice has already
  // been filled, this method has no effect on the underlying data, but it
  // changes the number returned by EndDistance into negative values.
  void push_back(T&& value) {
    if (EndDistance() > 0) begin_[position_] = std::move(value);
    ++position_;
  }

  // "Constructs" an element at the back of this by growing the slice by one,
  // and returns a mutable reference to the new last element.
  // REQUIRES: EndDistance() > 0.
  T& construct_at_end() {
    DCHECK_GT(EndDistance(), 0);
    return begin_[position_++];
  }

  // Returns a mutable reference to the last element actually stored in the
  // slice (never one of the discarded over-capacity writes).
  // REQUIRES: size() > 0.
  T& back() { return begin_[size() - 1]; }

  // Returns the number of elements stored; never exceeds the capacity passed
  // to the constructor.
  size_t size() const { return std::min(position_, capacity_); }

  // Attempts to resize the vector to the given size. It does so by moving the
  // logical write position, possibly beyond the end of the slice. As a
  // consequence, calling `size()` after `resize(x)` was called might return a
  // value less than `x`.
  void resize(size_t size) { position_ = size; }

  // Returns the pointer to the underlying data buffer.
  T* data() { return begin_; }

 private:
  T* begin_;
  size_t capacity_;
  size_t position_;  // Logical write position; may exceed capacity_.
};
95 
// Turns on aliasing for any object exposing `EnableAliasing(bool)`; this file
// calls it with CodedInputStream pointers. The trailing return type removes
// this overload from overload resolution when that member call is
// ill-formed. When both overloads are viable, this one wins because `A*` is
// more specialized than the forwarding reference below.
template <typename A>
auto EnableAliasing(A* a) -> decltype(void(a->EnableAliasing(true))) {
  a->EnableAliasing(true);
}

// Fallback for everything else: deliberately does nothing.
template <typename A>
void EnableAliasing(A&& /*unused*/) {}
103 
// Declared here, defined elsewhere. Judging by its call sites in this file
// (the peeked tag is later consumed with ExpectTag), it presumably returns
// the next tag byte on `stream` without consuming it — confirm against the
// definition.
104 uint8 PeekTag(protobuf::io::CodedInputStream* stream);
105 
kVarintTag(uint32 tag)106 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
kDelimitedTag(uint32 tag)107 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
kFixed32Tag(uint32 tag)108 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
109 
110 namespace parsed {
111 
112 // ParseDataType has to be called first, then appropriate ParseZzzzList.
113 class Feature {
114  public:
Feature()115   Feature() {}
Feature(StringPiece serialized)116   explicit Feature(StringPiece serialized) : serialized_(serialized) {}
117 
ParseDataType(DataType * dtype)118   Status ParseDataType(DataType* dtype) {
119     DCHECK(dtype != nullptr);
120     if (serialized_.empty()) {
121       *dtype = DT_INVALID;
122       return OkStatus();
123     }
124     uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
125     serialized_.remove_prefix(1);
126     switch (oneof_tag) {
127       case kDelimitedTag(1):
128         *dtype = DT_STRING;
129         break;
130       case kDelimitedTag(2):
131         *dtype = DT_FLOAT;
132         break;
133       case kDelimitedTag(3):
134         *dtype = DT_INT64;
135         break;
136       default:
137         // Initialize variable to avoid compiler warning
138         *dtype = DT_INVALID;
139         return errors::InvalidArgument("Unsupported datatype.");
140     }
141     return OkStatus();
142   }
143 
GetNumElementsInBytesList(int * num_elements)144   bool GetNumElementsInBytesList(int* num_elements) {
145     protobuf::io::CodedInputStream stream(
146         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
147     EnableAliasing(&stream);
148     uint32 length = 0;
149     if (!stream.ReadVarint32(&length)) return false;
150     auto limit = stream.PushLimit(length);
151     *num_elements = 0;
152     while (!stream.ExpectAtEnd()) {
153       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
154       uint32 bytes_length = 0;
155       if (!stream.ReadVarint32(&bytes_length)) return false;
156       if (!stream.Skip(bytes_length)) return false;
157       ++*num_elements;
158     }
159     stream.PopLimit(limit);
160     return true;
161   }
162 
163   // Helper methods
construct_at_end(LimitedArraySlice<tstring> * bytes_list)164   tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
165     if (bytes_list->EndDistance() <= 0) {
166       return nullptr;
167     }
168     return &bytes_list->construct_at_end();
169   }
construct_at_end(SmallVector<tstring> * bytes_list)170   tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
171     return &bytes_list->emplace_back();
172   }
173 
174   template <typename Result>
ParseBytesList(Result * bytes_list)175   bool ParseBytesList(Result* bytes_list) {
176     DCHECK(bytes_list != nullptr);
177 
178     protobuf::io::CodedInputStream stream(
179         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
180 
181     EnableAliasing(&stream);
182 
183     uint32 length;
184     if (!stream.ReadVarint32(&length)) return false;
185     auto limit = stream.PushLimit(length);
186 
187     while (!stream.ExpectAtEnd()) {
188       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
189       // parse string
190       uint32 bytes_length;
191       if (!stream.ReadVarint32(&bytes_length)) return false;
192       tstring* bytes = construct_at_end(bytes_list);
193       if (bytes == nullptr) return false;
194       bytes->resize_uninitialized(bytes_length);
195       if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
196     }
197     stream.PopLimit(limit);
198     return true;
199   }
200 
201   template <typename Result>
ParseFloatList(Result * float_list)202   bool ParseFloatList(Result* float_list) {
203     DCHECK(float_list != nullptr);
204     protobuf::io::CodedInputStream stream(
205         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
206     EnableAliasing(&stream);
207     uint32 length;
208     if (!stream.ReadVarint32(&length)) return false;
209     auto limit = stream.PushLimit(length);
210 
211     if (!stream.ExpectAtEnd()) {
212       uint8 peek_tag = PeekTag(&stream);
213       if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
214         return false;
215       }
216 
217       constexpr int32_t kNumFloatBytes = 4;
218       if (peek_tag == kDelimitedTag(1)) {                       // packed
219         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
220         uint32 packed_length;
221         if (!stream.ReadVarint32(&packed_length)) return false;
222         auto packed_limit = stream.PushLimit(packed_length);
223 
224         // Store the initial size to know the offset we have to start writing
225         // data from before resizing the output "vector".
226         const size_t initial_size = float_list->size();
227         float_list->resize(initial_size + packed_length / kNumFloatBytes);
228 
229         // If the result data type is float and we are on a little endian
230         // machine then we can simply memcpy the data from the proto into the
231         // result vector.
232         if (port::kLittleEndian &&
233             sizeof(typename Result::value_type) == kNumFloatBytes) {
234           // Calculate the length of the buffer available what can be less than
235           // what we requested in resize in case of a LimitedArraySlice.
236           const uint32 bytes_to_copy =
237               std::min(static_cast<uint32>((float_list->size() - initial_size) *
238                                            kNumFloatBytes),
239                        packed_length);
240           if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
241             return false;
242         } else {
243           int64_t index = initial_size;
244           while (!stream.ExpectAtEnd()) {
245             uint32 buffer32;
246             if (!stream.ReadLittleEndian32(&buffer32)) return false;
247             if (index < float_list->size()) {
248               float_list->data()[index] = absl::bit_cast<float>(buffer32);
249               ++index;
250             }
251           }
252         }
253 
254         stream.PopLimit(packed_limit);
255       } else {  // non-packed
256         const size_t initial_size = float_list->size();
257         // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
258         // the value.
259         const int64_t num_elements =
260             stream.BytesUntilLimit() / (1 + kNumFloatBytes);
261         float_list->resize(initial_size + num_elements);
262         int64_t index = initial_size;
263         while (!stream.ExpectAtEnd()) {
264           if (!stream.ExpectTag(kFixed32Tag(1))) return false;
265           uint32 buffer32;
266           if (!stream.ReadLittleEndian32(&buffer32)) return false;
267           float_list->data()[index] = absl::bit_cast<float>(buffer32);
268           ++index;
269         }
270       }
271     }
272 
273     stream.PopLimit(limit);
274     return true;
275   }
276 
277   template <typename Result>
ParseInt64List(Result * int64_list)278   bool ParseInt64List(Result* int64_list) {
279     DCHECK(int64_list != nullptr);
280     protobuf::io::CodedInputStream stream(
281         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
282     EnableAliasing(&stream);
283     uint32 length;
284     if (!stream.ReadVarint32(&length)) return false;
285     auto limit = stream.PushLimit(length);
286 
287     if (!stream.ExpectAtEnd()) {
288       uint8 peek_tag = PeekTag(&stream);
289       if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
290         return false;
291       }
292       if (peek_tag == kDelimitedTag(1)) {                       // packed
293         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
294         uint32 packed_length;
295         if (!stream.ReadVarint32(&packed_length)) return false;
296         auto packed_limit = stream.PushLimit(packed_length);
297 
298         while (!stream.ExpectAtEnd()) {
299           protobuf_uint64 n;  // There is no API for int64
300           if (!stream.ReadVarint64(&n)) return false;
301           int64_list->push_back(static_cast<int64_t>(n));
302         }
303 
304         stream.PopLimit(packed_limit);
305       } else {  // non-packed
306         while (!stream.ExpectAtEnd()) {
307           if (!stream.ExpectTag(kVarintTag(1))) return false;
308           protobuf_uint64 n;  // There is no API for int64
309           if (!stream.ReadVarint64(&n)) return false;
310           int64_list->push_back(static_cast<int64_t>(n));
311         }
312       }
313     }
314     stream.PopLimit(limit);
315     return true;
316   }
317 
GetSerialized()318   StringPiece GetSerialized() const { return serialized_; }
319 
320  private:
321   StringPiece serialized_;
322 };
323 
// One entry of an Example's feature map: presumably the feature name (map
// key) paired with its lazily-parsed Feature value. Both halves are views
// into the serialized bytes.
324 using FeatureMapEntry = std::pair<StringPiece, Feature>;
// The feature-map entries of one Example as a flat list. NOTE(review):
// presumably kept in serialized order, duplicates included — confirm against
// ParseFeatures.
325 using Example = std::vector<FeatureMapEntry>;
326 
327 }  // namespace parsed
328 
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)329 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
330   uint32 data;
331   protobuf_uint64 dummy;
332   switch (stream->ReadTag() & 0x7) {
333     case 0:  // varint
334       if (!stream->ReadVarint32(&data)) return false;
335       return true;
336     case 1:  // fixed64
337       if (!stream->ReadLittleEndian64(&dummy)) return false;
338       return true;
339     case 2:  // length delimited
340       if (!stream->ReadVarint32(&data)) return false;
341       stream->Skip(data);
342       return true;
343     case 3:          // group begin
344       return false;  // groups not supported.
345     case 4:          // group end
346       return false;  // groups not supported.
347     case 5:          // fixed32
348       if (!stream->ReadLittleEndian32(&data)) return false;
349       return true;
350   }
351   return false;  // unrecognized tag type
352 }
353 
// Parsing entry points declared here and defined elsewhere.
// NOTE(review): the bool results presumably signal success (false on
// malformed input), like the parsing helpers above — confirm against the
// definitions.
354 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);
355 
356 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
357                           parsed::FeatureMapEntry* feature_map_entry);
358 
359 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
360                    parsed::Example* example);
361 
362 bool ParseExample(protobuf::io::CodedInputStream* stream,
363                   parsed::Example* example);
364 
// Convenience overload taking the serialized Example bytes directly.
365 bool ParseExample(StringPiece serialized, parsed::Example* example);
366 
// Short alias for FastParseExampleConfig, which is not declared in this
// header (presumably it comes from the included
// tensorflow/core/util/example_proto_fast_parsing.h).
367 using Config = FastParseExampleConfig;
368 
369 // Enumeration for distinguishing feature types.
370 // Note: FastParseSequenceExample constructs a map that includes Type values,
371 // and relies on the fact that they are default-initialized to Dense.
// Keep Dense as the first enumerator: value-initialization yields the
// enumerator whose value is 0, which is what the note above depends on.
372 enum class Type { Dense, Sparse, Ragged };
373 
374 // Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
375 struct SparseBuffer {
376   // Features are in one of the 3 vectors below depending on config's dtype.
377   // Other 2 vectors remain empty.
  // The vector used is the one whose element type matches the dtype:
  // tstring for strings, float for floats, int64_t for int64s.
378   SmallVector<tstring> bytes_list;
379   SmallVector<float> float_list;
380   SmallVector<int64_t> int64_list;
381 
382   // Features of example i are elements with indices
383   // from example_end_indices[i-1] to example_end_indices[i]-1 on the
384   // appropriate xxxxx_list
  // NOTE(review): for i == 0 the range presumably starts at index 0 —
  // confirm in the callers that fill this buffer.
385   std::vector<size_t> example_end_indices;
386 };
387 
388 struct SeededHasher {
operatorSeededHasher389   uint64 operator()(StringPiece s) const {
390     return Hash64(s.data(), s.size(), seed);
391   }
392   uint64 seed{0xDECAFCAFFE};
393 };
394 
395 // Use this in the "default" clause of switch statements when dispatching
396 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)397 inline void ReportUnexpectedDataType(DataType dtype) {
398   DCHECK(false)
399       << "Encountered unexpected DataType " << DataTypeString(dtype)
400       << "in variable that should have been checked by CheckConfigDataType().";
401 }
402 
// Returns the list in `buffer` whose element type is T. Only the int64_t,
// float and tstring specializations below are declared; they are defined
// elsewhere. NOTE(review): presumably they return int64_list, float_list and
// bytes_list respectively — confirm against the definitions.
403 template <typename T>
404 const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
405 
406 template <>
407 const SmallVector<int64_t>& GetListFromBuffer<int64_t>(
408     const SparseBuffer& buffer);
409 
410 template <>
411 const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);
412 
413 template <>
414 const SmallVector<tstring>& GetListFromBuffer<tstring>(
415     const SparseBuffer& buffer);
416 
417 template <typename T>
CopyOrMoveBlock(const T * b,const T * e,T * t)418 void CopyOrMoveBlock(const T* b, const T* e, T* t) {
419   std::copy(b, e, t);
420 }
421 template <>
422 void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);
423 
// Declared here, defined elsewhere. NOTE(review): judging by the parameter
// names, this reports the total and the per-buffer maximum number of
// features for dense/sparse index `d` across `sparse_buffers` — confirm
// against the definition.
424 void CountSparseFeatures(
425     const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
426     size_t* total_num_features, size_t* max_num_features);
427 
// Declared here, defined elsewhere. NOTE(review): presumably copies the
// `dtype` list of `src` into `dst` starting at element `offset` — confirm
// against the definition.
428 void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
429                               Tensor* dst);
430 
431 // A struct used by FastParseSequenceExample to hold the serialized proto
432 // substrings for a single feature, plus some auxiliary information derived
433 // from those protos (such as the total value length).
434 struct FeatureProtos {
435   // Proto substrings from each serialized SequenceExample that correspond
436   // with this feature.  `protos_present` records whether the proto had a
437   // value defined (even if that value is empty).
438   std::vector<StringPiece> protos;
439   std::vector<bool> protos_present;
440 
441   // Information derived from protos:
442   size_t length;    // total length for ragged/sparse, max row length for dense.
443   size_t num_rows;  // only populated for ragged sequence features.
444 
445   // Information from the config:
446   Type type;  // Whether this feature is sparse, ragged, or dense.
447   DataType dtype;
448 };
449 
450 // Map from feature name to FeatureProtos for that feature.
// Keys are non-owning StringPiece views, so the bytes they point into must
// outlive the map.
451 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
452 
// Declared here, defined elsewhere. NOTE(review): presumably returns a
// human-readable name for example `n` (e.g. for error messages) — confirm
// against the definition.
453 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);
454 
455 // Return the number of bytes elements parsed, or -1 on error. If out is null,
456 // this method simply counts the number of elements without any copying.
ParseBytesFeature(protobuf::io::CodedInputStream * stream,tstring * out)457 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
458                              tstring* out) {
459   int num_elements = 0;
460   uint32 length;
461   if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
462     return -1;
463   }
464   if (length > 0) {
465     auto limit = stream->PushLimit(length);
466     while (!stream->ExpectAtEnd()) {
467       uint32 bytes_length;
468       if (!stream->ExpectTag(kDelimitedTag(1)) ||
469           !stream->ReadVarint32(&bytes_length)) {
470         return -1;
471       }
472       if (out == nullptr) {
473         stream->Skip(bytes_length);
474       } else {
475         out->resize_uninitialized(bytes_length);
476         if (!stream->ReadRaw(out->data(), bytes_length)) {
477           return -1;
478         }
479         out++;
480       }
481       num_elements++;
482     }
483     stream->PopLimit(limit);
484   }
485   return num_elements;
486 }
487 
// Writes `num_to_pad` zeros to out[0], out[1], ...; a non-positive count
// writes nothing.
inline void PadFloatFeature(int num_to_pad, float* out) {
  std::fill_n(out, num_to_pad, 0.0f);
}
493 
// Writes `num_to_pad` zeros to out[0], out[1], ...; a non-positive count
// writes nothing.
inline void PadInt64Feature(int num_to_pad, int64_t* out) {
  std::fill_n(out, num_to_pad, int64_t{0});
}
499 
500 // Return the number of float elements parsed, or -1 on error. If out is null,
501 // this method simply counts the number of elements without any copying.
ParseFloatFeature(protobuf::io::CodedInputStream * stream,float * out)502 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
503                              float* out) {
504   int num_elements = 0;
505   uint32 length;
506   if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
507     return -1;
508   }
509   if (length > 0) {
510     auto limit = stream->PushLimit(length);
511     uint8 peek_tag = PeekTag(stream);
512     if (peek_tag == kDelimitedTag(1)) {  // packed
513       uint32 packed_length;
514       if (!stream->ExpectTag(kDelimitedTag(1)) ||
515           !stream->ReadVarint32(&packed_length)) {
516         return -1;
517       }
518       auto packed_limit = stream->PushLimit(packed_length);
519       while (!stream->ExpectAtEnd()) {
520         uint32 buffer32;
521         if (!stream->ReadLittleEndian32(&buffer32)) {
522           return -1;
523         }
524         if (out != nullptr) {
525           *out++ = absl::bit_cast<float>(buffer32);
526         }
527         num_elements++;
528       }
529       stream->PopLimit(packed_limit);
530     } else if (peek_tag == kFixed32Tag(1)) {
531       while (!stream->ExpectAtEnd()) {
532         uint32 buffer32;
533         if (!stream->ExpectTag(kFixed32Tag(1)) ||
534             !stream->ReadLittleEndian32(&buffer32)) {
535           return -1;
536         }
537         if (out != nullptr) {
538           *out++ = absl::bit_cast<float>(buffer32);
539         }
540         num_elements++;
541       }
542     } else {
543       // Unknown tag.
544       return -1;
545     }
546     stream->PopLimit(limit);
547   }
548   return num_elements;
549 }
550 
551 // Return the number of int64 elements parsed, or -1 on error. If out is null,
552 // this method simply counts the number of elements without any copying.
ParseInt64Feature(protobuf::io::CodedInputStream * stream,int64_t * out)553 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
554                              int64_t* out) {
555   int num_elements = 0;
556   uint32 length;
557   if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
558     return -1;
559   }
560   if (length > 0) {
561     auto limit = stream->PushLimit(length);
562     uint8 peek_tag = PeekTag(stream);
563     if (peek_tag == kDelimitedTag(1)) {  // packed
564       uint32 packed_length;
565       if (!stream->ExpectTag(kDelimitedTag(1)) ||
566           !stream->ReadVarint32(&packed_length)) {
567         return -1;
568       }
569       auto packed_limit = stream->PushLimit(packed_length);
570       while (!stream->ExpectAtEnd()) {
571         protobuf_uint64 n;  // There is no API for int64
572         if (!stream->ReadVarint64(&n)) {
573           return -1;
574         }
575         if (out != nullptr) {
576           *out++ = n;
577         }
578         num_elements++;
579       }
580       stream->PopLimit(packed_limit);
581     } else if (peek_tag == kVarintTag(1)) {
582       while (!stream->ExpectAtEnd()) {
583         protobuf_uint64 n;  // There is no API for int64
584         if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
585           return -1;
586         }
587         if (out != nullptr) {
588           *out++ = n;
589         }
590         num_elements++;
591       }
592     } else {
593       // Unknown tag.
594       return -1;
595     }
596     stream->PopLimit(limit);
597   }
598   return num_elements;
599 }
600 
601 // Parses the next feature on `stream` into `out` starting at `out_offset`.
602 // Updates `out_offset`, and returns the number of values added.
603 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)604 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
605                         Tensor* out, size_t* out_offset) {
606   int delta;
607   switch (dtype) {
608     case DT_STRING:
609       delta =
610           ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
611       break;
612     case DT_FLOAT:
613       delta =
614           ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
615       break;
616     case DT_INT64:
617       delta =
618           ParseInt64Feature(stream, out->flat<int64_t>().data() + *out_offset);
619       break;
620     default:
621       ReportUnexpectedDataType(dtype);
622       delta = 0;
623   }
624   if (delta > 0) {
625     *out_offset += delta;
626   }
627   return delta;
628 }
629 
630 // Returns the length of the next feature on `stream`.
631 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)632 inline int GetFeatureLength(DataType dtype,
633                             protobuf::io::CodedInputStream* stream) {
634   switch (dtype) {
635     case DT_STRING:
636       return ParseBytesFeature(stream, nullptr);
637     case DT_FLOAT:
638       return ParseFloatFeature(stream, nullptr);
639     case DT_INT64:
640       return ParseInt64Feature(stream, nullptr);
641     default:
642       ReportUnexpectedDataType(dtype);
643       return -1;
644   }
645 }
646 
ParseDataType(protobuf::io::CodedInputStream * stream)647 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
648   uint8 peek_tag = PeekTag(stream);
649   switch (peek_tag) {
650     case kDelimitedTag(1):
651       return DT_STRING;
652     case kDelimitedTag(2):
653       return DT_FLOAT;
654     case kDelimitedTag(3):
655       return DT_INT64;
656     default:
657       return DT_INVALID;
658   }
659 }
660 
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)661 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
662                              DataType dtype) {
663   switch (dtype) {
664     case DT_STRING:
665       if (!stream->ExpectTag(kDelimitedTag(1))) {
666         return false;
667       }
668       break;
669     case DT_FLOAT:
670       if (!stream->ExpectTag(kDelimitedTag(2))) {
671         return false;
672       }
673       break;
674     case DT_INT64:
675       if (!stream->ExpectTag(kDelimitedTag(3))) {
676         return false;
677       }
678       break;
679     default:
680       return false;
681   }
682   uint32 length;
683   return stream->ReadVarint32(&length) && length == 0;
684 }
685 
686 }  // namespace example
687 }  // namespace tensorflow
688 
689 #endif  // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
690