1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
16 #define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
17 #include <algorithm>
18 #include <vector>
19
20 #include "absl/base/casts.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/example/example.pb.h"
23 #include "tensorflow/core/example/feature.pb.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/numeric_op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/threadpool.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/monitoring/counter.h"
34 #include "tensorflow/core/platform/byte_order.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/protobuf.h"
37 #include "tensorflow/core/util/example_proto_fast_parsing.h"
38 #include "tensorflow/core/util/presized_cuckoo_map.h"
39 #include "tensorflow/core/util/sparse/sparse_tensor.h"
40
41 namespace tensorflow {
42 namespace example {
43
44 template <typename T>
45 using SmallVector = gtl::InlinedVector<T, 4>;
46
// A fixed-capacity, vector-like view over a caller-owned buffer of
// `num_elements` elements of type T. It offers the subset of the vector API
// used by the Parse*List helpers (push_back, construct_at_end, resize, size,
// back, data) but never writes outside the buffer: once the buffer is full,
// further push_back/resize calls only advance a logical counter, so callers
// can detect the overflow through a negative EndDistance().
//
// The logical position is tracked as an index rather than a pointer, so an
// overfilled slice never forms a pointer beyond one-past-the-end of the
// buffer. Forming such a pointer is undefined behavior even when it is never
// dereferenced.
template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  // `begin` must point to storage for at least `num_elements` elements. The
  // slice does not take ownership of it.
  LimitedArraySlice(T* begin, size_t num_elements)
      : begin_(begin), capacity_(num_elements), logical_size_(0) {}

  // Number of free slots left in the buffer.
  // May return negative if there were push_back calls after slice was filled.
  int64_t EndDistance() const {
    return static_cast<int64_t>(capacity_) -
           static_cast<int64_t>(logical_size_);
  }

  // Attempts to push value to the back of this. If the slice has already been
  // filled, this method has no effect on the underlying data, but it changes
  // the number returned by EndDistance into negative values.
  void push_back(T&& value) {
    if (EndDistance() > 0) begin_[logical_size_] = std::move(value);
    ++logical_size_;
  }

  // "Constructs" an element at the back of this by growing the slice by one,
  // and returns a mutable reference to the new last element.
  // REQUIRES: EndDistance() > 0.
  T& construct_at_end() {
    DCHECK_GT(EndDistance(), 0);
    return begin_[logical_size_++];
  }

  // Returns a mutable reference to the last element stored in the buffer.
  // Even after an overfill, this stays within the buffer.
  // REQUIRES: size() > 0.
  T& back() { return begin_[size() - 1]; }

  // Returns the number of elements in the slice, capped at the capacity.
  size_t size() const { return std::min(logical_size_, capacity_); }

  // Sets the logical size to `size`, possibly beyond the end of the buffer.
  // As a consequence, calling `size()` after `resize(x)` was called might
  // return a value less than `x`. Never touches the underlying data.
  void resize(size_t size) { logical_size_ = size; }

  // Returns the pointer to the underlying data buffer.
  T* data() { return begin_; }

 private:
  T* begin_;
  size_t capacity_;
  // Number of elements pushed or resized so far. May exceed capacity_.
  size_t logical_size_;
};
95
// Calls `stream->EnableAliasing(true)` when the pointee type has an
// `EnableAliasing(bool)` member. The trailing return type is only well-formed
// when that call is, so otherwise this overload drops out of overload
// resolution (SFINAE).
template <typename Stream>
auto EnableAliasing(Stream* stream)
    -> decltype(static_cast<void>(stream->EnableAliasing(true))) {
  stream->EnableAliasing(true);
}

// Fallback chosen when the overload above is not viable: does nothing.
template <typename Stream>
void EnableAliasing(Stream&& /*unused*/) {}
103
104 uint8 PeekTag(protobuf::io::CodedInputStream* stream);
105
kVarintTag(uint32 tag)106 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
kDelimitedTag(uint32 tag)107 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
kFixed32Tag(uint32 tag)108 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
109
110 namespace parsed {
111
112 // ParseDataType has to be called first, then appropriate ParseZzzzList.
113 class Feature {
114 public:
Feature()115 Feature() {}
Feature(StringPiece serialized)116 explicit Feature(StringPiece serialized) : serialized_(serialized) {}
117
ParseDataType(DataType * dtype)118 Status ParseDataType(DataType* dtype) {
119 DCHECK(dtype != nullptr);
120 if (serialized_.empty()) {
121 *dtype = DT_INVALID;
122 return OkStatus();
123 }
124 uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
125 serialized_.remove_prefix(1);
126 switch (oneof_tag) {
127 case kDelimitedTag(1):
128 *dtype = DT_STRING;
129 break;
130 case kDelimitedTag(2):
131 *dtype = DT_FLOAT;
132 break;
133 case kDelimitedTag(3):
134 *dtype = DT_INT64;
135 break;
136 default:
137 // Initialize variable to avoid compiler warning
138 *dtype = DT_INVALID;
139 return errors::InvalidArgument("Unsupported datatype.");
140 }
141 return OkStatus();
142 }
143
GetNumElementsInBytesList(int * num_elements)144 bool GetNumElementsInBytesList(int* num_elements) {
145 protobuf::io::CodedInputStream stream(
146 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
147 EnableAliasing(&stream);
148 uint32 length = 0;
149 if (!stream.ReadVarint32(&length)) return false;
150 auto limit = stream.PushLimit(length);
151 *num_elements = 0;
152 while (!stream.ExpectAtEnd()) {
153 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
154 uint32 bytes_length = 0;
155 if (!stream.ReadVarint32(&bytes_length)) return false;
156 if (!stream.Skip(bytes_length)) return false;
157 ++*num_elements;
158 }
159 stream.PopLimit(limit);
160 return true;
161 }
162
163 // Helper methods
construct_at_end(LimitedArraySlice<tstring> * bytes_list)164 tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
165 if (bytes_list->EndDistance() <= 0) {
166 return nullptr;
167 }
168 return &bytes_list->construct_at_end();
169 }
construct_at_end(SmallVector<tstring> * bytes_list)170 tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
171 return &bytes_list->emplace_back();
172 }
173
174 template <typename Result>
ParseBytesList(Result * bytes_list)175 bool ParseBytesList(Result* bytes_list) {
176 DCHECK(bytes_list != nullptr);
177
178 protobuf::io::CodedInputStream stream(
179 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
180
181 EnableAliasing(&stream);
182
183 uint32 length;
184 if (!stream.ReadVarint32(&length)) return false;
185 auto limit = stream.PushLimit(length);
186
187 while (!stream.ExpectAtEnd()) {
188 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
189 // parse string
190 uint32 bytes_length;
191 if (!stream.ReadVarint32(&bytes_length)) return false;
192 tstring* bytes = construct_at_end(bytes_list);
193 if (bytes == nullptr) return false;
194 bytes->resize_uninitialized(bytes_length);
195 if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
196 }
197 stream.PopLimit(limit);
198 return true;
199 }
200
201 template <typename Result>
ParseFloatList(Result * float_list)202 bool ParseFloatList(Result* float_list) {
203 DCHECK(float_list != nullptr);
204 protobuf::io::CodedInputStream stream(
205 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
206 EnableAliasing(&stream);
207 uint32 length;
208 if (!stream.ReadVarint32(&length)) return false;
209 auto limit = stream.PushLimit(length);
210
211 if (!stream.ExpectAtEnd()) {
212 uint8 peek_tag = PeekTag(&stream);
213 if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
214 return false;
215 }
216
217 constexpr int32_t kNumFloatBytes = 4;
218 if (peek_tag == kDelimitedTag(1)) { // packed
219 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
220 uint32 packed_length;
221 if (!stream.ReadVarint32(&packed_length)) return false;
222 auto packed_limit = stream.PushLimit(packed_length);
223
224 // Store the initial size to know the offset we have to start writing
225 // data from before resizing the output "vector".
226 const size_t initial_size = float_list->size();
227 float_list->resize(initial_size + packed_length / kNumFloatBytes);
228
229 // If the result data type is float and we are on a little endian
230 // machine then we can simply memcpy the data from the proto into the
231 // result vector.
232 if (port::kLittleEndian &&
233 sizeof(typename Result::value_type) == kNumFloatBytes) {
234 // Calculate the length of the buffer available what can be less than
235 // what we requested in resize in case of a LimitedArraySlice.
236 const uint32 bytes_to_copy =
237 std::min(static_cast<uint32>((float_list->size() - initial_size) *
238 kNumFloatBytes),
239 packed_length);
240 if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
241 return false;
242 } else {
243 int64_t index = initial_size;
244 while (!stream.ExpectAtEnd()) {
245 uint32 buffer32;
246 if (!stream.ReadLittleEndian32(&buffer32)) return false;
247 if (index < float_list->size()) {
248 float_list->data()[index] = absl::bit_cast<float>(buffer32);
249 ++index;
250 }
251 }
252 }
253
254 stream.PopLimit(packed_limit);
255 } else { // non-packed
256 const size_t initial_size = float_list->size();
257 // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
258 // the value.
259 const int64_t num_elements =
260 stream.BytesUntilLimit() / (1 + kNumFloatBytes);
261 float_list->resize(initial_size + num_elements);
262 int64_t index = initial_size;
263 while (!stream.ExpectAtEnd()) {
264 if (!stream.ExpectTag(kFixed32Tag(1))) return false;
265 uint32 buffer32;
266 if (!stream.ReadLittleEndian32(&buffer32)) return false;
267 float_list->data()[index] = absl::bit_cast<float>(buffer32);
268 ++index;
269 }
270 }
271 }
272
273 stream.PopLimit(limit);
274 return true;
275 }
276
277 template <typename Result>
ParseInt64List(Result * int64_list)278 bool ParseInt64List(Result* int64_list) {
279 DCHECK(int64_list != nullptr);
280 protobuf::io::CodedInputStream stream(
281 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
282 EnableAliasing(&stream);
283 uint32 length;
284 if (!stream.ReadVarint32(&length)) return false;
285 auto limit = stream.PushLimit(length);
286
287 if (!stream.ExpectAtEnd()) {
288 uint8 peek_tag = PeekTag(&stream);
289 if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
290 return false;
291 }
292 if (peek_tag == kDelimitedTag(1)) { // packed
293 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
294 uint32 packed_length;
295 if (!stream.ReadVarint32(&packed_length)) return false;
296 auto packed_limit = stream.PushLimit(packed_length);
297
298 while (!stream.ExpectAtEnd()) {
299 protobuf_uint64 n; // There is no API for int64
300 if (!stream.ReadVarint64(&n)) return false;
301 int64_list->push_back(static_cast<int64_t>(n));
302 }
303
304 stream.PopLimit(packed_limit);
305 } else { // non-packed
306 while (!stream.ExpectAtEnd()) {
307 if (!stream.ExpectTag(kVarintTag(1))) return false;
308 protobuf_uint64 n; // There is no API for int64
309 if (!stream.ReadVarint64(&n)) return false;
310 int64_list->push_back(static_cast<int64_t>(n));
311 }
312 }
313 }
314 stream.PopLimit(limit);
315 return true;
316 }
317
GetSerialized()318 StringPiece GetSerialized() const { return serialized_; }
319
320 private:
321 StringPiece serialized_;
322 };
323
324 using FeatureMapEntry = std::pair<StringPiece, Feature>;
325 using Example = std::vector<FeatureMapEntry>;
326
327 } // namespace parsed
328
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)329 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
330 uint32 data;
331 protobuf_uint64 dummy;
332 switch (stream->ReadTag() & 0x7) {
333 case 0: // varint
334 if (!stream->ReadVarint32(&data)) return false;
335 return true;
336 case 1: // fixed64
337 if (!stream->ReadLittleEndian64(&dummy)) return false;
338 return true;
339 case 2: // length delimited
340 if (!stream->ReadVarint32(&data)) return false;
341 stream->Skip(data);
342 return true;
343 case 3: // group begin
344 return false; // groups not supported.
345 case 4: // group end
346 return false; // groups not supported.
347 case 5: // fixed32
348 if (!stream->ReadLittleEndian32(&data)) return false;
349 return true;
350 }
351 return false; // unrecognized tag type
352 }
353
354 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);
355
356 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
357 parsed::FeatureMapEntry* feature_map_entry);
358
359 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
360 parsed::Example* example);
361
362 bool ParseExample(protobuf::io::CodedInputStream* stream,
363 parsed::Example* example);
364
365 bool ParseExample(StringPiece serialized, parsed::Example* example);
366
367 using Config = FastParseExampleConfig;
368
369 // Enumeration for distinguishing feature types.
370 // Note: FastParseSequenceExample constructs a map that includes Type values,
371 // and relies on the fact that they are default-initialized to Dense.
372 enum class Type { Dense, Sparse, Ragged };
373
374 // Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
375 struct SparseBuffer {
376 // Features are in one of the 3 vectors below depending on config's dtype.
377 // Other 2 vectors remain empty.
378 SmallVector<tstring> bytes_list;
379 SmallVector<float> float_list;
380 SmallVector<int64_t> int64_list;
381
382 // Features of example i are elements with indices
383 // from example_end_indices[i-1] to example_end_indices[i]-1 on the
384 // appropriate xxxxx_list
385 std::vector<size_t> example_end_indices;
386 };
387
388 struct SeededHasher {
operatorSeededHasher389 uint64 operator()(StringPiece s) const {
390 return Hash64(s.data(), s.size(), seed);
391 }
392 uint64 seed{0xDECAFCAFFE};
393 };
394
395 // Use this in the "default" clause of switch statements when dispatching
396 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)397 inline void ReportUnexpectedDataType(DataType dtype) {
398 DCHECK(false)
399 << "Encountered unexpected DataType " << DataTypeString(dtype)
400 << "in variable that should have been checked by CheckConfigDataType().";
401 }
402
403 template <typename T>
404 const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
405
406 template <>
407 const SmallVector<int64_t>& GetListFromBuffer<int64_t>(
408 const SparseBuffer& buffer);
409
410 template <>
411 const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);
412
413 template <>
414 const SmallVector<tstring>& GetListFromBuffer<tstring>(
415 const SparseBuffer& buffer);
416
// Copies the elements of [first, last) to consecutive positions starting at
// `dest`, in order. `dest` must have room for (last - first) elements.
template <typename T>
void CopyOrMoveBlock(const T* first, const T* last, T* dest) {
  std::copy(first, last, dest);
}
421 template <>
422 void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);
423
424 void CountSparseFeatures(
425 const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
426 size_t* total_num_features, size_t* max_num_features);
427
428 void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
429 Tensor* dst);
430
431 // A struct used by FastParseSequenceExample to hold the serialized proto
432 // substrings for a single feature, plus some auxiliary information derived
433 // from those protos (such as the total value length).
434 struct FeatureProtos {
435 // Proto substrings from each serialized SequenceExample that correspond
436 // with this feature. `protos_present` records whether the proto had a
437 // value defined (even if that value is empty).
438 std::vector<StringPiece> protos;
439 std::vector<bool> protos_present;
440
441 // Information derived from protos:
442 size_t length; // total length for ragged/sparse, max row length for dense.
443 size_t num_rows; // only populated for ragged sequence features.
444
445 // Information from the config:
446 Type type; // Whether this feature is sparse, ragged, or dense.
447 DataType dtype;
448 };
449
450 // Map from feature name to FeatureProtos for that feature.
451 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
452
453 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);
454
455 // Return the number of bytes elements parsed, or -1 on error. If out is null,
456 // this method simply counts the number of elements without any copying.
ParseBytesFeature(protobuf::io::CodedInputStream * stream,tstring * out)457 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
458 tstring* out) {
459 int num_elements = 0;
460 uint32 length;
461 if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
462 return -1;
463 }
464 if (length > 0) {
465 auto limit = stream->PushLimit(length);
466 while (!stream->ExpectAtEnd()) {
467 uint32 bytes_length;
468 if (!stream->ExpectTag(kDelimitedTag(1)) ||
469 !stream->ReadVarint32(&bytes_length)) {
470 return -1;
471 }
472 if (out == nullptr) {
473 stream->Skip(bytes_length);
474 } else {
475 out->resize_uninitialized(bytes_length);
476 if (!stream->ReadRaw(out->data(), bytes_length)) {
477 return -1;
478 }
479 out++;
480 }
481 num_elements++;
482 }
483 stream->PopLimit(limit);
484 }
485 return num_elements;
486 }
487
// Writes `num_to_pad` zeros to consecutive floats starting at `out`.
// A zero or negative count writes nothing.
inline void PadFloatFeature(int num_to_pad, float* out) {
  if (num_to_pad <= 0) return;
  std::fill_n(out, num_to_pad, 0.0f);
}
493
// Writes `num_to_pad` zeros to consecutive int64 values starting at `out`.
// A zero or negative count writes nothing.
inline void PadInt64Feature(int num_to_pad, int64_t* out) {
  if (num_to_pad <= 0) return;
  std::fill_n(out, num_to_pad, int64_t{0});
}
499
500 // Return the number of float elements parsed, or -1 on error. If out is null,
501 // this method simply counts the number of elements without any copying.
ParseFloatFeature(protobuf::io::CodedInputStream * stream,float * out)502 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
503 float* out) {
504 int num_elements = 0;
505 uint32 length;
506 if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
507 return -1;
508 }
509 if (length > 0) {
510 auto limit = stream->PushLimit(length);
511 uint8 peek_tag = PeekTag(stream);
512 if (peek_tag == kDelimitedTag(1)) { // packed
513 uint32 packed_length;
514 if (!stream->ExpectTag(kDelimitedTag(1)) ||
515 !stream->ReadVarint32(&packed_length)) {
516 return -1;
517 }
518 auto packed_limit = stream->PushLimit(packed_length);
519 while (!stream->ExpectAtEnd()) {
520 uint32 buffer32;
521 if (!stream->ReadLittleEndian32(&buffer32)) {
522 return -1;
523 }
524 if (out != nullptr) {
525 *out++ = absl::bit_cast<float>(buffer32);
526 }
527 num_elements++;
528 }
529 stream->PopLimit(packed_limit);
530 } else if (peek_tag == kFixed32Tag(1)) {
531 while (!stream->ExpectAtEnd()) {
532 uint32 buffer32;
533 if (!stream->ExpectTag(kFixed32Tag(1)) ||
534 !stream->ReadLittleEndian32(&buffer32)) {
535 return -1;
536 }
537 if (out != nullptr) {
538 *out++ = absl::bit_cast<float>(buffer32);
539 }
540 num_elements++;
541 }
542 } else {
543 // Unknown tag.
544 return -1;
545 }
546 stream->PopLimit(limit);
547 }
548 return num_elements;
549 }
550
551 // Return the number of int64 elements parsed, or -1 on error. If out is null,
552 // this method simply counts the number of elements without any copying.
ParseInt64Feature(protobuf::io::CodedInputStream * stream,int64_t * out)553 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
554 int64_t* out) {
555 int num_elements = 0;
556 uint32 length;
557 if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
558 return -1;
559 }
560 if (length > 0) {
561 auto limit = stream->PushLimit(length);
562 uint8 peek_tag = PeekTag(stream);
563 if (peek_tag == kDelimitedTag(1)) { // packed
564 uint32 packed_length;
565 if (!stream->ExpectTag(kDelimitedTag(1)) ||
566 !stream->ReadVarint32(&packed_length)) {
567 return -1;
568 }
569 auto packed_limit = stream->PushLimit(packed_length);
570 while (!stream->ExpectAtEnd()) {
571 protobuf_uint64 n; // There is no API for int64
572 if (!stream->ReadVarint64(&n)) {
573 return -1;
574 }
575 if (out != nullptr) {
576 *out++ = n;
577 }
578 num_elements++;
579 }
580 stream->PopLimit(packed_limit);
581 } else if (peek_tag == kVarintTag(1)) {
582 while (!stream->ExpectAtEnd()) {
583 protobuf_uint64 n; // There is no API for int64
584 if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
585 return -1;
586 }
587 if (out != nullptr) {
588 *out++ = n;
589 }
590 num_elements++;
591 }
592 } else {
593 // Unknown tag.
594 return -1;
595 }
596 stream->PopLimit(limit);
597 }
598 return num_elements;
599 }
600
601 // Parses the next feature on `stream` into `out` starting at `out_offset`.
602 // Updates `out_offset`, and returns the number of values added.
603 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)604 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
605 Tensor* out, size_t* out_offset) {
606 int delta;
607 switch (dtype) {
608 case DT_STRING:
609 delta =
610 ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
611 break;
612 case DT_FLOAT:
613 delta =
614 ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
615 break;
616 case DT_INT64:
617 delta =
618 ParseInt64Feature(stream, out->flat<int64_t>().data() + *out_offset);
619 break;
620 default:
621 ReportUnexpectedDataType(dtype);
622 delta = 0;
623 }
624 if (delta > 0) {
625 *out_offset += delta;
626 }
627 return delta;
628 }
629
630 // Returns the length of the next feature on `stream`.
631 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)632 inline int GetFeatureLength(DataType dtype,
633 protobuf::io::CodedInputStream* stream) {
634 switch (dtype) {
635 case DT_STRING:
636 return ParseBytesFeature(stream, nullptr);
637 case DT_FLOAT:
638 return ParseFloatFeature(stream, nullptr);
639 case DT_INT64:
640 return ParseInt64Feature(stream, nullptr);
641 default:
642 ReportUnexpectedDataType(dtype);
643 return -1;
644 }
645 }
646
ParseDataType(protobuf::io::CodedInputStream * stream)647 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
648 uint8 peek_tag = PeekTag(stream);
649 switch (peek_tag) {
650 case kDelimitedTag(1):
651 return DT_STRING;
652 case kDelimitedTag(2):
653 return DT_FLOAT;
654 case kDelimitedTag(3):
655 return DT_INT64;
656 default:
657 return DT_INVALID;
658 }
659 }
660
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)661 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
662 DataType dtype) {
663 switch (dtype) {
664 case DT_STRING:
665 if (!stream->ExpectTag(kDelimitedTag(1))) {
666 return false;
667 }
668 break;
669 case DT_FLOAT:
670 if (!stream->ExpectTag(kDelimitedTag(2))) {
671 return false;
672 }
673 break;
674 case DT_INT64:
675 if (!stream->ExpectTag(kDelimitedTag(3))) {
676 return false;
677 }
678 break;
679 default:
680 return false;
681 }
682 uint32 length;
683 return stream->ReadVarint32(&length) && length == 0;
684 }
685
686 } // namespace example
687 } // namespace tensorflow
688
689 #endif // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
690