/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/example_proto_fast_parsing.h"

#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
namespace example {

namespace {

template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;

template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  LimitedArraySlice(T* begin, size_t num_elements)
      : current_(begin), begin_(begin), end_(begin + num_elements) {}

  // May return negative if there were push_back calls after slice was filled.
  int64_t EndDistance() const { return end_ - current_; }

  // Attempts to push value to the back of this. If the slice has
  // already been filled, this method has no effect on the underlying data, but
  // it changes the number returned by EndDistance into negative values.
  void push_back(T&& value) {
    if (EndDistance() > 0) *current_ = std::move(value);
    ++current_;
  }

  // "Constructs" an element at the back of this by resizing the slice, and
  // returns a mutable reference to the new last element.
  // REQUIRES: EndDistance() > 0.
  T& construct_at_end() {
    DCHECK_GT(EndDistance(), 0);
    return *(current_++);
  }

  // Returns a mutable reference to the last element in the slice.
  // REQUIRES: size() > 0.
  T& back() { return *(current_ - 1); }

  // Returns the number of elements in the slice.
  size_t size() const { return std::min(current_ - begin_, end_ - begin_); }

  // Attempts to resize the vector to the given size. It does so by advancing
  // the pointer to the current element, possibly beyond the end of the slice.
  // As a consequence, calling `size()` after `resize(x)` was called might
  // return a value less than `x`.
  void resize(size_t size) { current_ = begin_ + size; }

  // Returns the pointer to the underlying data buffer.
  T* data() { return begin_; }

 private:
  T* current_;
  T* begin_;
  T* end_;
};
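
// Illustrative usage of LimitedArraySlice (not part of the parser itself):
// the parsing code below relies on EndDistance() to detect overflow when
// filling a fixed-size output buffer.
//
//   float buf[2];
//   LimitedArraySlice<float> slice(buf, 2);
//   slice.push_back(1.0f);  // stored, EndDistance() == 1
//   slice.push_back(2.0f);  // stored, EndDistance() == 0
//   slice.push_back(3.0f);  // dropped, EndDistance() == -1
//   // EndDistance() != 0 signals a shape mismatch to the caller.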

template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

template <typename A>
void EnableAliasing(A&& a) {}

uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
  DCHECK(stream != nullptr);
  const void* ptr;
  int size;
  if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
  return *static_cast<const uint8*>(ptr);
}

constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }

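// For reference (standard proto wire format, not specific to this file): a
// tag byte is (field_number << 3) | wire_type. So for field number 1:
//   kVarintTag(1)    == 0x08  (wire type 0, varint)
//   kDelimitedTag(1) == 0x0A  (wire type 2, length-delimited)
//   kFixed32Tag(1)   == 0x0D  (wire type 5, fixed32)
// These single-byte constants let the parser dispatch on PeekTag() cheaply.
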
namespace parsed {

// ParseDataType has to be called first, then appropriate ParseZzzzList.
class Feature {
 public:
  Feature() {}
  explicit Feature(StringPiece serialized) : serialized_(serialized) {}

  Status ParseDataType(DataType* dtype) {
    DCHECK(dtype != nullptr);
    if (serialized_.empty()) {
      *dtype = DT_INVALID;
      return OkStatus();
    }
    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
    serialized_.remove_prefix(1);
    switch (oneof_tag) {
      case kDelimitedTag(1):
        *dtype = DT_STRING;
        break;
      case kDelimitedTag(2):
        *dtype = DT_FLOAT;
        break;
      case kDelimitedTag(3):
        *dtype = DT_INT64;
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DT_INVALID;
        return errors::InvalidArgument("Unsupported datatype.");
    }
    return OkStatus();
  }

  bool GetNumElementsInBytesList(int* num_elements) {
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length = 0;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      uint32 bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      if (!stream.Skip(bytes_length)) return false;
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Helper methods
  tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
    if (bytes_list->EndDistance() <= 0) {
      return nullptr;
    }
    return &bytes_list->construct_at_end();
  }
  tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
    return &bytes_list->emplace_back();
  }

  template <typename Result>
  bool ParseBytesList(Result* bytes_list) {
    DCHECK(bytes_list != nullptr);

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());

    EnableAliasing(&stream);

    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      // parse string
      uint32 bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      tstring* bytes = construct_at_end(bytes_list);
      if (bytes == nullptr) return false;
      bytes->resize_uninitialized(bytes_length);
      if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
    }
    stream.PopLimit(limit);
    return true;
  }

  template <typename Result>
  bool ParseFloatList(Result* float_list) {
    DCHECK(float_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      constexpr int32_t kNumFloatBytes = 4;
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        // Store the initial size to know the offset we have to start writing
        // data from before resizing the output "vector".
        const size_t initial_size = float_list->size();
        float_list->resize(initial_size + packed_length / kNumFloatBytes);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        if (port::kLittleEndian &&
            sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Calculate the length of the available buffer, which can be less
          // than what we requested via resize() in the case of a
          // LimitedArraySlice.
          const uint32 bytes_to_copy =
              std::min(static_cast<uint32>((float_list->size() - initial_size) *
                                           kNumFloatBytes),
                       packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
            return false;
        } else {
          int64_t index = initial_size;
          while (!stream.ExpectAtEnd()) {
            uint32 buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) return false;
            if (index < float_list->size()) {
              float_list->data()[index] = absl::bit_cast<float>(buffer32);
              ++index;
            }
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        const size_t initial_size = float_list->size();
        // 1 byte for the tag (`1` encoded as Varint32) and kNumFloatBytes for
        // the value.
        const int64_t num_elements =
            stream.BytesUntilLimit() / (1 + kNumFloatBytes);
        float_list->resize(initial_size + num_elements);
        int64_t index = initial_size;
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->data()[index] = absl::bit_cast<float>(buffer32);
          ++index;
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  template <typename Result>
  bool ParseInt64List(Result* int64_list) {
    DCHECK(int64_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64_t>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) return false;
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64_t>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  StringPiece GetSerialized() const { return serialized_; }

 private:
  // TODO(lew): Pair of uint8* would be more natural.
  StringPiece serialized_;
};

using FeatureMapEntry = std::pair<StringPiece, Feature>;
using Example = std::vector<FeatureMapEntry>;

}  // namespace parsed
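
// Illustrative calling sequence for parsed::Feature (not part of the parser):
// ParseDataType() consumes the leading oneof tag byte, and the matching
// ParseZzzzList() must then be used on the remaining payload.
//
//   parsed::Feature feature(serialized_feature);
//   DataType dtype;
//   TF_RETURN_IF_ERROR(feature.ParseDataType(&dtype));
//   if (dtype == DT_INT64) {
//     SmallVector<int64_t> values;
//     if (!feature.ParseInt64List(&values)) { /* handle parse error */ }
//   }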

inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
  uint32 data;
  protobuf_uint64 dummy;
  switch (stream->ReadTag() & 0x7) {
    case 0:  // varint
      if (!stream->ReadVarint32(&data)) return false;
      return true;
    case 1:  // fixed64
      if (!stream->ReadLittleEndian64(&dummy)) return false;
      return true;
    case 2:  // length delimited
      if (!stream->ReadVarint32(&data)) return false;
      stream->Skip(data);
      return true;
    case 3:          // group begin
      return false;  // groups not supported.
    case 4:          // group end
      return false;  // groups not supported.
    case 5:  // fixed32
      if (!stream->ReadLittleEndian32(&data)) return false;
      return true;
  }
  return false;  // unrecognized tag type
}

bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
  DCHECK(stream != nullptr);
  DCHECK(result != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  if (length == 0) {
    *result = StringPiece(nullptr, 0);
    return true;
  }
  const void* stream_alias;
  int stream_size;
  if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
    return false;
  }
  if (static_cast<uint32>(stream_size) < length) return false;
  *result = StringPiece(static_cast<const char*>(stream_alias), length);
  stream->Skip(length);
  return true;
}

bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
                          parsed::FeatureMapEntry* feature_map_entry) {
  DCHECK(stream != nullptr);
  DCHECK(feature_map_entry != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);

  // Protobufs allow an arbitrary order for the key and value fields.
  for (int n = 0; n < 2; ++n) {
    const uint32_t tag = stream->ReadTag();
    switch (tag) {
      case kDelimitedTag(1):
        if (!ParseString(stream, &feature_map_entry->first)) return false;
        break;

      case kDelimitedTag(2): {
        StringPiece feature_string_piece;
        if (!ParseString(stream, &feature_string_piece)) return false;
        feature_map_entry->second = parsed::Feature(feature_string_piece);
        break;
      }

      default:
        return false;
    }
  }

  if (!stream->ExpectAtEnd()) return false;
  stream->PopLimit(limit);
  return true;
}

bool ParseFeatures(protobuf::io::CodedInputStream* stream,
                   parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);
  while (!stream->ExpectAtEnd()) {
    parsed::FeatureMapEntry feature_map_entry;
    if (!stream->ExpectTag(kDelimitedTag(1))) return false;
    if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
    example->push_back(std::move(feature_map_entry));
  }
  stream->PopLimit(limit);
  return true;
}

bool ParseExample(protobuf::io::CodedInputStream* stream,
                  parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  // Loop over the input stream, which may contain multiple serialized Example
  // protos merged together as strings. This behavior is consistent with
  // Proto's ParseFromString when string representations are concatenated.
  while (!stream->ExpectAtEnd()) {
    if (!stream->ExpectTag(kDelimitedTag(1))) {
      if (!SkipExtraneousTag(stream)) return false;
    } else {
      if (!ParseFeatures(stream, example)) return false;
    }
  }
  return true;
}

bool ParseExample(StringPiece serialized, parsed::Example* example) {
  DCHECK(example != nullptr);
  protobuf::io::CodedInputStream stream(
      reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
  EnableAliasing(&stream);
  return ParseExample(&stream, example);
}

}  // namespace

bool TestFastParse(const string& serialized, Example* example) {
  DCHECK(example != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized, &parsed_example)) return false;
  auto& features = *example->mutable_features();
  size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This matches the logic that standard protobuf parsing implements:
    // the last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    string name(name_and_feature.first);
    if ((*features.mutable_feature()).count(name) > 0) continue;

    auto& value = (*features.mutable_feature())[name];
    DataType dtype;
    if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
    switch (dtype) {
      case DT_INVALID:
        break;
      case DT_STRING: {
        SmallVector<tstring> list;
        if (!name_and_feature.second.ParseBytesList(&list)) return false;
        auto* result_list = value.mutable_bytes_list();
        for (auto& bytes : list) {
          result_list->add_value(bytes.data(), bytes.size());
        }
        break;
      }
      case DT_FLOAT: {
        SmallVector<float> list;
        if (!name_and_feature.second.ParseFloatList(&list)) return false;
        auto* result_list = value.mutable_float_list();
        for (float f : list) {
          result_list->add_value(f);
        }
        break;
      }
      case DT_INT64: {
        SmallVector<int64_t> list;
        if (!name_and_feature.second.ParseInt64List(&list)) return false;
        auto* result_list = value.mutable_int64_list();
        for (int64_t i : list) {
          result_list->add_value(i);
        }
        break;
      }
      default:
        LOG(FATAL) << "Should not happen.";
    }
  }
  return true;
}

// -----------------------------------------------------------------------------

namespace {

using Config = FastParseExampleConfig;

void ParallelFor(const std::function<void(size_t)>& f, size_t n,
                 thread::ThreadPool* thread_pool) {
  if (n == 0) return;
  if (thread_pool == nullptr) {
    for (size_t i = 0; i < n; ++i) {
      f(i);
    }
  } else {
    BlockingCounter counter(n - 1);
    for (size_t i = 1; i < n; ++i) {
      thread_pool->Schedule([i, &f, &counter] {
        f(i);
        counter.DecrementCount();
      });
    }
    f(0);
    counter.Wait();
  }
}
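
// Note: ParallelFor runs f(0) on the calling thread while f(1)..f(n-1) are
// scheduled on the pool, then blocks until all iterations complete. For
// example, the ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool)
// call below returns only after every minibatch has been processed.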

// Enumeration for distinguishing feature types.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense.
enum class Type { Dense, Sparse, Ragged };

// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
struct SparseBuffer {
  // Features are stored in one of the 3 vectors below, depending on the
  // config's dtype. The other 2 vectors remain empty.
  SmallVector<tstring> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64_t> int64_list;

  // Features of example i are the elements with indices from
  // example_end_indices[i-1] to example_end_indices[i]-1 in the
  // appropriate xxxxx_list.
  std::vector<size_t> example_end_indices;
};
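
// Illustrative example of the SparseBuffer encoding (not part of the parser):
// three examples with int64 features [1, 2], [], and [3] would be buffered as
//   int64_list          == {1, 2, 3}
//   example_end_indices == {2, 2, 3}
// i.e. example i owns the half-open range
// [example_end_indices[i-1], example_end_indices[i]) of int64_list.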

struct SeededHasher {
  uint64 operator()(StringPiece s) const {
    return Hash64(s.data(), s.size(), seed);
  }
  uint64 seed{0xDECAFCAFFE};
};

void LogDenseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. Ignoring all but last one.";
  static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_dense_feature",
      "Dense feature appears twice in a tf.Example");
  duplicated_dense_feature->GetCell()->IncrementBy(1);
}

void LogSparseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. Ignoring all but last one.";
  static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_sparse_feature",
      "Sparse feature appears twice in a tf.Example");
  duplicated_sparse_feature->GetCell()->IncrementBy(1);
}

Status FastParseSerializedExample(
    const tstring& serialized_example, const tstring& example_name,
    const size_t example_index, const Config& config,
    const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
    SeededHasher hasher, std::vector<Tensor>* output_dense,
    std::vector<SparseBuffer>* output_varlen_dense,
    std::vector<SparseBuffer>* output_sparse,
    std::vector<SparseBuffer>* output_ragged,
    PerExampleFeatureStats* output_stats) {
  DCHECK(output_dense != nullptr);
  DCHECK(output_sparse != nullptr);
  DCHECK(output_ragged != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized_example, &parsed_example)) {
    return errors::InvalidArgument("Could not parse example input, value: '",
                                   serialized_example, "'");
  }
  std::vector<int64_t> sparse_feature_last_example(config.sparse.size(), -1);
  std::vector<int64_t> dense_feature_last_example(config.dense.size(), -1);
  std::vector<int64_t> ragged_feature_last_example(config.ragged.size(), -1);

  // Handle features present in the example.
  const size_t parsed_example_size = parsed_example.size();

  if (output_stats) {
    // TODO(b/111553342): This may over-count the number of features if there
    // are duplicate keys in the feature map. Consider deduplicating the keys
    // before computing the count.
    output_stats->features_count = parsed_example_size;
  }

  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This matches the logic that standard protobuf parsing implements:
    // the last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];

    const StringPiece feature_name = name_and_feature.first;
    parsed::Feature& feature = name_and_feature.second;

    std::pair<size_t, Type> d_and_type;
    uint64 h = hasher(feature_name);
    if (!config_index.Find(h, &d_and_type)) continue;

    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;
    bool is_ragged = d_and_type.second == Type::Ragged;

    {
      // Testing for PresizedCuckooMap collision.
      // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
      const tstring& config_feature_name =
          is_dense ? config.dense[d].feature_name
                   : (is_ragged ? config.ragged[d].feature_name
                                : config.sparse[d].feature_name);
      if (feature_name != config_feature_name) continue;
    }

    auto example_error = [&](StringPiece suffix) {
      return errors::InvalidArgument("Name: ", example_name,
                                     ", Key: ", feature_name,
                                     ", Index: ", example_index, ". ", suffix);
    };

    auto parse_error = [&] {
      return example_error("Can't parse serialized Example.");
    };

    DataType example_dtype;
    TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));

    if (is_dense) {
      if (example_dtype == DT_INVALID) continue;

      // If feature was already visited, skip.
      // Compare comment at the beginning of the loop.
      if (dense_feature_last_example[d] == example_index) {
        LogDenseFeatureDataLoss(feature_name);
        continue;
      }
      dense_feature_last_example[d] = example_index;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(strings::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }
      if (!config.dense[d].variable_length) {
        Tensor& out = (*output_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;
        if (output_stats) {
          // TODO(b/111553342): If desirable, we could add support for counting
          // elements in the features that aren't parsed, but this could add
          // considerable runtime cost.
          output_stats->feature_values_count += num_elements;
        }

        const std::size_t offset = example_index * num_elements;

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(strings::StrCat(
              "Number of ", type_str,
              " values != expected. "
              "Values size: ",
              size,
              " but output shape: ", config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case DT_INT64: {
            auto out_p = out.flat<int64_t>().data() + offset;
            LimitedArraySlice<int64_t> slice(out_p, num_elements);
            if (!feature.ParseInt64List(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "int64");
            }
            break;
          }
          case DT_FLOAT: {
            auto out_p = out.flat<float>().data() + offset;
            LimitedArraySlice<float> slice(out_p, num_elements);
            if (!feature.ParseFloatList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "float");
            }
            break;
          }
          case DT_STRING: {
            auto out_p = out.flat<tstring>().data() + offset;
            LimitedArraySlice<tstring> slice(out_p, num_elements);
            if (!feature.ParseBytesList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "bytes");
            }
            break;
          }
          default:
            LOG(FATAL) << "Should not happen.";
        }
      } else {  // if variable length
        SparseBuffer& out = (*output_varlen_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;

        if (example_dtype != DT_INVALID &&
            example_dtype != config.dense[d].dtype) {
          return example_error(strings::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(config.dense[d].dtype)));
        }

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(strings::StrCat(
              "Number of ", type_str,
              " values is not a multiple of stride length. Saw ", size,
              " values but output shape is: ",
              config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case DT_INT64: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseInt64List(&out.int64_list)) {
                return parse_error();
              }
              if (out.int64_list.size() % num_elements != 0) {
                return shape_error(out.int64_list.size(), "int64");
              }
            }
            out.example_end_indices.push_back(out.int64_list.size());
            break;
          }
          case DT_FLOAT: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseFloatList(&out.float_list)) {
                return parse_error();
              }
              if (out.float_list.size() % num_elements != 0) {
                return shape_error(out.float_list.size(), "float");
              }
            }
            out.example_end_indices.push_back(out.float_list.size());
            break;
          }
          case DT_STRING: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseBytesList(&out.bytes_list)) {
                return parse_error();
              }
              if (out.bytes_list.size() % num_elements != 0) {
                return shape_error(out.bytes_list.size(), "bytes");
              }
            }
            out.example_end_indices.push_back(out.bytes_list.size());
            break;
          }
          default:
            LOG(FATAL) << "Should not happen.";
        }

        if (output_stats) {
          // Use `out.example_end_indices` to determine the feature-value count
          // for this feature, because the preceding switch statement pushes
          // the length of the appropriate feature list to that vector.
          // TODO(b/111553342): If desirable, we could add support for counting
          // elements in the features that aren't parsed, but this could add
          // considerable runtime cost.
          const size_t out_examples_count = out.example_end_indices.size();
          if (out_examples_count == 1) {
            output_stats->feature_values_count += out.example_end_indices[0];
          } else {
            output_stats->feature_values_count +=
                out.example_end_indices[out_examples_count - 1] -
                out.example_end_indices[out_examples_count - 2];
          }
        }
      }
    } else {
      // Feature is sparse or ragged.
      auto& last_example =
          is_ragged ? ragged_feature_last_example : sparse_feature_last_example;

      // If feature was already visited, skip.
      // Compare comment at the beginning of the loop.
      if (last_example[d] == example_index) {
        LogSparseFeatureDataLoss(feature_name);
        continue;
      }
      last_example[d] = example_index;

      // Handle sparse features.
      SparseBuffer& out = is_ragged ? (*output_ragged)[d] : (*output_sparse)[d];
      DataType feature_dtype =
          is_ragged ? config.ragged[d].dtype : config.sparse[d].dtype;
      if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
        return example_error(
            strings::StrCat("Data types don't match. ",
                            "Expected type: ", DataTypeString(feature_dtype),
                            ", Actual type: ", DataTypeString(example_dtype)));
      }

      switch (feature_dtype) {
        case DT_INT64: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseInt64List(&out.int64_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.int64_list.size());
          break;
        }
        case DT_FLOAT: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseFloatList(&out.float_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.float_list.size());
          break;
        }
        case DT_STRING: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseBytesList(&out.bytes_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.bytes_list.size());
          break;
        }
        default:
          LOG(FATAL) << "Should not happen.";
      }

      if (output_stats) {
        // Use `out.example_end_indices` to determine the feature-value count
        // for this feature, because the preceding switch statement pushes
        // the length of the appropriate feature list to that vector.
        // TODO(b/111553342): If desirable, we could add support for counting
        // elements in the features that aren't parsed, but this could add
        // considerable runtime cost.
        const size_t out_examples_count = out.example_end_indices.size();
        if (out_examples_count == 1) {
          output_stats->feature_values_count += out.example_end_indices[0];
        } else {
          output_stats->feature_values_count +=
              out.example_end_indices[out_examples_count - 1] -
              out.example_end_indices[out_examples_count - 2];
        }
      }
    }
  }

  // Handle missing dense features for fixed strides.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    if (config.dense[d].default_value.NumElements() == 0) {
      return errors::InvalidArgument(
          "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
          " (data type: ", DataTypeString(config.dense[d].dtype), ")",
          " is required but could not be found.");
    }
    const Tensor& in = config.dense[d].default_value;
    Tensor& out = (*output_dense)[d];
    const std::size_t num_elements = in.shape().num_elements();
    const std::size_t offset = example_index * num_elements;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        std::copy_n(in.flat<int64_t>().data(), num_elements,
                    out.flat<int64_t>().data() + offset);
        break;
      }
      case DT_FLOAT: {
        std::copy_n(in.flat<float>().data(), num_elements,
                    out.flat<float>().data() + offset);
        break;
      }
      case DT_STRING: {
        std::copy_n(in.flat<tstring>().data(), num_elements,
                    out.flat<tstring>().data() + offset);
        break;
      }
      default:
        LOG(FATAL) << "Should not happen.";
    }
  }

  // Handle missing varlen dense features.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_varlen_dense)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Handle missing sparse features.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (sparse_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_sparse)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Handle missing ragged features.
  for (size_t d = 0; d < config.ragged.size(); ++d) {
    if (ragged_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_ragged)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  return OkStatus();
}

Status CheckConfigDataType(DataType dtype) {
  switch (dtype) {
    case DT_INT64:
    case DT_FLOAT:
    case DT_STRING:
      return OkStatus();
    default:
      return errors::InvalidArgument("Invalid config dtype: ",
                                     DataTypeString(dtype));
  }
}

// Use this in the "default" clause of switch statements when dispatching
// on a dtype variable that was checked by CheckConfigDataType():
inline void ReportUnexpectedDataType(DataType dtype) {
  DCHECK(false)
      << "Encountered unexpected DataType " << DataTypeString(dtype)
      << " in variable that should have been checked by CheckConfigDataType().";
}

Status CheckConfigDataTypes(const Config& config) {
  // Check config so we can safely CHECK(false) in switches on config.*.dtype
  for (auto& c : config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }
  for (auto& c : config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }
  for (auto& c : config.ragged) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    if (!(c.splits_dtype == DT_INT32 || c.splits_dtype == DT_INT64)) {
      return errors::InvalidArgument("Invalid ragged_split_type: ",
                                     DataTypeString(c.splits_dtype));
    }
  }
  return OkStatus();
}

template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

template <>
const SmallVector<int64_t>& GetListFromBuffer<int64_t>(
    const SparseBuffer& buffer) {
  return buffer.int64_list;
}
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}

template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  std::copy(b, e, t);
}
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
  std::move(b, e, t);
}

template <typename T>
void FillAndCopyVarLen(
    const int d, const size_t num_elements,
    const size_t num_elements_per_minibatch, const Config& config,
    const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
    Tensor* values) {
  const Tensor& default_value = config.dense[d].default_value;

  // Copy-fill the tensors (creating the zero/fill-padding)
  std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
            default_value.flat<T>()(0));

  // Data is [batch_size, max_num_elements, data_stride_size]
  // and num_elements_per_minibatch = max_num_elements * data_stride_size
  auto data = values->flat<T>().data();

  // Iterate over minibatch elements
  for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
    const SparseBuffer& buffer = varlen_dense_buffers[i][d];
    // Number of examples being stored in this buffer
    const auto& end_indices = buffer.example_end_indices;
    const size_t examples_in_buffer = end_indices.size();
    // const size_t stride_size = config.dense[d].elements_per_stride;

    const auto& list = GetListFromBuffer<T>(buffer);
    auto list_ptr = list.begin();

    size_t elements_tally = 0;
    // Iterate through all the examples stored in this buffer.
    for (size_t j = 0; j < examples_in_buffer; ++j) {
      // Number of elements stored for this example.
      const size_t num_elems = end_indices[j] - elements_tally;
      CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
      // Move forward this many elements in the varlen buffer.
      list_ptr += num_elems;
      // Move forward to the next minibatch entry in the values output.
      data += num_elements_per_minibatch;
      elements_tally = end_indices[j];
    }
    DCHECK(elements_tally == list.size());
  }
}
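
// Illustrative (not part of the parser): with batch_size == 3, variable-length
// examples {1, 2}, {}, {3}, and a padded width of max_num_elements == 2, the
// output of FillAndCopyVarLen is laid out row-major as
//   [1, 2,  pad, pad,  3, pad]
// where `pad` is default_value(0) and each example's row starts at a multiple
// of num_elements_per_minibatch.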

// Thin vector-like interface wrapper around a Tensor. This enables us to
// directly populate a tensor during parsing instead of having to first create
// a vector and then copy the data over.
template <typename T>
class TensorVector {
 public:
  using value_type = T;

  const Tensor& tensor() {
    if (!tensor_.has_value()) {
      resize(0);
    }
    return *tensor_;
  }

  int64_t size() const {
    return tensor_.has_value() ? tensor_->NumElements() : 0;
  }
  void resize(int64_t new_size) {
    DCHECK(!tensor_.has_value());
    tensor_ = Tensor(DataTypeToEnum<T>::v(), TensorShape({new_size}));
    data_ = tensor_->flat<T>().data();
  }
  T* data() { return data_; }
  const T* data() const { return data_; }

 private:
  // Use absl::optional to avoid calling the default constructor of Tensor
  // unnecessarily.
  absl::optional<Tensor> tensor_;

  // Cached pointer to the raw data inside the tensor.
  T* data_ = nullptr;
};
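
// Note: TensorVector exposes just enough of the vector interface (value_type,
// size, resize, data) to act as the Result type for parsers such as
// ParseFloatList above, which pre-size via resize() and then write through
// data(); resize() may be called at most once per TensorVector.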

void CountSparseFeatures(
    const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
    size_t* total_num_features, size_t* max_num_features) {
  for (auto& sparse_values_tmp : sparse_buffers) {
    const std::vector<size_t>& end_indices =
        sparse_values_tmp[d].example_end_indices;
    *total_num_features += end_indices.back();
    *max_num_features = std::max(*max_num_features, end_indices[0]);
    for (size_t i = 1; i < end_indices.size(); ++i) {
      size_t example_size = end_indices[i] - end_indices[i - 1];
      *max_num_features = std::max(*max_num_features, example_size);
    }
  }
}

void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
                              Tensor* dst) {
  switch (dtype) {
    case DT_INT64: {
      std::copy(src->int64_list.begin(), src->int64_list.end(),
                dst->flat<int64_t>().data() + offset);
      break;
    }
    case DT_FLOAT: {
      std::copy(src->float_list.begin(), src->float_list.end(),
                dst->flat<float>().data() + offset);
      break;
    }
    case DT_STRING: {
      std::move(src->bytes_list.begin(), src->bytes_list.end(),
                dst->flat<tstring>().data() + offset);
      break;
    }
    default:
      ReportUnexpectedDataType(dtype);
  }
}

}  // namespace

Status FastParseExample(const Config& config,
                        gtl::ArraySlice<tstring> serialized,
                        gtl::ArraySlice<tstring> example_names,
                        thread::ThreadPool* thread_pool, Result* result) {
  DCHECK(result != nullptr);
  // Check config so we can safely CHECK(false) in switches on config.*.dtype
  TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));

  if (config.collect_feature_stats) {
    result->feature_stats.resize(serialized.size());
  }

  size_t config_size =
      config.dense.size() + config.sparse.size() + config.ragged.size();
  SeededHasher hasher;
  // Build config index.
  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
  bool ok = true;
  for (size_t i = 0; i < 1000; ++i) {
    for (size_t d = 0; d < config.dense.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
                                      {d, Type::Dense});
    }
    for (size_t d = 0; d < config.sparse.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
                                      {d, Type::Sparse});
    }
    for (size_t d = 0; d < config.ragged.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
                                      {d, Type::Ragged});
    }
    if (ok) break;
    LOG(WARNING) << "Collision found. This should happen only if you have "
                    "around 2^32 entries in your config.";
    hasher.seed++;
    config_index.Clear(config_size);
    ok = true;
  }
  if (!ok) {
    return errors::Internal(
        "Could not avoid collision. This should not happen.");
  }

  // Allocate dense output for fixed length dense values
  // (variable-length dense and sparse and ragged have to be buffered).
  std::vector<Tensor> fixed_dense_values(config.dense.size());
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    TensorShape out_shape;
    out_shape.AddDim(serialized.size());
    for (const int64_t dim : config.dense[d].shape.dim_sizes()) {
      out_shape.AddDim(dim);
    }
    fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
  }

  // This parameter affects performance in a big and data-dependent way.
  const size_t kMiniBatchSizeBytes = 50000;

  // Calculate number of minibatches.
  // In the main regime, make each minibatch around kMiniBatchSizeBytes bytes.
  // Apply the 'special logic' below for the small and big regimes.
  const size_t num_minibatches = [&] {
    size_t result = 0;
    size_t minibatch_bytes = 0;
    for (size_t i = 0; i < serialized.size(); i++) {
      if (minibatch_bytes == 0) {  // start minibatch
        result++;
      }
      minibatch_bytes += serialized[i].size() + 1;
      if (minibatch_bytes > kMiniBatchSizeBytes) {
        minibatch_bytes = 0;
      }
    }
    // 'special logic'
    const size_t min_minibatches = std::min<size_t>(8, serialized.size());
    const size_t max_minibatches = 64;
    return std::max<size_t>(min_minibatches,
                            std::min<size_t>(max_minibatches, result));
  }();
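
  // Worked example of the calculation above (illustrative): 1000 examples of
  // 1KB each is ~1MB of input, so the byte-based count is ~20 minibatches of
  // ~50KB; the clamp then keeps the result within [min(8, batch_size), 64].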

  auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
    return (serialized.size() * minibatch) / num_minibatches;
  };
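
  // E.g. (illustrative) with 10 examples and 4 minibatches, minibatch m covers
  // examples [first(m), first(m+1)): [0,2), [2,5), [5,7), [7,10).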

  // TODO(lew): A big performance low-hanging fruit here is to improve
  // num_minibatches calculation to take into account actual amount of work
  // needed, as the size in bytes is not perfect. Linear combination of
  // size in bytes and average number of features per example is promising.
  // Even better: measure time instead of estimating, but this is too costly
  // in small batches.
  // Maybe accept outside parameter #num_minibatches?

  // Do minibatches in parallel.
  std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> ragged_buffers(num_minibatches);
  std::vector<Status> status_of_minibatch(num_minibatches);
  auto ProcessMiniBatch = [&](size_t minibatch) {
    sparse_buffers[minibatch].resize(config.sparse.size());
    varlen_dense_buffers[minibatch].resize(config.dense.size());
    ragged_buffers[minibatch].resize(config.ragged.size());
    size_t start = first_example_of_minibatch(minibatch);
    size_t end = first_example_of_minibatch(minibatch + 1);
    for (size_t e = start; e < end; ++e) {
      PerExampleFeatureStats* stats = nullptr;
      if (config.collect_feature_stats) {
        stats = &result->feature_stats[e];
      }
      status_of_minibatch[minibatch] = FastParseSerializedExample(
          serialized[e],
          (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
          config_index, hasher, &fixed_dense_values,
          &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch],
          &ragged_buffers[minibatch], stats);
      if (!status_of_minibatch[minibatch].ok()) break;
    }
  };

  ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);

  for (Status& status : status_of_minibatch) {
    TF_RETURN_IF_ERROR(status);
  }

  result->sparse_indices.reserve(config.sparse.size());
  result->sparse_values.reserve(config.sparse.size());
  result->sparse_shapes.reserve(config.sparse.size());
  result->dense_values.reserve(config.dense.size());
  result->ragged_values.reserve(config.ragged.size());
  result->ragged_splits.reserve(config.ragged.size());

  for (size_t d = 0; d < config.dense.size(); ++d) {
    result->dense_values.push_back(std::move(fixed_dense_values[d]));
  }

  // Merge SparseBuffers from all minibatches for every config.sparse.
  auto MergeSparseMinibatches = [&](size_t d) {
    // Loop over minibatches
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(sparse_buffers, d, &total_num_features,
                        &max_num_features);

    TensorShape indices_shape;
    indices_shape.AddDim(total_num_features);
    indices_shape.AddDim(2);
    result->sparse_indices.emplace_back(DT_INT64, indices_shape);
    Tensor* indices = &result->sparse_indices.back();

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
    Tensor* values = &result->sparse_values.back();

    result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
    auto shapes_shape_t = result->sparse_shapes.back().vec<int64_t>();
    shapes_shape_t(0) = serialized.size();
    shapes_shape_t(1) = max_num_features;

    size_t offset = 0;
    for (size_t i = 0; i < sparse_buffers.size(); ++i) {
      SparseBuffer& buffer = sparse_buffers[i][d];

      // Update indices.
      size_t delta = 0;

      if (indices->NumElements() > 0) {
        int64* ix_p = &indices->matrix<int64_t>()(offset, 0);
        size_t example_index = first_example_of_minibatch(i);
        for (size_t example_end_index : buffer.example_end_indices) {
          size_t feature_index = 0;
          for (; delta < example_end_index; ++delta) {
            // Column 0: example index
            *ix_p = example_index;
            // Column 1: the feature index within the example
            *(ix_p + 1) = feature_index;
            ix_p += 2;
            ++feature_index;
          }
          ++example_index;
        }
      }

      CopySparseBufferToTensor(config.sparse[d].dtype, offset, &buffer, values);
      offset += delta;
    }
  };
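
  // Illustrative COO output of the merge above (not part of the parser): for
  // sparse int64 features [1, 2], [], [3] across a batch of 3 examples,
  //   indices == [[0, 0], [0, 1], [2, 0]]   (example index, feature index)
  //   values  == [1, 2, 3]
  //   shape   == [3, 2]                     (batch_size, max_num_features)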

  // Merge SparseBuffers from all minibatches for every config.ragged.
  auto MergeRaggedMinibatches = [&](size_t d) {
    // Loop over minibatches
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(ragged_buffers, d, &total_num_features,
                        &max_num_features);

    TensorShape row_splits_shape;
    row_splits_shape.AddDim(serialized.size() + 1);
    result->ragged_splits.emplace_back(config.ragged[d].splits_dtype,
                                       row_splits_shape);
    Tensor* row_splits = &result->ragged_splits.back();
    if (config.ragged[d].splits_dtype == DT_INT64) {
      row_splits->flat<int64_t>()(0) = 0;
    } else {
      row_splits->flat<int32>()(0) = 0;
    }

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->ragged_values.emplace_back(config.ragged[d].dtype, values_shape);
    Tensor* values = &result->ragged_values.back();

    size_t values_offset = 0;
    size_t splits_offset = 0;
    for (size_t i = 0; i < ragged_buffers.size(); ++i) {
      SparseBuffer& buffer = ragged_buffers[i][d];
      if (buffer.example_end_indices.empty()) continue;

      // Update row_splits. row_splits are formed by concatenating the example
      // end_indices (adjusting each to start after the previous one ends).
      if (config.ragged[d].splits_dtype == DT_INT64) {
        int64* row_splits_out = &row_splits->flat<int64_t>()(splits_offset);
        int64_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      } else {
        int32* row_splits_out = &row_splits->flat<int32>()(splits_offset);
        int32_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      }

      CopySparseBufferToTensor(config.ragged[d].dtype, values_offset, &buffer,
                               values);
      values_offset += buffer.example_end_indices.back();
      splits_offset += buffer.example_end_indices.size();
    }
  };
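
  // Illustrative (not part of the parser): two minibatches with
  // example_end_indices {2, 2} and {3} concatenate into row_splits
  // [0, 2, 2, 5]: the second buffer's indices are shifted by the running
  // total (2) so each row starts where the previous one ended.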

  // Merge SparseBuffers from all minibatches for every config.dense having
  // variable_length.
  auto MergeDenseVarLenMinibatches = [&](size_t d) {
    if (!config.dense[d].variable_length) return;

    // Loop over minibatches
    size_t max_num_features = 0;
    for (auto& dense_values_tmp : varlen_dense_buffers) {
      std::vector<size_t>& end_indices =
          dense_values_tmp[d].example_end_indices;
      max_num_features = std::max(max_num_features, end_indices[0]);
      for (size_t i = 1; i < end_indices.size(); ++i) {
        size_t example_size = end_indices[i] - end_indices[i - 1];
        max_num_features = std::max(max_num_features, example_size);
      }
    }

    const size_t stride_size = config.dense[d].elements_per_stride;
    const size_t max_num_elements = max_num_features / stride_size;
    TensorShape values_shape;
    DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
    const size_t batch_size = serialized.size();
    values_shape.AddDim(batch_size);
    values_shape.AddDim(max_num_elements);
    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
      values_shape.AddDim(config.dense[d].shape.dim_size(i));
    }
    Tensor values(config.dense[d].dtype, values_shape);
    result->dense_values[d] = values;
    const size_t num_elements = values.NumElements();

    // Nothing to write, exit early.
    if (num_elements == 0) return;

    const size_t num_elements_per_minibatch = num_elements / batch_size;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        FillAndCopyVarLen<int64_t>(d, num_elements, num_elements_per_minibatch,
                                   config, varlen_dense_buffers, &values);
        break;
      }
      case DT_FLOAT: {
        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_STRING: {
        FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
                                   config, varlen_dense_buffers, &values);
        break;
      }
      default:
        ReportUnexpectedDataType(config.dense[d].dtype);
    }
  };

  for (size_t d = 0; d < config.dense.size(); ++d) {
    MergeDenseVarLenMinibatches(d);
  }

  for (size_t d = 0; d < config.sparse.size(); ++d) {
    MergeSparseMinibatches(d);
  }

  for (size_t d = 0; d < config.ragged.size(); ++d) {
    MergeRaggedMinibatches(d);
  }

  return OkStatus();
}
1443
FastParseSingleExample(const Config & config,StringPiece serialized,Result * result)1444 Status FastParseSingleExample(const Config& config, StringPiece serialized,
1445 Result* result) {
1446 DCHECK(result != nullptr);
1447 // Check config so we can safely CHECK(false) in switches on config.*.dtype
1448 TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1449
1450 PerExampleFeatureStats* stats = nullptr;
1451 if (config.collect_feature_stats) {
1452 result->feature_stats.emplace_back();
1453 stats = &result->feature_stats.back();
1454 }
1455
1456 // TODO(mrry): Cache the construction of this map at Op construction time.
1457 size_t config_size =
1458 config.dense.size() + config.sparse.size() + config.ragged.size();
1459 SeededHasher hasher;
1460 // Build config index.
1461 PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1462 bool ok = true;
1463 for (size_t i = 0; i < 1000; ++i) {
1464 for (size_t d = 0; d < config.dense.size(); ++d) {
1465 ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1466 {d, Type::Dense});
1467 }
1468 for (size_t d = 0; d < config.sparse.size(); ++d) {
1469 ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1470 {d, Type::Sparse});
1471 }
1472 for (size_t d = 0; d < config.ragged.size(); ++d) {
1473 ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1474 {d, Type::Ragged});
1475 }
1476 if (ok) break;
1477 LOG(WARNING) << "Collision found. This should happen only if you have "
1478 "around 2^32 entries in your config.";
1479 hasher.seed++;
1480 config_index.Clear(config_size);
1481 ok = true;
1482 }
1483 if (!ok) {
1484 return errors::Internal(
1485 "Could not avoid collision. This should not happen.");
1486 }
1487
1488 result->sparse_indices.reserve(config.sparse.size());
1489 result->sparse_values.reserve(config.sparse.size());
1490 result->sparse_shapes.reserve(config.sparse.size());
1491 result->dense_values.reserve(config.dense.size());
1492 result->ragged_values.reserve(config.ragged.size());
1493 result->ragged_splits.reserve(config.ragged.size());
1494
1495 // Allocate dense output tensors.
1496 for (size_t d = 0; d < config.dense.size(); ++d) {
1497 if (!config.dense[d].variable_length) {
1498 TensorShape values_shape;
1499 if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1500 return errors::Internal(
1501 "Fixed-length shape was not a statically defined shape.");
1502 }
1503 result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1504 } else {
1505 // Variable-length tensor will be allocated later.
1506 result->dense_values.emplace_back();
1507 }
1508 }
1509
1510 // Allocate sparse output tensors.
1511 for (size_t d = 0; d < config.sparse.size(); ++d) {
1512 // The dense_shape is always a vector of length 1.
1513 result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1514 // Variable-length tensors will be allocated later.
1515 result->sparse_indices.emplace_back();
1516 result->sparse_values.emplace_back();
1517 }
1518
1519 // Allocate ragged output tensors.
1520 for (size_t d = 0; d < config.ragged.size(); ++d) {
1521 // Variable-length values tensors will be allocated later.
1522 result->ragged_values.emplace_back();
1523 // Splits tensors are empty (unused) for single (scalar) inputs.
1524 const auto splits_dtype = config.ragged[d].splits_dtype;
1525 result->ragged_splits.emplace_back(splits_dtype, TensorShape({0}));
1526 }
1527
1528 parsed::Example parsed_example;
1529 if (!ParseExample(serialized, &parsed_example)) {
1530 return errors::InvalidArgument("Could not parse example input, value: '",
1531 serialized, "'");
1532 }
1533 std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1534 std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1535 std::vector<bool> ragged_feature_already_seen(config.ragged.size(), false);
1536
1537 if (stats) {
1538 // TODO(b/111553342): This may over-count the number of features if there
1539 // are duplicate keys in the feature map. Consider deduplicating the keys
1540 // before computing the count.
1541 stats->features_count = parsed_example.size();
1542 }
1543
1544 // Handle features present in the example.
1545 const size_t parsed_example_size = parsed_example.size();
1546 for (size_t i = 0; i < parsed_example_size; ++i) {
    // Mirror standard protobuf parsing semantics: when a key appears multiple
    // times in the map, the last entry overwrites all previous ones, so we
    // iterate in reverse and skip keys that were already seen.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];

    const StringPiece feature_name = name_and_feature.first;
    parsed::Feature& feature = name_and_feature.second;

    std::pair<size_t, Type> d_and_type;
    uint64 h = hasher(feature_name);
    if (!config_index.Find(h, &d_and_type)) continue;

    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;
    bool is_sparse = d_and_type.second == Type::Sparse;

    {
      // Testing for PresizedCuckooMap collision.
      // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
      const tstring& config_feature_name =
          is_dense ? config.dense[d].feature_name
                   : (is_sparse ? config.sparse[d].feature_name
                                : config.ragged[d].feature_name);
      if (feature_name != config_feature_name) continue;
    }

    auto example_error = [feature_name](StringPiece suffix) {
      return errors::InvalidArgument("Key: ", feature_name, ". ", suffix);
    };

    auto parse_error = [feature_name] {
      return errors::InvalidArgument("Key: ", feature_name,
                                     ". Can't parse serialized Example.");
    };

    DataType example_dtype;
    TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
    if (example_dtype == DT_INVALID) continue;

    if (is_dense && !config.dense[d].variable_length) {
      // If this feature was already visited, skip it.
      // See the comment at the beginning of the loop.
      if (dense_feature_already_seen[d]) {
        LogDenseFeatureDataLoss(feature_name);
        continue;
      }
      dense_feature_already_seen[d] = true;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(strings::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }

      Tensor* out = &result->dense_values[d];
      const std::size_t num_elements = config.dense[d].elements_per_stride;
      if (stats) {
        // TODO(b/111553342): If desirable, we could add support for counting
        // elements in the features that aren't parsed, but this could add
        // considerable runtime cost.
        stats->feature_values_count += num_elements;
      }
      switch (example_dtype) {
        case DT_INT64: {
          auto out_p = out->flat<int64_t>().data();
          LimitedArraySlice<int64_t> slice(out_p, num_elements);
          if (!feature.ParseInt64List(&slice)) return parse_error();
          if (slice.EndDistance() != 0) {
            return parse_error();
          }
          break;
        }
        case DT_FLOAT: {
          auto out_p = out->flat<float>().data();
          LimitedArraySlice<float> slice(out_p, num_elements);
          if (!feature.ParseFloatList(&slice)) return parse_error();
          if (slice.EndDistance() != 0) {
            return parse_error();
          }
          break;
        }
        case DT_STRING: {
          auto out_p = out->flat<tstring>().data();
          LimitedArraySlice<tstring> slice(out_p, num_elements);
          if (!feature.ParseBytesList(&slice)) return parse_error();
          if (slice.EndDistance() != 0) {
            return parse_error();
          }
          break;
        }
        default:
          ReportUnexpectedDataType(example_dtype);
      }

    } else {  // if variable length
      SmallVector<tstring> bytes_list;
      TensorVector<float> float_list;
      SmallVector<int64_t> int64_list;

      const size_t num_elements_divisor =
          is_dense ? config.dense[d].elements_per_stride : 1;
      size_t num_elements;

      if (is_dense) {
        // If this feature was already visited, skip it.
        // See the comment at the beginning of the loop.
        if (dense_feature_already_seen[d]) {
          LogDenseFeatureDataLoss(feature_name);
          continue;
        }
        dense_feature_already_seen[d] = true;
        if (example_dtype != config.dense[d].dtype) {
          return example_error(strings::StrCat(
              "Data types don't match. Data type: ",
              DataTypeString(example_dtype),
              " but expected type: ", DataTypeString(config.dense[d].dtype)));
        }
      } else {
        // Feature is sparse or ragged.
        auto& feature_already_seen = is_sparse ? sparse_feature_already_seen
                                               : ragged_feature_already_seen;
        auto& feature_dtype =
            is_sparse ? config.sparse[d].dtype : config.ragged[d].dtype;
        // If this feature was already visited, skip it.
        // See the comment at the beginning of the loop.
        if (feature_already_seen[d]) {
          LogSparseFeatureDataLoss(feature_name);
          continue;
        }
        feature_already_seen[d] = true;

        // Handle sparse features.
        if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
          return example_error(strings::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(feature_dtype),
              ", Actual type: ", DataTypeString(example_dtype)));
        }
      }

      switch (example_dtype) {
        case DT_INT64: {
          // TODO(mrry): Use the fact that the `int64_list` is packed to read
          // out the length and pre-allocate the output tensor.
          if (!feature.ParseInt64List(&int64_list)) return parse_error();
          num_elements = int64_list.size();
          break;
        }
        case DT_FLOAT: {
          if (!feature.ParseFloatList(&float_list)) return parse_error();
          num_elements = float_list.size();
          break;
        }
        case DT_STRING: {
          int actual_num_elements = 0;
          if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
            return parse_error();
          }
          bytes_list.reserve(actual_num_elements);
          if (!feature.ParseBytesList(&bytes_list)) return parse_error();
          num_elements = bytes_list.size();
          break;
        }
        default:
          num_elements = 0;
          ReportUnexpectedDataType(example_dtype);
      }

      if (num_elements % num_elements_divisor != 0) {
        return parse_error();
      }

      if (stats) {
        stats->feature_values_count += num_elements;
      }

      Tensor* out;
      DataType out_dtype;
      TensorShape out_shape;
      if (is_dense) {
        out_shape.AddDim(num_elements / num_elements_divisor);
        for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
          out_shape.AddDim(config.dense[d].shape.dim_size(i));
        }

        out = &result->dense_values[d];
        out_dtype = config.dense[d].dtype;
      } else if (is_sparse) {
        Tensor* out_indices = &result->sparse_indices[d];
        Tensor* out_dense_shape = &result->sparse_shapes[d];

        // TODO(mrry): Investigate the possibility of not materializing
        // the indices (and perhaps dense_shape) until they are needed.
        *out_indices = Tensor(
            DT_INT64, TensorShape({static_cast<int64_t>(num_elements), 1}));
        auto indices_flat = out_indices->flat<int64_t>();
        for (size_t i = 0; i < num_elements; ++i) {
          indices_flat(i) = static_cast<int64_t>(i);
        }

        *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
        auto shapes_shape_t = out_dense_shape->vec<int64_t>();
        shapes_shape_t(0) = num_elements;

        out = &result->sparse_values[d];
        out_dtype = config.sparse[d].dtype;
        out_shape.AddDim(num_elements);
      } else {
        out = &result->ragged_values[d];
        out_dtype = config.ragged[d].dtype;
        out_shape.AddDim(num_elements);
      }

      switch (example_dtype) {
        case DT_INT64: {
          *out = Tensor(out_dtype, out_shape);
          CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
                          out->flat<int64_t>().data());
          break;
        }
        case DT_FLOAT: {
          if (!out->CopyFrom(float_list.tensor(), out_shape)) {
            return parse_error();
          }
          break;
        }
        case DT_STRING: {
          *out = Tensor(out_dtype, out_shape);
          CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(),
                          out->flat<tstring>().data());
          break;
        }
        default:
          ReportUnexpectedDataType(example_dtype);
      }
    }
  }

  // Handle missing dense features.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!dense_feature_already_seen[d]) {
      if (!config.dense[d].variable_length) {
        // Handle missing fixed-length dense feature.
        if (config.dense[d].default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: ", config.dense[d].feature_name,
              " (data type: ", DataTypeString(config.dense[d].dtype), ")",
              " is required but could not be found.");
        }
        result->dense_values[d] = config.dense[d].default_value;
      } else {
        // Handle missing varlen dense feature.
        TensorShape empty_shape;
        empty_shape.AddDim(0);
        for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
          empty_shape.AddDim(config.dense[d].shape.dim_size(i));
        }
        result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
      }
    }
  }

  // Handle missing sparse features.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (!sparse_feature_already_seen[d]) {
      result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
      result->sparse_values[d] =
          Tensor(config.sparse[d].dtype, TensorShape({0}));
      result->sparse_shapes[d].vec<int64_t>()(0) = 0;
    }
  }

  // Handle missing ragged features.
  for (size_t d = 0; d < config.ragged.size(); ++d) {
    if (!ragged_feature_already_seen[d]) {
      result->ragged_values[d] =
          Tensor(config.ragged[d].dtype, TensorShape({0}));
    }
  }

  return OkStatus();
}
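
// Illustrative usage sketch: parsing one serialized Example with a config
// that declares a single sparse string feature. The `...` stands in for the
// config-specific constructor arguments of the FastParseExampleConfig
// structs declared in the header.
//
//   Config config;
//   config.sparse.push_back(...);  // e.g. feature "tags", dtype DT_STRING
//   Result result;
//   TF_RETURN_IF_ERROR(FastParseSingleExample(config, serialized, &result));
//   // result.sparse_values[0] holds the parsed values; sparse_indices[0]
//   // and sparse_shapes[0] describe the rank-1 SparseTensor layout.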

// Private helper functions for FastParseSequenceExample.
namespace {

// A struct used by FastParseSequenceExample to hold the serialized proto
// substrings for a single feature, plus some auxiliary information derived
// from those protos (such as the total value length).
struct FeatureProtos {
  // Proto substrings from each serialized SequenceExample that correspond
  // with this feature. `protos_present` records whether the proto had a
  // value defined (even if that value is empty).
  std::vector<StringPiece> protos;
  std::vector<bool> protos_present;

  // Information derived from protos:
  size_t length;    // total length for ragged/sparse, max row length for dense.
  size_t num_rows;  // only populated for ragged sequence features.

  // Information from the config:
  Type type;  // Whether this feature is sparse, ragged, or dense.
  DataType dtype;
};

// Map from feature name to FeatureProtos for that feature.
using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;

string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
  return example_names.empty() ? "<unknown>" : example_names[n];
}
// Return the number of bytes-list elements parsed, or -1 on error. If out is
// null, this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
                             tstring* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    while (!stream->ExpectAtEnd()) {
      uint32 bytes_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&bytes_length)) {
        return -1;
      }
      if (out == nullptr) {
        stream->Skip(bytes_length);
      } else {
        out->resize_uninitialized(bytes_length);
        if (!stream->ReadRaw(out->data(), bytes_length)) {
          return -1;
        }
        out++;
      }
      num_elements++;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
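
// Wire-format note: the stream passed to ParseBytesFeature is positioned at
// the start of a serialized Feature message, so the first tag read is the
// bytes_list field (field 1, length-delimited). For example, a Feature
// holding BytesList{value: ["ab"]} is encoded as
//   0x0A 0x04            // Feature.bytes_list, length 4
//   0x0A 0x02 0x61 0x62  // BytesList.value, length 2, "ab"
// and the parser returns 1 after writing "ab" to *out.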

inline void PadFloatFeature(int num_to_pad, float* out) {
  for (int i = 0; i < num_to_pad; i++) {
    *out++ = 0.0;
  }
}

inline void PadInt64Feature(int num_to_pad, int64_t* out) {
  for (int i = 0; i < num_to_pad; i++) {
    *out++ = 0;
  }
}

// Return the number of float elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
                             float* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kFixed32Tag(1)) {
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ExpectTag(kFixed32Tag(1)) ||
            !stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
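
// Wire-format note: FloatList is field 2 of Feature, and its repeated
// `value` field (field 1) may arrive packed (one length-delimited run of
// little-endian IEEE-754 floats) or unpacked (repeated fixed32 entries);
// ParseFloatFeature accepts both. For example, FloatList{value: [1.0]} in
// packed form is
//   0x12 0x06                      // Feature.float_list, length 6
//   0x0A 0x04 0x00 0x00 0x80 0x3F  // FloatList.value, 4 bytes, 1.0f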

// Return the number of int64 elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
                             int64_t* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kVarintTag(1)) {
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
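
// Wire-format note: Int64List is field 3 of Feature (tag 0x1A), and its
// `value` field is either one packed run of varints or repeated varint
// entries. For example, Int64List{value: [3, 270]} in packed form is
//   0x1A 0x05                 // Feature.int64_list, length 5
//   0x0A 0x03 0x03 0x8E 0x02  // Int64List.value: varints 3 and 270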

// Parses the next feature on `stream` into `out` starting at `out_offset`.
// Updates `out_offset`, and returns the number of values added.
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
                        Tensor* out, size_t* out_offset) {
  int delta;
  switch (dtype) {
    case DT_STRING:
      delta =
          ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
      break;
    case DT_FLOAT:
      delta =
          ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
      break;
    case DT_INT64:
      delta =
          ParseInt64Feature(stream, out->flat<int64_t>().data() + *out_offset);
      break;
    default:
      ReportUnexpectedDataType(dtype);
      delta = 0;
  }
  if (delta > 0) {
    *out_offset += delta;
  }
  return delta;
}

// Returns the length of the next feature on `stream`.
// Returns -1 if the next feature on `stream` doesn't match `dtype`.
inline int GetFeatureLength(DataType dtype,
                            protobuf::io::CodedInputStream* stream) {
  switch (dtype) {
    case DT_STRING:
      return ParseBytesFeature(stream, nullptr);
    case DT_FLOAT:
      return ParseFloatFeature(stream, nullptr);
    case DT_INT64:
      return ParseInt64Feature(stream, nullptr);
    default:
      ReportUnexpectedDataType(dtype);
      return -1;
  }
}

inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
  uint8 peek_tag = PeekTag(stream);
  switch (peek_tag) {
    case kDelimitedTag(1):
      return DT_STRING;
    case kDelimitedTag(2):
      return DT_FLOAT;
    case kDelimitedTag(3):
      return DT_INT64;
    default:
      return DT_INVALID;
  }
}
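
// The mapping above mirrors the Feature oneof: a length-delimited tag for
// field 1 means bytes_list (DT_STRING), field 2 means float_list (DT_FLOAT),
// and field 3 means int64_list (DT_INT64). Anything else, including an unset
// oneof, is reported as DT_INVALID.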

inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
                             DataType dtype) {
  switch (dtype) {
    case DT_STRING:
      if (!stream->ExpectTag(kDelimitedTag(1))) {
        return false;
      }
      break;
    case DT_FLOAT:
      if (!stream->ExpectTag(kDelimitedTag(2))) {
        return false;
      }
      break;
    case DT_INT64:
      if (!stream->ExpectTag(kDelimitedTag(3))) {
        return false;
      }
      break;
    default:
      return false;
  }
  uint32 length;
  return stream->ReadVarint32(&length) && length == 0;
}
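
// Why callers below treat a two-byte feature as "empty": a Feature whose
// typed list is set but holds no values serializes to exactly two bytes, the
// list field's tag byte followed by a zero length. SkipEmptyFeature consumes
// those two bytes after verifying that the tag matches `dtype`.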

// Reads each serialized SequenceExample proto and extracts a StringPiece
// pointing at the serialized bytes of every requested feature.
Status ExtractFeaturesFromSequenceExamples(
    const gtl::ArraySlice<tstring> examples,
    const gtl::ArraySlice<tstring> example_names,
    FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) {
  for (int d = 0; d < examples.size(); d++) {
    const tstring& example = examples[d];
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(example.data()), example.size());
    // Enable aliasing so that the ParseString() calls below can return
    // StringPieces that point directly into `example` instead of copying.
    EnableAliasing(&stream);

    // Extract pointers to all features within this serialized example.
    while (!stream.ExpectAtEnd()) {
      FeatureProtosMap* features = nullptr;
      if (stream.ExpectTag(kDelimitedTag(1))) {
        // Context
        features = context_features;
      } else if (stream.ExpectTag(kDelimitedTag(2))) {
        // Sequence
        features = sequence_features;
      } else if (!SkipExtraneousTag(&stream)) {
        return errors::InvalidArgument(
            "Invalid protocol message input, example id: ",
            ExampleName(example_names, d));
      }
      if (features != nullptr) {
        uint32 length;
        if (!stream.ReadVarint32(&length)) {
          return errors::InvalidArgument(
              "Invalid protocol message input, example id: ",
              ExampleName(example_names, d));
        }
        auto limit = stream.PushLimit(length);
        while (!stream.ExpectAtEnd()) {
          StringPiece key, value;
          uint32 length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&length)) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ",
                ExampleName(example_names, d));
          }
          auto limit = stream.PushLimit(length);
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !ParseString(&stream, &key) ||
              !stream.ExpectTag(kDelimitedTag(2)) ||
              !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ",
                ExampleName(example_names, d));
          }
          stream.PopLimit(limit);
          // Only save if this feature was requested.
          auto feature_iter = features->find(key);
          if (feature_iter != features->end()) {
            auto& feature = feature_iter->second;
            feature.protos[d] = value;
            feature.protos_present[d] = true;
          }
        }
        stream.PopLimit(limit);
      }
    }
  }
  return OkStatus();
}
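
// Message structure handled above, for reference: a SequenceExample holds
//   context       (field 1): a Features message, i.e. map<string, Feature>
//   feature_lists (field 2): a FeatureLists message, i.e.
//                            map<string, FeatureList>
// Each map entry is itself a length-delimited submessage with the key as
// field 1 and the value as field 2, which is why the loop pushes two nested
// limits before landing on the (key, value) StringPieces.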

// Populates context_features[k].length based on context_features[k].protos
// (for all k).
Status GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,
                                FeatureProtosMap* context_features) {
  for (auto& c : *context_features) {
    FeatureProtos& feature = c.second;
    for (int d = 0; d < feature.protos.size(); ++d) {
      const auto& proto = feature.protos[d];
      if (proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(proto.data()), proto.size());
      EnableAliasing(&stream);
      int num_elements = GetFeatureLength(feature.dtype, &stream);
      if (num_elements < 0) {
        return errors::InvalidArgument(
            "Name: ", ExampleName(example_names, d),
            ", Context feature: ", c.first,
            ". Data types don't match. Expected type: ",
            DataTypeString(feature.dtype));
      }
      switch (feature.type) {
        case Type::Sparse:  // intentional fall-through
        case Type::Ragged:
          feature.length += num_elements;
          break;
        case Type::Dense:
          feature.length =
              std::max(feature.length, static_cast<size_t>(num_elements));
          break;
      }
    }
  }
  return OkStatus();
}

// Populates sequence_features[k].length and sequence_features[k].num_rows
// based on sequence_features[k].protos (for all k).
Status GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,
                                 FeatureProtosMap* sequence_features) {
  for (auto& c : *sequence_features) {
    FeatureProtos& feature = c.second;
    for (int d = 0; d < feature.protos.size(); ++d) {
      const auto& proto = feature.protos[d];
      if (proto.empty()) continue;

      size_t num_rows = 0;
      size_t num_elements = 0;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(proto.data()), proto.size());
      EnableAliasing(&stream);
      while (!stream.ExpectAtEnd()) {
        uint32 feature_bytes;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_bytes)) {
          return errors::InvalidArgument("Error in sequence feature ", c.first,
                                         " in example ",
                                         ExampleName(example_names, d));
        }
        if (feature_bytes > 2) {
          auto limit = stream.PushLimit(feature_bytes);
          int delta = GetFeatureLength(feature.dtype, &stream);
          if (delta < 0) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, d),
                ", Feature list: ", c.first, ", Index: ", num_rows,
                ". Data types don't match. Expected type: ",
                DataTypeString(feature.dtype));
          }
          num_elements += delta;
          stream.PopLimit(limit);
        } else if (feature_bytes == 2) {
          if (!SkipEmptyFeature(&stream, feature.dtype)) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, d),
                ", Feature list: ", c.first, ", Index: ", num_rows,
                ". Data types don't match. Expected type: ",
                DataTypeString(feature.dtype));
          }
        } else if (feature_bytes != 0) {
          return errors::InvalidArgument("Error in sequence feature ", c.first,
                                         " in example ",
                                         ExampleName(example_names, d));
        }
        ++num_rows;
      }
      switch (feature.type) {
        case Type::Sparse:
          feature.length += num_elements;
          break;
        case Type::Ragged:
          feature.length += num_elements;
          feature.num_rows += num_rows;
          break;
        case Type::Dense:
          feature.length = std::max(feature.length, num_elements);
          break;
      }
    }
  }
  return OkStatus();
}

// Copies src into dst[dst_offset:dst_offset+src.size], and then increments
// dst_offset by src.size.
void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst,
                          size_t* dst_offset) {
  size_t src_size = src.NumElements();
  switch (dtype) {
    case DT_INT64: {
      auto src_t = src.flat<int64_t>().data();
      std::copy(src_t, src_t + src_size,
                dst->flat<int64_t>().data() + *dst_offset);
      break;
    }
    case DT_FLOAT: {
      auto src_t = src.flat<float>().data();
      std::copy(src_t, src_t + src_size,
                dst->flat<float>().data() + *dst_offset);
      break;
    }
    case DT_STRING: {
      auto src_t = src.flat<tstring>().data();
      std::copy(src_t, src_t + src_size,
                dst->flat<tstring>().data() + *dst_offset);
      break;
    }
    default:
      ReportUnexpectedDataType(dtype);
  }
  *dst_offset += src_size;
}

// Parses dense features in `context_features`, and writes their parsed
// values to `context_result`.
Status ParseContextDenseFeatures(const FeatureProtosMap& context_features,
                                 const FastParseExampleConfig& context_config,
                                 gtl::ArraySlice<tstring> example_names,
                                 bool is_batch, int num_examples,
                                 Allocator* allocator, Result* context_result) {
  for (int t = 0; t < context_config.dense.size(); ++t) {
    const auto& c = context_config.dense[t];
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape dense_shape, example_shape;
    DataType dtype = c.dtype;
    const size_t data_max_elements = feature.length;
    if (!c.shape.AsTensorShape(&example_shape) ||
        data_max_elements != example_shape.num_elements()) {
      return errors::InvalidArgument(
          "Inconsistent max number of elements for feature ", c.feature_name,
          ": expected ", example_shape.num_elements(), ", but found ",
          data_max_elements);
    }
    if (is_batch) {
      dense_shape.AddDim(num_examples);
    }
    for (const int dim : c.shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);

    Tensor& out = context_result->dense_values[t];
    size_t out_offset = 0;

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0;
      const auto& feature_proto = feature.protos[e];
      if (!feature.protos_present[e]) {
        // Copy the default value, if present. If not, return an error.
        if (c.default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: ", c.feature_name,
              " (data type: ", DataTypeString(c.dtype), ")",
              " is required but could not be found.");
        }
        CopyTensorIntoTensor(dtype, c.default_value, &out, &out_offset);
        num_elements += c.default_value.NumElements();
      } else if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        num_elements += ParseFeature(dtype, &stream, &out, &out_offset);
      }
      if (num_elements != data_max_elements) {
        return errors::InvalidArgument(
            "Unexpected number of elements in example ",
            ExampleName(example_names, e));
      }
    }
  }
  return OkStatus();
}

// Parses sparse features in `context_features`, and writes their parsed
// values to `context_result`.
Status ParseContextSparseFeatures(const FeatureProtosMap& context_features,
                                  const FastParseExampleConfig& context_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator,
                                  Result* context_result) {
  for (int t = 0; t < context_config.sparse.size(); ++t) {
    const auto& c = context_config.sparse[t];
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    size_t expected_num_elements = feature.length;
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(is_batch ? 2 : 1);
    values_shape.AddDim(expected_num_elements);
    context_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    context_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 2 : 1}));
    Tensor& out_values = context_result->sparse_values[t];
    size_t out_values_offset = 0;
    int64_t* out_indices =
        context_result->sparse_indices[t].flat<int64_t>().data();
    auto out_shape = context_result->sparse_shapes[t].vec<int64_t>();

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_cols = 0;
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      EnableAliasing(&stream);
      size_t num_added =
          ParseFeature(dtype, &stream, &out_values, &out_values_offset);
      num_elements += num_added;
      max_num_cols = std::max(max_num_cols, num_added);
      for (int i = 0; i < num_added; i++) {
        if (is_batch) *out_indices++ = e;
        *out_indices++ = i;
      }
    }
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_cols;
    } else {
      out_shape(0) = max_num_cols;
    }
  }
  return OkStatus();
}
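
// Concrete sketch of the layout produced above: with is_batch == true, two
// examples whose sparse context feature holds {3, 5} and {7} respectively
// yield
//   values  = [3, 5, 7]
//   indices = [[0, 0], [0, 1], [1, 0]]  // (example, position) pairs
//   shape   = [2, 2]                    // num_examples x max_num_cols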

// Parses ragged features in `context_features`, and writes their parsed
// values to `context_result`.
Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features,
                                  const FastParseExampleConfig& context_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator,
                                  Result* context_result) {
  for (int t = 0; t < context_config.ragged.size(); ++t) {
    const auto& c = context_config.ragged[t];
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape values_shape, splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    size_t expected_num_elements = feature.length;
    values_shape.AddDim(expected_num_elements);
    if (is_batch) {
      splits_shape.AddDim(num_examples + 1);
    }
    context_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    context_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, splits_shape);
    Tensor& out_values = context_result->ragged_values[t];
    size_t out_values_offset = 0;
    int32* int32_splits =
        is_batch && splits_dtype == DT_INT32
            ? context_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_splits =
        is_batch && splits_dtype == DT_INT64
            ? context_result->ragged_splits[t].vec<int64_t>().data()
            : nullptr;
    if (int32_splits) {
      *int32_splits++ = 0;
    } else if (int64_splits) {
      *int64_splits++ = 0;
    }

    // Fill in the values.
    size_t split = 0;  // = total number of elements we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        size_t num_added =
            ParseFeature(dtype, &stream, &out_values, &out_values_offset);
        split += num_added;
      }
      if (int32_splits) {
        *int32_splits++ = split;
      } else if (int64_splits) {
        *int64_splits++ = split;
      }
    }
    if (split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    if (int32_splits || int64_splits) {
      int actual_splits =
          int32_splits
              ? int32_splits -
                    context_result->ragged_splits[t].vec<int32>().data()
              : int64_splits -
                    context_result->ragged_splits[t].vec<int64_t>().data();
      if (actual_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return OkStatus();
}
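
// Concrete sketch of the splits semantics above: with is_batch == true, two
// examples whose ragged context feature holds {3, 5} and {7} yield
//   values = [3, 5, 7]
//   splits = [0, 2, 3]  // row e spans values[splits[e]:splits[e + 1]]
// For a single (non-batch) input the splits output is left unused.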

// Parses dense features in `sequence_features`, and writes their parsed
// values to `sequence_result`.
Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features,
                                  const FastParseExampleConfig& sequence_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator, Result* sequence_result,
                                  std::vector<Tensor>* dense_feature_lengths) {
  TensorShape dense_length_shape;
  if (is_batch) {
    dense_length_shape.AddDim(num_examples);
  }
  for (int t = 0; t < sequence_config.dense.size(); ++t) {
    const auto& c = sequence_config.dense[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape dense_shape, row_shape;
    DataType dtype = c.dtype;
    const size_t expected_max_elements = feature.length;
    if (!c.shape.AsTensorShape(&row_shape) ||
        expected_max_elements !=
            (expected_max_elements / row_shape.num_elements()) *
                row_shape.num_elements()) {
      PartialTensorShape total_shape = row_shape;
      total_shape.InsertDim(0, -1);
      return errors::InvalidArgument(
          "Feature list '", c.feature_name,
          "' has an unexpected number of values. Total values size: ",
          expected_max_elements,
          " is not consistent with output shape: ", total_shape.DebugString());
    }
    int64_t expected_max_rows =
        expected_max_elements / row_shape.num_elements();
    if (is_batch) {
      dense_shape.AddDim(num_examples);
    }
    dense_shape.AddDim(expected_max_rows);
    for (const int dim : sequence_config.dense[t].shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    sequence_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
    (*dense_feature_lengths)[t] =
        Tensor(allocator, DT_INT64, dense_length_shape);
    int64_t* out_lengths = (*dense_feature_lengths)[t].flat<int64_t>().data();

    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64_t* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->dense_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->dense_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->dense_values[t].flat<int64_t>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0, num_rows = 0;
      const auto& feature_proto = feature.protos[e];
      if (!feature.protos_present[e]) {
        // Return an error if this feature was not allowed to be missing.
        // Otherwise, we'll pad as needed below.
        if (!c.variable_length) {
          return errors::InvalidArgument(
              "Name: ", ExampleName(example_names, e), ", Feature list '",
              c.feature_name,
              "' is required but could not be found. "
              "Did you mean to include it in "
              "feature_list_dense_missing_assumed_empty or "
              "feature_list_dense_defaults?");
        }
      } else if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          auto limit = stream.PushLimit(feature_length);
          int num_added = 0;
          if (feature_length > 2) {
            switch (dtype) {
              case DT_STRING:
                num_added = ParseBytesFeature(&stream, out_bytes);
                out_bytes += num_added;
                break;
              case DT_FLOAT:
                num_added = ParseFloatFeature(&stream, out_float);
                out_float += num_added;
                break;
              case DT_INT64:
                num_added = ParseInt64Feature(&stream, out_int64);
                out_int64 += num_added;
                break;
              default:
                ReportUnexpectedDataType(dtype);
                num_added = 0;
            }
            if (num_added < 0) {
              // This should be unreachable -- we already scanned the feature
              // in GetSequenceFeatureLengths, and it hasn't changed since
              // then.
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.feature_name, " in example ",
                                             ExampleName(example_names, e));
            }
          }
          if (num_added != row_shape.num_elements()) {
            return errors::InvalidArgument(
                "Name: ", ExampleName(example_names, e),
                ", Key: ", c.feature_name, ", Index: ", num_rows,
                ". Number of values != expected. values size: ", num_added,
                " but output shape: ", row_shape.DebugString());
          }
          num_elements += num_added;
          num_rows++;
          stream.PopLimit(limit);
        }
      }
      *out_lengths++ = num_rows;
      // Pad as necessary.
      int num_to_pad = expected_max_elements - num_elements;
      switch (dtype) {
        case DT_STRING:
          out_bytes += num_to_pad;
          break;
        case DT_FLOAT:
          PadFloatFeature(num_to_pad, out_float);
          out_float += num_to_pad;
          break;
        case DT_INT64:
          PadInt64Feature(num_to_pad, out_int64);
          out_int64 += num_to_pad;
          break;
        default:
          ReportUnexpectedDataType(dtype);
      }
    }
  }
  return OkStatus();
}

// Parses sparse features in `sequence_features`, and writes their parsed
// values to `sequence_result`.
Status ParseSequenceSparseFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.sparse.size(); ++t) {
    const auto& c = sequence_config.sparse[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    size_t expected_num_elements = feature.length;
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(is_batch ? 3 : 2);
    values_shape.AddDim(expected_num_elements);
    sequence_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    sequence_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 3 : 2}));

    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64_t* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->sparse_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->sparse_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->sparse_values[t].flat<int64_t>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }
    int64_t* out_indices =
        sequence_result->sparse_indices[t].flat<int64_t>().data();
    auto out_shape = sequence_result->sparse_shapes[t].vec<int64_t>();

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_rows = 0;
    size_t max_num_cols = 0;
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      EnableAliasing(&stream);
      size_t num_rows = 0;
      while (!stream.ExpectAtEnd()) {
        uint32 feature_length;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_length)) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        if (feature_length > 2) {
          auto limit = stream.PushLimit(feature_length);
          size_t num_added;
          switch (dtype) {
            case DT_STRING:
              num_added = ParseBytesFeature(&stream, out_bytes);
              out_bytes += num_added;
              break;
            case DT_FLOAT:
              num_added = ParseFloatFeature(&stream, out_float);
              out_float += num_added;
              break;
            case DT_INT64:
              num_added = ParseInt64Feature(&stream, out_int64);
              out_int64 += num_added;
              break;
            default:
              ReportUnexpectedDataType(dtype);
              num_added = 0;
          }
          num_elements += num_added;
          max_num_cols = std::max(max_num_cols, num_added);
          for (int i = 0; i < num_added; i++) {
            if (is_batch) *out_indices++ = e;
            *out_indices++ = num_rows;
            *out_indices++ = i;
          }
          stream.PopLimit(limit);
        } else if (feature_length == 2) {
          if (!SkipEmptyFeature(&stream, dtype)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
        } else if (feature_length != 0) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature ",
                                         c.feature_name, " in example ",
                                         ExampleName(example_names, e));
        }
        num_rows++;
      }
      max_num_rows = std::max(max_num_rows, num_rows);
    }
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements in feature ", c.feature_name);
    }
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_rows;
      out_shape(2) = max_num_cols;
    } else {
      out_shape(0) = max_num_rows;
      out_shape(1) = max_num_cols;
    }
  }
  return OkStatus();
}
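
// Concrete sketch of the index layout above: with is_batch == true, an
// example e whose FeatureList rows hold {a, b} then {c} contributes
//   indices = [[e, 0, 0], [e, 0, 1], [e, 1, 0]]  // (example, row, column)
// and the dense shape is [num_examples, max_num_rows, max_num_cols].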

// Parses ragged features in `sequence_features`, and writes their parsed
// values to `sequence_result`.
Status ParseSequenceRaggedFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.ragged.size(); ++t) {
    const auto& c = sequence_config.ragged[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape values_shape, inner_splits_shape, outer_splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    size_t expected_num_elements = feature.length;
    size_t expected_num_rows = feature.num_rows;
    values_shape.AddDim(expected_num_elements);
    inner_splits_shape.AddDim(expected_num_rows + 1);
    if (is_batch) {
      outer_splits_shape.AddDim(num_examples + 1);
    }
    sequence_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, inner_splits_shape);
    sequence_result->ragged_outer_splits[t] =
        Tensor(allocator, splits_dtype, outer_splits_shape);
    Tensor& out_values = sequence_result->ragged_values[t];
    size_t out_values_offset = 0;
    int32* int32_inner_splits =
        splits_dtype == DT_INT32
            ? sequence_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_inner_splits =
        splits_dtype == DT_INT64
            ? sequence_result->ragged_splits[t].vec<int64_t>().data()
            : nullptr;
    int32* int32_outer_splits =
        is_batch && splits_dtype == DT_INT32
            ? sequence_result->ragged_outer_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_outer_splits =
        is_batch && splits_dtype == DT_INT64
            ? sequence_result->ragged_outer_splits[t].vec<int64_t>().data()
            : nullptr;
    if (int32_inner_splits) {
      *int32_inner_splits++ = 0;
    } else if (int64_inner_splits) {
      *int64_inner_splits++ = 0;
    }
    if (int32_outer_splits) {
      *int32_outer_splits++ = 0;
    } else if (int64_outer_splits) {
      *int64_outer_splits++ = 0;
    }

    // Fill in the values.
    size_t inner_split = 0;  // total number of elements we've seen so far
    size_t outer_split = 0;  // total number of rows we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          if (feature_length > 2) {
            auto limit = stream.PushLimit(feature_length);
            size_t num_added =
                ParseFeature(dtype, &stream, &out_values, &out_values_offset);
            inner_split += num_added;
            stream.PopLimit(limit);
          } else if (feature_length == 2) {
            if (!SkipEmptyFeature(&stream, dtype)) {
              // This should be unreachable -- we already scanned the feature
              // in GetSequenceFeatureLengths, and it hasn't changed since
              // then.
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.feature_name, " in example ",
                                             ExampleName(example_names, e));
            }
          } else if (feature_length != 0) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           ExampleName(example_names, e));
          }
          if (int32_inner_splits) {
            *int32_inner_splits++ = inner_split;
          } else if (int64_inner_splits) {
            *int64_inner_splits++ = inner_split;
          }
          outer_split++;
        }
      }
      if (int32_outer_splits) {
        *int32_outer_splits++ = outer_split;
      } else if (int64_outer_splits) {
        *int64_outer_splits++ = outer_split;
      }
    }
    if (outer_split != expected_num_rows) {
      return errors::InvalidArgument("Unexpected number of rows for feature ",
                                     c.feature_name);
    }
    if (inner_split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements for feature ", c.feature_name);
    }

    if (int32_inner_splits || int64_inner_splits) {
      const auto& inner_splits = sequence_result->ragged_splits[t];
      int num_inner_splits =
          int32_inner_splits
              ? int32_inner_splits - inner_splits.vec<int32>().data()
              : int64_inner_splits - inner_splits.vec<int64_t>().data();
      if (num_inner_splits != expected_num_rows + 1) {
        return errors::InvalidArgument("Unexpected number of rows for feature ",
                                       c.feature_name);
      }
    }
    if (int32_outer_splits || int64_outer_splits) {
      const auto& outer_splits = sequence_result->ragged_outer_splits[t];
      int num_outer_splits =
          int32_outer_splits
              ? int32_outer_splits - outer_splits.vec<int32>().data()
              : int64_outer_splits - outer_splits.vec<int64_t>().data();
      if (num_outer_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature ", c.feature_name);
      }
    }
  }
  return OkStatus();
}

}  // namespace

// TODO(sundberg): Use the threadpool to parallelize example parsing.
// TODO(b/111553342): Support extracting feature statistics from the examples.
Status FastParseSequenceExample(const FastParseExampleConfig& context_config,
                                const FastParseExampleConfig& sequence_config,
                                gtl::ArraySlice<tstring> serialized,
                                gtl::ArraySlice<tstring> example_names,
                                thread::ThreadPool* thread_pool,
                                Result* context_result, Result* sequence_result,
                                std::vector<Tensor>* dense_feature_lengths,
                                bool is_batch) {
  int num_examples = serialized.size();
  DCHECK(context_result != nullptr);
  DCHECK(sequence_result != nullptr);
  DCHECK(dense_feature_lengths != nullptr);
  size_t num_context_features = context_config.sparse.size() +
                                context_config.dense.size() +
                                context_config.ragged.size();
  FeatureProtosMap context_features;
  context_features.reserve(num_context_features);

  if (!example_names.empty() && example_names.size() != num_examples) {
    return errors::InvalidArgument(
        "example_names must be empty or have the correct number of elements");
  }
  for (auto& c : context_config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = context_features[c.feature_name];
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Sparse;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : context_config.ragged) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = context_features[c.feature_name];
    if (feature.type == Type::Sparse) {
      return errors::InvalidArgument("Context feature " + c.feature_name +
                                     " cannot be both ragged and sparse");
    }
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Ragged;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : context_config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = context_features[c.feature_name];
    if (feature.type != Type::Dense) {
      return errors::InvalidArgument("Context feature " + c.feature_name +
                                     " cannot be both dense and sparse");
    }
    if (c.default_value.NumElements() > 0) {
      if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
        return errors::InvalidArgument("Default value for context feature ",
                                       c.feature_name,
                                       " has an incorrect shape: saw ",
                                       c.default_value.shape().DebugString(),
                                       " but expected ", c.shape.DebugString());
      }
    }
    feature.dtype = c.dtype;
    feature.length = c.default_value.NumElements();
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  size_t num_sequence_features = sequence_config.sparse.size() +
                                 sequence_config.dense.size() +
                                 sequence_config.ragged.size();
  FeatureProtosMap sequence_features;
  sequence_features.reserve(num_sequence_features);
  for (auto& c : sequence_config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = sequence_features[c.feature_name];
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Sparse;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : sequence_config.ragged) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = sequence_features[c.feature_name];
    if (feature.type == Type::Sparse) {
      return errors::InvalidArgument("Sequence feature " + c.feature_name +
                                     " cannot be both ragged and sparse");
    }
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Ragged;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : sequence_config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = sequence_features[c.feature_name];
    if (feature.type != Type::Dense) {
      return errors::InvalidArgument("Sequence feature " + c.feature_name +
                                     " cannot be both dense and sparse");
    }
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }

  // Find the serialized proto substrings for each feature.
  TF_RETURN_IF_ERROR(ExtractFeaturesFromSequenceExamples(
      serialized, example_names, &context_features, &sequence_features));

  // Scan through the protos to determine how much memory we need to allocate.
  TF_RETURN_IF_ERROR(
      GetContextFeatureLengths(example_names, &context_features));
  TF_RETURN_IF_ERROR(
      GetSequenceFeatureLengths(example_names, &sequence_features));

  // Allocate memory.
  context_result->sparse_values.resize(context_config.sparse.size());
  context_result->sparse_indices.resize(context_config.sparse.size());
  context_result->sparse_shapes.resize(context_config.sparse.size());
  context_result->dense_values.resize(context_config.dense.size());
  context_result->ragged_values.resize(context_config.ragged.size());
  context_result->ragged_splits.resize(context_config.ragged.size());
  context_result->ragged_outer_splits.resize(context_config.ragged.size());
  sequence_result->sparse_values.resize(sequence_config.sparse.size());
  sequence_result->sparse_indices.resize(sequence_config.sparse.size());
  sequence_result->sparse_shapes.resize(sequence_config.sparse.size());
  sequence_result->dense_values.resize(sequence_config.dense.size());
  sequence_result->ragged_values.resize(sequence_config.ragged.size());
  sequence_result->ragged_splits.resize(sequence_config.ragged.size());
  sequence_result->ragged_outer_splits.resize(sequence_config.ragged.size());
  dense_feature_lengths->resize(sequence_config.dense.size());

  // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction,
  // to avoid lock contention in `tensorflow::cpu_allocator()`.
  Allocator* allocator = tensorflow::cpu_allocator();

  TF_RETURN_IF_ERROR(ParseContextDenseFeatures(
      context_features, context_config, example_names, is_batch, num_examples,
      allocator, context_result));
  TF_RETURN_IF_ERROR(ParseContextSparseFeatures(
      context_features, context_config, example_names, is_batch, num_examples,
      allocator, context_result));
  TF_RETURN_IF_ERROR(ParseContextRaggedFeatures(
      context_features, context_config, example_names, is_batch, num_examples,
      allocator, context_result));
  TF_RETURN_IF_ERROR(ParseSequenceDenseFeatures(
      sequence_features, sequence_config, example_names, is_batch,
      num_examples, allocator, sequence_result, dense_feature_lengths));
  TF_RETURN_IF_ERROR(ParseSequenceSparseFeatures(
      sequence_features, sequence_config, example_names, is_batch,
      num_examples, allocator, sequence_result));
  TF_RETURN_IF_ERROR(ParseSequenceRaggedFeatures(
      sequence_features, sequence_config, example_names, is_batch,
      num_examples, allocator, sequence_result));

  return OkStatus();
}
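
// Illustrative call sketch: parsing a batch of serialized SequenceExample
// protos with separate context and sequence configs. Variable names are
// hypothetical, and `thread_pool` is not yet used by this implementation
// (see the TODO above the function).
//
//   Result context_result, sequence_result;
//   std::vector<Tensor> dense_feature_lengths;
//   TF_RETURN_IF_ERROR(FastParseSequenceExample(
//       context_config, sequence_config, serialized_batch, example_names,
//       /*thread_pool=*/nullptr, &context_result, &sequence_result,
//       &dense_feature_lengths, /*is_batch=*/true));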

}  // namespace example
}  // namespace tensorflow