1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
16 #define TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
17 #include "tensorflow/core/util/example_proto_fast_parsing.h"
18
19 #include <vector>
20
21 #include "absl/base/casts.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "tensorflow/core/example/example.pb.h"
24 #include "tensorflow/core/example/feature.pb.h"
25 #include "tensorflow/core/framework/allocator.h"
26 #include "tensorflow/core/framework/numeric_op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/blocking_counter.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/threadpool.h"
33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
34 #include "tensorflow/core/lib/monitoring/counter.h"
35 #include "tensorflow/core/platform/byte_order.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/protobuf.h"
38 #include "tensorflow/core/util/presized_cuckoo_map.h"
39 #include "tensorflow/core/util/sparse/sparse_tensor.h"
40
41 namespace tensorflow {
42 namespace example {
43
44 template <typename T>
45 using SmallVector = gtl::InlinedVector<T, 4>;
46
47 template <typename T>
48 class LimitedArraySlice {
49 public:
50 using value_type = T;
51
LimitedArraySlice(T * begin,size_t num_elements)52 LimitedArraySlice(T* begin, size_t num_elements)
53 : current_(begin), begin_(begin), end_(begin + num_elements) {}
54
55 // May return negative if there were push_back calls after slice was filled.
EndDistance()56 int64 EndDistance() const { return end_ - current_; }
57
58 // Attempts to push value to the back of this. If the slice has
59 // already been filled, this method has no effect on the underlying data, but
60 // it changes the number returned by EndDistance into negative values.
push_back(T && value)61 void push_back(T&& value) {
62 if (EndDistance() > 0) *current_ = std::move(value);
63 ++current_;
64 }
65
66 // "Constructs" an element at the back of this by resizing the slice, and
67 // returns a mutable reference to the new last element.
68 // REQUIRES: EndDistance() > 0.
construct_at_end()69 T& construct_at_end() {
70 DCHECK_GT(EndDistance(), 0);
71 return *(current_++);
72 }
73
74 // Returns a mutable reference to the last element in the slice.
75 // REQUIRES: size() > 0.
back()76 T& back() { return *(current_ - 1); }
77
78 // Returns the number of elements in the slice.
size()79 size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
80
81 // Attempts to resize the vector to the given size. It does so by advancing
82 // the pointer to the current element, possibly beyond the end of the slice.
83 // As a consequence, calling `size()` after `resize(x)` was called might
84 // return a value less than `x`.
resize(size_t size)85 void resize(size_t size) { current_ = begin_ + size; }
86
87 // Returns the pointer to the underlying data buffer.
data()88 T* data() { return begin_; }
89
90 private:
91 T* current_;
92 T* begin_;
93 T* end_;
94 };
95
// Calls `a->EnableAliasing(true)`. The trailing return type makes this
// overload participate only when that member call is well-formed for `Stream`.
template <typename Stream>
auto EnableAliasing(Stream* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

// Catch-all used when the overload above is not viable: does nothing.
template <typename Anything>
void EnableAliasing(Anything&& a) {}
103
104 uint8 PeekTag(protobuf::io::CodedInputStream* stream);
105
kVarintTag(uint32 tag)106 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
kDelimitedTag(uint32 tag)107 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
kFixed32Tag(uint32 tag)108 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
109
110 namespace parsed {
111
112 // ParseDataType has to be called first, then appropriate ParseZzzzList.
113 class Feature {
114 public:
Feature()115 Feature() {}
Feature(StringPiece serialized)116 explicit Feature(StringPiece serialized) : serialized_(serialized) {}
117
ParseDataType(DataType * dtype)118 Status ParseDataType(DataType* dtype) {
119 DCHECK(dtype != nullptr);
120 if (serialized_.empty()) {
121 *dtype = DT_INVALID;
122 return Status::OK();
123 }
124 uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
125 serialized_.remove_prefix(1);
126 switch (oneof_tag) {
127 case kDelimitedTag(1):
128 *dtype = DT_STRING;
129 break;
130 case kDelimitedTag(2):
131 *dtype = DT_FLOAT;
132 break;
133 case kDelimitedTag(3):
134 *dtype = DT_INT64;
135 break;
136 default:
137 // Initialize variable to avoid compiler warning
138 *dtype = DT_INVALID;
139 return errors::InvalidArgument("Unsupported datatype.");
140 }
141 return Status::OK();
142 }
143
GetNumElementsInBytesList(int * num_elements)144 bool GetNumElementsInBytesList(int* num_elements) {
145 protobuf::io::CodedInputStream stream(
146 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
147 EnableAliasing(&stream);
148 uint32 length = 0;
149 if (!stream.ReadVarint32(&length)) return false;
150 auto limit = stream.PushLimit(length);
151 *num_elements = 0;
152 while (!stream.ExpectAtEnd()) {
153 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
154 uint32 bytes_length = 0;
155 if (!stream.ReadVarint32(&bytes_length)) return false;
156 if (!stream.Skip(bytes_length)) return false;
157 ++*num_elements;
158 }
159 stream.PopLimit(limit);
160 return true;
161 }
162
163 // Helper methods
construct_at_end(LimitedArraySlice<tstring> * bytes_list)164 tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
165 if (bytes_list->EndDistance() <= 0) {
166 return nullptr;
167 }
168 return &bytes_list->construct_at_end();
169 }
construct_at_end(SmallVector<tstring> * bytes_list)170 tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
171 return &bytes_list->emplace_back();
172 }
173
174 template <typename Result>
ParseBytesList(Result * bytes_list)175 bool ParseBytesList(Result* bytes_list) {
176 DCHECK(bytes_list != nullptr);
177
178 protobuf::io::CodedInputStream stream(
179 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
180
181 EnableAliasing(&stream);
182
183 uint32 length;
184 if (!stream.ReadVarint32(&length)) return false;
185 auto limit = stream.PushLimit(length);
186
187 while (!stream.ExpectAtEnd()) {
188 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
189 // parse string
190 uint32 bytes_length;
191 if (!stream.ReadVarint32(&bytes_length)) return false;
192 tstring* bytes = construct_at_end(bytes_list);
193 if (bytes == nullptr) return false;
194 bytes->resize_uninitialized(bytes_length);
195 if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
196 }
197 stream.PopLimit(limit);
198 return true;
199 }
200
201 template <typename Result>
ParseFloatList(Result * float_list)202 bool ParseFloatList(Result* float_list) {
203 DCHECK(float_list != nullptr);
204 protobuf::io::CodedInputStream stream(
205 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
206 EnableAliasing(&stream);
207 uint32 length;
208 if (!stream.ReadVarint32(&length)) return false;
209 auto limit = stream.PushLimit(length);
210
211 if (!stream.ExpectAtEnd()) {
212 uint8 peek_tag = PeekTag(&stream);
213 if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
214 return false;
215 }
216
217 constexpr int32 kNumFloatBytes = 4;
218 if (peek_tag == kDelimitedTag(1)) { // packed
219 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
220 uint32 packed_length;
221 if (!stream.ReadVarint32(&packed_length)) return false;
222 auto packed_limit = stream.PushLimit(packed_length);
223
224 // Store the initial size to know the offset we have to start writing
225 // data from before resizing the output "vector".
226 const size_t initial_size = float_list->size();
227 float_list->resize(initial_size + packed_length / kNumFloatBytes);
228
229 // If the result data type is float and we are on a little endian
230 // machine then we can simply memcpy the data from the proto into the
231 // result vector.
232 if (port::kLittleEndian &&
233 sizeof(typename Result::value_type) == kNumFloatBytes) {
234 // Calculate the length of the buffer available what can be less than
235 // what we requested in resize in case of a LimitedArraySlice.
236 const uint32 bytes_to_copy =
237 std::min(static_cast<uint32>((float_list->size() - initial_size) *
238 kNumFloatBytes),
239 packed_length);
240 if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
241 return false;
242 } else {
243 int64 index = initial_size;
244 while (!stream.ExpectAtEnd()) {
245 uint32 buffer32;
246 if (!stream.ReadLittleEndian32(&buffer32)) return false;
247 if (index < float_list->size()) {
248 float_list->data()[index] = absl::bit_cast<float>(buffer32);
249 ++index;
250 }
251 }
252 }
253
254 stream.PopLimit(packed_limit);
255 } else { // non-packed
256 const size_t initial_size = float_list->size();
257 // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
258 // the value.
259 const int64 num_elements =
260 stream.BytesUntilLimit() / (1 + kNumFloatBytes);
261 float_list->resize(initial_size + num_elements);
262 int64 index = initial_size;
263 while (!stream.ExpectAtEnd()) {
264 if (!stream.ExpectTag(kFixed32Tag(1))) return false;
265 uint32 buffer32;
266 if (!stream.ReadLittleEndian32(&buffer32)) return false;
267 float_list->data()[index] = absl::bit_cast<float>(buffer32);
268 ++index;
269 }
270 }
271 }
272
273 stream.PopLimit(limit);
274 return true;
275 }
276
277 template <typename Result>
ParseInt64List(Result * int64_list)278 bool ParseInt64List(Result* int64_list) {
279 DCHECK(int64_list != nullptr);
280 protobuf::io::CodedInputStream stream(
281 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
282 EnableAliasing(&stream);
283 uint32 length;
284 if (!stream.ReadVarint32(&length)) return false;
285 auto limit = stream.PushLimit(length);
286
287 if (!stream.ExpectAtEnd()) {
288 uint8 peek_tag = PeekTag(&stream);
289 if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
290 return false;
291 }
292 if (peek_tag == kDelimitedTag(1)) { // packed
293 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
294 uint32 packed_length;
295 if (!stream.ReadVarint32(&packed_length)) return false;
296 auto packed_limit = stream.PushLimit(packed_length);
297
298 while (!stream.ExpectAtEnd()) {
299 protobuf_uint64 n; // There is no API for int64
300 if (!stream.ReadVarint64(&n)) return false;
301 int64_list->push_back(static_cast<int64>(n));
302 }
303
304 stream.PopLimit(packed_limit);
305 } else { // non-packed
306 while (!stream.ExpectAtEnd()) {
307 if (!stream.ExpectTag(kVarintTag(1))) return false;
308 protobuf_uint64 n; // There is no API for int64
309 if (!stream.ReadVarint64(&n)) return false;
310 int64_list->push_back(static_cast<int64>(n));
311 }
312 }
313 }
314 stream.PopLimit(limit);
315 return true;
316 }
317
GetSerialized()318 StringPiece GetSerialized() const { return serialized_; }
319
320 private:
321 StringPiece serialized_;
322 };
323
324 using FeatureMapEntry = std::pair<StringPiece, Feature>;
325 using Example = std::vector<FeatureMapEntry>;
326
327 } // namespace parsed
328
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)329 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
330 uint32 data;
331 protobuf_uint64 dummy;
332 switch (stream->ReadTag() & 0x7) {
333 case 0: // varint
334 if (!stream->ReadVarint32(&data)) return false;
335 return true;
336 case 1: // fixed64
337 if (!stream->ReadLittleEndian64(&dummy)) return false;
338 return true;
339 case 2: // length delimited
340 if (!stream->ReadVarint32(&data)) return false;
341 stream->Skip(data);
342 return true;
343 case 3: // group begin
344 return false; // groups not supported.
345 case 4: // group end
346 return false; // groups not supported.
347 case 5: // fixed32
348 if (!stream->ReadLittleEndian32(&data)) return false;
349 return true;
350 }
351 return false; // unrecognized tag type
352 }
353
354 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result);
355
356 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
357 parsed::FeatureMapEntry* feature_map_entry);
358
359 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
360 parsed::Example* example);
361
362 bool ParseExample(protobuf::io::CodedInputStream* stream,
363 parsed::Example* example);
364
365 bool ParseExample(StringPiece serialized, parsed::Example* example);
366
367 using Config = FastParseExampleConfig;
368
369 // Enumeration for distinguishing feature types.
370 // Note: FastParseSequenceExample constructs a map that includes Type values,
371 // and relies on the fact that they are default-initialized to Dense.
372 enum class Type { Dense, Sparse, Ragged };
373
374 // Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
375 struct SparseBuffer {
376 // Features are in one of the 3 vectors below depending on config's dtype.
377 // Other 2 vectors remain empty.
378 SmallVector<tstring> bytes_list;
379 SmallVector<float> float_list;
380 SmallVector<int64> int64_list;
381
382 // Features of example i are elements with indices
383 // from example_end_indices[i-1] to example_end_indices[i]-1 on the
384 // appropriate xxxxx_list
385 std::vector<size_t> example_end_indices;
386 };
387
388 struct SeededHasher {
operatorSeededHasher389 uint64 operator()(StringPiece s) const {
390 return Hash64(s.data(), s.size(), seed);
391 }
392 uint64 seed{0xDECAFCAFFE};
393 };
394
395 // Use this in the "default" clause of switch statements when dispatching
396 // on a dtype variable that was checked by CheckConfigDataType():
ReportUnexpectedDataType(DataType dtype)397 inline void ReportUnexpectedDataType(DataType dtype) {
398 DCHECK(false)
399 << "Encountered unexpected DataType " << DataTypeString(dtype)
400 << "in variable that should have been checked by CheckConfigDataType().";
401 }
402
403 template <typename T>
404 const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
405
406 template <>
407 const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer);
408
409 template <>
410 const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer);
411
412 template <>
413 const SmallVector<tstring>& GetListFromBuffer<tstring>(
414 const SparseBuffer& buffer);
415
// Generic case: copies the elements of [b, e) in order into the range that
// starts at `t`.
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  for (; b != e; ++b, ++t) {
    *t = *b;
  }
}
420 template <>
421 void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t);
422
423 void CountSparseFeatures(
424 const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
425 size_t* total_num_features, size_t* max_num_features);
426
427 void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
428 Tensor* dst);
429
430 // A struct used by FastParseSequenceExample to hold the serialized proto
431 // substrings for a single feature, plus some auxiliary information derived
432 // from those protos (such as the total value length).
433 struct FeatureProtos {
434 // Proto substrings from each serialized SequenceExample that correspond
435 // with this feature. `protos_present` records whether the proto had a
436 // value defined (even if that value is empty).
437 std::vector<StringPiece> protos;
438 std::vector<bool> protos_present;
439
440 // Information derived from protos:
441 size_t length; // total length for ragged/sparse, max row length for dense.
442 size_t num_rows; // only populated for ragged sequence features.
443
444 // Information from the config:
445 Type type; // Whether this feature is sparse, ragged, or dense.
446 DataType dtype;
447 };
448
449 // Map from feature name to FeatureProtos for that feature.
450 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
451
452 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n);
453
454 // Return the number of bytes elements parsed, or -1 on error. If out is null,
455 // this method simply counts the number of elements without any copying.
ParseBytesFeature(protobuf::io::CodedInputStream * stream,tstring * out)456 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
457 tstring* out) {
458 int num_elements = 0;
459 uint32 length;
460 if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
461 return -1;
462 }
463 if (length > 0) {
464 auto limit = stream->PushLimit(length);
465 while (!stream->ExpectAtEnd()) {
466 uint32 bytes_length;
467 if (!stream->ExpectTag(kDelimitedTag(1)) ||
468 !stream->ReadVarint32(&bytes_length)) {
469 return -1;
470 }
471 if (out == nullptr) {
472 stream->Skip(bytes_length);
473 } else {
474 out->resize_uninitialized(bytes_length);
475 if (!stream->ReadRaw(out->data(), bytes_length)) {
476 return -1;
477 }
478 out++;
479 }
480 num_elements++;
481 }
482 stream->PopLimit(limit);
483 }
484 return num_elements;
485 }
486
// Writes `num_to_pad` zeros to consecutive floats starting at `out`. Does
// nothing when `num_to_pad` is zero or negative.
inline void PadFloatFeature(int num_to_pad, float* out) {
  for (int remaining = num_to_pad; remaining > 0; --remaining) {
    *out = 0.0;
    ++out;
  }
}
492
// Writes `num_to_pad` zeros to consecutive int64 values starting at `out`.
// Does nothing when `num_to_pad` is zero or negative.
inline void PadInt64Feature(int num_to_pad, int64* out) {
  for (int remaining = num_to_pad; remaining > 0; --remaining) {
    *out = 0;
    ++out;
  }
}
498
499 // Return the number of float elements parsed, or -1 on error. If out is null,
500 // this method simply counts the number of elements without any copying.
ParseFloatFeature(protobuf::io::CodedInputStream * stream,float * out)501 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
502 float* out) {
503 int num_elements = 0;
504 uint32 length;
505 if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
506 return -1;
507 }
508 if (length > 0) {
509 auto limit = stream->PushLimit(length);
510 uint8 peek_tag = PeekTag(stream);
511 if (peek_tag == kDelimitedTag(1)) { // packed
512 uint32 packed_length;
513 if (!stream->ExpectTag(kDelimitedTag(1)) ||
514 !stream->ReadVarint32(&packed_length)) {
515 return -1;
516 }
517 auto packed_limit = stream->PushLimit(packed_length);
518 while (!stream->ExpectAtEnd()) {
519 uint32 buffer32;
520 if (!stream->ReadLittleEndian32(&buffer32)) {
521 return -1;
522 }
523 if (out != nullptr) {
524 *out++ = absl::bit_cast<float>(buffer32);
525 }
526 num_elements++;
527 }
528 stream->PopLimit(packed_limit);
529 } else if (peek_tag == kFixed32Tag(1)) {
530 while (!stream->ExpectAtEnd()) {
531 uint32 buffer32;
532 if (!stream->ExpectTag(kFixed32Tag(1)) ||
533 !stream->ReadLittleEndian32(&buffer32)) {
534 return -1;
535 }
536 if (out != nullptr) {
537 *out++ = absl::bit_cast<float>(buffer32);
538 }
539 num_elements++;
540 }
541 } else {
542 // Unknown tag.
543 return -1;
544 }
545 stream->PopLimit(limit);
546 }
547 return num_elements;
548 }
549
550 // Return the number of int64 elements parsed, or -1 on error. If out is null,
551 // this method simply counts the number of elements without any copying.
ParseInt64Feature(protobuf::io::CodedInputStream * stream,int64 * out)552 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
553 int64* out) {
554 int num_elements = 0;
555 uint32 length;
556 if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
557 return -1;
558 }
559 if (length > 0) {
560 auto limit = stream->PushLimit(length);
561 uint8 peek_tag = PeekTag(stream);
562 if (peek_tag == kDelimitedTag(1)) { // packed
563 uint32 packed_length;
564 if (!stream->ExpectTag(kDelimitedTag(1)) ||
565 !stream->ReadVarint32(&packed_length)) {
566 return -1;
567 }
568 auto packed_limit = stream->PushLimit(packed_length);
569 while (!stream->ExpectAtEnd()) {
570 protobuf_uint64 n; // There is no API for int64
571 if (!stream->ReadVarint64(&n)) {
572 return -1;
573 }
574 if (out != nullptr) {
575 *out++ = n;
576 }
577 num_elements++;
578 }
579 stream->PopLimit(packed_limit);
580 } else if (peek_tag == kVarintTag(1)) {
581 while (!stream->ExpectAtEnd()) {
582 protobuf_uint64 n; // There is no API for int64
583 if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
584 return -1;
585 }
586 if (out != nullptr) {
587 *out++ = n;
588 }
589 num_elements++;
590 }
591 } else {
592 // Unknown tag.
593 return -1;
594 }
595 stream->PopLimit(limit);
596 }
597 return num_elements;
598 }
599
600 // Parses the next feature on `stream` into `out` starting at `out_offset`.
601 // Updates `out_offset`, and returns the number of values added.
602 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
ParseFeature(DataType dtype,protobuf::io::CodedInputStream * stream,Tensor * out,size_t * out_offset)603 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
604 Tensor* out, size_t* out_offset) {
605 int delta;
606 switch (dtype) {
607 case DT_STRING:
608 delta =
609 ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
610 break;
611 case DT_FLOAT:
612 delta =
613 ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
614 break;
615 case DT_INT64:
616 delta =
617 ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
618 break;
619 default:
620 ReportUnexpectedDataType(dtype);
621 delta = 0;
622 }
623 if (delta > 0) {
624 *out_offset += delta;
625 }
626 return delta;
627 }
628
629 // Returns the length of the next feature on `stream`.
630 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
GetFeatureLength(DataType dtype,protobuf::io::CodedInputStream * stream)631 inline int GetFeatureLength(DataType dtype,
632 protobuf::io::CodedInputStream* stream) {
633 switch (dtype) {
634 case DT_STRING:
635 return ParseBytesFeature(stream, nullptr);
636 case DT_FLOAT:
637 return ParseFloatFeature(stream, nullptr);
638 case DT_INT64:
639 return ParseInt64Feature(stream, nullptr);
640 default:
641 ReportUnexpectedDataType(dtype);
642 return -1;
643 }
644 }
645
ParseDataType(protobuf::io::CodedInputStream * stream)646 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
647 uint8 peek_tag = PeekTag(stream);
648 switch (peek_tag) {
649 case kDelimitedTag(1):
650 return DT_STRING;
651 case kDelimitedTag(2):
652 return DT_FLOAT;
653 case kDelimitedTag(3):
654 return DT_INT64;
655 default:
656 return DT_INVALID;
657 }
658 }
659
SkipEmptyFeature(protobuf::io::CodedInputStream * stream,DataType dtype)660 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
661 DataType dtype) {
662 switch (dtype) {
663 case DT_STRING:
664 if (!stream->ExpectTag(kDelimitedTag(1))) {
665 return false;
666 }
667 break;
668 case DT_FLOAT:
669 if (!stream->ExpectTag(kDelimitedTag(2))) {
670 return false;
671 }
672 break;
673 case DT_INT64:
674 if (!stream->ExpectTag(kDelimitedTag(3))) {
675 return false;
676 }
677 break;
678 default:
679 return false;
680 }
681 uint32 length;
682 return stream->ReadVarint32(&length) && length == 0;
683 }
684
685 } // namespace example
686 } // namespace tensorflow
687
688 #endif // TENSORFLOW_LITE_KERNELS_PARSE_EXAMPLE_EXAMPLE_PROTO_FAST_PARSING_H_
689