• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
16 #define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
17 #include "tensorflow/core/util/example_proto_fast_parsing.h"
18 
19 #include <vector>
20 
21 #include "absl/base/casts.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "tensorflow/core/example/example.pb.h"
24 #include "tensorflow/core/example/feature.pb.h"
25 #include "tensorflow/core/framework/allocator.h"
26 #include "tensorflow/core/framework/numeric_op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/blocking_counter.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/threadpool.h"
33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
34 #include "tensorflow/core/lib/monitoring/counter.h"
35 #include "tensorflow/core/platform/byte_order.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/protobuf.h"
38 #include "tensorflow/core/util/presized_cuckoo_map.h"
39 #include "tensorflow/core/util/sparse/sparse_tensor.h"
40 
41 namespace tensorflow {
42 namespace example {
43 
44 template <typename T>
45 using SmallVector = gtl::InlinedVector<T, 4>;
46 
47 template <typename T>
48 class LimitedArraySlice {
49  public:
50   using value_type = T;
51 
LimitedArraySlice(T * begin,size_t num_elements)52   LimitedArraySlice(T* begin, size_t num_elements)
53       : current_(begin), begin_(begin), end_(begin + num_elements) {}
54 
55   // May return negative if there were push_back calls after slice was filled.
EndDistance()56   int64 EndDistance() const { return end_ - current_; }
57 
58   // Attempts to push value to the back of this. If the slice has
59   // already been filled, this method has no effect on the underlying data, but
60   // it changes the number returned by EndDistance into negative values.
push_back(T && value)61   void push_back(T&& value) {
62     if (EndDistance() > 0) *current_ = std::move(value);
63     ++current_;
64   }
65 
66   // "Constructs" an element at the back of this by resizing the slice, and
67   // returns a mutable reference to the new last element.
68   // REQUIRES: EndDistance() > 0.
construct_at_end()69   T& construct_at_end() {
70     DCHECK_GT(EndDistance(), 0);
71     return *(current_++);
72   }
73 
74   // Returns a mutable reference to the last element in the slice.
75   // REQUIRES: size() > 0.
back()76   T& back() { return *(current_ - 1); }
77 
78   // Returns the number of elements in the slice.
size()79   size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
80 
81   // Attempts to resize the vector to the given size. It does so by advancing
82   // the pointer to the current element, possibly beyond the end of the slice.
83   // As a consequence, calling `size()` after `resize(x)` was called might
84   // return a value less than `x`.
resize(size_t size)85   void resize(size_t size) { current_ = begin_ + size; }
86 
87   // Returns the pointer to the underlying data buffer.
data()88   T* data() { return begin_; }
89 
90  private:
91   T* current_;
92   T* begin_;
93   T* end_;
94 };
95 
// Calls `a->EnableAliasing(true)` on types that provide such a member
// function; this file calls it on protobuf::io::CodedInputStream objects.
// SFINAE on the trailing return type removes this overload for other types.
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  constexpr bool kAllowAliasing = true;
  a->EnableAliasing(kAllowAliasing);
}

// Fallback for arguments that have no EnableAliasing(bool) member: a no-op.
// When both overloads are viable, the pointer overload above is more
// specialized, so this one is selected only when that overload is removed.
template <typename A>
void EnableAliasing(A&& /*unused*/) {}
103 
104 uint8 PeekTag(protobuf::io::CodedInputStream* stream);
105 
kVarintTag(uint32 tag)106 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
kDelimitedTag(uint32 tag)107 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
kFixed32Tag(uint32 tag)108 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
109 
110 namespace parsed {
111 
112 // ParseDataType has to be called first, then appropriate ParseZzzzList.
113 class Feature {
114  public:
Feature()115   Feature() {}
Feature(StringPiece serialized)116   explicit Feature(StringPiece serialized) : serialized_(serialized) {}
117 
ParseDataType(DataType * dtype)118   Status ParseDataType(DataType* dtype) {
119     DCHECK(dtype != nullptr);
120     if (serialized_.empty()) {
121       *dtype = DT_INVALID;
122       return Status::OK();
123     }
124     uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
125     serialized_.remove_prefix(1);
126     switch (oneof_tag) {
127       case kDelimitedTag(1):
128         *dtype = DT_STRING;
129         break;
130       case kDelimitedTag(2):
131         *dtype = DT_FLOAT;
132         break;
133       case kDelimitedTag(3):
134         *dtype = DT_INT64;
135         break;
136       default:
137         // Initialize variable to avoid compiler warning
138         *dtype = DT_INVALID;
139         return errors::InvalidArgument("Unsupported datatype.");
140     }
141     return Status::OK();
142   }
143 
GetNumElementsInBytesList(int * num_elements)144   bool GetNumElementsInBytesList(int* num_elements) {
145     protobuf::io::CodedInputStream stream(
146         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
147     EnableAliasing(&stream);
148     uint32 length = 0;
149     if (!stream.ReadVarint32(&length)) return false;
150     auto limit = stream.PushLimit(length);
151     *num_elements = 0;
152     while (!stream.ExpectAtEnd()) {
153       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
154       uint32 bytes_length = 0;
155       if (!stream.ReadVarint32(&bytes_length)) return false;
156       if (!stream.Skip(bytes_length)) return false;
157       ++*num_elements;
158     }
159     stream.PopLimit(limit);
160     return true;
161   }
162 
163   // Helper methods
construct_at_end(LimitedArraySlice<tstring> * bytes_list)164   tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
165     if (bytes_list->EndDistance() <= 0) {
166       return nullptr;
167     }
168     return &bytes_list->construct_at_end();
169   }
construct_at_end(SmallVector<tstring> * bytes_list)170   tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
171     return &bytes_list->emplace_back();
172   }
173 
174   template <typename Result>
ParseBytesList(Result * bytes_list)175   bool ParseBytesList(Result* bytes_list) {
176     DCHECK(bytes_list != nullptr);
177 
178     protobuf::io::CodedInputStream stream(
179         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
180 
181     EnableAliasing(&stream);
182 
183     uint32 length;
184     if (!stream.ReadVarint32(&length)) return false;
185     auto limit = stream.PushLimit(length);
186 
187     while (!stream.ExpectAtEnd()) {
188       if (!stream.ExpectTag(kDelimitedTag(1))) return false;
189       // parse string
190       uint32 bytes_length;
191       if (!stream.ReadVarint32(&bytes_length)) return false;
192       tstring* bytes = construct_at_end(bytes_list);
193       if (bytes == nullptr) return false;
194       bytes->resize_uninitialized(bytes_length);
195       if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
196     }
197     stream.PopLimit(limit);
198     return true;
199   }
200 
201   template <typename Result>
ParseFloatList(Result * float_list)202   bool ParseFloatList(Result* float_list) {
203     DCHECK(float_list != nullptr);
204     protobuf::io::CodedInputStream stream(
205         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
206     EnableAliasing(&stream);
207     uint32 length;
208     if (!stream.ReadVarint32(&length)) return false;
209     auto limit = stream.PushLimit(length);
210 
211     if (!stream.ExpectAtEnd()) {
212       uint8 peek_tag = PeekTag(&stream);
213       if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
214         return false;
215       }
216 
217       constexpr int32 kNumFloatBytes = 4;
218       if (peek_tag == kDelimitedTag(1)) {                       // packed
219         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
220         uint32 packed_length;
221         if (!stream.ReadVarint32(&packed_length)) return false;
222         auto packed_limit = stream.PushLimit(packed_length);
223 
224         // Store the initial size to know the offset we have to start writing
225         // data from before resizing the output "vector".
226         const size_t initial_size = float_list->size();
227         float_list->resize(initial_size + packed_length / kNumFloatBytes);
228 
229         // If the result data type is float and we are on a little endian
230         // machine then we can simply memcpy the data from the proto into the
231         // result vector.
232         if (port::kLittleEndian &&
233             sizeof(typename Result::value_type) == kNumFloatBytes) {
234           // Calculate the length of the buffer available what can be less than
235           // what we requested in resize in case of a LimitedArraySlice.
236           const uint32 bytes_to_copy =
237               std::min(static_cast<uint32>((float_list->size() - initial_size) *
238                                            kNumFloatBytes),
239                        packed_length);
240           if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
241             return false;
242         } else {
243           int64 index = initial_size;
244           while (!stream.ExpectAtEnd()) {
245             uint32 buffer32;
246             if (!stream.ReadLittleEndian32(&buffer32)) return false;
247             if (index < float_list->size()) {
248               float_list->data()[index] = absl::bit_cast<float>(buffer32);
249               ++index;
250             }
251           }
252         }
253 
254         stream.PopLimit(packed_limit);
255       } else {  // non-packed
256         const size_t initial_size = float_list->size();
257         // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
258         // the value.
259         const int64 num_elements =
260             stream.BytesUntilLimit() / (1 + kNumFloatBytes);
261         float_list->resize(initial_size + num_elements);
262         int64 index = initial_size;
263         while (!stream.ExpectAtEnd()) {
264           if (!stream.ExpectTag(kFixed32Tag(1))) return false;
265           uint32 buffer32;
266           if (!stream.ReadLittleEndian32(&buffer32)) return false;
267           float_list->data()[index] = absl::bit_cast<float>(buffer32);
268           ++index;
269         }
270       }
271     }
272 
273     stream.PopLimit(limit);
274     return true;
275   }
276 
277   template <typename Result>
ParseInt64List(Result * int64_list)278   bool ParseInt64List(Result* int64_list) {
279     DCHECK(int64_list != nullptr);
280     protobuf::io::CodedInputStream stream(
281         reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
282     EnableAliasing(&stream);
283     uint32 length;
284     if (!stream.ReadVarint32(&length)) return false;
285     auto limit = stream.PushLimit(length);
286 
287     if (!stream.ExpectAtEnd()) {
288       uint8 peek_tag = PeekTag(&stream);
289       if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
290         return false;
291       }
292       if (peek_tag == kDelimitedTag(1)) {                       // packed
293         if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
294         uint32 packed_length;
295         if (!stream.ReadVarint32(&packed_length)) return false;
296         auto packed_limit = stream.PushLimit(packed_length);
297 
298         while (!stream.ExpectAtEnd()) {
299           protobuf_uint64 n;  // There is no API for int64
300           if (!stream.ReadVarint64(&n)) return false;
301           int64_list->push_back(static_cast<int64>(n));
302         }
303 
304         stream.PopLimit(packed_limit);
305       } else {  // non-packed
306         while (!stream.ExpectAtEnd()) {
307           if (!stream.ExpectTag(kVarintTag(1))) return false;
308           protobuf_uint64 n;  // There is no API for int64
309           if (!stream.ReadVarint64(&n)) return false;
310           int64_list->push_back(static_cast<int64>(n));
311         }
312       }
313     }
314     stream.PopLimit(limit);
315     return true;
316   }
317 
GetSerialized()318   StringPiece GetSerialized() const { return serialized_; }
319 
320  private:
321   StringPiece serialized_;
322 };
323 
324 using FeatureMapEntry = std::pair<StringPiece, Feature>;
325 using Example = std::vector<FeatureMapEntry>;
326 
327 }  // namespace parsed
328 
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)329 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
330   uint32 data;
331   protobuf_uint64 dummy;
332   switch (stream->ReadTag() & 0x7) {
333     case 0:  // varint
334       if (!stream->ReadVarint32(&data)) return false;
335       return true;
336     case 1:  // fixed64
337       if (!stream->ReadLittleEndian64(&dummy)) return false;
338       return true;
339     case 2:  // length delimited
340       if (!stream->ReadVarint32(&data)) return false;
341       stream->Skip(data);
342       return true;
343     case 3:          // group begin
344       return false;  // groups not supported.
345     case 4:          // group end
346       return false;  // groups not supported.
347     case 5:          // fixed32
348       if (!stream->ReadLittleEndian32(&data)) return false;
349       return true;
350   }
351   return false;  // unrecognized tag type
352 }
353 
354 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);
355 
356 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
357                           parsed::FeatureMapEntry* feature_map_entry);
358 
359 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
360                    parsed::Example* example);
361 
362 bool ParseExample(protobuf::io::CodedInputStream* stream,
363                   parsed::Example* example);
364 
365 bool ParseExample(StringPiece serialized, parsed::Example* example);
366 
367 using Config = FastParseExampleConfig;
368 
369 // Enumeration for distinguishing feature types.
370 // Note: FastParseSequenceExample constructs a map that includes Type values,
371 // and relies on the fact that they are default-initialized to Dense.
372 enum class Type { Dense, Sparse, Ragged };
373 
374 // Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
375 struct SparseBuffer {
376   // Features are in one of the 3 vectors below depending on config's dtype.
377   // Other 2 vectors remain empty.
378   SmallVector<tstring> bytes_list;
379   SmallVector<float> float_list;
380   SmallVector<int64> int64_list;
381 
382   // Features of example i are elements with indices
383   // from example_end_indices[i-1] to example_end_indices[i]-1 on the
384   // appropriate xxxxx_list
385   std::vector<size_t> example_end_indices;
386 };
387 
388 struct SeededHasher {
operatorSeededHasher389   uint64 operator()(StringPiece s) const {
390     return Hash64(s.data(), s.size(), seed);
391   }
392   uint64 seed{0xDECAFCAFFE};
393 };
394 
395 // Use this in the "default" clause of switch statements when dispatching
396 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)397 inline void ReportUnexpectedDataType(DataType dtype) {
398   DCHECK(false)
399       << "Encountered unexpected DataType " << DataTypeString(dtype)
400       << "in variable that should have been checked by CheckConfigDataType().";
401 }
402 
403 template <typename T>
404 const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
405 
406 template <>
407 const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer);
408 
409 template <>
410 const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);
411 
412 template <>
413 const SmallVector<tstring>& GetListFromBuffer<tstring>(
414     const SparseBuffer& buffer);
415 
// Copies the elements of the half-open range [b, e) into the buffer starting
// at `t`, which must have room for (e - b) elements.
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  const auto num_elements = e - b;
  std::copy_n(b, num_elements, t);
}
420 template <>
421 void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);
422 
423 void CountSparseFeatures(
424     const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
425     size_t* total_num_features, size_t* max_num_features);
426 
427 void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
428                               Tensor* dst);
429 
430 // A struct used by FastParseSequenceExample to hold the serialized proto
431 // substrings for a single feature, plus some auxiliary information derived
432 // from those protos (such as the total value length).
433 struct FeatureProtos {
434   // Proto substrings from each serialized SequenceExample that correspond
435   // with this feature.  `protos_present` records whether the proto had a
436   // value defined (even if that value is empty).
437   std::vector<StringPiece> protos;
438   std::vector<bool> protos_present;
439 
440   // Information derived from protos:
441   size_t length;    // total length for ragged/sparse, max row length for dense.
442   size_t num_rows;  // only populated for ragged sequence features.
443 
444   // Information from the config:
445   Type type;  // Whether this feature is sparse, ragged, or dense.
446   DataType dtype;
447 };
448 
449 // Map from feature name to FeatureProtos for that feature.
450 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
451 
452 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);
453 
454 // Return the number of bytes elements parsed, or -1 on error. If out is null,
455 // this method simply counts the number of elements without any copying.
ParseBytesFeature(protobuf::io::CodedInputStream * stream,tstring * out)456 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
457                              tstring* out) {
458   int num_elements = 0;
459   uint32 length;
460   if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
461     return -1;
462   }
463   if (length > 0) {
464     auto limit = stream->PushLimit(length);
465     while (!stream->ExpectAtEnd()) {
466       uint32 bytes_length;
467       if (!stream->ExpectTag(kDelimitedTag(1)) ||
468           !stream->ReadVarint32(&bytes_length)) {
469         return -1;
470       }
471       if (out == nullptr) {
472         stream->Skip(bytes_length);
473       } else {
474         out->resize_uninitialized(bytes_length);
475         if (!stream->ReadRaw(out->data(), bytes_length)) {
476           return -1;
477         }
478         out++;
479       }
480       num_elements++;
481     }
482     stream->PopLimit(limit);
483   }
484   return num_elements;
485 }
486 
// Writes `num_to_pad` zeros to consecutive floats starting at `out`.
// Does nothing when `num_to_pad` is zero or negative.
inline void PadFloatFeature(int num_to_pad, float* out) {
  for (int remaining = num_to_pad; remaining > 0; --remaining) {
    *out = 0.0f;
    ++out;
  }
}
492 
// Writes `num_to_pad` zeros to consecutive int64 values starting at `out`.
// Does nothing when `num_to_pad` is zero or negative.
inline void PadInt64Feature(int num_to_pad, int64* out) {
  for (int remaining = num_to_pad; remaining > 0; --remaining) {
    *out = 0;
    ++out;
  }
}
498 
499 // Return the number of float elements parsed, or -1 on error. If out is null,
500 // this method simply counts the number of elements without any copying.
ParseFloatFeature(protobuf::io::CodedInputStream * stream,float * out)501 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
502                              float* out) {
503   int num_elements = 0;
504   uint32 length;
505   if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
506     return -1;
507   }
508   if (length > 0) {
509     auto limit = stream->PushLimit(length);
510     uint8 peek_tag = PeekTag(stream);
511     if (peek_tag == kDelimitedTag(1)) {  // packed
512       uint32 packed_length;
513       if (!stream->ExpectTag(kDelimitedTag(1)) ||
514           !stream->ReadVarint32(&packed_length)) {
515         return -1;
516       }
517       auto packed_limit = stream->PushLimit(packed_length);
518       while (!stream->ExpectAtEnd()) {
519         uint32 buffer32;
520         if (!stream->ReadLittleEndian32(&buffer32)) {
521           return -1;
522         }
523         if (out != nullptr) {
524           *out++ = absl::bit_cast<float>(buffer32);
525         }
526         num_elements++;
527       }
528       stream->PopLimit(packed_limit);
529     } else if (peek_tag == kFixed32Tag(1)) {
530       while (!stream->ExpectAtEnd()) {
531         uint32 buffer32;
532         if (!stream->ExpectTag(kFixed32Tag(1)) ||
533             !stream->ReadLittleEndian32(&buffer32)) {
534           return -1;
535         }
536         if (out != nullptr) {
537           *out++ = absl::bit_cast<float>(buffer32);
538         }
539         num_elements++;
540       }
541     } else {
542       // Unknown tag.
543       return -1;
544     }
545     stream->PopLimit(limit);
546   }
547   return num_elements;
548 }
549 
550 // Return the number of int64 elements parsed, or -1 on error. If out is null,
551 // this method simply counts the number of elements without any copying.
ParseInt64Feature(protobuf::io::CodedInputStream * stream,int64 * out)552 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
553                              int64* out) {
554   int num_elements = 0;
555   uint32 length;
556   if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
557     return -1;
558   }
559   if (length > 0) {
560     auto limit = stream->PushLimit(length);
561     uint8 peek_tag = PeekTag(stream);
562     if (peek_tag == kDelimitedTag(1)) {  // packed
563       uint32 packed_length;
564       if (!stream->ExpectTag(kDelimitedTag(1)) ||
565           !stream->ReadVarint32(&packed_length)) {
566         return -1;
567       }
568       auto packed_limit = stream->PushLimit(packed_length);
569       while (!stream->ExpectAtEnd()) {
570         protobuf_uint64 n;  // There is no API for int64
571         if (!stream->ReadVarint64(&n)) {
572           return -1;
573         }
574         if (out != nullptr) {
575           *out++ = n;
576         }
577         num_elements++;
578       }
579       stream->PopLimit(packed_limit);
580     } else if (peek_tag == kVarintTag(1)) {
581       while (!stream->ExpectAtEnd()) {
582         protobuf_uint64 n;  // There is no API for int64
583         if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
584           return -1;
585         }
586         if (out != nullptr) {
587           *out++ = n;
588         }
589         num_elements++;
590       }
591     } else {
592       // Unknown tag.
593       return -1;
594     }
595     stream->PopLimit(limit);
596   }
597   return num_elements;
598 }
599 
600 // Parses the next feature on `stream` into `out` starting at `out_offset`.
601 // Updates `out_offset`, and returns the number of values added.
602 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)603 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
604                         Tensor* out, size_t* out_offset) {
605   int delta;
606   switch (dtype) {
607     case DT_STRING:
608       delta =
609           ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
610       break;
611     case DT_FLOAT:
612       delta =
613           ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
614       break;
615     case DT_INT64:
616       delta =
617           ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
618       break;
619     default:
620       ReportUnexpectedDataType(dtype);
621       delta = 0;
622   }
623   if (delta > 0) {
624     *out_offset += delta;
625   }
626   return delta;
627 }
628 
629 // Returns the length of the next feature on `stream`.
630 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)631 inline int GetFeatureLength(DataType dtype,
632                             protobuf::io::CodedInputStream* stream) {
633   switch (dtype) {
634     case DT_STRING:
635       return ParseBytesFeature(stream, nullptr);
636     case DT_FLOAT:
637       return ParseFloatFeature(stream, nullptr);
638     case DT_INT64:
639       return ParseInt64Feature(stream, nullptr);
640     default:
641       ReportUnexpectedDataType(dtype);
642       return -1;
643   }
644 }
645 
ParseDataType(protobuf::io::CodedInputStream * stream)646 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
647   uint8 peek_tag = PeekTag(stream);
648   switch (peek_tag) {
649     case kDelimitedTag(1):
650       return DT_STRING;
651     case kDelimitedTag(2):
652       return DT_FLOAT;
653     case kDelimitedTag(3):
654       return DT_INT64;
655     default:
656       return DT_INVALID;
657   }
658 }
659 
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)660 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
661                              DataType dtype) {
662   switch (dtype) {
663     case DT_STRING:
664       if (!stream->ExpectTag(kDelimitedTag(1))) {
665         return false;
666       }
667       break;
668     case DT_FLOAT:
669       if (!stream->ExpectTag(kDelimitedTag(2))) {
670         return false;
671       }
672       break;
673     case DT_INT64:
674       if (!stream->ExpectTag(kDelimitedTag(3))) {
675         return false;
676       }
677       break;
678     default:
679       return false;
680   }
681   uint32 length;
682   return stream->ReadVarint32(&length) && length == 0;
683 }
684 
685 }  // namespace example
686 }  // namespace tensorflow
687 
688 #endif  // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
689