/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/example_proto_fast_parsing.h"

#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/presized_cuckoo_map.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {
namespace example {

namespace {

template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;

template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  LimitedArraySlice(T* begin, size_t num_elements)
      : current_(begin), begin_(begin), end_(begin + num_elements) {}

  // May return a negative value if there were push_back calls after the slice
  // was filled.
  int64 EndDistance() const { return end_ - current_; }

  // Attempts to push value to the back of this. If the slice has
  // already been filled, this method has no effect on the underlying data, but
  // it causes EndDistance() to return a negative value.
  void push_back(T&& value) {
    if (EndDistance() > 0) *current_ = std::move(value);
    ++current_;
  }

  // "Constructs" an element at the back of this by resizing the slice, and
  // returns a mutable reference to the new last element.
  // REQUIRES: EndDistance() > 0.
  T& construct_at_end() {
    DCHECK_GT(EndDistance(), 0);
    return *(current_++);
  }

  // Returns a mutable reference to the last element in the slice.
  // REQUIRES: size() > 0.
  T& back() { return *(current_ - 1); }

  // Returns the number of elements in the slice.
  size_t size() const { return std::min(current_ - begin_, end_ - begin_); }

  // Attempts to resize the vector to the given size. It does so by advancing
  // the pointer to the current element, possibly beyond the end of the slice.
  // As a consequence, calling `size()` after `resize(x)` was called might
  // return a value less than `x`.
  void resize(size_t size) { current_ = begin_ + size; }

  // Returns the pointer to the underlying data buffer.
  T* data() { return begin_; }

 private:
  T* current_;
  T* begin_;
  T* end_;
};
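
// Illustrative sketch (added note, not from the original source): how a caller
// might use LimitedArraySlice to parse into a fixed-size buffer while
// detecting overflow via EndDistance():
//
//   float buf[2];
//   LimitedArraySlice<float> slice(buf, 2);
//   slice.push_back(1.0f);
//   slice.push_back(2.0f);
//   slice.push_back(3.0f);  // Ignored; buf still holds {1.0f, 2.0f}.
//   // slice.EndDistance() == -1 now signals that more values were pushed
//   // than the buffer could hold, which callers report as a shape error.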

template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

template <typename A>
void EnableAliasing(A&& a) {}

uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
  DCHECK(stream != nullptr);
  const void* ptr;
  int size;
  if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
  return *static_cast<const uint8*>(ptr);
}

constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
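
// Worked example (added note, not from the original source): a protobuf tag
// byte packs the field number in the upper bits and the wire type in the low
// three bits. For field 1:
//   kVarintTag(1)    == (1 << 3) | 0 == 0x08  // varint-encoded scalar
//   kDelimitedTag(1) == (1 << 3) | 2 == 0x0A  // length-delimited payload
//   kFixed32Tag(1)   == (1 << 3) | 5 == 0x0D  // 4-byte little-endian value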

namespace parsed {

// ParseDataType has to be called first, then the appropriate ParseZzzzList.
class Feature {
 public:
  Feature() {}
  explicit Feature(StringPiece serialized) : serialized_(serialized) {}

  Status ParseDataType(DataType* dtype) {
    DCHECK(dtype != nullptr);
    if (serialized_.empty()) {
      *dtype = DT_INVALID;
      return Status::OK();
    }
    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
    serialized_.remove_prefix(1);
    switch (oneof_tag) {
      case kDelimitedTag(1):
        *dtype = DT_STRING;
        break;
      case kDelimitedTag(2):
        *dtype = DT_FLOAT;
        break;
      case kDelimitedTag(3):
        *dtype = DT_INT64;
        break;
      default:
        // Initialize the output variable to avoid a compiler warning.
        *dtype = DT_INVALID;
        return errors::InvalidArgument("Unsupported datatype.");
    }
    return Status::OK();
  }
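
  // Example (added note, not from the original source): the first byte of a
  // serialized Feature proto selects the oneof field. A Feature holding a
  // FloatList starts with kDelimitedTag(2) == 0x12, so ParseDataType() reports
  // DT_FLOAT and leaves serialized_ positioned at the length-delimited payload
  // (a varint length followed by the FloatList bytes).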

  bool GetNumElementsInBytesList(int* num_elements) {
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length = 0;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      uint32 bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      if (!stream.Skip(bytes_length)) return false;
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Helper methods
  tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
    if (bytes_list->EndDistance() <= 0) {
      return nullptr;
    }
    return &bytes_list->construct_at_end();
  }
  tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
    return &bytes_list->emplace_back();
  }

  template <typename Result>
  bool ParseBytesList(Result* bytes_list) {
    DCHECK(bytes_list != nullptr);

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());

    EnableAliasing(&stream);

    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      // parse string
      uint32 bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      tstring* bytes = construct_at_end(bytes_list);
      if (bytes == nullptr) return false;
      bytes->resize_uninitialized(bytes_length);
      if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
    }
    stream.PopLimit(limit);
    return true;
  }
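
  // Worked example (added note, not from the original source): a BytesList
  // containing "ab" and "c" is serialized as
  //   0x07               // varint: 7 payload bytes follow
  //   0x0A 0x02 'a' 'b'  // field 1, length-delimited, 2 bytes
  //   0x0A 0x01 'c'      // field 1, length-delimited, 1 byte
  // ParseBytesList() walks exactly this structure, copying each string into
  // the result via construct_at_end().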

  template <typename Result>
  bool ParseFloatList(Result* float_list) {
    DCHECK(float_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      constexpr int32_t kNumFloatBytes = 4;
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        // Store the initial size to know the offset we have to start writing
        // data from before resizing the output "vector".
        const size_t initial_size = float_list->size();
        float_list->resize(initial_size + packed_length / kNumFloatBytes);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        if (port::kLittleEndian &&
            sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Calculate the length of the available buffer, which can be less
          // than what we requested in resize() in the case of a
          // LimitedArraySlice.
          const uint32 bytes_to_copy =
              std::min(static_cast<uint32>((float_list->size() - initial_size) *
                                           kNumFloatBytes),
                       packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
            return false;
        } else {
          int64_t index = initial_size;
          while (!stream.ExpectAtEnd()) {
            uint32 buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) return false;
            if (index < float_list->size()) {
              float_list->data()[index] = absl::bit_cast<float>(buffer32);
              ++index;
            }
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        const size_t initial_size = float_list->size();
        // 1 byte for the tag (field number 1 with wire type fixed32, encoded
        // as a varint) and kNumFloatBytes for the value.
        const int64_t num_elements =
            stream.BytesUntilLimit() / (1 + kNumFloatBytes);
        float_list->resize(initial_size + num_elements);
        int64_t index = initial_size;
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->data()[index] = absl::bit_cast<float>(buffer32);
          ++index;
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }
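
  // Worked example (added note, not from the original source): the same two
  // floats can arrive in either encoding, and ParseFloatList() handles both:
  //   packed:     0x0A 0x08 <4 bytes> <4 bytes>  // one tag, one length
  //   non-packed: 0x0D <4 bytes> 0x0D <4 bytes>  // one tag per value
  // In the packed little-endian case the payload is memcpy'd wholesale;
  // otherwise each fixed32 is read and bit_cast to float individually.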

  template <typename Result>
  bool ParseInt64List(Result* int64_list) {
    DCHECK(int64_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {                       // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) return false;
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  StringPiece GetSerialized() const { return serialized_; }

 private:
  // TODO(lew): A pair of uint8* would be more natural.
  StringPiece serialized_;
};

using FeatureMapEntry = std::pair<StringPiece, Feature>;
using Example = std::vector<FeatureMapEntry>;

}  // namespace parsed

inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
  uint32 data;
  protobuf_uint64 dummy;
  switch (stream->ReadTag() & 0x7) {
    case 0:  // varint
      if (!stream->ReadVarint32(&data)) return false;
      return true;
    case 1:  // fixed64
      if (!stream->ReadLittleEndian64(&dummy)) return false;
      return true;
    case 2:  // length delimited
      if (!stream->ReadVarint32(&data)) return false;
      stream->Skip(data);
      return true;
    case 3:  // group begin
      return false;  // groups not supported.
    case 4:  // group end
      return false;  // groups not supported.
    case 5:  // fixed32
      if (!stream->ReadLittleEndian32(&data)) return false;
      return true;
  }
  return false;  // unrecognized tag type
}

bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
  DCHECK(stream != nullptr);
  DCHECK(result != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  if (length == 0) {
    *result = StringPiece(nullptr, 0);
    return true;
  }
  const void* stream_alias;
  int stream_size;
  if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
    return false;
  }
  if (static_cast<uint32>(stream_size) < length) return false;
  *result = StringPiece(static_cast<const char*>(stream_alias), length);
  stream->Skip(length);
  return true;
}

bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
                          parsed::FeatureMapEntry* feature_map_entry) {
  DCHECK(stream != nullptr);
  DCHECK(feature_map_entry != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);
  if (!stream->ExpectTag(kDelimitedTag(1))) return false;
  if (!ParseString(stream, &feature_map_entry->first)) return false;
  if (!stream->ExpectTag(kDelimitedTag(2))) return false;
  StringPiece feature_string_piece;
  if (!ParseString(stream, &feature_string_piece)) return false;
  feature_map_entry->second = parsed::Feature(feature_string_piece);
  if (!stream->ExpectAtEnd()) return false;
  stream->PopLimit(limit);
  return true;
}

bool ParseFeatures(protobuf::io::CodedInputStream* stream,
                   parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  uint32 length;
  if (!stream->ReadVarint32(&length)) return false;
  auto limit = stream->PushLimit(length);
  while (!stream->ExpectAtEnd()) {
    parsed::FeatureMapEntry feature_map_entry;
    if (!stream->ExpectTag(kDelimitedTag(1))) return false;
    if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
    example->push_back(std::move(feature_map_entry));
  }
  stream->PopLimit(limit);
  return true;
}

bool ParseExample(protobuf::io::CodedInputStream* stream,
                  parsed::Example* example) {
  DCHECK(stream != nullptr);
  DCHECK(example != nullptr);
  // Loop over the input stream, which may contain multiple serialized Example
  // protos merged together as strings. This behavior is consistent with
  // proto's ParseFromString when string representations are concatenated.
  while (!stream->ExpectAtEnd()) {
    if (!stream->ExpectTag(kDelimitedTag(1))) {
      if (!SkipExtraneousTag(stream)) return false;
    } else {
      if (!ParseFeatures(stream, example)) return false;
    }
  }
  return true;
}
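
// Worked example (added note, not from the original source): because the loop
// above accepts repeated `features` fields, concatenating the serialization of
// Example{features: {"a": ...}} with Example{features: {"b": ...}} parses as a
// single parsed::Example whose entries include both "a" and "b", matching
// proto merge semantics for ParseFromString on concatenated strings.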

bool ParseExample(StringPiece serialized, parsed::Example* example) {
  DCHECK(example != nullptr);
  protobuf::io::CodedInputStream stream(
      reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
  EnableAliasing(&stream);
  return ParseExample(&stream, example);
}

}  // namespace

bool TestFastParse(const string& serialized, Example* example) {
  DCHECK(example != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized, &parsed_example)) return false;
  auto& features = *example->mutable_features();
  size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This mirrors the logic of standard protobuf parsing: the last entry in
    // the map overwrites all previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    string name(name_and_feature.first);
    if ((*features.mutable_feature()).count(name) > 0) continue;

    auto& value = (*features.mutable_feature())[name];
    DataType dtype;
    if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
    switch (dtype) {
      case DT_INVALID:
        break;
      case DT_STRING: {
        SmallVector<tstring> list;
        if (!name_and_feature.second.ParseBytesList(&list)) return false;
        auto* result_list = value.mutable_bytes_list();
        for (auto& bytes : list) {
          result_list->add_value(bytes.data(), bytes.size());
        }
        break;
      }
      case DT_FLOAT: {
        SmallVector<float> list;
        if (!name_and_feature.second.ParseFloatList(&list)) return false;
        auto* result_list = value.mutable_float_list();
        for (float f : list) {
          result_list->add_value(f);
        }
        break;
      }
      case DT_INT64: {
        SmallVector<int64> list;
        if (!name_and_feature.second.ParseInt64List(&list)) return false;
        auto* result_list = value.mutable_int64_list();
        for (int64_t i : list) {
          result_list->add_value(i);
        }
        break;
      }
      default:
        LOG(FATAL) << "Should not happen.";
    }
  }
  return true;
}

// -----------------------------------------------------------------------------

namespace {

using Config = FastParseExampleConfig;

void ParallelFor(const std::function<void(size_t)>& f, size_t n,
                 thread::ThreadPool* thread_pool) {
  if (n == 0) return;
  if (thread_pool == nullptr) {
    for (size_t i = 0; i < n; ++i) {
      f(i);
    }
  } else {
    BlockingCounter counter(n - 1);
    for (size_t i = 1; i < n; ++i) {
      thread_pool->Schedule([i, &f, &counter] {
        f(i);
        counter.DecrementCount();
      });
    }
    f(0);
    counter.Wait();
  }
}
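
// Usage sketch (added note, not from the original source): ParallelFor runs
// f(0), ..., f(n - 1), scheduling n - 1 of the calls on the pool while the
// calling thread executes f(0) itself, then blocks until all are done:
//
//   std::vector<int> squares(8);
//   ParallelFor([&](size_t i) { squares[i] = i * i; }, squares.size(), pool);
//
// With a null thread_pool it degrades gracefully to a sequential loop.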

// Enumeration for distinguishing feature types.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense.
enum class Type { Dense, Sparse, Ragged };

// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
struct SparseBuffer {
  // Features are stored in one of the 3 vectors below, depending on the
  // config's dtype. The other 2 vectors remain empty.
  SmallVector<tstring> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64> int64_list;

  // The features of example i are the elements with indices
  // example_end_indices[i-1] through example_end_indices[i] - 1 in the
  // appropriate xxxxx_list.
  std::vector<size_t> example_end_indices;
};
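
// Worked example (added note, not from the original source): if example 0
// contributes the int64 values {7, 8} and example 1 contributes {9}, then
// after parsing both:
//   int64_list          == {7, 8, 9}
//   example_end_indices == {2, 3}
// so example i's values span [example_end_indices[i-1], example_end_indices[i])
// with an implicit 0 for i == 0.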

struct SeededHasher {
  uint64 operator()(StringPiece s) const {
    return Hash64(s.data(), s.size(), seed);
  }
  uint64 seed{0xDECAFCAFFE};
};

void LogDenseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. Ignoring all but last one.";
  static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_dense_feature",
      "Dense feature appears twice in a tf.Example");
  duplicated_dense_feature->GetCell()->IncrementBy(1);
}

void LogSparseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. Ignoring all but last one.";
  static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_sparse_feature",
      "Sparse feature appears twice in a tf.Example");
  duplicated_sparse_feature->GetCell()->IncrementBy(1);
}

Status FastParseSerializedExample(
    const tstring& serialized_example, const tstring& example_name,
    const size_t example_index, const Config& config,
    const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
    SeededHasher hasher, std::vector<Tensor>* output_dense,
    std::vector<SparseBuffer>* output_varlen_dense,
    std::vector<SparseBuffer>* output_sparse,
    std::vector<SparseBuffer>* output_ragged,
    PerExampleFeatureStats* output_stats) {
  DCHECK(output_dense != nullptr);
  DCHECK(output_sparse != nullptr);
  DCHECK(output_ragged != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized_example, &parsed_example)) {
    return errors::InvalidArgument("Could not parse example input, value: '",
                                   serialized_example, "'");
  }
  std::vector<int64> sparse_feature_last_example(config.sparse.size(), -1);
  std::vector<int64> dense_feature_last_example(config.dense.size(), -1);
  std::vector<int64> ragged_feature_last_example(config.ragged.size(), -1);

  // Handle features present in the example.
  const size_t parsed_example_size = parsed_example.size();

  if (output_stats) {
    // TODO(b/111553342): This may over-count the number of features if there
    // are duplicate keys in the feature map. Consider deduplicating the keys
    // before computing the count.
    output_stats->features_count = parsed_example_size;
  }

  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This mirrors the logic of standard protobuf parsing: the last entry in
    // the map overwrites all previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];

    const StringPiece feature_name = name_and_feature.first;
    parsed::Feature& feature = name_and_feature.second;

    std::pair<size_t, Type> d_and_type;
    uint64 h = hasher(feature_name);
    if (!config_index.Find(h, &d_and_type)) continue;

    size_t d = d_and_type.first;
    bool is_dense = d_and_type.second == Type::Dense;
    bool is_ragged = d_and_type.second == Type::Ragged;

    {
      // Testing for a PresizedCuckooMap collision.
      // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
      const tstring& config_feature_name =
          is_dense ? config.dense[d].feature_name
                   : (is_ragged ? config.ragged[d].feature_name
                                : config.sparse[d].feature_name);
      if (feature_name != config_feature_name) continue;
    }

    auto example_error = [&](StringPiece suffix) {
      return errors::InvalidArgument("Name: ", example_name,
                                     ", Key: ", feature_name,
                                     ", Index: ", example_index, ". ", suffix);
    };

    auto parse_error = [&] {
      return example_error("Can't parse serialized Example.");
    };

    DataType example_dtype;
    TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));

    if (is_dense) {
      if (example_dtype == DT_INVALID) continue;

      // If the feature was already visited, skip.
      // Compare the comment at the beginning of the loop.
      if (dense_feature_last_example[d] == example_index) {
        LogDenseFeatureDataLoss(feature_name);
        continue;
      }
      dense_feature_last_example[d] = example_index;

      if (example_dtype != config.dense[d].dtype) {
        return example_error(strings::StrCat(
            "Data types don't match. Data type: ",
            DataTypeString(example_dtype),
            " but expected type: ", DataTypeString(config.dense[d].dtype)));
      }
      if (!config.dense[d].variable_length) {
        Tensor& out = (*output_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;
        if (output_stats) {
          // TODO(b/111553342): If desirable, we could add support for counting
          // elements in the features that aren't parsed, but this could add
          // considerable runtime cost.
          output_stats->feature_values_count += num_elements;
        }

        const std::size_t offset = example_index * num_elements;

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(strings::StrCat(
              "Number of ", type_str,
              " values != expected. "
              "Values size: ",
              size,
              " but output shape: ", config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case DT_INT64: {
            auto out_p = out.flat<int64>().data() + offset;
            LimitedArraySlice<int64> slice(out_p, num_elements);
            if (!feature.ParseInt64List(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "int64");
            }
            break;
          }
          case DT_FLOAT: {
            auto out_p = out.flat<float>().data() + offset;
            LimitedArraySlice<float> slice(out_p, num_elements);
            if (!feature.ParseFloatList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "float");
            }
            break;
          }
          case DT_STRING: {
            auto out_p = out.flat<tstring>().data() + offset;
            LimitedArraySlice<tstring> slice(out_p, num_elements);
            if (!feature.ParseBytesList(&slice)) return parse_error();
            if (slice.EndDistance() != 0) {
              return shape_error(num_elements - slice.EndDistance(), "bytes");
            }
            break;
          }
          default:
            LOG(FATAL) << "Should not happen.";
        }
      } else {  // if variable length
        SparseBuffer& out = (*output_varlen_dense)[d];

        const std::size_t num_elements = config.dense[d].elements_per_stride;

        if (example_dtype != DT_INVALID &&
            example_dtype != config.dense[d].dtype) {
          return example_error(strings::StrCat(
              "Data types don't match. ",
              "Expected type: ", DataTypeString(config.dense[d].dtype)));
        }

        auto shape_error = [&](size_t size, StringPiece type_str) {
          return example_error(strings::StrCat(
              "Number of ", type_str,
              " values is not a multiple of stride length. Saw ", size,
              " values but output shape is: ",
              config.dense[d].shape.DebugString()));
        };

        switch (config.dense[d].dtype) {
          case DT_INT64: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseInt64List(&out.int64_list)) {
                return parse_error();
              }
              if (out.int64_list.size() % num_elements != 0) {
                return shape_error(out.int64_list.size(), "int64");
              }
            }
            out.example_end_indices.push_back(out.int64_list.size());
            break;
          }
          case DT_FLOAT: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseFloatList(&out.float_list)) {
                return parse_error();
              }
              if (out.float_list.size() % num_elements != 0) {
                return shape_error(out.float_list.size(), "float");
              }
            }
            out.example_end_indices.push_back(out.float_list.size());
            break;
          }
          case DT_STRING: {
            if (example_dtype != DT_INVALID) {
              if (!feature.ParseBytesList(&out.bytes_list)) {
                return parse_error();
              }
              if (out.bytes_list.size() % num_elements != 0) {
                return shape_error(out.bytes_list.size(), "bytes");
              }
            }
            out.example_end_indices.push_back(out.bytes_list.size());
            break;
          }
          default:
            LOG(FATAL) << "Should not happen.";
        }

        if (output_stats) {
          // Use `out.example_end_indices` to determine the feature-value count
          // for this feature, because the preceding switch statement pushes
          // the length of the appropriate feature list to that vector.
          // TODO(b/111553342): If desirable, we could add support for counting
          // elements in the features that aren't parsed, but this could add
          // considerable runtime cost.
          const size_t out_examples_count = out.example_end_indices.size();
          if (out_examples_count == 1) {
            output_stats->feature_values_count += out.example_end_indices[0];
          } else {
            output_stats->feature_values_count +=
                out.example_end_indices[out_examples_count - 1] -
                out.example_end_indices[out_examples_count - 2];
          }
        }
      }
    } else {
      // The feature is sparse or ragged.
      auto& last_example =
          is_ragged ? ragged_feature_last_example : sparse_feature_last_example;

      // If the feature was already visited, skip.
      // Compare the comment at the beginning of the loop.
      if (last_example[d] == example_index) {
        LogSparseFeatureDataLoss(feature_name);
        continue;
      }
      last_example[d] = example_index;

      // Handle sparse features.
      SparseBuffer& out = is_ragged ? (*output_ragged)[d] : (*output_sparse)[d];
      DataType feature_dtype =
          is_ragged ? config.ragged[d].dtype : config.sparse[d].dtype;
      if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
        return example_error(
            strings::StrCat("Data types don't match. ",
                            "Expected type: ", DataTypeString(feature_dtype),
                            ", Actual type: ", DataTypeString(example_dtype)));
      }

      switch (feature_dtype) {
        case DT_INT64: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseInt64List(&out.int64_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.int64_list.size());
          break;
        }
        case DT_FLOAT: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseFloatList(&out.float_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.float_list.size());
          break;
        }
        case DT_STRING: {
          if (example_dtype != DT_INVALID) {
            if (!feature.ParseBytesList(&out.bytes_list)) {
              return parse_error();
            }
          }
          out.example_end_indices.push_back(out.bytes_list.size());
          break;
        }
        default:
          LOG(FATAL) << "Should not happen.";
      }

      if (output_stats) {
        // Use `out.example_end_indices` to determine the feature-value count
        // for this feature, because the preceding switch statement pushes
        // the length of the appropriate feature list to that vector.
        // TODO(b/111553342): If desirable, we could add support for counting
        // elements in the features that aren't parsed, but this could add
        // considerable runtime cost.
        const size_t out_examples_count = out.example_end_indices.size();
        if (out_examples_count == 1) {
          output_stats->feature_values_count += out.example_end_indices[0];
        } else {
          output_stats->feature_values_count +=
              out.example_end_indices[out_examples_count - 1] -
              out.example_end_indices[out_examples_count - 2];
        }
      }
    }
  }

  // Handle missing dense features for fixed strides.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    if (config.dense[d].default_value.NumElements() == 0) {
      return errors::InvalidArgument(
          "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
          " (data type: ", DataTypeString(config.dense[d].dtype), ")",
          " is required but could not be found.");
    }
    const Tensor& in = config.dense[d].default_value;
    Tensor& out = (*output_dense)[d];
    const std::size_t num_elements = in.shape().num_elements();
    const std::size_t offset = example_index * num_elements;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        std::copy_n(in.flat<int64>().data(), num_elements,
                    out.flat<int64>().data() + offset);
        break;
      }
      case DT_FLOAT: {
        std::copy_n(in.flat<float>().data(), num_elements,
                    out.flat<float>().data() + offset);
        break;
      }
      case DT_STRING: {
        std::copy_n(in.flat<tstring>().data(), num_elements,
                    out.flat<tstring>().data() + offset);
        break;
      }
      default:
        LOG(FATAL) << "Should not happen.";
    }
  }

  // Handle missing varlen dense features.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!config.dense[d].variable_length) continue;
    if (dense_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_varlen_dense)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Handle missing sparse features.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (sparse_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_sparse)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  // Handle missing ragged features.
  for (size_t d = 0; d < config.ragged.size(); ++d) {
    if (ragged_feature_last_example[d] == example_index) continue;
    SparseBuffer& out = (*output_ragged)[d];
    size_t prev_example_end_index =
        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
    out.example_end_indices.push_back(prev_example_end_index);
  }

  return Status::OK();
}

Status CheckConfigDataType(DataType dtype) {
  switch (dtype) {
    case DT_INT64:
    case DT_FLOAT:
    case DT_STRING:
      return Status::OK();
    default:
      return errors::InvalidArgument("Invalid config dtype: ",
                                     DataTypeString(dtype));
  }
}

// Use this in the "default" clause of switch statements when dispatching
// on a dtype variable that was checked by CheckConfigDataType():
inline void ReportUnexpectedDataType(DataType dtype) {
  DCHECK(false)
      << "Encountered unexpected DataType " << DataTypeString(dtype)
      << " in variable that should have been checked by CheckConfigDataType().";
}

Status CheckConfigDataTypes(const Config& config) {
  // Check the config so we can safely CHECK(false) in switches on
  // config.*.dtype.
  for (auto& c : config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }
  for (auto& c : config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
  }
  for (auto& c : config.ragged) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    if (!(c.splits_dtype == DT_INT32 || c.splits_dtype == DT_INT64)) {
      return errors::InvalidArgument("Invalid ragged_split_type: ",
                                     DataTypeString(c.splits_dtype));
    }
  }
  return Status::OK();
}

template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

template <>
const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
  return buffer.int64_list;
}
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}

template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  std::copy(b, e, t);
}
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
  std::move(b, e, t);
}
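
// Design note (added, not from the original source): CopyOrMoveBlock exists so
// that string elements can, in principle, be moved into the destination rather
// than copied, while trivially copyable types (int64, float) always use
// std::copy. (With const source pointers the std::move algorithm still falls
// back to copying, so the specialization is mainly an optimization hook.)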

template <typename T>
void FillAndCopyVarLen(
    const int d, const size_t num_elements,
    const size_t num_elements_per_minibatch, const Config& config,
    const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
    Tensor* values) {
  const Tensor& default_value = config.dense[d].default_value;

  // Copy-fill the tensors (creating the zero/fill-padding).
  std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
            default_value.flat<T>()(0));

  // Data is [batch_size, max_num_elements, data_stride_size]
  // and num_elements_per_minibatch = max_num_elements * data_stride_size.
  auto data = values->flat<T>().data();

  // Iterate over minibatch elements.
  for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
    const SparseBuffer& buffer = varlen_dense_buffers[i][d];
    // Number of examples being stored in this buffer.
    const auto& end_indices = buffer.example_end_indices;
    const size_t examples_in_buffer = end_indices.size();
    // const size_t stride_size = config.dense[d].elements_per_stride;

    const auto& list = GetListFromBuffer<T>(buffer);
    auto list_ptr = list.begin();

    size_t elements_tally = 0;
    // Iterate through all the examples stored in this buffer.
    for (size_t j = 0; j < examples_in_buffer; ++j) {
      // Number of elements stored for this example.
      const size_t num_elems = end_indices[j] - elements_tally;
      CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
      // Move forward this many elements in the varlen buffer.
      list_ptr += num_elems;
      // Move forward to the next minibatch entry in the values output.
      data += num_elements_per_minibatch;
      elements_tally = end_indices[j];
    }
    DCHECK(elements_tally == list.size());
  }
}
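
// Worked example (added note, not from the original source): with batch_size 3,
// one value per stride, and per-example value counts {2, 0, 1}, the output is
// first filled with the default value and then rows are overwritten in place:
//   max_num_elements = 2, so values is laid out as
//     row 0: v0 v1   // 2 parsed values
//     row 1: d  d    // padding only
//     row 2: v2 d    // 1 parsed value + padding
// where d is default_value(0).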

// Thin vector-like interface wrapper around a Tensor. This enables us to
// directly populate a tensor during parsing instead of having to first create
// a vector and then copy the data over.
template <typename T>
class TensorVector {
 public:
  using value_type = T;

  const Tensor& tensor() {
    if (!tensor_.has_value()) {
      resize(0);
    }
    return *tensor_;
  }

  int64 size() const {
    return tensor_.has_value() ? tensor_->NumElements() : 0;
  }
  void resize(int64_t new_size) {
    DCHECK(!tensor_.has_value());
    tensor_ = Tensor(DataTypeToEnum<T>::v(), TensorShape({new_size}));
    data_ = tensor_->flat<T>().data();
  }
  T* data() { return data_; }
  const T* data() const { return data_; }

 private:
  // Use absl::optional to avoid calling the default constructor of Tensor
  // unnecessarily.
  absl::optional<Tensor> tensor_;

  // Cached pointer to the raw data inside the tensor.
  T* data_ = nullptr;
};
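
// Usage sketch (added note, not from the original source): TensorVector
// satisfies just enough of the vector interface (value_type, size, resize,
// data) for the ParseZzzzList routines above, so a parse can write straight
// into tensor memory:
//
//   TensorVector<float> out;
//   feature.ParseFloatList(&out);  // resize() allocates the Tensor once.
//   const Tensor& t = out.tensor();
//
// Note the DCHECK in resize(): the wrapper supports a single allocation, so
// it only fits parsers that resize at most once.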

void CountSparseFeatures(
    const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d,
    size_t* total_num_features, size_t* max_num_features) {
  for (auto& sparse_values_tmp : sparse_buffers) {
    const std::vector<size_t>& end_indices =
        sparse_values_tmp[d].example_end_indices;
    *total_num_features += end_indices.back();
    *max_num_features = std::max(*max_num_features, end_indices[0]);
    for (size_t i = 1; i < end_indices.size(); ++i) {
      size_t example_size = end_indices[i] - end_indices[i - 1];
      *max_num_features = std::max(*max_num_features, example_size);
    }
  }
}

void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src,
                              Tensor* dst) {
  switch (dtype) {
    case DT_INT64: {
      std::copy(src->int64_list.begin(), src->int64_list.end(),
                dst->flat<int64>().data() + offset);
      break;
    }
    case DT_FLOAT: {
      std::copy(src->float_list.begin(), src->float_list.end(),
                dst->flat<float>().data() + offset);
      break;
    }
    case DT_STRING: {
      std::move(src->bytes_list.begin(), src->bytes_list.end(),
                dst->flat<tstring>().data() + offset);
      break;
    }
    default:
      ReportUnexpectedDataType(dtype);
  }
}

}  // namespace

Status FastParseExample(const Config& config,
                        gtl::ArraySlice<tstring> serialized,
                        gtl::ArraySlice<tstring> example_names,
                        thread::ThreadPool* thread_pool, Result* result) {
  DCHECK(result != nullptr);
  // Check the config so we can safely CHECK(false) in switches on
  // config.*.dtype.
  TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));

  if (config.collect_feature_stats) {
    result->feature_stats.resize(serialized.size());
  }

  size_t config_size =
      config.dense.size() + config.sparse.size() + config.ragged.size();
  SeededHasher hasher;
  // Build the config index.
  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
  bool ok = true;
  for (size_t i = 0; i < 1000; ++i) {
    for (size_t d = 0; d < config.dense.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
                                      {d, Type::Dense});
    }
    for (size_t d = 0; d < config.sparse.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
                                      {d, Type::Sparse});
    }
    for (size_t d = 0; d < config.ragged.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
                                      {d, Type::Ragged});
    }
    if (ok) break;
    LOG(WARNING) << "Collision found. This should happen only if you have "
                    "around 2^32 entries in your config.";
    hasher.seed++;
    config_index.Clear(config_size);
    ok = true;
  }
  if (!ok) {
    return errors::Internal(
        "Could not avoid collision. This should not happen.");
  }

  // Allocate dense output for fixed-length dense values
  // (variable-length dense, sparse, and ragged have to be buffered).
  std::vector<Tensor> fixed_dense_values(config.dense.size());
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    TensorShape out_shape;
    out_shape.AddDim(serialized.size());
    for (const int64_t dim : config.dense[d].shape.dim_sizes()) {
      out_shape.AddDim(dim);
    }
    fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
  }

  // This parameter affects performance in a big and data-dependent way.
  const size_t kMiniBatchSizeBytes = 50000;

  // Calculate the number of minibatches.
  // In the main regime, make each minibatch around kMiniBatchSizeBytes bytes.
  // Apply the 'special logic' below for the small and big regimes.
  const size_t num_minibatches = [&] {
    size_t result = 0;
    size_t minibatch_bytes = 0;
    for (size_t i = 0; i < serialized.size(); i++) {
      if (minibatch_bytes == 0) {  // start minibatch
        result++;
      }
      minibatch_bytes += serialized[i].size() + 1;
      if (minibatch_bytes > kMiniBatchSizeBytes) {
        minibatch_bytes = 0;
      }
    }
    // 'special logic'
    const size_t min_minibatches = std::min<size_t>(8, serialized.size());
    const size_t max_minibatches = 64;
    return std::max<size_t>(min_minibatches,
                            std::min<size_t>(max_minibatches, result));
  }();

  auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
    return (serialized.size() * minibatch) / num_minibatches;
  };
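
  // Worked example (added note, not from the original source): with
  // serialized.size() == 10 and num_minibatches == 4, the lambda above
  // partitions the examples as [0, 2), [2, 5), [5, 7), [7, 10), since
  // (10 * m) / 4 yields 0, 2, 5, 7, 10 for m = 0..4. Each minibatch gets a
  // nearly equal share without any per-example bookkeeping.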

  // TODO(lew): A big performance low-hanging fruit here is to improve the
  // num_minibatches calculation to take into account the actual amount of work
  // needed, as the size in bytes is not a perfect proxy. A linear combination
  // of size in bytes and average number of features per example is promising.
  // Even better: measure the time instead of estimating it, but this is too
  // costly in small batches.
  // Maybe accept an outside parameter #num_minibatches?

  // Do minibatches in parallel.
  std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> ragged_buffers(num_minibatches);
  std::vector<Status> status_of_minibatch(num_minibatches);
  auto ProcessMiniBatch = [&](size_t minibatch) {
    sparse_buffers[minibatch].resize(config.sparse.size());
    varlen_dense_buffers[minibatch].resize(config.dense.size());
    ragged_buffers[minibatch].resize(config.ragged.size());
    size_t start = first_example_of_minibatch(minibatch);
    size_t end = first_example_of_minibatch(minibatch + 1);
    for (size_t e = start; e < end; ++e) {
      PerExampleFeatureStats* stats = nullptr;
      if (config.collect_feature_stats) {
        stats = &result->feature_stats[e];
      }
      status_of_minibatch[minibatch] = FastParseSerializedExample(
          serialized[e],
          (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
          config_index, hasher, &fixed_dense_values,
          &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch],
          &ragged_buffers[minibatch], stats);
      if (!status_of_minibatch[minibatch].ok()) break;
    }
  };

  ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);

  for (Status& status : status_of_minibatch) {
    TF_RETURN_IF_ERROR(status);
  }

  result->sparse_indices.reserve(config.sparse.size());
  result->sparse_values.reserve(config.sparse.size());
  result->sparse_shapes.reserve(config.sparse.size());
  result->dense_values.reserve(config.dense.size());
  result->ragged_values.reserve(config.ragged.size());
  result->ragged_splits.reserve(config.ragged.size());

  for (size_t d = 0; d < config.dense.size(); ++d) {
    result->dense_values.push_back(std::move(fixed_dense_values[d]));
  }

  // Merge SparseBuffers from all minibatches for every config.sparse.
  auto MergeSparseMinibatches = [&](size_t d) {
    // Loop over minibatches.
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(sparse_buffers, d, &total_num_features,
                        &max_num_features);

    TensorShape indices_shape;
    indices_shape.AddDim(total_num_features);
    indices_shape.AddDim(2);
    result->sparse_indices.emplace_back(DT_INT64, indices_shape);
    Tensor* indices = &result->sparse_indices.back();

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
    Tensor* values = &result->sparse_values.back();

    result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
    auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
    shapes_shape_t(0) = serialized.size();
    shapes_shape_t(1) = max_num_features;

    size_t offset = 0;
    for (size_t i = 0; i < sparse_buffers.size(); ++i) {
      SparseBuffer& buffer = sparse_buffers[i][d];

      // Update indices.
      size_t delta = 0;

      if (indices->NumElements() > 0) {
        int64* ix_p = &indices->matrix<int64>()(offset, 0);
        size_t example_index = first_example_of_minibatch(i);
        for (size_t example_end_index : buffer.example_end_indices) {
          size_t feature_index = 0;
          for (; delta < example_end_index; ++delta) {
            // Column 0: the example index.
            *ix_p = example_index;
            // Column 1: the feature index within the example.
            *(ix_p + 1) = feature_index;
            ix_p += 2;
            ++feature_index;
          }
          ++example_index;
        }
      }

      CopySparseBufferToTensor(config.sparse[d].dtype, offset, &buffer, values);
      offset += delta;
    }
  };

  // Merge SparseBuffers from all minibatches for every config.ragged.
  auto MergeRaggedMinibatches = [&](size_t d) {
    // Loop over minibatches.
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(ragged_buffers, d, &total_num_features,
                        &max_num_features);

    TensorShape row_splits_shape;
    row_splits_shape.AddDim(serialized.size() + 1);
    result->ragged_splits.emplace_back(config.ragged[d].splits_dtype,
                                       row_splits_shape);
    Tensor* row_splits = &result->ragged_splits.back();
    if (config.ragged[d].splits_dtype == DT_INT64) {
      row_splits->flat<int64>()(0) = 0;
    } else {
      row_splits->flat<int32>()(0) = 0;
    }

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->ragged_values.emplace_back(config.ragged[d].dtype, values_shape);
    Tensor* values = &result->ragged_values.back();

    size_t values_offset = 0;
    size_t splits_offset = 0;
    for (size_t i = 0; i < ragged_buffers.size(); ++i) {
      SparseBuffer& buffer = ragged_buffers[i][d];
      if (buffer.example_end_indices.empty()) continue;

      // Update row_splits. row_splits are formed by concatenating the example
      // end_indices (adjusting each to start after the previous one ends).
      if (config.ragged[d].splits_dtype == DT_INT64) {
        int64* row_splits_out = &row_splits->flat<int64>()(splits_offset);
        int64_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      } else {
        int32* row_splits_out = &row_splits->flat<int32>()(splits_offset);
        int32_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      }

      CopySparseBufferToTensor(config.ragged[d].dtype, values_offset, &buffer,
                               values);
      values_offset += buffer.example_end_indices.back();
      splits_offset += buffer.example_end_indices.size();
    }
  };

  // Merge SparseBuffers from all minibatches for every config.dense having
  // variable_length.
  auto MergeDenseVarLenMinibatches = [&](size_t d) {
    if (!config.dense[d].variable_length) return;

    // Loop over minibatches.
    size_t max_num_features = 0;
    for (auto& dense_values_tmp : varlen_dense_buffers) {
      std::vector<size_t>& end_indices =
          dense_values_tmp[d].example_end_indices;
      max_num_features = std::max(max_num_features, end_indices[0]);
      for (size_t i = 1; i < end_indices.size(); ++i) {
        size_t example_size = end_indices[i] - end_indices[i - 1];
        max_num_features = std::max(max_num_features, example_size);
      }
    }

    const size_t stride_size = config.dense[d].elements_per_stride;
    const size_t max_num_elements = max_num_features / stride_size;
    TensorShape values_shape;
    DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
    const size_t batch_size = serialized.size();
    values_shape.AddDim(batch_size);
    values_shape.AddDim(max_num_elements);
    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
      values_shape.AddDim(config.dense[d].shape.dim_size(i));
    }
    Tensor values(config.dense[d].dtype, values_shape);
    result->dense_values[d] = values;
    const size_t num_elements = values.NumElements();

    // Nothing to write, exit early.
    if (num_elements == 0) return;

    const size_t num_elements_per_minibatch = num_elements / batch_size;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_FLOAT: {
        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_STRING: {
        FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
                                   config, varlen_dense_buffers, &values);
        break;
      }
      default:
        ReportUnexpectedDataType(config.dense[d].dtype);
    }
  };

  for (size_t d = 0; d < config.dense.size(); ++d) {
    MergeDenseVarLenMinibatches(d);
  }

  for (size_t d = 0; d < config.sparse.size(); ++d) {
    MergeSparseMinibatches(d);
  }

  for (size_t d = 0; d < config.ragged.size(); ++d) {
    MergeRaggedMinibatches(d);
  }

  return Status::OK();
}
1427
FastParseSingleExample(const Config & config,StringPiece serialized,Result * result)1428 Status FastParseSingleExample(const Config& config, StringPiece serialized,
1429 Result* result) {
1430 DCHECK(result != nullptr);
1431 // Check config so we can safely CHECK(false) in switches on config.*.dtype
1432 TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));
1433
1434 PerExampleFeatureStats* stats = nullptr;
1435 if (config.collect_feature_stats) {
1436 result->feature_stats.emplace_back();
1437 stats = &result->feature_stats.back();
1438 }
1439
1440 // TODO(mrry): Cache the construction of this map at Op construction time.
1441 size_t config_size =
1442 config.dense.size() + config.sparse.size() + config.ragged.size();
1443 SeededHasher hasher;
1444 // Build config index.
1445 PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1446 bool ok = true;
1447 for (size_t i = 0; i < 1000; ++i) {
1448 for (size_t d = 0; d < config.dense.size(); ++d) {
1449 ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1450 {d, Type::Dense});
1451 }
1452 for (size_t d = 0; d < config.sparse.size(); ++d) {
1453 ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1454 {d, Type::Sparse});
1455 }
1456 for (size_t d = 0; d < config.ragged.size(); ++d) {
1457 ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
1458 {d, Type::Ragged});
1459 }
1460 if (ok) break;
1461 LOG(WARNING) << "Collision found. This should happen only if you have "
1462 "around 2^32 entries in your config.";
1463 hasher.seed++;
1464 config_index.Clear(config_size);
1465 ok = true;
1466 }
1467 if (!ok) {
1468 return errors::Internal(
1469 "Could not avoid collision. This should not happen.");
1470 }
1471
1472 result->sparse_indices.reserve(config.sparse.size());
1473 result->sparse_values.reserve(config.sparse.size());
1474 result->sparse_shapes.reserve(config.sparse.size());
1475 result->dense_values.reserve(config.dense.size());
1476 result->ragged_values.reserve(config.ragged.size());
1477 result->ragged_splits.reserve(config.ragged.size());
1478
1479 // Allocate dense output tensors.
1480 for (size_t d = 0; d < config.dense.size(); ++d) {
1481 if (!config.dense[d].variable_length) {
1482 TensorShape values_shape;
1483 if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1484 return errors::Internal(
1485 "Fixed-length shape was not a statically defined shape.");
1486 }
1487 result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1488 } else {
1489 // Variable-length tensor will be allocated later.
1490 result->dense_values.emplace_back();
1491 }
1492 }
1493
1494 // Allocate sparse output tensors.
1495 for (size_t d = 0; d < config.sparse.size(); ++d) {
1496 // The dense_shape is always a vector of length 1.
1497 result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1498 // Variable-length tensors will be allocated later.
1499 result->sparse_indices.emplace_back();
1500 result->sparse_values.emplace_back();
1501 }
1502
1503 // Allocate ragged output tensors.
1504 for (size_t d = 0; d < config.ragged.size(); ++d) {
1505 // Variable-length values tensors will be allocated later.
1506 result->ragged_values.emplace_back();
1507 // Splits tensors are empty (unused) for single (scalar) inputs.
1508 const auto splits_dtype = config.ragged[d].splits_dtype;
1509 result->ragged_splits.emplace_back(splits_dtype, TensorShape({0}));
1510 }
1511
1512 parsed::Example parsed_example;
1513 if (!ParseExample(serialized, &parsed_example)) {
1514 return errors::InvalidArgument("Could not parse example input, value: '",
1515 serialized, "'");
1516 }
1517 std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1518 std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1519 std::vector<bool> ragged_feature_already_seen(config.ragged.size(), false);
1520
1521 if (stats) {
1522 // TODO(b/111553342): This may over-count the number of features if there
1523 // are duplicate keys in the feature map. Consider deduplicating the keys
1524 // before computing the count.
1525 stats->features_count = parsed_example.size();
1526 }
1527
1528 // Handle features present in the example.
1529 const size_t parsed_example_size = parsed_example.size();
1530 for (size_t i = 0; i < parsed_example_size; ++i) {
1531 // This mirrors the logic of standard protobuf parsing: the last entry
1532 // in the map overwrites all previous ones, so we iterate in reverse.
1533 parsed::FeatureMapEntry& name_and_feature =
1534 parsed_example[parsed_example_size - i - 1];
1535
1536 const StringPiece feature_name = name_and_feature.first;
1537 parsed::Feature& feature = name_and_feature.second;
1538
1539 std::pair<size_t, Type> d_and_type;
1540 uint64 h = hasher(feature_name);
1541 if (!config_index.Find(h, &d_and_type)) continue;
1542
1543 size_t d = d_and_type.first;
1544 bool is_dense = d_and_type.second == Type::Dense;
1545 bool is_sparse = d_and_type.second == Type::Sparse;
1546
1547 {
1548 // Testing for PresizedCuckooMap collision.
1549 // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
1550 const tstring& config_feature_name =
1551 is_dense ? config.dense[d].feature_name
1552 : (is_sparse ? config.sparse[d].feature_name
1553 : config.ragged[d].feature_name);
1554 if (feature_name != config_feature_name) continue;
1555 }
1556
1557 auto example_error = [feature_name](StringPiece suffix) {
1558 return errors::InvalidArgument("Key: ", feature_name, ". ", suffix);
1559 };
1560
1561 auto parse_error = [feature_name] {
1562 return errors::InvalidArgument("Key: ", feature_name,
1563 ". Can't parse serialized Example.");
1564 };
1565
1566 DataType example_dtype;
1567 TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
1568 if (example_dtype == DT_INVALID) continue;
1569
1570 if (is_dense && !config.dense[d].variable_length) {
1571 // If the feature was already visited, skip it.
1572 // See the comment at the beginning of the loop.
1573 if (dense_feature_already_seen[d]) {
1574 LogDenseFeatureDataLoss(feature_name);
1575 continue;
1576 }
1577 dense_feature_already_seen[d] = true;
1578
1579 if (example_dtype != config.dense[d].dtype) {
1580 return example_error(strings::StrCat(
1581 "Data types don't match. Data type: ",
1582 DataTypeString(example_dtype),
1583 " but expected type: ", DataTypeString(config.dense[d].dtype)));
1584 }
1585
1586 Tensor* out = &result->dense_values[d];
1587 const std::size_t num_elements = config.dense[d].elements_per_stride;
1588 if (stats) {
1589 // TODO(b/111553342): If desirable, we could add support for counting
1590 // elements in the features that aren't parsed, but this could add
1591 // considerable runtime cost.
1592 stats->feature_values_count += num_elements;
1593 }
1594 switch (example_dtype) {
1595 case DT_INT64: {
1596 auto out_p = out->flat<int64>().data();
1597 LimitedArraySlice<int64> slice(out_p, num_elements);
1598 if (!feature.ParseInt64List(&slice)) return parse_error();
1599 if (slice.EndDistance() != 0) {
1600 return parse_error();
1601 }
1602 break;
1603 }
1604 case DT_FLOAT: {
1605 auto out_p = out->flat<float>().data();
1606 LimitedArraySlice<float> slice(out_p, num_elements);
1607 if (!feature.ParseFloatList(&slice)) return parse_error();
1608 if (slice.EndDistance() != 0) {
1609 return parse_error();
1610 }
1611 break;
1612 }
1613 case DT_STRING: {
1614 auto out_p = out->flat<tstring>().data();
1615 LimitedArraySlice<tstring> slice(out_p, num_elements);
1616 if (!feature.ParseBytesList(&slice)) return parse_error();
1617 if (slice.EndDistance() != 0) {
1618 return parse_error();
1619 }
1620 break;
1621 }
1622 default:
1623 ReportUnexpectedDataType(example_dtype);
1624 }
1625
1626 } else { // Feature is variable-length dense, sparse, or ragged.
1627 SmallVector<tstring> bytes_list;
1628 TensorVector<float> float_list;
1629 SmallVector<int64> int64_list;
1630
1631 const size_t num_elements_divisor =
1632 is_dense ? config.dense[d].elements_per_stride : 1;
1633 size_t num_elements;
1634
1635 if (is_dense) {
1636 // If the feature was already visited, skip it.
1637 // See the comment at the beginning of the loop.
1638 if (dense_feature_already_seen[d]) {
1639 LogDenseFeatureDataLoss(feature_name);
1640 continue;
1641 }
1642 dense_feature_already_seen[d] = true;
1643 if (example_dtype != config.dense[d].dtype) {
1644 return example_error(strings::StrCat(
1645 "Data types don't match. Data type: ",
1646 DataTypeString(example_dtype),
1647 " but expected type: ", DataTypeString(config.dense[d].dtype)));
1648 }
1649 } else {
1650 // Feature is sparse or ragged.
1651 auto& feature_already_seen = is_sparse ? sparse_feature_already_seen
1652 : ragged_feature_already_seen;
1653 auto& feature_dtype =
1654 is_sparse ? config.sparse[d].dtype : config.ragged[d].dtype;
1655 // If the feature was already visited, skip it.
1656 // See the comment at the beginning of the loop.
1657 if (feature_already_seen[d]) {
1658 LogSparseFeatureDataLoss(feature_name);
1659 continue;
1660 }
1661 feature_already_seen[d] = true;
1662
1663 // Handle sparse features.
1664 if (example_dtype != DT_INVALID && example_dtype != feature_dtype) {
1665 return example_error(strings::StrCat(
1666 "Data types don't match. ",
1667 "Expected type: ", DataTypeString(feature_dtype),
1668 ", Actual type: ", DataTypeString(example_dtype)));
1669 }
1670 }
1671
1672 switch (example_dtype) {
1673 case DT_INT64: {
1674 // TODO(mrry): Use the fact that the `int64_list` is packed to read
1675 // out the length and pre-allocate the output tensor.
1676 if (!feature.ParseInt64List(&int64_list)) return parse_error();
1677 num_elements = int64_list.size();
1678 break;
1679 }
1680 case DT_FLOAT: {
1681 if (!feature.ParseFloatList(&float_list)) return parse_error();
1682 num_elements = float_list.size();
1683 break;
1684 }
1685 case DT_STRING: {
1686 int actual_num_elements = 0;
1687 if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1688 return parse_error();
1689 }
1690 bytes_list.reserve(actual_num_elements);
1691 if (!feature.ParseBytesList(&bytes_list)) return parse_error();
1692 num_elements = bytes_list.size();
1693 break;
1694 }
1695 default:
1696 num_elements = 0;
1697 ReportUnexpectedDataType(example_dtype);
1698 }
1699
1700 if (num_elements % num_elements_divisor != 0) {
1701 return parse_error();
1702 }
1703
1704 if (stats) {
1705 stats->feature_values_count += num_elements;
1706 }
1707
1708 Tensor* out;
1709 DataType out_dtype;
1710 TensorShape out_shape;
1711 if (is_dense) {
1712 out_shape.AddDim(num_elements / num_elements_divisor);
1713 for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1714 out_shape.AddDim(config.dense[d].shape.dim_size(i));
1715 }
1716
1717 out = &result->dense_values[d];
1718 out_dtype = config.dense[d].dtype;
1719 } else if (is_sparse) {
1720 Tensor* out_indices = &result->sparse_indices[d];
1721 Tensor* out_dense_shape = &result->sparse_shapes[d];
1722
1723 // TODO(mrry): Investigate the possibility of not materializing
1724 // the indices (and perhaps dense_shape) until they are needed.
1725 *out_indices = Tensor(
1726 DT_INT64, TensorShape({static_cast<int64>(num_elements), 1}));
1727 auto indices_flat = out_indices->flat<int64>();
1728 for (size_t i = 0; i < num_elements; ++i) {
1729 indices_flat(i) = static_cast<int64>(i);
1730 }
1731
1732 *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
1733 auto shapes_shape_t = out_dense_shape->vec<int64>();
1734 shapes_shape_t(0) = num_elements;
1735
1736 out = &result->sparse_values[d];
1737 out_dtype = config.sparse[d].dtype;
1738 out_shape.AddDim(num_elements);
1739 } else {
1740 out = &result->ragged_values[d];
1741 out_dtype = config.ragged[d].dtype;
1742 out_shape.AddDim(num_elements);
1743 }
1744
1745 switch (example_dtype) {
1746 case DT_INT64: {
1747 *out = Tensor(out_dtype, out_shape);
1748 CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
1749 out->flat<int64>().data());
1750 break;
1751 }
1752 case DT_FLOAT: {
1753 if (!out->CopyFrom(float_list.tensor(), out_shape)) {
1754 return parse_error();
1755 }
1756 break;
1757 }
1758 case DT_STRING: {
1759 *out = Tensor(out_dtype, out_shape);
1760 CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(),
1761 out->flat<tstring>().data());
1762 break;
1763 }
1764 default:
1765 ReportUnexpectedDataType(example_dtype);
1766 }
1767 }
1768 }
1769
1770 // Handle missing dense features.
1771 for (size_t d = 0; d < config.dense.size(); ++d) {
1772 if (!dense_feature_already_seen[d]) {
1773 if (!config.dense[d].variable_length) {
1774 // Handle missing fixed-length dense feature.
1775 if (config.dense[d].default_value.NumElements() == 0) {
1776 return errors::InvalidArgument(
1777 "Feature: ", config.dense[d].feature_name,
1778 " (data type: ", DataTypeString(config.dense[d].dtype), ")",
1779 " is required but could not be found.");
1780 }
1781 result->dense_values[d] = config.dense[d].default_value;
1782 } else {
1783 // Handle missing varlen dense feature.
1784 TensorShape empty_shape;
1785 empty_shape.AddDim(0);
1786 for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1787 empty_shape.AddDim(config.dense[d].shape.dim_size(i));
1788 }
1789 result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
1790 }
1791 }
1792 }
1793
1794 // Handle missing sparse features.
1795 for (size_t d = 0; d < config.sparse.size(); ++d) {
1796 if (!sparse_feature_already_seen[d]) {
1797 result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
1798 result->sparse_values[d] =
1799 Tensor(config.sparse[d].dtype, TensorShape({0}));
1800 result->sparse_shapes[d].vec<int64>()(0) = 0;
1801 }
1802 }
1803
1804 // Handle missing ragged features.
1805 for (size_t d = 0; d < config.ragged.size(); ++d) {
1806 if (!ragged_feature_already_seen[d]) {
1807 result->ragged_values[d] =
1808 Tensor(config.ragged[d].dtype, TensorShape({0}));
1809 }
1810 }
1811
1812 return Status::OK();
1813 }
1814
1815 // Private helper functions for FastParseSequenceExample.
1816 namespace {
1817
1818 // A struct used by FastParseSequenceExample to hold the serialized proto
1819 // substrings for a single feature, plus some auxiliary information derived
1820 // from those protos (such as the total value length).
1821 struct FeatureProtos {
1822 // Proto substrings from each serialized SequenceExample that correspond
1823 // with this feature. `protos_present` records whether the proto had a
1824 // value defined (even if that value is empty).
1825 std::vector<StringPiece> protos;
1826 std::vector<bool> protos_present;
1827
1828 // Information derived from protos:
1829 size_t length; // Total number of values for sparse/ragged features; maximum number of values per example for dense features.
1830 size_t num_rows; // only populated for ragged sequence features.
1831
1832 // Information from the config:
1833 Type type; // Whether this feature is sparse, ragged, or dense.
1834 DataType dtype;
1835 };
1836
1837 // Map from feature name to FeatureProtos for that feature.
1838 using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>;
1839
1840 string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) {
1841 return example_names.empty() ? "<unknown>" : example_names[n];
1842 }
1843
1844 // Returns the number of bytes-list elements parsed, or -1 on error. If out is
1845 // null, this method simply counts the number of elements without any copying.
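// For reference, the wire format scanned here (from feature.proto) is:
//   Feature { bytes_list = field 1 } with BytesList { repeated bytes value = 1 }
// so kDelimitedTag(1) is expected first for the list itself, and then again
// for each length-delimited element inside it.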
1846 inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
1847 tstring* out) {
1848 int num_elements = 0;
1849 uint32 length;
1850 if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
1851 return -1;
1852 }
1853 if (length > 0) {
1854 auto limit = stream->PushLimit(length);
1855 while (!stream->ExpectAtEnd()) {
1856 uint32 bytes_length;
1857 if (!stream->ExpectTag(kDelimitedTag(1)) ||
1858 !stream->ReadVarint32(&bytes_length)) {
1859 return -1;
1860 }
1861 if (out == nullptr) {
1862 stream->Skip(bytes_length);
1863 } else {
1864 out->resize_uninitialized(bytes_length);
1865 if (!stream->ReadRaw(out->data(), bytes_length)) {
1866 return -1;
1867 }
1868 out++;
1869 }
1870 num_elements++;
1871 }
1872 stream->PopLimit(limit);
1873 }
1874 return num_elements;
1875 }
1876
1877 inline void PadFloatFeature(int num_to_pad, float* out) {
1878 for (int i = 0; i < num_to_pad; i++) {
1879 *out++ = 0.0;
1880 }
1881 }
1882
1883 inline void PadInt64Feature(int num_to_pad, int64* out) {
1884 for (int i = 0; i < num_to_pad; i++) {
1885 *out++ = 0;
1886 }
1887 }
1888
1889 // Returns the number of float elements parsed, or -1 on error. If out is
1890 // null, this method simply counts the number of elements without any copying.
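// FloatList values can arrive in either proto encoding: packed (one
// length-delimited field containing consecutive little-endian floats) or
// unpacked (a repeated fixed32 field per value); both cases are handled below.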
1891 inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
1892 float* out) {
1893 int num_elements = 0;
1894 uint32 length;
1895 if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
1896 return -1;
1897 }
1898 if (length > 0) {
1899 auto limit = stream->PushLimit(length);
1900 uint8 peek_tag = PeekTag(stream);
1901 if (peek_tag == kDelimitedTag(1)) { // packed
1902 uint32 packed_length;
1903 if (!stream->ExpectTag(kDelimitedTag(1)) ||
1904 !stream->ReadVarint32(&packed_length)) {
1905 return -1;
1906 }
1907 auto packed_limit = stream->PushLimit(packed_length);
1908 while (!stream->ExpectAtEnd()) {
1909 uint32 buffer32;
1910 if (!stream->ReadLittleEndian32(&buffer32)) {
1911 return -1;
1912 }
1913 if (out != nullptr) {
1914 *out++ = absl::bit_cast<float>(buffer32);
1915 }
1916 num_elements++;
1917 }
1918 stream->PopLimit(packed_limit);
1919 } else if (peek_tag == kFixed32Tag(1)) {
1920 while (!stream->ExpectAtEnd()) {
1921 uint32 buffer32;
1922 if (!stream->ExpectTag(kFixed32Tag(1)) ||
1923 !stream->ReadLittleEndian32(&buffer32)) {
1924 return -1;
1925 }
1926 if (out != nullptr) {
1927 *out++ = absl::bit_cast<float>(buffer32);
1928 }
1929 num_elements++;
1930 }
1931 } else {
1932 // Unknown tag.
1933 return -1;
1934 }
1935 stream->PopLimit(limit);
1936 }
1937 return num_elements;
1938 }
1939
1940 // Returns the number of int64 elements parsed, or -1 on error. If out is
1941 // null, this method simply counts the number of elements without any copying.
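// As with floats, Int64List values may be packed (one length-delimited field
// of consecutive varints) or unpacked (a repeated varint field per value).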
1942 inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
1943 int64* out) {
1944 int num_elements = 0;
1945 uint32 length;
1946 if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
1947 return -1;
1948 }
1949 if (length > 0) {
1950 auto limit = stream->PushLimit(length);
1951 uint8 peek_tag = PeekTag(stream);
1952 if (peek_tag == kDelimitedTag(1)) { // packed
1953 uint32 packed_length;
1954 if (!stream->ExpectTag(kDelimitedTag(1)) ||
1955 !stream->ReadVarint32(&packed_length)) {
1956 return -1;
1957 }
1958 auto packed_limit = stream->PushLimit(packed_length);
1959 while (!stream->ExpectAtEnd()) {
1960 protobuf_uint64 n; // CodedInputStream has no signed int64 read API.
1961 if (!stream->ReadVarint64(&n)) {
1962 return -1;
1963 }
1964 if (out != nullptr) {
1965 *out++ = n;
1966 }
1967 num_elements++;
1968 }
1969 stream->PopLimit(packed_limit);
1970 } else if (peek_tag == kVarintTag(1)) {
1971 while (!stream->ExpectAtEnd()) {
1972 protobuf_uint64 n; // CodedInputStream has no signed int64 read API.
1973 if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
1974 return -1;
1975 }
1976 if (out != nullptr) {
1977 *out++ = n;
1978 }
1979 num_elements++;
1980 }
1981 } else {
1982 // Unknown tag.
1983 return -1;
1984 }
1985 stream->PopLimit(limit);
1986 }
1987 return num_elements;
1988 }
1989
1990 // Parses the next feature on `stream` into `out` starting at `out_offset`.
1991 // Updates `out_offset`, and returns the number of values added.
1992 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
1993 inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream,
1994 Tensor* out, size_t* out_offset) {
1995 int delta;
1996 switch (dtype) {
1997 case DT_STRING:
1998 delta =
1999 ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset);
2000 break;
2001 case DT_FLOAT:
2002 delta =
2003 ParseFloatFeature(stream, out->flat<float>().data() + *out_offset);
2004 break;
2005 case DT_INT64:
2006 delta =
2007 ParseInt64Feature(stream, out->flat<int64>().data() + *out_offset);
2008 break;
2009 default:
2010 ReportUnexpectedDataType(dtype);
2011 delta = 0;
2012 }
2013 if (delta > 0) {
2014 *out_offset += delta;
2015 }
2016 return delta;
2017 }
2018
2019 // Returns the length of the next feature on `stream`.
2020 // Returns -1 if the next feature on `stream` doesn't match `dtype`.
2021 inline int GetFeatureLength(DataType dtype,
2022 protobuf::io::CodedInputStream* stream) {
2023 switch (dtype) {
2024 case DT_STRING:
2025 return ParseBytesFeature(stream, nullptr);
2026 case DT_FLOAT:
2027 return ParseFloatFeature(stream, nullptr);
2028 case DT_INT64:
2029 return ParseInt64Feature(stream, nullptr);
2030 default:
2031 ReportUnexpectedDataType(dtype);
2032 return -1;
2033 }
2034 }
2035
2036 inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
2037 uint8 peek_tag = PeekTag(stream);
2038 switch (peek_tag) {
2039 case kDelimitedTag(1):
2040 return DT_STRING;
2041 case kDelimitedTag(2):
2042 return DT_FLOAT;
2043 case kDelimitedTag(3):
2044 return DT_INT64;
2045 default:
2046 return DT_INVALID;
2047 }
2048 }
2049
2050 inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
2051 DataType dtype) {
2052 switch (dtype) {
2053 case DT_STRING:
2054 if (!stream->ExpectTag(kDelimitedTag(1))) {
2055 return false;
2056 }
2057 break;
2058 case DT_FLOAT:
2059 if (!stream->ExpectTag(kDelimitedTag(2))) {
2060 return false;
2061 }
2062 break;
2063 case DT_INT64:
2064 if (!stream->ExpectTag(kDelimitedTag(3))) {
2065 return false;
2066 }
2067 break;
2068 default:
2069 return false;
2070 }
2071 uint32 length;
2072 return stream->ReadVarint32(&length) && length == 0;
2073 }
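// Note: an empty feature-list entry (e.g. an Int64List with no values)
// serializes to exactly two bytes (a one-byte field tag plus a one-byte zero
// length), which is why callers below special-case feature_bytes == 2.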
2074
2075 // Reads an example proto, and extracts a StringPiece pointer to each feature.
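// The layout assumed here (from example.proto): SequenceExample has
// context = field 1 and feature_lists = field 2; each map entry within them
// is a length-delimited pair of key (field 1) and value message (field 2).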
2076 Status ExtractFeaturesFromSequenceExamples(
2077 const gtl::ArraySlice<tstring> examples,
2078 const gtl::ArraySlice<tstring> example_names,
2079 FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) {
2080 for (int d = 0; d < examples.size(); d++) {
2081 const tstring& example = examples[d];
2082 protobuf::io::CodedInputStream stream(
2083 reinterpret_cast<const uint8*>(example.data()), example.size());
2084 // EnableAliasing is a local helper: it calls stream.EnableAliasing(true) when the stream type supports aliasing, and is a no-op otherwise.
2085 EnableAliasing(&stream);
2086
2087 // Extract pointers to all features within this serialized example.
2088 while (!stream.ExpectAtEnd()) {
2089 FeatureProtosMap* features = nullptr;
2090 if (stream.ExpectTag(kDelimitedTag(1))) {
2091 // Context
2092 features = context_features;
2093 } else if (stream.ExpectTag(kDelimitedTag(2))) {
2094 // Sequence
2095 features = sequence_features;
2096 } else if (!SkipExtraneousTag(&stream)) {
2097 return errors::InvalidArgument(
2098 "Invalid protocol message input, example id: ",
2099 ExampleName(example_names, d));
2100 }
2101 if (features != nullptr) {
2102 uint32 length;
2103 if (!stream.ReadVarint32(&length)) {
2104 return errors::InvalidArgument(
2105 "Invalid protocol message input, example id: ",
2106 ExampleName(example_names, d));
2107 }
2108 auto limit = stream.PushLimit(length);
2109 while (!stream.ExpectAtEnd()) {
2110 StringPiece key, value;
2111 uint32 length;
2112 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2113 !stream.ReadVarint32(&length)) {
2114 return errors::InvalidArgument(
2115 "Invalid protocol message input, example id: ",
2116 ExampleName(example_names, d));
2117 }
2118 auto limit = stream.PushLimit(length);
2119 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2120 !ParseString(&stream, &key) ||
2121 !stream.ExpectTag(kDelimitedTag(2)) ||
2122 !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
2123 return errors::InvalidArgument(
2124 "Invalid protocol message input, example id: ",
2125 ExampleName(example_names, d));
2126 }
2127 stream.PopLimit(limit);
2128 // Only save if this feature was requested.
2129 auto feature_iter = features->find(key);
2130 if (feature_iter != features->end()) {
2131 auto& feature = feature_iter->second;
2132 feature.protos[d] = value;
2133 feature.protos_present[d] = true;
2134 }
2135 }
2136 stream.PopLimit(limit);
2137 }
2138 }
2139 }
2140 return Status::OK();
2141 }
2142
2143 // Populates context_features[k].length based on context_features[k].protos
2144 // (for all k).
2145 Status GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2146 FeatureProtosMap* context_features) {
2147 for (auto& c : *context_features) {
2148 FeatureProtos& feature = c.second;
2149 for (int d = 0; d < feature.protos.size(); ++d) {
2150 const auto& proto = feature.protos[d];
2151 if (proto.empty()) continue;
2152 protobuf::io::CodedInputStream stream(
2153 reinterpret_cast<const uint8*>(proto.data()), proto.size());
2154 EnableAliasing(&stream);
2155 int num_elements = GetFeatureLength(feature.dtype, &stream);
2156 if (num_elements < 0) {
2157 return errors::InvalidArgument(
2158 "Name: ", ExampleName(example_names, d),
2159 ", Context feature: ", c.first,
2160 ". Data types don't match. Expected type: ",
2161 DataTypeString(feature.dtype));
2162 }
2163 switch (feature.type) {
2164 case Type::Sparse: // intentional fall-through
2165 case Type::Ragged:
2166 feature.length += num_elements;
2167 break;
2168 case Type::Dense:
2169 feature.length =
2170 std::max(feature.length, static_cast<size_t>(num_elements));
2171 break;
2172 }
2173 }
2174 }
2175 return Status::OK();
2176 }
2177
2178 // Populates sequence_features[k].length and sequence_features[k].num_rows based
2179 // on sequence_features[k].protos (for all k).
2180 Status GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,
2181 FeatureProtosMap* sequence_features) {
2182 for (auto& c : *sequence_features) {
2183 FeatureProtos& feature = c.second;
2184 for (int d = 0; d < feature.protos.size(); ++d) {
2185 const auto& proto = feature.protos[d];
2186 if (proto.empty()) continue;
2187
2188 size_t num_rows = 0;
2189 size_t num_elements = 0;
2190 protobuf::io::CodedInputStream stream(
2191 reinterpret_cast<const uint8*>(proto.data()), proto.size());
2192 EnableAliasing(&stream);
2193 while (!stream.ExpectAtEnd()) {
2194 uint32 feature_bytes;
2195 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2196 !stream.ReadVarint32(&feature_bytes)) {
2197 return errors::InvalidArgument("Error in sequence feature ", c.first,
2198 " in example ",
2199 ExampleName(example_names, d));
2200 }
2201 if (feature_bytes > 2) {
2202 auto limit = stream.PushLimit(feature_bytes);
2203 int delta = GetFeatureLength(feature.dtype, &stream);
2204 if (delta < 0) {
2205 return errors::InvalidArgument(
2206 "Name: ", ExampleName(example_names, d),
2207 ", Feature list: ", c.first, ", Index: ", num_rows,
2208 ". Data types don't match. Expected type: ",
2209 DataTypeString(feature.dtype));
2210 }
2211 num_elements += delta;
2212 stream.PopLimit(limit);
2213 } else if (feature_bytes == 2) {
2214 if (!SkipEmptyFeature(&stream, feature.dtype)) {
2215 return errors::InvalidArgument(
2216 "Name: ", ExampleName(example_names, d),
2217 ", Feature list: ", c.first, ", Index: ", num_rows,
2218 ". Data types don't match. Expected type: ",
2219 DataTypeString(feature.dtype));
2220 }
2221 } else if (feature_bytes != 0) {
2222 return errors::InvalidArgument("Error in sequence feature ", c.first,
2223 " in example ",
2224 ExampleName(example_names, d));
2225 }
2226 ++num_rows;
2227 }
2228 switch (feature.type) {
2229 case Type::Sparse:
2230 feature.length += num_elements;
2231 break;
2232 case Type::Ragged:
2233 feature.length += num_elements;
2234 feature.num_rows += num_rows;
2235 break;
2236 case Type::Dense:
2237 feature.length = std::max(feature.length, num_elements);
2238 break;
2239 }
2240 }
2241 }
2242 return Status::OK();
2243 }
2244
2245 // Copies src into dst[dst_offset:dst_offset+src.size], and then increments
2246 // dst_offset by src.size.
2247 void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst,
2248 size_t* dst_offset) {
2249 size_t src_size = src.NumElements();
2250 switch (dtype) {
2251 case DT_INT64: {
2252 auto src_t = src.flat<int64>().data();
2253 std::copy(src_t, src_t + src_size,
2254 dst->flat<int64>().data() + *dst_offset);
2255 break;
2256 }
2257 case DT_FLOAT: {
2258 auto src_t = src.flat<float>().data();
2259 std::copy(src_t, src_t + src_size,
2260 dst->flat<float>().data() + *dst_offset);
2261 break;
2262 }
2263 case DT_STRING: {
2264 auto src_t = src.flat<tstring>().data();
2265 std::copy(src_t, src_t + src_size,
2266 dst->flat<tstring>().data() + *dst_offset);
2267 break;
2268 }
2269 default:
2270 ReportUnexpectedDataType(dtype);
2271 }
2272 *dst_offset += src_size;
2273 }
2274
2275 // Parses dense features in `context_features`, and writes their parsed
2276 // values to `context_results`.
2277 Status ParseContextDenseFeatures(const FeatureProtosMap& context_features,
2278 const FastParseExampleConfig& context_config,
2279 gtl::ArraySlice<tstring> example_names,
2280 bool is_batch, int num_examples,
2281 Allocator* allocator, Result* context_result) {
2282 for (int t = 0; t < context_config.dense.size(); ++t) {
2283 const auto& c = context_config.dense[t];
2284 const FeatureProtos& feature =
2285 context_features.find(c.feature_name)->second;
2286 TensorShape dense_shape, example_shape;
2287 DataType dtype = c.dtype;
2288 const size_t data_max_elements = feature.length;
2289 if (!c.shape.AsTensorShape(&example_shape) ||
2290 data_max_elements != example_shape.num_elements()) {
2291 return errors::InvalidArgument(
2292 "Inconsistent max number of elements for feature ", c.feature_name,
2293 ": expected ", example_shape.num_elements(), ", but found ",
2294 data_max_elements);
2295 }
2296 if (is_batch) {
2297 dense_shape.AddDim(num_examples);
2298 }
2299 for (const int dim : c.shape.dim_sizes()) {
2300 dense_shape.AddDim(dim);
2301 }
2302 context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
2303
2304 Tensor& out = context_result->dense_values[t];
2305 size_t out_offset = 0;
2306
2307 // Fill in the values.
2308 for (int e = 0; e < num_examples; e++) {
2309 size_t num_elements = 0;
2310 const auto& feature_proto = feature.protos[e];
2311 if (!feature.protos_present[e]) {
2312 // Copy the default value, if present. If not, return an error.
2313 if (c.default_value.NumElements() == 0) {
2314 return errors::InvalidArgument(
2315 "Feature: ", c.feature_name,
2316 " (data type: ", DataTypeString(c.dtype), ")",
2317 " is required but could not be found.");
2318 }
2319 CopyTensorIntoTensor(dtype, c.default_value, &out, &out_offset);
2320 num_elements += c.default_value.NumElements();
2321 } else if (!feature_proto.empty()) {
2322 protobuf::io::CodedInputStream stream(
2323 reinterpret_cast<const uint8*>(feature_proto.data()),
2324 feature_proto.size());
2325 EnableAliasing(&stream);
2326 num_elements += ParseFeature(dtype, &stream, &out, &out_offset);
2327 }
2328 if (num_elements != data_max_elements) {
2329 return errors::InvalidArgument(
2330 "Unexpected number of elements in example ",
2331 ExampleName(example_names, e));
2332 }
2333 }
2334 }
2335 return Status::OK();
2336 }
2337
2338 // Parses sparse features in `context_features`, and writes their parsed
2339 // values to `context_results`.
2340 Status ParseContextSparseFeatures(const FeatureProtosMap& context_features,
2341 const FastParseExampleConfig& context_config,
2342 gtl::ArraySlice<tstring> example_names,
2343 bool is_batch, int num_examples,
2344 Allocator* allocator,
2345 Result* context_result) {
2346 for (int t = 0; t < context_config.sparse.size(); ++t) {
2347 const auto& c = context_config.sparse[t];
2348 const FeatureProtos& feature =
2349 context_features.find(c.feature_name)->second;
2350 TensorShape indices_shape, values_shape;
2351 DataType dtype = c.dtype;
2352 size_t expected_num_elements = feature.length;
2353 indices_shape.AddDim(expected_num_elements);
2354 indices_shape.AddDim(is_batch ? 2 : 1);
2355 values_shape.AddDim(expected_num_elements);
2356 context_result->sparse_indices[t] =
2357 Tensor(allocator, DT_INT64, indices_shape);
2358 context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
2359 context_result->sparse_shapes[t] =
2360 Tensor(allocator, DT_INT64, TensorShape({is_batch ? 2 : 1}));
2361 Tensor& out_values = context_result->sparse_values[t];
2362 size_t out_values_offset = 0;
2363 int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
2364 auto out_shape = context_result->sparse_shapes[t].vec<int64>();
2365
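// Indices are emitted as (example, column) pairs, flattened row-major, and
// the dense_shape becomes [num_examples, max_num_cols] in the batch case
// (or just [max_num_cols] for a single example).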
2366 // Fill in the values.
2367 size_t num_elements = 0;
2368 size_t max_num_cols = 0;
2369 for (int e = 0; e < num_examples; e++) {
2370 const auto& feature_proto = feature.protos[e];
2371 if (feature_proto.empty()) continue;
2372 protobuf::io::CodedInputStream stream(
2373 reinterpret_cast<const uint8*>(feature_proto.data()),
2374 feature_proto.size());
2375 EnableAliasing(&stream);
2376 size_t num_added =
2377 ParseFeature(dtype, &stream, &out_values, &out_values_offset);
2378 num_elements += num_added;
2379 max_num_cols = std::max(max_num_cols, num_added);
2380 for (int i = 0; i < num_added; i++) {
2381 if (is_batch) *out_indices++ = e;
2382 *out_indices++ = i;
2383 }
2384 }
2385 if (num_elements != expected_num_elements) {
2386 return errors::InvalidArgument(
2387 "Unexpected total number of elements in feature ", c.feature_name);
2388 }
2389 if (is_batch) {
2390 out_shape(0) = num_examples;
2391 out_shape(1) = max_num_cols;
2392 } else {
2393 out_shape(0) = max_num_cols;
2394 }
2395 }
2396 return Status::OK();
2397 }
2398
2399 // Parses ragged features in `context_features`, and writes their parsed
2400 // values to `context_results`.
2401 Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features,
2402 const FastParseExampleConfig& context_config,
2403 gtl::ArraySlice<tstring> example_names,
2404 bool is_batch, int num_examples,
2405 Allocator* allocator,
2406 Result* context_result) {
2407 for (int t = 0; t < context_config.ragged.size(); ++t) {
2408 const auto& c = context_config.ragged[t];
2409 const FeatureProtos& feature =
2410 context_features.find(c.feature_name)->second;
2411 TensorShape values_shape, splits_shape;
2412 DataType dtype = c.dtype;
2413 DataType splits_dtype = c.splits_dtype;
2414 size_t expected_num_elements = feature.length;
2415 values_shape.AddDim(expected_num_elements);
2416 if (is_batch) {
2417 splits_shape.AddDim(num_examples + 1);
2418 }
2419 context_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
2420 context_result->ragged_splits[t] =
2421 Tensor(allocator, splits_dtype, splits_shape);
2422 Tensor& out_values = context_result->ragged_values[t];
2423 size_t out_values_offset = 0;
2424 int32* int32_splits =
2425 is_batch && splits_dtype == DT_INT32
2426 ? context_result->ragged_splits[t].vec<int32>().data()
2427 : nullptr;
2428 int64* int64_splits =
2429 is_batch && splits_dtype == DT_INT64
2430 ? context_result->ragged_splits[t].vec<int64>().data()
2431 : nullptr;
2432 if (int32_splits) {
2433 *int32_splits++ = 0;
2434 } else if (int64_splits) {
2435 *int64_splits++ = 0;
2436 }
2437
2438 // Fill in the values.
2439 size_t split = 0; // = total number of elements we've seen so far
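// Illustrative example: if three examples contribute 2, 0, and 3 values,
// the resulting splits vector is [0, 2, 2, 5] and ragged_values holds all
// 5 values in order.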
2440 for (int e = 0; e < num_examples; e++) {
2441 const auto& feature_proto = feature.protos[e];
2442 if (!feature_proto.empty()) {
2443 protobuf::io::CodedInputStream stream(
2444 reinterpret_cast<const uint8*>(feature_proto.data()),
2445 feature_proto.size());
2446 EnableAliasing(&stream);
2447 size_t num_added =
2448 ParseFeature(dtype, &stream, &out_values, &out_values_offset);
2449 split += num_added;
2450 }
2451 if (int32_splits) {
2452 *int32_splits++ = split;
2453 } else if (int64_splits) {
2454 *int64_splits++ = split;
2455 }
2456 }
2457 if (split != expected_num_elements) {
2458 return errors::InvalidArgument(
2459 "Unexpected total number of elements in feature ", c.feature_name);
2460 }
2461 if (int32_splits || int64_splits) {
2462 int actual_splits =
2463 int32_splits
2464 ? int32_splits -
2465 context_result->ragged_splits[t].vec<int32>().data()
2466 : int64_splits -
2467 context_result->ragged_splits[t].vec<int64>().data();
2468 if (actual_splits != num_examples + 1) {
2469 return errors::InvalidArgument(
2470 "Unexpected number of examples for feature ", c.feature_name);
2471 }
2472 }
2473 }
2474 return Status::OK();
2475 }
2476
2477 // Parses dense features in `sequence_features`, and writes their parsed
2478 // values to `sequence_result`.
2479 Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features,
2480 const FastParseExampleConfig& sequence_config,
2481 gtl::ArraySlice<tstring> example_names,
2482 bool is_batch, int num_examples,
2483 Allocator* allocator, Result* sequence_result,
2484 std::vector<Tensor>* dense_feature_lengths) {
2485 TensorShape dense_length_shape;
2486 if (is_batch) {
2487 dense_length_shape.AddDim(num_examples);
2488 }
2489 for (int t = 0; t < sequence_config.dense.size(); ++t) {
2490 const auto& c = sequence_config.dense[t];
2491 const FeatureProtos& feature =
2492 sequence_features.find(c.feature_name)->second;
2493 TensorShape dense_shape, row_shape;
2494 DataType dtype = c.dtype;
2495 const size_t expected_max_elements = feature.length;
2496 if (!c.shape.AsTensorShape(&row_shape) ||
2497 expected_max_elements !=
2498 (expected_max_elements / row_shape.num_elements()) *
2499 row_shape.num_elements()) {
2500 PartialTensorShape total_shape = row_shape;
2501 total_shape.InsertDim(0, -1);
2502 return errors::InvalidArgument(
2503 "Feature list '", c.feature_name,
2504 "' has an unexpected number of values. Total values size: ",
2505 expected_max_elements,
2506 " is not consistent with output shape: ", total_shape.DebugString());
2507 }
2508 int64_t expected_max_rows =
2509 expected_max_elements / row_shape.num_elements();
2510 if (is_batch) {
2511 dense_shape.AddDim(num_examples);
2512 }
2513 dense_shape.AddDim(expected_max_rows);
2514 for (const int dim : sequence_config.dense[t].shape.dim_sizes()) {
2515 dense_shape.AddDim(dim);
2516 }
2517 sequence_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
2518 (*dense_feature_lengths)[t] =
2519 Tensor(allocator, DT_INT64, dense_length_shape);
2520 int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
2521
2522 tstring* out_bytes = nullptr;
2523 float* out_float = nullptr;
2524 int64* out_int64 = nullptr;
2525 switch (dtype) {
2526 case DT_STRING:
2527 out_bytes = sequence_result->dense_values[t].flat<tstring>().data();
2528 break;
2529 case DT_FLOAT:
2530 out_float = sequence_result->dense_values[t].flat<float>().data();
2531 break;
2532 case DT_INT64:
2533 out_int64 = sequence_result->dense_values[t].flat<int64>().data();
2534 break;
2535 default:
2536 ReportUnexpectedDataType(dtype);
2537 }
2538
2539 // Fill in the values.
2540 for (int e = 0; e < num_examples; e++) {
2541 size_t num_elements = 0, num_rows = 0;
2542 const auto& feature_proto = feature.protos[e];
2543 if (!feature.protos_present[e]) {
2544 // Return an error if this feature was not allowed to be missing.
2545 // Otherwise, we'll pad as needed below.
2546 if (!c.variable_length) {
2547 return errors::InvalidArgument(
2548 "Name: ", ExampleName(example_names, e), ", Feature list '",
2549 c.feature_name,
2550 "' is required but could not be found. "
2551 "Did you mean to include it in "
2552 "feature_list_dense_missing_assumed_empty or "
2553 "feature_list_dense_defaults?");
2554 }
2555 } else if (!feature_proto.empty()) {
2556 protobuf::io::CodedInputStream stream(
2557 reinterpret_cast<const uint8*>(feature_proto.data()),
2558 feature_proto.size());
2559 EnableAliasing(&stream);
2560 while (!stream.ExpectAtEnd()) {
2561 uint32 feature_length;
2562 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2563 !stream.ReadVarint32(&feature_length)) {
2564 return errors::InvalidArgument("Error in sequence feature ",
2565 c.feature_name, " in example ",
2566 ExampleName(example_names, e));
2567 }
2568 auto limit = stream.PushLimit(feature_length);
2569 int num_added = 0;
2570 if (feature_length > 2) {
2571 switch (dtype) {
2572 case DT_STRING:
2573 num_added = ParseBytesFeature(&stream, out_bytes);
2574 out_bytes += num_added;
2575 break;
2576 case DT_FLOAT:
2577 num_added = ParseFloatFeature(&stream, out_float);
2578 out_float += num_added;
2579 break;
2580 case DT_INT64:
2581 num_added = ParseInt64Feature(&stream, out_int64);
2582 out_int64 += num_added;
2583 break;
2584 default:
2585 ReportUnexpectedDataType(dtype);
2586 num_added = 0;
2587 }
2588 if (num_added < 0) {
2589 // This should be unreachable -- we already scanned the feature in
2590 // GetSequenceFeatureLengths, and it hasn't changed since then.
2591 return errors::InvalidArgument("Error in sequence feature ",
2592 c.feature_name, " in example ",
2593 ExampleName(example_names, e));
2594 }
2595 }
2596 if (num_added != row_shape.num_elements()) {
2597 return errors::InvalidArgument(
2598 "Name: ", ExampleName(example_names, e),
2599 ", Key: ", c.feature_name, ", Index: ", num_rows,
2600 ". Number of values != expected. values size: ", num_added,
2601 " but output shape: ", row_shape.DebugString());
2602 }
2603 num_elements += num_added;
2604 num_rows++;
2605 stream.PopLimit(limit);
2606 }
2607 }
2608 *out_lengths++ = num_rows;
2609 // Pad as necessary.
2610 int num_to_pad = expected_max_elements - num_elements;
2611 switch (dtype) {
2612 case DT_STRING:
2613 out_bytes += num_to_pad;
2614 break;
2615 case DT_FLOAT:
2616 PadFloatFeature(num_to_pad, out_float);
2617 out_float += num_to_pad;
2618 break;
2619 case DT_INT64:
2620 PadInt64Feature(num_to_pad, out_int64);
2621 out_int64 += num_to_pad;
2622 break;
2623 default:
2624 ReportUnexpectedDataType(dtype);
2625 }
2626 }
2627 }
2628 return Status::OK();
2629 }
2630
2631 // Parses sparse features in `sequence_features`, and writes their parsed
2632 // values to `sequence_result`.
2633 Status ParseSequenceSparseFeatures(
2634 const FeatureProtosMap& sequence_features,
2635 const FastParseExampleConfig& sequence_config,
2636 gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
2637 Allocator* allocator, Result* sequence_result) {
2638 for (int t = 0; t < sequence_config.sparse.size(); ++t) {
2639 const auto& c = sequence_config.sparse[t];
2640 const FeatureProtos& feature =
2641 sequence_features.find(c.feature_name)->second;
2642 TensorShape indices_shape, values_shape;
2643 DataType dtype = c.dtype;
2644 size_t expected_num_elements = feature.length;
2645 indices_shape.AddDim(expected_num_elements);
2646 indices_shape.AddDim(is_batch ? 3 : 2);
2647 values_shape.AddDim(expected_num_elements);
2648 sequence_result->sparse_indices[t] =
2649 Tensor(allocator, DT_INT64, indices_shape);
2650 sequence_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
2651 sequence_result->sparse_shapes[t] =
2652 Tensor(allocator, DT_INT64, TensorShape({is_batch ? 3 : 2}));
2653
2654 tstring* out_bytes = nullptr;
2655 float* out_float = nullptr;
2656 int64* out_int64 = nullptr;
2657 switch (dtype) {
2658 case DT_STRING:
2659 out_bytes = sequence_result->sparse_values[t].flat<tstring>().data();
2660 break;
2661 case DT_FLOAT:
2662 out_float = sequence_result->sparse_values[t].flat<float>().data();
2663 break;
2664 case DT_INT64:
2665 out_int64 = sequence_result->sparse_values[t].flat<int64>().data();
2666 break;
2667 default:
2668 ReportUnexpectedDataType(dtype);
2669 }
2670 int64* out_indices =
2671 sequence_result->sparse_indices[t].flat<int64>().data();
2672 auto out_shape = sequence_result->sparse_shapes[t].vec<int64>();
2673
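// Indices are emitted as (example, row, column) triples (or (row, column)
// pairs when not batched), and the dense_shape becomes
// [num_examples, max_num_rows, max_num_cols] in the batch case.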
2674 // Fill in the values.
2675 size_t num_elements = 0;
2676 size_t max_num_rows = 0;
2677 size_t max_num_cols = 0;
2678 for (int e = 0; e < num_examples; e++) {
2679 const auto& feature_proto = feature.protos[e];
2680 if (feature_proto.empty()) continue;
2681 protobuf::io::CodedInputStream stream(
2682 reinterpret_cast<const uint8*>(feature_proto.data()),
2683 feature_proto.size());
2684 EnableAliasing(&stream);
2685 size_t num_rows = 0;
2686 while (!stream.ExpectAtEnd()) {
2687 uint32 feature_length;
2688 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2689 !stream.ReadVarint32(&feature_length)) {
2690 // This should be unreachable -- we already scanned the feature in
2691 // GetSequenceFeatureLengths, and it hasn't changed since then.
2692 return errors::InvalidArgument("Error in sequence feature ",
2693 c.feature_name, " in example ",
2694 ExampleName(example_names, e));
2695 }
2696 if (feature_length > 2) {
2697 auto limit = stream.PushLimit(feature_length);
2698 size_t num_added;
2699 switch (dtype) {
2700 case DT_STRING:
2701 num_added = ParseBytesFeature(&stream, out_bytes);
2702 out_bytes += num_added;
2703 break;
2704 case DT_FLOAT:
2705 num_added = ParseFloatFeature(&stream, out_float);
2706 out_float += num_added;
2707 break;
2708 case DT_INT64:
2709 num_added = ParseInt64Feature(&stream, out_int64);
2710 out_int64 += num_added;
2711 break;
2712 default:
2713 ReportUnexpectedDataType(dtype);
2714 num_added = 0;
2715 }
2716 num_elements += num_added;
2717 max_num_cols = std::max(max_num_cols, num_added);
2718 for (int i = 0; i < num_added; i++) {
2719 if (is_batch) *out_indices++ = e;
2720 *out_indices++ = num_rows;
2721 *out_indices++ = i;
2722 }
2723 stream.PopLimit(limit);
2724 } else if (feature_length == 2) {
2725 if (!SkipEmptyFeature(&stream, dtype)) {
2726 // This should be unreachable -- we already scanned the feature in
2727 // GetSequenceFeatureLengths, and it hasn't changed since then.
2728 return errors::InvalidArgument("Error in sequence feature ",
2729 c.feature_name, " in example ",
2730 ExampleName(example_names, e));
2731 }
2732 } else if (feature_length != 0) {
2733 // This should be unreachable -- we already scanned the feature in
2734 // GetSequenceFeatureLengths, and it hasn't changed since then.
2735 return errors::InvalidArgument("Error in sequence feature ",
2736 c.feature_name, " in example ",
2737 ExampleName(example_names, e));
2738 }
2739 num_rows++;
2740 }
2741 max_num_rows = std::max(max_num_rows, num_rows);
2742 }
2743 if (num_elements != expected_num_elements) {
2744 return errors::InvalidArgument(
2745 "Unexpected number of elements in feature ", c.feature_name);
2746 }
2747 if (is_batch) {
2748 out_shape(0) = num_examples;
2749 out_shape(1) = max_num_rows;
2750 out_shape(2) = max_num_cols;
2751 } else {
2752 out_shape(0) = max_num_rows;
2753 out_shape(1) = max_num_cols;
2754 }
2755 }
2756 return Status::OK();
2757 }
2758
2759 // Parses ragged features in `sequence_features`, and writes their parsed
2760 // values to `sequence_result`.
2761 Status ParseSequenceRaggedFeatures(
2762 const FeatureProtosMap& sequence_features,
2763 const FastParseExampleConfig& sequence_config,
2764 gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
2765 Allocator* allocator, Result* sequence_result) {
2766 for (int t = 0; t < sequence_config.ragged.size(); ++t) {
2767 const auto& c = sequence_config.ragged[t];
2768 const FeatureProtos& feature =
2769 sequence_features.find(c.feature_name)->second;
2770 TensorShape values_shape, inner_splits_shape, outer_splits_shape;
2771 DataType dtype = c.dtype;
2772 DataType splits_dtype = c.splits_dtype;
2773 size_t expected_num_elements = feature.length;
2774 size_t expected_num_rows = feature.num_rows;
2775 values_shape.AddDim(expected_num_elements);
2776 inner_splits_shape.AddDim(expected_num_rows + 1);
2777 if (is_batch) {
2778 outer_splits_shape.AddDim(num_examples + 1);
2779 }
2780 sequence_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
2781 sequence_result->ragged_splits[t] =
2782 Tensor(allocator, splits_dtype, inner_splits_shape);
2783 sequence_result->ragged_outer_splits[t] =
2784 Tensor(allocator, splits_dtype, outer_splits_shape);
2785 Tensor& out_values = sequence_result->ragged_values[t];
2786 size_t out_values_offset = 0;
2787 int32* int32_inner_splits =
2788 splits_dtype == DT_INT32
2789 ? sequence_result->ragged_splits[t].vec<int32>().data()
2790 : nullptr;
2791 int64* int64_inner_splits =
2792 splits_dtype == DT_INT64
2793 ? sequence_result->ragged_splits[t].vec<int64>().data()
2794 : nullptr;
2795 int32* int32_outer_splits =
2796 is_batch && splits_dtype == DT_INT32
2797 ? sequence_result->ragged_outer_splits[t].vec<int32>().data()
2798 : nullptr;
2799 int64* int64_outer_splits =
2800 is_batch && splits_dtype == DT_INT64
2801 ? sequence_result->ragged_outer_splits[t].vec<int64>().data()
2802 : nullptr;
2803 if (int32_inner_splits) {
2804 *int32_inner_splits++ = 0;
2805 } else if (int64_inner_splits) {
2806 *int64_inner_splits++ = 0;
2807 }
2808 if (int32_outer_splits) {
2809 *int32_outer_splits++ = 0;
2810 } else if (int64_outer_splits) {
2811 *int64_outer_splits++ = 0;
2812 }
2813
2814 // Fill in the values.
2815 size_t inner_split = 0; // total number of elements we've seen so far
2816 size_t outer_split = 0; // total number of rows we've seen so far
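// Illustrative example: a batch of two SequenceExamples, the first with rows
// of 2 and 1 values and the second with one row of 3 values, yields inner
// splits [0, 2, 3, 6] and outer splits [0, 2, 3].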
2817 for (int e = 0; e < num_examples; e++) {
2818 const auto& feature_proto = feature.protos[e];
2819 if (!feature_proto.empty()) {
2820 protobuf::io::CodedInputStream stream(
2821 reinterpret_cast<const uint8*>(feature_proto.data()),
2822 feature_proto.size());
2823 EnableAliasing(&stream);
2824 while (!stream.ExpectAtEnd()) {
2825 uint32 feature_length;
2826 if (!stream.ExpectTag(kDelimitedTag(1)) ||
2827 !stream.ReadVarint32(&feature_length)) {
2828 // This should be unreachable -- we already scanned the feature in
2829 // GetSequenceFeatureLengths, and it hasn't changed since then.
2830 return errors::InvalidArgument("Error in sequence feature ",
2831 c.feature_name, " in example ",
2832 ExampleName(example_names, e));
2833 }
2834 if (feature_length > 2) {
2835 auto limit = stream.PushLimit(feature_length);
2836 size_t num_added =
2837 ParseFeature(dtype, &stream, &out_values, &out_values_offset);
2838 inner_split += num_added;
2839 stream.PopLimit(limit);
2840 } else if (feature_length == 2) {
2841 if (!SkipEmptyFeature(&stream, dtype)) {
2842 // This should be unreachable -- we already scanned the feature in
2843 // GetSequenceFeatureLengths, and it hasn't changed since then.
2844 return errors::InvalidArgument("Error in sequence feature ",
2845 c.feature_name, " in example ",
2846 ExampleName(example_names, e));
2847 }
2848 } else if (feature_length != 0) {
2849 // This should be unreachable -- we already scanned the feature in
2850 // GetSequenceFeatureLengths, and it hasn't changed since then.
2851 return errors::InvalidArgument("Error in sequence feature ",
2852 c.feature_name, " in example ",
2853 ExampleName(example_names, e));
2854 }
2855 if (int32_inner_splits) {
2856 *int32_inner_splits++ = inner_split;
2857 } else if (int64_inner_splits) {
2858 *int64_inner_splits++ = inner_split;
2859 }
2860 outer_split++;
2861 }
2862 }
2863 if (int32_outer_splits) {
2864 *int32_outer_splits++ = outer_split;
2865 } else if (int64_outer_splits) {
2866 *int64_outer_splits++ = outer_split;
2867 }
2868 }
2869 if (outer_split != expected_num_rows) {
2870 return errors::InvalidArgument("Unexpected number of rows for feature ",
2871 c.feature_name);
2872 }
2873 if (inner_split != expected_num_elements) {
2874 return errors::InvalidArgument(
2875 "Unexpected number of elements for feature ", c.feature_name);
2876 }
2877
2878 if (int32_inner_splits || int64_inner_splits) {
2879 const auto& inner_splits = sequence_result->ragged_splits[t];
2880 int num_inner_splits =
2881 int32_inner_splits
2882 ? int32_inner_splits - inner_splits.vec<int32>().data()
2883 : int64_inner_splits - inner_splits.vec<int64>().data();
2884 if (num_inner_splits != expected_num_rows + 1) {
2885 return errors::InvalidArgument("Unexpected number of rows for feature ",
2886 c.feature_name);
2887 }
2888 }
2889 if (int32_outer_splits || int64_outer_splits) {
2890 const auto& outer_splits = sequence_result->ragged_outer_splits[t];
2891 int num_outer_splits =
2892 int32_outer_splits
2893 ? int32_outer_splits - outer_splits.vec<int32>().data()
2894 : int64_outer_splits - outer_splits.vec<int64>().data();
2895 if (num_outer_splits != num_examples + 1) {
2896 return errors::InvalidArgument(
2897 "Unexpected number of examples for feature ", c.feature_name);
2898 }
2899 }
2900 }
2901 return Status::OK();
2902 }
2903
2904 } // namespace
2905
2906 // TODO(sundberg): Use the threadpool to parallelize example parsing.
2907 // TODO(b/111553342): Support extracting feature statistics from the examples.
2908 Status FastParseSequenceExample(const FastParseExampleConfig& context_config,
2909 const FastParseExampleConfig& sequence_config,
2910 gtl::ArraySlice<tstring> serialized,
2911 gtl::ArraySlice<tstring> example_names,
2912 thread::ThreadPool* thread_pool,
2913 Result* context_result, Result* sequence_result,
2914 std::vector<Tensor>* dense_feature_lengths,
2915 bool is_batch) {
2916 int num_examples = serialized.size();
2917 DCHECK(context_result != nullptr);
2918 DCHECK(sequence_result != nullptr);
2919 DCHECK(dense_feature_lengths != nullptr);
2920 size_t num_context_features = context_config.sparse.size() +
2921 context_config.dense.size() +
2922 context_config.ragged.size();
2923 FeatureProtosMap context_features;
2924 context_features.reserve(num_context_features);
2925
2926 if (!example_names.empty() && example_names.size() != num_examples) {
2927 return errors::InvalidArgument(
2928 "example_names must be empty or have the correct number of elements");
2929 }
2930 for (auto& c : context_config.sparse) {
2931 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2932 FeatureProtos& feature = context_features[c.feature_name];
2933 feature.dtype = c.dtype;
2934 feature.length = 0;
2935 feature.type = Type::Sparse;
2936 feature.protos.resize(num_examples);
2937 feature.protos_present.resize(num_examples);
2938 }
2939 for (auto& c : context_config.ragged) {
2940 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2941 FeatureProtos& feature = context_features[c.feature_name];
2942 if (feature.type == Type::Sparse) {
2943 return errors::InvalidArgument("Context feature " + c.feature_name +
2944 " cannot be both ragged and sparse");
2945 }
2946 feature.dtype = c.dtype;
2947 feature.length = 0;
2948 feature.type = Type::Ragged;
2949 feature.protos.resize(num_examples);
2950 feature.protos_present.resize(num_examples);
2951 }
2952 for (auto& c : context_config.dense) {
2953 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2954 FeatureProtos& feature = context_features[c.feature_name];
2955 if (feature.type != Type::Dense) {
2956 return errors::InvalidArgument("Context feature " + c.feature_name +
2957 " cannot be both dense and sparse");
2958 }
2959 if (c.default_value.NumElements() > 0) {
2960 if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
2961 return errors::InvalidArgument("Default value for context feature ",
2962 c.feature_name,
2963 " has an incorrect shape: saw ",
2964 c.default_value.shape().DebugString(),
2965 " but expected ", c.shape.DebugString());
2966 }
2967 }
2968 feature.dtype = c.dtype;
2969 feature.length = c.default_value.NumElements();
2970 feature.protos.resize(num_examples);
2971 feature.protos_present.resize(num_examples);
2972 }
2973 size_t num_sequence_features = sequence_config.sparse.size() +
2974 sequence_config.dense.size() +
2975 sequence_config.ragged.size();
2976 FeatureProtosMap sequence_features;
2977 sequence_features.reserve(num_sequence_features);
2978 for (auto& c : sequence_config.sparse) {
2979 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2980 FeatureProtos& feature = sequence_features[c.feature_name];
2981 feature.dtype = c.dtype;
2982 feature.length = 0;
2983 feature.type = Type::Sparse;
2984 feature.protos.resize(num_examples);
2985 feature.protos_present.resize(num_examples);
2986 }
2987 for (auto& c : sequence_config.ragged) {
2988 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
2989 FeatureProtos& feature = sequence_features[c.feature_name];
2990 if (feature.type == Type::Sparse) {
2991 return errors::InvalidArgument("Sequence feature " + c.feature_name +
2992 " cannot be both ragged and sparse");
2993 }
2994 feature.dtype = c.dtype;
2995 feature.length = 0;
2996 feature.type = Type::Ragged;
2997 feature.protos.resize(num_examples);
2998 feature.protos_present.resize(num_examples);
2999 }
3000 for (auto& c : sequence_config.dense) {
3001 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
3002 FeatureProtos& feature = sequence_features[c.feature_name];
3003 if (feature.type != Type::Dense) {
3004 return errors::InvalidArgument("Sequence feature " + c.feature_name +
3005 " cannot be both dense and sparse");
3006 }
3007 feature.dtype = c.dtype;
3008 feature.length = 0;
3009 feature.protos.resize(num_examples);
3010 feature.protos_present.resize(num_examples);
3011 }
3012
3013 // Find the serialized proto substrings for each feature.
3014 TF_RETURN_IF_ERROR(ExtractFeaturesFromSequenceExamples(
3015 serialized, example_names, &context_features, &sequence_features));
3016
3017 // Scan through the protos to determine how much memory we need to allocate.
3018 TF_RETURN_IF_ERROR(
3019 GetContextFeatureLengths(example_names, &context_features));
3020 TF_RETURN_IF_ERROR(
3021 GetSequenceFeatureLengths(example_names, &sequence_features));
3022
3023 // Allocate memory.
3024 context_result->sparse_values.resize(context_config.sparse.size());
3025 context_result->sparse_indices.resize(context_config.sparse.size());
3026 context_result->sparse_shapes.resize(context_config.sparse.size());
3027 context_result->dense_values.resize(context_config.dense.size());
3028 context_result->ragged_values.resize(context_config.ragged.size());
3029 context_result->ragged_splits.resize(context_config.ragged.size());
3030 context_result->ragged_outer_splits.resize(context_config.ragged.size());
3031 sequence_result->sparse_values.resize(sequence_config.sparse.size());
3032 sequence_result->sparse_indices.resize(sequence_config.sparse.size());
3033 sequence_result->sparse_shapes.resize(sequence_config.sparse.size());
3034 sequence_result->dense_values.resize(sequence_config.dense.size());
3035 sequence_result->ragged_values.resize(sequence_config.ragged.size());
3036 sequence_result->ragged_splits.resize(sequence_config.ragged.size());
3037 sequence_result->ragged_outer_splits.resize(sequence_config.ragged.size());
3038 dense_feature_lengths->resize(sequence_config.dense.size());
3039
3040 // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction,
3041 // to avoid lock contention in `tensorflow::cpu_allocator()`.
3042 Allocator* allocator = tensorflow::cpu_allocator();
3043
3044 TF_RETURN_IF_ERROR(ParseContextDenseFeatures(
3045 context_features, context_config, example_names, is_batch, num_examples,
3046 allocator, context_result));
3047 TF_RETURN_IF_ERROR(ParseContextSparseFeatures(
3048 context_features, context_config, example_names, is_batch, num_examples,
3049 allocator, context_result));
3050 TF_RETURN_IF_ERROR(ParseContextRaggedFeatures(
3051 context_features, context_config, example_names, is_batch, num_examples,
3052 allocator, context_result));
3053 TF_RETURN_IF_ERROR(ParseSequenceDenseFeatures(
3054 sequence_features, sequence_config, example_names, is_batch, num_examples,
3055 allocator, sequence_result, dense_feature_lengths));
3056 TF_RETURN_IF_ERROR(ParseSequenceSparseFeatures(
3057 sequence_features, sequence_config, example_names, is_batch, num_examples,
3058 allocator, sequence_result));
3059 TF_RETURN_IF_ERROR(ParseSequenceRaggedFeatures(
3060 sequence_features, sequence_config, example_names, is_batch, num_examples,
3061 allocator, sequence_result));
3062
3063 return Status::OK();
3064 }
3065
3066 } // namespace example
3067 } // namespace tensorflow
3068