1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/util/example_proto_fast_parsing.h"
16
17 #include <vector>
18
19 #include "absl/base/casts.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/example/example.pb.h"
22 #include "tensorflow/core/example/feature.pb_text.h"
23 #include "tensorflow/core/framework/numeric_op.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/lib/core/blocking_counter.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/core/threadpool.h"
29 #include "tensorflow/core/lib/gtl/inlined_vector.h"
30 #include "tensorflow/core/lib/monitoring/counter.h"
31 #include "tensorflow/core/platform/byte_order.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/protobuf.h"
34 #include "tensorflow/core/util/presized_cuckoo_map.h"
35 #include "tensorflow/core/util/sparse/sparse_tensor.h"
36
37 namespace tensorflow {
38 namespace example {
39
40 namespace {
41
42 template <typename T>
43 using SmallVector = gtl::InlinedVector<T, 4>;
44
45 template <typename A>
46 auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
47 a->EnableAliasing(true);
48 }
49
50 template <typename A>
51 void EnableAliasing(A&& a) {}
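// The two EnableAliasing overloads above select at compile time via expression
// SFINAE: when the stream type A provides EnableAliasing(bool) (as protobuf's
// CodedInputStream does), the first overload is chosen and aliasing is turned
// on so parsed StringPieces can point directly into the input buffer;
// otherwise the no-op overload is used.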
52
53 uint8 PeekTag(protobuf::io::CodedInputStream* stream) {
54 DCHECK(stream != nullptr);
55 const void* ptr;
56 int size;
57 if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0;
58 return *static_cast<const uint8*>(ptr);
59 }
60
61 constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
62 constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
63 constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
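// Proto wire-format tags are (field_number << 3) | wire_type, where wire type
// 0 is varint, 2 is length-delimited and 5 is fixed32. For example, the
// Feature oneof fields bytes_list = 1, float_list = 2 and int64_list = 3 are
// all length-delimited sub-messages, so ParseDataType below distinguishes them
// by the single tag bytes kDelimitedTag(1) = 0x0a, kDelimitedTag(2) = 0x12 and
// kDelimitedTag(3) = 0x1a.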
64
65 namespace parsed {
66
67 // ParseDataType has to be called first, then appropriate ParseZzzzList.
68 class Feature {
69 public:
70 Feature() {}
71 explicit Feature(StringPiece serialized) : serialized_(serialized) {}
72
73 Status ParseDataType(DataType* dtype) {
74 DCHECK(dtype != nullptr);
75 if (serialized_.empty()) {
76 *dtype = DT_INVALID;
77 return Status::OK();
78 }
79 uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
80 serialized_.remove_prefix(1);
81 switch (oneof_tag) {
82 case kDelimitedTag(1):
83 *dtype = DT_STRING;
84 break;
85 case kDelimitedTag(2):
86 *dtype = DT_FLOAT;
87 break;
88 case kDelimitedTag(3):
89 *dtype = DT_INT64;
90 break;
91 default:
92 // Initialize variable to avoid compiler warning
93 *dtype = DT_INVALID;
94 return errors::InvalidArgument("Unsupported datatype.");
95 }
96 return Status::OK();
97 }
98
99 bool GetNumElementsInBytesList(int* num_elements) {
100 protobuf::io::CodedInputStream stream(
101 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
102 EnableAliasing(&stream);
103 uint32 length = 0;
104 if (!stream.ReadVarint32(&length)) return false;
105 auto limit = stream.PushLimit(length);
106 *num_elements = 0;
107 while (!stream.ExpectAtEnd()) {
108 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
109 uint32 bytes_length = 0;
110 if (!stream.ReadVarint32(&bytes_length)) return false;
111 if (!stream.Skip(bytes_length)) return false;
112 ++*num_elements;
113 }
114 stream.PopLimit(limit);
115 return true;
116 }
117
118 template <typename Result>
119 bool ParseBytesList(Result* bytes_list) {
120 DCHECK(bytes_list != nullptr);
121
122 protobuf::io::CodedInputStream stream(
123 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
124
125 EnableAliasing(&stream);
126
127 uint32 length;
128 if (!stream.ReadVarint32(&length)) return false;
129 auto limit = stream.PushLimit(length);
130
131 while (!stream.ExpectAtEnd()) {
132 if (!stream.ExpectTag(kDelimitedTag(1))) return false;
133 // parse string
134 uint32 bytes_length;
135 if (!stream.ReadVarint32(&bytes_length)) return false;
136 string bytes;
137 if (!stream.ReadString(&bytes, bytes_length)) return false;
138 bytes_list->push_back(std::move(bytes));
139 }
140 stream.PopLimit(limit);
141 return true;
142 }
143
144 template <typename Result>
145 bool ParseFloatList(Result* float_list) {
146 DCHECK(float_list != nullptr);
147 protobuf::io::CodedInputStream stream(
148 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
149 EnableAliasing(&stream);
150 uint32 length;
151 if (!stream.ReadVarint32(&length)) return false;
152 auto limit = stream.PushLimit(length);
153
154 if (!stream.ExpectAtEnd()) {
155 uint8 peek_tag = PeekTag(&stream);
156 if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
157 return false;
158 }
159
160 if (peek_tag == kDelimitedTag(1)) { // packed
161 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
162 uint32 packed_length;
163 if (!stream.ReadVarint32(&packed_length)) return false;
164 auto packed_limit = stream.PushLimit(packed_length);
165
166 // If the result data type is float and we are on a little endian
167 // machine then we can simply memcpy the data from the proto into the
168 // result vector.
169 constexpr int32 kNumFloatBytes = 4;
170 if (port::kLittleEndian &&
171 sizeof(typename Result::value_type) == kNumFloatBytes) {
172 // Store the initial size to know the offset we have to start writing
173 // data from before resizing the output "vector".
174 const size_t initial_size = float_list->size();
175 float_list->resize(initial_size + packed_length / kNumFloatBytes);
176 // Calculate the length of the available buffer, which can be less than
177 // what we requested in resize() in the case of a LimitedArraySlice.
178 const uint32 bytes_to_copy =
179 std::min(static_cast<uint32>((float_list->size() - initial_size) *
180 kNumFloatBytes),
181 packed_length);
182 if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
183 return false;
184 } else {
185 while (!stream.ExpectAtEnd()) {
186 uint32 buffer32;
187 if (!stream.ReadLittleEndian32(&buffer32)) return false;
188 float_list->push_back(absl::bit_cast<float>(buffer32));
189 }
190 }
191
192 stream.PopLimit(packed_limit);
193 } else { // non-packed
194 while (!stream.ExpectAtEnd()) {
195 if (!stream.ExpectTag(kFixed32Tag(1))) return false;
196 uint32 buffer32;
197 if (!stream.ReadLittleEndian32(&buffer32)) return false;
198 float_list->push_back(absl::bit_cast<float>(buffer32));
199 }
200 }
201 }
202
203 stream.PopLimit(limit);
204 return true;
205 }
206
207 template <typename Result>
208 bool ParseInt64List(Result* int64_list) {
209 DCHECK(int64_list != nullptr);
210 protobuf::io::CodedInputStream stream(
211 reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
212 EnableAliasing(&stream);
213 uint32 length;
214 if (!stream.ReadVarint32(&length)) return false;
215 auto limit = stream.PushLimit(length);
216
217 if (!stream.ExpectAtEnd()) {
218 uint8 peek_tag = PeekTag(&stream);
219 if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
220 return false;
221 }
222 if (peek_tag == kDelimitedTag(1)) { // packed
223 if (!stream.ExpectTag(kDelimitedTag(1))) return false; // packed tag
224 uint32 packed_length;
225 if (!stream.ReadVarint32(&packed_length)) return false;
226 auto packed_limit = stream.PushLimit(packed_length);
227
228 while (!stream.ExpectAtEnd()) {
229 protobuf_uint64 n; // There is no API for int64
230 if (!stream.ReadVarint64(&n)) return false;
231 int64_list->push_back(static_cast<int64>(n));
232 }
233
234 stream.PopLimit(packed_limit);
235 } else { // non-packed
236 while (!stream.ExpectAtEnd()) {
237 if (!stream.ExpectTag(kVarintTag(1))) return false;
238 protobuf_uint64 n; // There is no API for int64
239 if (!stream.ReadVarint64(&n)) return false;
240 int64_list->push_back(static_cast<int64>(n));
241 }
242 }
243 }
244 stream.PopLimit(limit);
245 return true;
246 }
247
248 StringPiece GetSerialized() const { return serialized_; }
249
250 private:
251 // TODO(lew): Pair of uint8* would be more natural.
252 StringPiece serialized_;
253 };
254
255 using FeatureMapEntry = std::pair<StringPiece, Feature>;
256 using Example = std::vector<FeatureMapEntry>;
257
258 } // namespace parsed
259
260 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream* stream) {
261 uint32 data;
262 protobuf_uint64 dummy;
263 switch (stream->ReadTag() & 0x7) {
264 case 0: // varint
265 if (!stream->ReadVarint32(&data)) return false;
266 return true;
267 case 1: // fixed64
268 if (!stream->ReadLittleEndian64(&dummy)) return false;
269 return true;
270 case 2: // length delimited
271 if (!stream->ReadVarint32(&data)) return false;
272 stream->Skip(data);
273 return true;
274 case 3: // group begin
275 return false; // groups not supported.
276 case 4: // group end
277 return false; // groups not supported.
278 case 5: // fixed32
279 if (!stream->ReadLittleEndian32(&data)) return false;
280 return true;
281 }
282 return false; // unrecognized tag type
283 }
284
285 bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) {
286 DCHECK(stream != nullptr);
287 DCHECK(result != nullptr);
288 uint32 length;
289 if (!stream->ReadVarint32(&length)) return false;
290 if (length == 0) {
291 *result = StringPiece(nullptr, 0);
292 return true;
293 }
294 const void* stream_alias;
295 int stream_size;
296 if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
297 return false;
298 }
299 if (static_cast<uint32>(stream_size) < length) return false;
300 *result = StringPiece(static_cast<const char*>(stream_alias), length);
301 stream->Skip(length);
302 return true;
303 }
304
305 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream,
306 parsed::FeatureMapEntry* feature_map_entry) {
307 DCHECK(stream != nullptr);
308 DCHECK(feature_map_entry != nullptr);
309 uint32 length;
310 if (!stream->ReadVarint32(&length)) return false;
311 auto limit = stream->PushLimit(length);
312 if (!stream->ExpectTag(kDelimitedTag(1))) return false;
313 if (!ParseString(stream, &feature_map_entry->first)) return false;
314 if (!stream->ExpectTag(kDelimitedTag(2))) return false;
315 StringPiece feature_string_piece;
316 if (!ParseString(stream, &feature_string_piece)) return false;
317 feature_map_entry->second = parsed::Feature(feature_string_piece);
318 if (!stream->ExpectAtEnd()) return false;
319 stream->PopLimit(limit);
320 return true;
321 }
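// Note: a map<string, Feature> field is serialized as a repeated message whose
// entries carry the key in field 1 and the value in field 2, which is exactly
// the layout ParseFeatureMapEntry above relies on.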
322
323 bool ParseFeatures(protobuf::io::CodedInputStream* stream,
324 parsed::Example* example) {
325 DCHECK(stream != nullptr);
326 DCHECK(example != nullptr);
327 uint32 length;
328 if (!stream->ReadVarint32(&length)) return false;
329 auto limit = stream->PushLimit(length);
330 while (!stream->ExpectAtEnd()) {
331 parsed::FeatureMapEntry feature_map_entry;
332 if (!stream->ExpectTag(kDelimitedTag(1))) return false;
333 if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false;
334 example->push_back(std::move(feature_map_entry));
335 }
336 stream->PopLimit(limit);
337 return true;
338 }
339
340 bool ParseExample(protobuf::io::CodedInputStream* stream,
341 parsed::Example* example) {
342 DCHECK(stream != nullptr);
343 DCHECK(example != nullptr);
344 // Loop over the input stream which may contain multiple serialized Example
345 // protos merged together as strings. This behavior is consistent with Proto's
346 // ParseFromString when string representations are concatenated.
347 while (!stream->ExpectAtEnd()) {
348 if (!stream->ExpectTag(kDelimitedTag(1))) {
349 if (!SkipExtraneousTag(stream)) return false;
350 } else {
351 if (!ParseFeatures(stream, example)) return false;
352 }
353 }
354 return true;
355 }
356
357 bool ParseExample(StringPiece serialized, parsed::Example* example) {
358 DCHECK(example != nullptr);
359 protobuf::io::CodedInputStream stream(
360 reinterpret_cast<const uint8*>(serialized.data()), serialized.size());
361 EnableAliasing(&stream);
362 return ParseExample(&stream, example);
363 }
364
365 } // namespace
366
367 bool TestFastParse(const string& serialized, Example* example) {
368 DCHECK(example != nullptr);
369 parsed::Example parsed_example;
370 if (!ParseExample(serialized, &parsed_example)) return false;
371 auto& features = *example->mutable_features();
372 size_t parsed_example_size = parsed_example.size();
373 for (size_t i = 0; i < parsed_example_size; ++i) {
374 // This mirrors the logic of standard protobuf parsing:
375 // the last entry in the map overwrites all previous ones.
376 parsed::FeatureMapEntry& name_and_feature =
377 parsed_example[parsed_example_size - i - 1];
378 string name(name_and_feature.first);
379 if ((*features.mutable_feature()).count(name) > 0) continue;
380
381 auto& value = (*features.mutable_feature())[name];
382 DataType dtype;
383 if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
384 switch (dtype) {
385 case DT_INVALID:
386 break;
387 case DT_STRING: {
388 SmallVector<string> list;
389 if (!name_and_feature.second.ParseBytesList(&list)) return false;
390 auto* result_list = value.mutable_bytes_list();
391 for (auto& bytes : list) {
392 auto* new_value = result_list->add_value();
393 new_value->swap(bytes);
394 }
395 break;
396 }
397 case DT_FLOAT: {
398 SmallVector<float> list;
399 if (!name_and_feature.second.ParseFloatList(&list)) return false;
400 auto* result_list = value.mutable_float_list();
401 for (float f : list) {
402 result_list->add_value(f);
403 }
404 break;
405 }
406 case DT_INT64: {
407 SmallVector<int64> list;
408 if (!name_and_feature.second.ParseInt64List(&list)) return false;
409 auto* result_list = value.mutable_int64_list();
410 for (int64 i : list) {
411 result_list->add_value(i);
412 }
413 break;
414 }
415 default:
416 LOG(FATAL) << "Should not happen.";
417 }
418 }
419 return true;
420 }
421
422 // -----------------------------------------------------------------------------
423
424 namespace {
425
426 using Config = FastParseExampleConfig;
427
428 void ParallelFor(const std::function<void(size_t)>& f, size_t n,
429 thread::ThreadPool* thread_pool) {
430 if (n == 0) return;
431 if (thread_pool == nullptr) {
432 for (size_t i = 0; i < n; ++i) {
433 f(i);
434 }
435 } else {
436 BlockingCounter counter(n - 1);
437 for (size_t i = 1; i < n; ++i) {
438 thread_pool->Schedule([i, &f, &counter] {
439 f(i);
440 counter.DecrementCount();
441 });
442 }
443 f(0);
444 counter.Wait();
445 }
446 }
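// Note: the BlockingCounter is initialized to n - 1 because f(0) runs inline
// on the calling thread; only the remaining n - 1 invocations are scheduled on
// the pool and decrement the counter.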
447
448 enum class Type { Sparse, Dense };
449
450 struct SparseBuffer {
451 // Feature values land in one of the three vectors below, depending on the
452 // config's dtype; the other two vectors remain empty.
453 SmallVector<string> bytes_list;
454 SmallVector<float> float_list;
455 SmallVector<int64> int64_list;
456
457 // Features of example i are elements with indices
458 // from example_end_indices[i-1] to example_end_indices[i]-1 on the
459 // appropriate xxxxx_list
460 std::vector<size_t> example_end_indices;
461 };
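// Illustration of the layout above: if int64_list = {7, 8, 9, 10} and
// example_end_indices = {1, 1, 4}, then example 0 owns {7}, example 1 owns no
// values, and example 2 owns {8, 9, 10}.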
462
463 struct SeededHasher {
464 uint64 operator()(StringPiece s) const {
465 return Hash64(s.data(), s.size(), seed);
466 }
467 uint64 seed{0xDECAFCAFFE};
468 };
469
470 template <typename T>
471 class LimitedArraySlice {
472 public:
473 using value_type = T;
474
475 LimitedArraySlice(T* begin, size_t num_elements)
476 : current_(begin), begin_(begin), end_(begin + num_elements) {}
477
478 // May return negative if there were push_back calls after slice was filled.
479 int64 EndDistance() const { return end_ - current_; }
480
481 // Attempts to push value to the back of this. If the slice has
482 // already been filled, this method has no effect on the underlying data, but
483 // it changes the number returned by EndDistance into negative values.
484 void push_back(T&& value) {
485 if (EndDistance() > 0) *current_ = std::move(value);
486 ++current_;
487 }
488
489 // Returns the number of elements in the slice.
490 size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
491
492 // Attempts to resize the vector to the given size. It does so by advancing
493 // the pointer to the current element, possibly beyond the end of the slice.
494 // As a consequence, calling `size()` after `resize(x)` was called might
495 // return a value less than `x`.
496 void resize(size_t size) { current_ = begin_ + size; }
497
498 // Returns the pointer to the underlying data buffer.
499 T* data() { return begin_; }
500
501 private:
502 T* current_;
503 T* begin_;
504 T* end_;
505 };
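// Illustrative use of LimitedArraySlice: given a destination buffer of three
// floats, LimitedArraySlice<float> slice(out, 3) accepts the first three
// push_back calls; a fourth call is dropped but drives EndDistance() to -1.
// Callers below treat any nonzero EndDistance() as a shape mismatch, so both
// too many and too few parsed values are detected.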
506
507 void LogDenseFeatureDataLoss(StringPiece feature_name) {
508 LOG(WARNING) << "Data loss! Feature '" << feature_name
509 << "' is present in multiple concatenated "
510 "tf.Examples. Ignoring all but last one.";
511 static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
512 "/tensorflow/core/util/example_proto_fast_parsing/"
513 "duplicated_dense_feature",
514 "Dense feature appears twice in a tf.Example");
515 duplicated_dense_feature->GetCell()->IncrementBy(1);
516 }
517
518 void LogSparseFeatureDataLoss(StringPiece feature_name) {
519 LOG(WARNING) << "Data loss! Feature '" << feature_name
520 << "' is present in multiple concatenated "
521 "tf.Examples. Ignoring all but last one.";
522 static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
523 "/tensorflow/core/util/example_proto_fast_parsing/"
524 "duplicated_sparse_feature",
525 "Sparse feature appears twice in a tf.Example");
526 duplicated_sparse_feature->GetCell()->IncrementBy(1);
527 }
528
529 Status FastParseSerializedExample(
530 const string& serialized_example, const string& example_name,
531 const size_t example_index, const Config& config,
532 const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
533 SeededHasher hasher, std::vector<Tensor>* output_dense,
534 std::vector<SparseBuffer>* output_varlen_dense,
535 std::vector<SparseBuffer>* output_sparse,
536 PerExampleFeatureStats* output_stats) {
537 DCHECK(output_dense != nullptr);
538 DCHECK(output_sparse != nullptr);
539 parsed::Example parsed_example;
540 if (!ParseExample(serialized_example, &parsed_example)) {
541 return errors::InvalidArgument("Could not parse example input, value: '",
542 serialized_example, "'");
543 }
544 std::vector<int64> sparse_feature_last_example(config.sparse.size(), -1);
545 std::vector<int64> dense_feature_last_example(config.dense.size(), -1);
546
547 // Handle features present in the example.
548 const size_t parsed_example_size = parsed_example.size();
549
550 if (output_stats) {
551 // TODO(b/111553342): This may over-count the number of features if there
552 // are duplicate keys in the feature map. Consider deduplicating the keys
553 // before computing the count.
554 output_stats->features_count = parsed_example_size;
555 }
556
557 for (size_t i = 0; i < parsed_example_size; ++i) {
558 // This mirrors the logic of standard protobuf parsing:
559 // the last entry in the map overwrites all previous ones.
560 parsed::FeatureMapEntry& name_and_feature =
561 parsed_example[parsed_example_size - i - 1];
562
563 const StringPiece feature_name = name_and_feature.first;
564 parsed::Feature& feature = name_and_feature.second;
565
566 std::pair<size_t, Type> d_and_type;
567 uint64 h = hasher(feature_name);
568 if (!config_index.Find(h, &d_and_type)) continue;
569
570 size_t d = d_and_type.first;
571 bool is_dense = d_and_type.second == Type::Dense;
572
573 {
574 // Testing for PresizedCuckooMap collision.
575 // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
576 const string& config_feature_name = is_dense
577 ? config.dense[d].feature_name
578 : config.sparse[d].feature_name;
579 if (feature_name != config_feature_name) continue;
580 }
581
582 auto example_error = [&](StringPiece suffix) {
583 return errors::InvalidArgument("Name: ", example_name,
584 ", Key: ", feature_name,
585 ", Index: ", example_index, ". ", suffix);
586 };
587
588 auto parse_error = [&] {
589 return example_error("Can't parse serialized Example.");
590 };
591
592 DataType example_dtype;
593 TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
594
595 if (is_dense) {
596 if (example_dtype == DT_INVALID) continue;
597
598 // If feature was already visited, skip.
599 // Compare comment at the beginning of the loop.
600 if (dense_feature_last_example[d] == example_index) {
601 LogDenseFeatureDataLoss(feature_name);
602 continue;
603 }
604 dense_feature_last_example[d] = example_index;
605
606 if (example_dtype != config.dense[d].dtype) {
607 return example_error(strings::StrCat(
608 "Data types don't match. Data type: ",
609 DataTypeString(example_dtype),
610 " but expected type: ", DataTypeString(config.dense[d].dtype)));
611 }
612 if (!config.dense[d].variable_length) {
613 Tensor& out = (*output_dense)[d];
614
615 const std::size_t num_elements = config.dense[d].elements_per_stride;
616 if (output_stats) {
617 // TODO(b/111553342): If desirable, we could add support for counting
618 // elements in the features that aren't parsed, but this could add
619 // considerable runtime cost.
620 output_stats->feature_values_count += num_elements;
621 }
622
623 const std::size_t offset = example_index * num_elements;
624
625 auto shape_error = [&](size_t size, StringPiece type_str) {
626 return example_error(strings::StrCat(
627 "Number of ", type_str,
628 " values != expected. "
629 "Values size: ",
630 size,
631 " but output shape: ", config.dense[d].shape.DebugString()));
632 };
633
634 switch (config.dense[d].dtype) {
635 case DT_INT64: {
636 auto out_p = out.flat<int64>().data() + offset;
637 LimitedArraySlice<int64> slice(out_p, num_elements);
638 if (!feature.ParseInt64List(&slice)) return parse_error();
639 if (slice.EndDistance() != 0) {
640 return shape_error(num_elements - slice.EndDistance(), "int64");
641 }
642 break;
643 }
644 case DT_FLOAT: {
645 auto out_p = out.flat<float>().data() + offset;
646 LimitedArraySlice<float> slice(out_p, num_elements);
647 if (!feature.ParseFloatList(&slice)) return parse_error();
648 if (slice.EndDistance() != 0) {
649 return shape_error(num_elements - slice.EndDistance(), "float");
650 }
651 break;
652 }
653 case DT_STRING: {
654 auto out_p = out.flat<string>().data() + offset;
655 LimitedArraySlice<string> slice(out_p, num_elements);
656 if (!feature.ParseBytesList(&slice)) return parse_error();
657 if (slice.EndDistance() != 0) {
658 return shape_error(num_elements - slice.EndDistance(), "bytes");
659 }
660 break;
661 }
662 default:
663 LOG(FATAL) << "Should not happen.";
664 }
665 } else { // if variable length
666 SparseBuffer& out = (*output_varlen_dense)[d];
667
668 const std::size_t num_elements = config.dense[d].elements_per_stride;
669
670 if (example_dtype != DT_INVALID &&
671 example_dtype != config.dense[d].dtype) {
672 return example_error(strings::StrCat(
673 "Data types don't match. ",
674 "Expected type: ", DataTypeString(config.dense[d].dtype)));
675 }
676
677 auto shape_error = [&](size_t size, StringPiece type_str) {
678 return example_error(strings::StrCat(
679 "Number of ", type_str,
680 " values is not a multiple of stride length. Saw ", size,
681 " values but output shape is: ",
682 config.dense[d].shape.DebugString()));
683 };
684
685 switch (config.dense[d].dtype) {
686 case DT_INT64: {
687 if (example_dtype != DT_INVALID) {
688 if (!feature.ParseInt64List(&out.int64_list)) {
689 return parse_error();
690 }
691 if (out.int64_list.size() % num_elements != 0) {
692 return shape_error(out.int64_list.size(), "int64");
693 }
694 }
695 out.example_end_indices.push_back(out.int64_list.size());
696 break;
697 }
698 case DT_FLOAT: {
699 if (example_dtype != DT_INVALID) {
700 if (!feature.ParseFloatList(&out.float_list)) {
701 return parse_error();
702 }
703 if (out.float_list.size() % num_elements != 0) {
704 return shape_error(out.float_list.size(), "float");
705 }
706 }
707 out.example_end_indices.push_back(out.float_list.size());
708 break;
709 }
710 case DT_STRING: {
711 if (example_dtype != DT_INVALID) {
712 if (!feature.ParseBytesList(&out.bytes_list)) {
713 return parse_error();
714 }
715 if (out.bytes_list.size() % num_elements != 0) {
716 return shape_error(out.bytes_list.size(), "bytes");
717 }
718 }
719 out.example_end_indices.push_back(out.bytes_list.size());
720 break;
721 }
722 default:
723 LOG(FATAL) << "Should not happen.";
724 }
725
726 if (output_stats) {
727 // Use `out.example_end_indices` to determine the feature-value count
728 // for this feature, because the preceding switch statement pushes
729 // the length of the appropriate feature list to that vector.
730 // TODO(b/111553342): If desirable, we could add support for counting
731 // elements in the features that aren't parsed, but this could add
732 // considerable runtime cost.
733 const size_t out_examples_count = out.example_end_indices.size();
734 if (out_examples_count == 1) {
735 output_stats->feature_values_count += out.example_end_indices[0];
736 } else {
737 output_stats->feature_values_count +=
738 out.example_end_indices[out_examples_count - 1] -
739 out.example_end_indices[out_examples_count - 2];
740 }
741 }
742 }
743 } else {
744 // If feature was already visited, skip.
745 // Compare comment at the beginning of the loop.
746 if (sparse_feature_last_example[d] == example_index) {
747 LogSparseFeatureDataLoss(feature_name);
748 continue;
749 }
750 sparse_feature_last_example[d] = example_index;
751
752 // Handle sparse features.
753 SparseBuffer& out = (*output_sparse)[d];
754 if (example_dtype != DT_INVALID &&
755 example_dtype != config.sparse[d].dtype) {
756 return example_error(strings::StrCat(
757 "Data types don't match. ",
758 "Expected type: ", DataTypeString(config.sparse[d].dtype),
759 ", Actual type: ", DataTypeString(example_dtype)));
760 }
761
762 switch (config.sparse[d].dtype) {
763 case DT_INT64: {
764 if (example_dtype != DT_INVALID) {
765 if (!feature.ParseInt64List(&out.int64_list)) {
766 return parse_error();
767 }
768 }
769 out.example_end_indices.push_back(out.int64_list.size());
770 break;
771 }
772 case DT_FLOAT: {
773 if (example_dtype != DT_INVALID) {
774 if (!feature.ParseFloatList(&out.float_list)) {
775 return parse_error();
776 }
777 }
778 out.example_end_indices.push_back(out.float_list.size());
779 break;
780 }
781 case DT_STRING: {
782 if (example_dtype != DT_INVALID) {
783 if (!feature.ParseBytesList(&out.bytes_list)) {
784 return parse_error();
785 }
786 }
787 out.example_end_indices.push_back(out.bytes_list.size());
788 break;
789 }
790 default:
791 LOG(FATAL) << "Should not happen.";
792 }
793
794 if (output_stats) {
795 // Use `out.example_end_indices` to determine the feature-value count
796 // for this feature, because the preceding switch statement pushes
797 // the length of the appropriate feature list to that vector.
798 // TODO(b/111553342): If desirable, we could add support for counting
799 // elements in the features that aren't parsed, but this could add
800 // considerable runtime cost.
801 const size_t out_examples_count = out.example_end_indices.size();
802 if (out_examples_count == 1) {
803 output_stats->feature_values_count += out.example_end_indices[0];
804 } else {
805 output_stats->feature_values_count +=
806 out.example_end_indices[out_examples_count - 1] -
807 out.example_end_indices[out_examples_count - 2];
808 }
809 }
810 }
811 }
812
813 // Handle missing dense features for fixed strides.
814 for (size_t d = 0; d < config.dense.size(); ++d) {
815 if (config.dense[d].variable_length) continue;
816 if (dense_feature_last_example[d] == example_index) continue;
817 if (config.dense[d].default_value.NumElements() == 0) {
818 return errors::InvalidArgument(
819 "Name: ", example_name, ", Feature: ", config.dense[d].feature_name,
820 " (data type: ", DataTypeString(config.dense[d].dtype), ")",
821 " is required but could not be found.");
822 }
823 const Tensor& in = config.dense[d].default_value;
824 Tensor& out = (*output_dense)[d];
825 const std::size_t num_elements = in.shape().num_elements();
826 const std::size_t offset = example_index * num_elements;
827
828 switch (config.dense[d].dtype) {
829 case DT_INT64: {
830 std::copy_n(in.flat<int64>().data(), num_elements,
831 out.flat<int64>().data() + offset);
832 break;
833 }
834 case DT_FLOAT: {
835 std::copy_n(in.flat<float>().data(), num_elements,
836 out.flat<float>().data() + offset);
837 break;
838 }
839 case DT_STRING: {
840 std::copy_n(in.flat<string>().data(), num_elements,
841 out.flat<string>().data() + offset);
842 break;
843 }
844 default:
845 LOG(FATAL) << "Should not happen.";
846 }
847 }
848
849 // Handle missing varlen dense features.
850 for (size_t d = 0; d < config.dense.size(); ++d) {
851 if (!config.dense[d].variable_length) continue;
852 if (dense_feature_last_example[d] == example_index) continue;
853 SparseBuffer& out = (*output_varlen_dense)[d];
854 size_t prev_example_end_index =
855 out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
856 out.example_end_indices.push_back(prev_example_end_index);
857 }
858
859 // Handle missing sparse features.
860 for (size_t d = 0; d < config.sparse.size(); ++d) {
861 if (sparse_feature_last_example[d] == example_index) continue;
862 SparseBuffer& out = (*output_sparse)[d];
863 size_t prev_example_end_index =
864 out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
865 out.example_end_indices.push_back(prev_example_end_index);
866 }
867
868 return Status::OK();
869 }
870
871 Status CheckConfigDataType(DataType dtype) {
872 switch (dtype) {
873 case DT_INT64:
874 case DT_FLOAT:
875 case DT_STRING:
876 return Status::OK();
877 default:
878 return errors::InvalidArgument("Invalid config dtype: ",
879 DataTypeString(dtype));
880 }
881 }
882
883 template <typename T>
884 const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
885
886 template <>
887 const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
888 return buffer.int64_list;
889 }
890 template <>
891 const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
892 return buffer.float_list;
893 }
894 template <>
895 const SmallVector<string>& GetListFromBuffer<string>(
896 const SparseBuffer& buffer) {
897 return buffer.bytes_list;
898 }
899
900 template <typename T>
901 void CopyOrMoveBlock(const T* b, const T* e, T* t) {
902 std::copy(b, e, t);
903 }
904 template <>
905 void CopyOrMoveBlock(const string* b, const string* e, string* t) {
906 std::move(b, e, t);
907 }
908
909 template <typename T>
910 void FillAndCopyVarLen(
911 const int d, const size_t num_elements,
912 const size_t num_elements_per_minibatch, const Config& config,
913 const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
914 Tensor* values) {
915 const Tensor& default_value = config.dense[d].default_value;
916
917 // Copy-fill the tensors (creating the zero/fill-padding)
918 std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
919 default_value.flat<T>()(0));
920
921 // Data is [batch_size, max_num_elements, data_stride_size]
922 // and num_elements_per_minibatch = max_num_elements * data_stride_size
923 auto data = values->flat<T>().data();
924
925 // Iterate over minibatch elements
926 for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
927 const SparseBuffer& buffer = varlen_dense_buffers[i][d];
928 // Number of examples being stored in this buffer
929 const auto& end_indices = buffer.example_end_indices;
930 const size_t examples_in_buffer = end_indices.size();
931 // const size_t stride_size = config.dense[d].elements_per_stride;
932
933 const auto& list = GetListFromBuffer<T>(buffer);
934 auto list_ptr = list.begin();
935
936 size_t elements_tally = 0;
937 // Iterate through all the examples stored in this buffer.
938 for (size_t j = 0; j < examples_in_buffer; ++j) {
939 // Number of elements stored for this example.
940 const size_t num_elems = end_indices[j] - elements_tally;
941 CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data);
942 // Move forward this many elements in the varlen buffer.
943 list_ptr += num_elems;
944 // Move forward to the next minibatch entry in the values output.
945 data += num_elements_per_minibatch;
946 elements_tally = end_indices[j];
947 }
948 DCHECK(elements_tally == list.size());
949 }
950 }
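// FillAndCopyVarLen above first fills the whole [batch_size, max_num_elements,
// ...] output with the feature's default value (its first element) and then
// overwrites the prefix of each example's row with the values that were
// actually parsed, so shorter examples end up right-padded with the default.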
951
952 } // namespace
953
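// Illustrative caller-side sketch (not part of this file; feature names and
// variables below are made up, and the exact constructors of the config
// entries live in the header, so the braces only indicate which fields get
// populated):
//
//   FastParseExampleConfig config;
//   config.sparse.push_back({"tokens", DT_STRING});       // feature_name, dtype
//   config.dense.push_back({"label", DT_INT64, ...});     // plus shape, default
//   Result result;
//   TF_RETURN_IF_ERROR(FastParseExample(config, serialized_protos,
//                                       example_names, pool, &result));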
954 Status FastParseExample(const Config& config,
955 gtl::ArraySlice<string> serialized,
956 gtl::ArraySlice<string> example_names,
957 thread::ThreadPool* thread_pool, Result* result) {
958 DCHECK(result != nullptr);
959 // Check config so we can safely CHECK(false) in switches on config.*.dtype
960 for (auto& c : config.sparse) {
961 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
962 }
963 for (auto& c : config.dense) {
964 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
965 }
966
967 if (config.collect_feature_stats) {
968 result->feature_stats.resize(serialized.size());
969 }
970
971 size_t config_size = config.dense.size() + config.sparse.size();
972 SeededHasher hasher;
973 // Build config index.
974 PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
975 bool ok = true;
976 for (size_t i = 0; i < 1000; ++i) {
977 for (size_t d = 0; d < config.dense.size(); ++d) {
978 ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
979 {d, Type::Dense});
980 }
981 for (size_t d = 0; d < config.sparse.size(); ++d) {
982 ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
983 {d, Type::Sparse});
984 }
985 if (ok) break;
986 LOG(WARNING) << "Collision found. This should happen only if you have "
987 "around 2^32 entries in your config.";
988 hasher.seed++;
989 config_index.Clear(config_size);
990 }
991 if (!ok) {
992 return errors::Internal(
993 "Could not avoid collision. This should not happen.");
994 }
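// The retry loop above picks a fresh hash seed (up to 1000 times) whenever
// InsertUnique fails, which the warning attributes to a hash collision; with
// realistic config sizes the first attempt essentially always succeeds.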
995
996 // Allocate dense output for fixed length dense values
997 // (variable-length dense and sparse have to be buffered).
998 std::vector<Tensor> fixed_dense_values(config.dense.size());
999 for (size_t d = 0; d < config.dense.size(); ++d) {
1000 if (config.dense[d].variable_length) continue;
1001 TensorShape out_shape;
1002 out_shape.AddDim(serialized.size());
1003 for (const int64 dim : config.dense[d].shape.dim_sizes()) {
1004 out_shape.AddDim(dim);
1005 }
1006 fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
1007 }
1008
1009 // This parameter affects performance in a big and data-dependent way.
1010 const size_t kMiniBatchSizeBytes = 50000;
1011
1012 // Calculate number of minibatches.
1013 // In the main regime, make each minibatch around kMiniBatchSizeBytes bytes.
1014 // Apply the 'special logic' below for very small and very large batches.
1015 const size_t num_minibatches = [&] {
1016 size_t result = 0;
1017 size_t minibatch_bytes = 0;
1018 for (size_t i = 0; i < serialized.size(); i++) {
1019 if (minibatch_bytes == 0) { // start minibatch
1020 result++;
1021 }
1022 minibatch_bytes += serialized[i].size() + 1;
1023 if (minibatch_bytes > kMiniBatchSizeBytes) {
1024 minibatch_bytes = 0;
1025 }
1026 }
1027 // 'special logic'
1028 const size_t min_minibatches = std::min<size_t>(8, serialized.size());
1029 const size_t max_minibatches = 64;
1030 return std::max<size_t>(min_minibatches,
1031 std::min<size_t>(max_minibatches, result));
1032 }();
1033
1034 auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
1035 return (serialized.size() * minibatch) / num_minibatches;
1036 };
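// num_minibatches above targets roughly one minibatch per kMiniBatchSizeBytes
// of input and is then clamped to [min(8, serialized.size()), 64]. Examples
// are split into contiguous, nearly equal index ranges; e.g. 100 examples over
// 8 minibatches gives minibatch 3 the half-open range [100*3/8, 100*4/8) =
// [37, 50).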
1037
1038 // TODO(lew): A big performance low-hanging fruit here is to improve
1039 // num_minibatches calculation to take into account actual amount of work
1040 // needed, as the size in bytes is not perfect. Linear combination of
1041 // size in bytes and average number of features per example is promising.
1042 // Even better: measure time instead of estimating, but this is too costly
1043 // in small batches.
1044 // Maybe accept outside parameter #num_minibatches?
1045
1046 // Do minibatches in parallel.
1047 std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
1048 std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
1049 std::vector<Status> status_of_minibatch(num_minibatches);
1050 auto ProcessMiniBatch = [&](size_t minibatch) {
1051 sparse_buffers[minibatch].resize(config.sparse.size());
1052 varlen_dense_buffers[minibatch].resize(config.dense.size());
1053 size_t start = first_example_of_minibatch(minibatch);
1054 size_t end = first_example_of_minibatch(minibatch + 1);
1055 for (size_t e = start; e < end; ++e) {
1056 PerExampleFeatureStats* stats = nullptr;
1057 if (config.collect_feature_stats) {
1058 stats = &result->feature_stats[e];
1059 }
1060 status_of_minibatch[minibatch] = FastParseSerializedExample(
1061 serialized[e],
1062 (!example_names.empty() ? example_names[e] : "<unknown>"), e, config,
1063 config_index, hasher, &fixed_dense_values,
1064 &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch], stats);
1065 if (!status_of_minibatch[minibatch].ok()) break;
1066 }
1067 };
1068
1069 ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);
1070
1071 for (Status& status : status_of_minibatch) {
1072 TF_RETURN_IF_ERROR(status);
1073 }
1074
1075 for (size_t d = 0; d < config.dense.size(); ++d) {
1076 result->dense_values.push_back(std::move(fixed_dense_values[d]));
1077 }
1078
1079 // Merge SparseBuffers from all minibatches for every config.sparse.
1080 auto MergeSparseMinibatches = [&](size_t d) {
1081 // Loop over minibatches
1082 size_t total_num_features = 0;
1083 size_t max_num_features = 0;
1084 for (auto& sparse_values_tmp : sparse_buffers) {
1085 const std::vector<size_t>& end_indices =
1086 sparse_values_tmp[d].example_end_indices;
1087 total_num_features += end_indices.back();
1088 max_num_features = std::max(max_num_features, end_indices[0]);
1089 for (size_t i = 1; i < end_indices.size(); ++i) {
1090 size_t example_size = end_indices[i] - end_indices[i - 1];
1091 max_num_features = std::max(max_num_features, example_size);
1092 }
1093 }
1094
1095 TensorShape indices_shape;
1096 indices_shape.AddDim(total_num_features);
1097 indices_shape.AddDim(2);
1098 result->sparse_indices.emplace_back(DT_INT64, indices_shape);
1099 Tensor* indices = &result->sparse_indices.back();
1100
1101 TensorShape values_shape;
1102 values_shape.AddDim(total_num_features);
1103 result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
1104 Tensor* values = &result->sparse_values.back();
1105
1106 result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
1107 auto shapes_shape_t = result->sparse_shapes.back().vec<int64>();
1108 shapes_shape_t(0) = serialized.size();
1109 shapes_shape_t(1) = max_num_features;
1110
1111 size_t offset = 0;
1112 for (size_t i = 0; i < sparse_buffers.size(); ++i) {
1113 const SparseBuffer& buffer = sparse_buffers[i][d];
1114
1115 // Update indices.
1116 int64* ix_p = &indices->matrix<int64>()(offset, 0);
1117 size_t delta = 0;
1118 size_t example_index = first_example_of_minibatch(i);
1119 for (size_t example_end_index : buffer.example_end_indices) {
1120 size_t feature_index = 0;
1121 for (; delta < example_end_index; ++delta) {
1122 // Column 0: example index
1123 *ix_p = example_index;
1124 // Column 1: the feature index within the example
1125 *(ix_p + 1) = feature_index;
1126 ix_p += 2;
1127 ++feature_index;
1128 }
1129 ++example_index;
1130 }
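// At this point the indices matrix holds one [example_index, feature_index]
// row per sparse value: column 0 is the example's position in the whole batch
// and column 1 is the value's position within that example, matching the
// dense_shape [batch_size, max_num_features] filled in above.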
1131
1132 // Copy values over.
1133 switch (config.sparse[d].dtype) {
1134 case DT_INT64: {
1135 std::copy(buffer.int64_list.begin(), buffer.int64_list.end(),
1136 values->flat<int64>().data() + offset);
1137 break;
1138 }
1139 case DT_FLOAT: {
1140 std::copy(buffer.float_list.begin(), buffer.float_list.end(),
1141 values->flat<float>().data() + offset);
1142 break;
1143 }
1144 case DT_STRING: {
1145 std::move(buffer.bytes_list.begin(), buffer.bytes_list.end(),
1146 values->flat<string>().data() + offset);
1147 break;
1148 }
1149 default:
1150 LOG(FATAL) << "Should not happen.";
1151 }
1152
1153 offset += delta;
1154 }
1155 };
1156
1157 // Merge SparseBuffers from all minibatches for every config.dense having
1158 // variable_length.
1159 auto MergeDenseVarLenMinibatches = [&](size_t d) {
1160 if (!config.dense[d].variable_length) return;
1161
1162 // Loop over minibatches
1163 size_t max_num_features = 0;
1164 for (auto& dense_values_tmp : varlen_dense_buffers) {
1165 std::vector<size_t>& end_indices =
1166 dense_values_tmp[d].example_end_indices;
1167 max_num_features = std::max(max_num_features, end_indices[0]);
1168 for (size_t i = 1; i < end_indices.size(); ++i) {
1169 size_t example_size = end_indices[i] - end_indices[i - 1];
1170 max_num_features = std::max(max_num_features, example_size);
1171 }
1172 }
1173
1174 const size_t stride_size = config.dense[d].elements_per_stride;
1175 const size_t max_num_elements = max_num_features / stride_size;
1176 TensorShape values_shape;
1177 DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
1178 const size_t batch_size = serialized.size();
1179 values_shape.AddDim(batch_size);
1180 values_shape.AddDim(max_num_elements);
1181 for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1182 values_shape.AddDim(config.dense[d].shape.dim_size(i));
1183 }
1184 Tensor values(config.dense[d].dtype, values_shape);
1185 result->dense_values[d] = values;
1186 const size_t num_elements = values.NumElements();
1187
1188 // Nothing to write, exit early.
1189 if (num_elements == 0) return;
1190
1191 const size_t num_elements_per_minibatch = num_elements / batch_size;
1192
1193 switch (config.dense[d].dtype) {
1194 case DT_INT64: {
1195 FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
1196 config, varlen_dense_buffers, &values);
1197 break;
1198 }
1199 case DT_FLOAT: {
1200 FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
1201 config, varlen_dense_buffers, &values);
1202 break;
1203 }
1204 case DT_STRING: {
1205 FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch,
1206 config, varlen_dense_buffers, &values);
1207 break;
1208 }
1209 default:
1210 LOG(FATAL) << "Should not happen.";
1211 }
1212 };
1213
1214 for (size_t d = 0; d < config.dense.size(); ++d) {
1215 MergeDenseVarLenMinibatches(d);
1216 }
1217
1218 for (size_t d = 0; d < config.sparse.size(); ++d) {
1219 MergeSparseMinibatches(d);
1220 }
1221
1222 return Status::OK();
1223 }
1224
1225 Status FastParseSingleExample(const Config& config, const string& serialized,
1226 Result* result) {
1227 DCHECK(result != nullptr);
1228 // Check config so we can safely CHECK(false) in switches on config.*.dtype
1229 for (auto& c : config.sparse) {
1230 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1231 }
1232 for (auto& c : config.dense) {
1233 TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
1234 }
1235
1236 PerExampleFeatureStats* stats = nullptr;
1237 if (config.collect_feature_stats) {
1238 result->feature_stats.emplace_back();
1239 stats = &result->feature_stats.back();
1240 }
1241
1242 // TODO(mrry): Cache the construction of this map at Op construction time.
1243 size_t config_size = config.dense.size() + config.sparse.size();
1244 SeededHasher hasher;
1245 // Build config index.
1246 PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
1247 bool ok = true;
1248 for (size_t i = 0; i < 1000; ++i) {
1249 for (size_t d = 0; d < config.dense.size(); ++d) {
1250 ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
1251 {d, Type::Dense});
1252 }
1253 for (size_t d = 0; d < config.sparse.size(); ++d) {
1254 ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
1255 {d, Type::Sparse});
1256 }
1257 if (ok) break;
1258 LOG(WARNING) << "Collision found. This should happen only if you have "
1259 "around 2^32 entries in your config.";
1260 hasher.seed++;
1261 config_index.Clear(config_size);
1262 }
1263 if (!ok) {
1264 return errors::Internal(
1265 "Could not avoid collision. This should not happen.");
1266 }
1267
1268 // Allocate dense output tensors.
1269 for (size_t d = 0; d < config.dense.size(); ++d) {
1270 if (!config.dense[d].variable_length) {
1271 TensorShape values_shape;
1272 if (!config.dense[d].shape.AsTensorShape(&values_shape)) {
1273 return errors::Internal(
1274 "Fixed-length shape was not a statically defined shape.");
1275 }
1276 result->dense_values.emplace_back(config.dense[d].dtype, values_shape);
1277 } else {
1278 // Variable-length tensor will be allocated later.
1279 result->dense_values.emplace_back();
1280 }
1281 }
1282
1283 // Allocate sparse output tensors.
1284 for (size_t d = 0; d < config.sparse.size(); ++d) {
1285 // The dense_shape is always a vector of length 1.
1286 result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1}));
1287 // Variable-length tensors will be allocated later.
1288 result->sparse_indices.emplace_back();
1289 result->sparse_values.emplace_back();
1290 }
1291
1292 parsed::Example parsed_example;
1293 if (!ParseExample(serialized, &parsed_example)) {
1294 return errors::InvalidArgument("Could not parse example input, value: '",
1295 serialized, "'");
1296 }
1297 std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false);
1298 std::vector<bool> dense_feature_already_seen(config.dense.size(), false);
1299
1300 if (stats) {
1301 // TODO(b/111553342): This may over-count the number of features if there
1302 // are duplicate keys in the feature map. Consider deduplicating the keys
1303 // before computing the count.
1304 stats->features_count = parsed_example.size();
1305 }
1306
1307 // Handle features present in the example.
1308 const size_t parsed_example_size = parsed_example.size();
1309 for (size_t i = 0; i < parsed_example_size; ++i) {
1310 // This mirrors the logic of standard protobuf parsing:
1311 // the last entry in the map overwrites all previous ones.
1312 parsed::FeatureMapEntry& name_and_feature =
1313 parsed_example[parsed_example_size - i - 1];
1314
1315 const StringPiece feature_name = name_and_feature.first;
1316 parsed::Feature& feature = name_and_feature.second;
1317
1318 std::pair<size_t, Type> d_and_type;
1319 uint64 h = hasher(feature_name);
1320 if (!config_index.Find(h, &d_and_type)) continue;
1321
1322 size_t d = d_and_type.first;
1323 bool is_dense = d_and_type.second == Type::Dense;
1324
1325 {
1326 // Testing for PresizedCuckooMap collision.
1327 // TODO(lew): Use dense_hash_map and avoid this and hasher creation.
1328 const string& config_feature_name = is_dense
1329 ? config.dense[d].feature_name
1330 : config.sparse[d].feature_name;
1331 if (feature_name != config_feature_name) continue;
1332 }
1333
1334 auto example_error = [feature_name](StringPiece suffix) {
1335 return errors::InvalidArgument("Key: ", feature_name, ". ", suffix);
1336 };
1337
1338 auto parse_error = [feature_name] {
1339 return errors::InvalidArgument("Key: ", feature_name,
1340 ". Can't parse serialized Example.");
1341 };
1342
1343 DataType example_dtype;
1344 TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype));
1345 if (example_dtype == DT_INVALID) continue;
1346
1347 if (is_dense && !config.dense[d].variable_length) {
1348 // If feature was already visited, skip.
1349 // Compare comment at the beginning of the loop.
1350 if (dense_feature_already_seen[d]) {
1351 LogDenseFeatureDataLoss(feature_name);
1352 continue;
1353 }
1354 dense_feature_already_seen[d] = true;
1355
1356 if (example_dtype != config.dense[d].dtype) {
1357 return example_error(strings::StrCat(
1358 "Data types don't match. Data type: ",
1359 DataTypeString(example_dtype),
1360 " but expected type: ", DataTypeString(config.dense[d].dtype)));
1361 }
1362
1363 Tensor* out = &result->dense_values[d];
1364 const std::size_t num_elements = config.dense[d].elements_per_stride;
1365 if (stats) {
1366 // TODO(b/111553342): If desirable, we could add support for counting
1367 // elements in the features that aren't parsed, but this could add
1368 // considerable runtime cost.
1369 stats->feature_values_count += num_elements;
1370 }
1371 switch (example_dtype) {
1372 case DT_INT64: {
1373 auto out_p = out->flat<int64>().data();
1374 LimitedArraySlice<int64> slice(out_p, num_elements);
1375 if (!feature.ParseInt64List(&slice)) return parse_error();
1376 if (slice.EndDistance() != 0) {
1377 return parse_error();
1378 }
1379 break;
1380 }
1381 case DT_FLOAT: {
1382 auto out_p = out->flat<float>().data();
1383 LimitedArraySlice<float> slice(out_p, num_elements);
1384 if (!feature.ParseFloatList(&slice)) return parse_error();
1385 if (slice.EndDistance() != 0) {
1386 return parse_error();
1387 }
1388 break;
1389 }
1390 case DT_STRING: {
1391 auto out_p = out->flat<string>().data();
1392 LimitedArraySlice<string> slice(out_p, num_elements);
1393 if (!feature.ParseBytesList(&slice)) return parse_error();
1394 if (slice.EndDistance() != 0) {
1395 return parse_error();
1396 }
1397 break;
1398 }
1399 default:
1400 LOG(FATAL) << "Should not happen.";
1401 }
1402
1403 } else { // if variable length
1404 SparseBuffer out_temp;
1405 const size_t num_elements_divisor =
1406 is_dense ? config.dense[d].elements_per_stride : 1;
1407 size_t num_elements;
1408
1409 if (is_dense) {
1410 // If feature was already visited, skip.
1411 // Compare comment at the beginning of the loop.
1412 if (dense_feature_already_seen[d]) {
1413 LogDenseFeatureDataLoss(feature_name);
1414 continue;
1415 }
1416 dense_feature_already_seen[d] = true;
1417 if (example_dtype != config.dense[d].dtype) {
1418 return example_error(strings::StrCat(
1419 "Data types don't match. Data type: ",
1420 DataTypeString(example_dtype),
1421 " but expected type: ", DataTypeString(config.dense[d].dtype)));
1422 }
1423 } else {
1424 // If feature was already visited, skip.
1425 // Compare comment at the beginning of the loop.
1426 if (sparse_feature_already_seen[d]) {
1427 LogSparseFeatureDataLoss(feature_name);
1428 continue;
1429 }
1430 sparse_feature_already_seen[d] = true;
1431
1432 // Handle sparse features.
1433 if (example_dtype != DT_INVALID &&
1434 example_dtype != config.sparse[d].dtype) {
1435 return example_error(strings::StrCat(
1436 "Data types don't match. ",
1437 "Expected type: ", DataTypeString(config.sparse[d].dtype),
1438 ", Actual type: ", DataTypeString(example_dtype)));
1439 }
1440 }
1441
1442 switch (example_dtype) {
1443 case DT_INT64: {
1444 // TODO(mrry): Use the fact that the `int64_list` is packed to read
1445 // out the length and pre-allocate the output tensor.
1446 if (!feature.ParseInt64List(&out_temp.int64_list))
1447 return parse_error();
1448 num_elements = out_temp.int64_list.size();
1449 break;
1450 }
1451 case DT_FLOAT: {
1452 // TODO(mrry): Use the fact that the `float_list` is packed to read
1453 // out the length and pre-allocate the output tensor.
1454 if (!feature.ParseFloatList(&out_temp.float_list))
1455 return parse_error();
1456 num_elements = out_temp.float_list.size();
1457 break;
1458 }
1459 case DT_STRING: {
1460 int actual_num_elements = 0;
1461 if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1462 return parse_error();
1463 }
1464 out_temp.bytes_list.reserve(actual_num_elements);
1465 if (!feature.ParseBytesList(&out_temp.bytes_list))
1466 return parse_error();
1467 num_elements = out_temp.bytes_list.size();
1468 break;
1469 }
1470 default:
1471 LOG(FATAL) << "Should not happen. " << DataTypeString(example_dtype);
1472 }
1473
1474 if (num_elements % num_elements_divisor != 0) {
1475 return parse_error();
1476 }
1477
1478 if (stats) {
1479 stats->feature_values_count += num_elements;
1480 }
1481
1482 Tensor* out;
1483 if (is_dense) {
1484 TensorShape values_shape;
1485 values_shape.AddDim(num_elements / num_elements_divisor);
1486 for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
1487 values_shape.AddDim(config.dense[d].shape.dim_size(i));
1488 }
1489
1490 out = &result->dense_values[d];
1491 *out = Tensor(config.dense[d].dtype, values_shape);
1492
1493 } else {
1494 Tensor* out_indices = &result->sparse_indices[d];
1495 Tensor* out_dense_shape = &result->sparse_shapes[d];
1496 out = &result->sparse_values[d];
1497
1498 // TODO(mrry): Investigate the possibility of not materializing
1499 // the indices (and perhaps dense_shape) until they are needed.
        *out_indices = Tensor(
            DT_INT64, TensorShape({static_cast<int64>(num_elements), 1}));
        auto indices_flat = out_indices->flat<int64>();
        for (size_t i = 0; i < num_elements; ++i) {
          indices_flat(i) = static_cast<int64>(i);
        }

        *out_dense_shape = Tensor(DT_INT64, TensorShape({1}));
        auto shapes_shape_t = out_dense_shape->vec<int64>();
        shapes_shape_t(0) = num_elements;

        *out = Tensor(config.sparse[d].dtype,
                      TensorShape({static_cast<int64>(num_elements)}));
      }

      switch (example_dtype) {
        case DT_INT64: {
          CopyOrMoveBlock(out_temp.int64_list.begin(),
                          out_temp.int64_list.end(), out->flat<int64>().data());
          break;
        }
        case DT_FLOAT: {
          CopyOrMoveBlock(out_temp.float_list.begin(),
                          out_temp.float_list.end(), out->flat<float>().data());
          break;
        }
        case DT_STRING: {
          CopyOrMoveBlock(out_temp.bytes_list.begin(),
                          out_temp.bytes_list.end(),
                          out->flat<string>().data());
          break;
        }
        default:
          LOG(FATAL) << "Should not happen.";
      }
    }
  }

  // Handle missing dense features.
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (!dense_feature_already_seen[d]) {
      if (!config.dense[d].variable_length) {
        // Handle missing fixed-length dense feature.
        if (config.dense[d].default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: ", config.dense[d].feature_name,
              " (data type: ", DataTypeString(config.dense[d].dtype), ")",
              " is required but could not be found.");
        }
        result->dense_values[d] = config.dense[d].default_value;
      } else {
        // Handle missing varlen dense feature.
        TensorShape empty_shape;
        empty_shape.AddDim(0);
        for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
          empty_shape.AddDim(config.dense[d].shape.dim_size(i));
        }
        result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape);
      }
    }
  }

  // Handle missing sparse features.
  for (size_t d = 0; d < config.sparse.size(); ++d) {
    if (!sparse_feature_already_seen[d]) {
      result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1}));
      result->sparse_values[d] =
          Tensor(config.sparse[d].dtype, TensorShape({0}));
      result->sparse_shapes[d].vec<int64>()(0) = 0;
    }
  }

  return Status::OK();
}

// Returns the number of bytes-list elements parsed, or -1 on error. If out is
// null, this method simply counts the number of elements without any copying.
inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
                             string* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    while (!stream->ExpectAtEnd()) {
      uint32 bytes_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&bytes_length) ||
          (out != nullptr && !stream->ReadString(out++, bytes_length))) {
        return -1;
      }
      if (out == nullptr) {
        stream->Skip(bytes_length);
      }
      num_elements++;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}

inline void PadFloatFeature(int num_to_pad, float* out) {
  for (int i = 0; i < num_to_pad; i++) {
    *out++ = 0.0;
  }
}

inline void PadInt64Feature(int num_to_pad, int64* out) {
  for (int i = 0; i < num_to_pad; i++) {
    *out++ = 0;
  }
}

// Return the number of float elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
                             float* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
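    // FloatList.value may be encoded packed (a single length-delimited blob of
    // little-endian 32-bit floats) or unpacked (repeated fixed32 fields), so
    // peek at the next tag to decide which branch to take.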
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kFixed32Tag(1)) {
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ExpectTag(kFixed32Tag(1)) ||
            !stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}

// Return the number of int64 elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
                             int64* out) {
  int num_elements = 0;
  uint32 length;
  if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
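    // Int64List.value may likewise be packed (a length-delimited run of
    // varints) or unpacked (repeated varint fields).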
    if (peek_tag == kDelimitedTag(1)) {  // packed
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kVarintTag(1)) {
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}

inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
  uint8 peek_tag = PeekTag(stream);
  switch (peek_tag) {
    case kDelimitedTag(1):
      return DT_STRING;
    case kDelimitedTag(2):
      return DT_FLOAT;
    case kDelimitedTag(3):
      return DT_INT64;
    default:
      return DT_INVALID;
  }
}

inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
                             DataType dtype) {
  switch (dtype) {
    case DT_STRING:
      if (!stream->ExpectTag(kDelimitedTag(1))) {
        return false;
      }
      break;
    case DT_FLOAT:
      if (!stream->ExpectTag(kDelimitedTag(2))) {
        return false;
      }
      break;
    case DT_INT64:
      if (!stream->ExpectTag(kDelimitedTag(3))) {
        return false;
      }
      break;
    default:
      return false;
  }
  uint32 length;
  return stream->ReadVarint32(&length) && length == 0;
}

// TODO(sundberg): Use the threadpool to parallelize example parsing.
// TODO(b/111553342): Support extracting feature statistics from the examples.
Status FastParseSequenceExample(
    const FastParseExampleConfig& context_config,
    const FastParseExampleConfig& feature_list_config,
    gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
    thread::ThreadPool* thread_pool, Result* context_result,
    Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
  int num_examples = serialized.size();
  DCHECK(context_result != nullptr);
  DCHECK(feature_list_result != nullptr);
  DCHECK(dense_feature_lengths != nullptr);
  size_t num_context_features =
      context_config.sparse.size() + context_config.dense.size();
  absl::flat_hash_map<StringPiece, bool> context_is_sparse;
  context_is_sparse.reserve(num_context_features);
  absl::flat_hash_map<StringPiece, std::pair<DataType, size_t>>
      context_feature_type_and_lengths;
  context_feature_type_and_lengths.reserve(num_context_features);
  if (!example_names.empty() && example_names.size() != num_examples) {
    return errors::InvalidArgument(
        "example_names must be empty or have the correct number of elements");
  }
  for (auto& c : context_config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    context_feature_type_and_lengths[c.feature_name] =
        std::make_pair(c.dtype, 0);
    context_is_sparse[c.feature_name] = true;
  }
  for (auto& c : context_config.dense) {
    if (context_is_sparse[c.feature_name]) {
      return errors::InvalidArgument("Context feature " + c.feature_name +
                                     " cannot be both dense and sparse");
    }
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    context_feature_type_and_lengths[c.feature_name] =
        std::make_pair(c.dtype, c.default_value.NumElements());
    if (c.default_value.NumElements() > 0) {
      if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
        return errors::InvalidArgument("Default value for context feature ",
                                       c.feature_name,
                                       " has an incorrect shape: saw ",
                                       c.default_value.shape().DebugString(),
                                       " but expected ", c.shape.DebugString());
      }
    }
  }
  size_t num_sequence_features =
      feature_list_config.sparse.size() + feature_list_config.dense.size();
  absl::flat_hash_map<StringPiece, bool> sequence_is_sparse;
  sequence_is_sparse.reserve(num_sequence_features);
  absl::flat_hash_map<StringPiece, std::pair<DataType, size_t>>
      sequence_feature_type_and_lengths;
  sequence_feature_type_and_lengths.reserve(num_sequence_features);
  for (auto& c : feature_list_config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    sequence_feature_type_and_lengths[c.feature_name] =
        std::make_pair(c.dtype, 0);
    sequence_is_sparse[c.feature_name] = true;
  }
  for (auto& c : feature_list_config.dense) {
    if (sequence_is_sparse[c.feature_name]) {
      return errors::InvalidArgument("Sequence feature " + c.feature_name +
                                     " cannot be both dense and sparse");
    }
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    sequence_feature_type_and_lengths[c.feature_name] =
        std::make_pair(c.dtype, 0);
  }

  std::vector<absl::flat_hash_map<StringPiece, StringPiece>>
      all_context_features(num_examples);
  std::vector<absl::flat_hash_map<StringPiece, StringPiece>>
      all_sequence_features(num_examples);
  const string kUnknown = "<unknown>";
  for (int d = 0; d < num_examples; d++) {
    const string& example = serialized[d];
    const string& example_name =
        example_names.empty() ? kUnknown : example_names[d];
    auto* context_features = &all_context_features[d];
    auto* sequence_features = &all_sequence_features[d];

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(example.data()), example.size());
    // EnableAliasing here is the SFINAE wrapper defined earlier in this file;
    // it calls stream.EnableAliasing(true) only when the CodedInputStream
    // implementation provides that method.
    EnableAliasing(&stream);

    // Extract pointers to all features within this serialized example.
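    // A SequenceExample has two submessages: field 1 is the context Features
    // map and field 2 is the FeatureLists map; any other field is skipped.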
    while (!stream.ExpectAtEnd()) {
      absl::flat_hash_map<StringPiece, StringPiece>* features = nullptr;
      const absl::flat_hash_map<StringPiece, std::pair<DataType, size_t>>*
          config = nullptr;
      if (stream.ExpectTag(kDelimitedTag(1))) {
        // Context
        features = context_features;
        config = &context_feature_type_and_lengths;
      } else if (stream.ExpectTag(kDelimitedTag(2))) {
        // Sequence
        features = sequence_features;
        config = &sequence_feature_type_and_lengths;
      } else if (!SkipExtraneousTag(&stream)) {
        return errors::InvalidArgument(
            "Invalid protocol message input, example id: ", example_name);
      }
      if (features != nullptr) {
        uint32 length;
        if (!stream.ReadVarint32(&length)) {
          return errors::InvalidArgument(
              "Invalid protocol message input, example id: ", example_name);
        }
        auto limit = stream.PushLimit(length);
        while (!stream.ExpectAtEnd()) {
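          // Each map entry is a nested message holding the feature name as
          // field 1 (bytes) and the serialized Feature as field 2. Both are
          // stored as StringPieces into the serialized example; the Feature
          // value is parsed later.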
          StringPiece key, value;
          uint32 length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&length)) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ", example_name);
          }
          auto limit = stream.PushLimit(length);
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !ParseString(&stream, &key) ||
              !stream.ExpectTag(kDelimitedTag(2)) ||
              !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
            return errors::InvalidArgument(
                "Invalid protocol message input, example id: ", example_name);
          }
          stream.PopLimit(limit);
          // Only save if this feature was requested.
          if (config->count(key) > 0) {
            (*features)[key] = value;
          }
        }
        stream.PopLimit(limit);
      }
    }

    for (const auto& c : *context_features) {
      size_t num_elements = 0;
      if (!c.second.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
        EnableAliasing(&stream);
        DataType dtype = context_feature_type_and_lengths[c.first].first;
        int64 num;
        switch (dtype) {
          case DT_STRING:
            num = ParseBytesFeature(&stream, nullptr);
            break;
          case DT_FLOAT:
            num = ParseFloatFeature(&stream, nullptr);
            break;
          case DT_INT64:
            num = ParseInt64Feature(&stream, nullptr);
            break;
          default:
            num = -1;
            break;
        }
        if (num == -1) {
          return errors::InvalidArgument("Error in context feature ", c.first,
                                         " in example ", example_name);
        }
        num_elements += num;
      }
      if (context_is_sparse[c.first]) {
        context_feature_type_and_lengths[c.first].second += num_elements;
      } else {
        size_t current_max = context_feature_type_and_lengths[c.first].second;
        context_feature_type_and_lengths[c.first].second =
            std::max(current_max, num_elements);
      }
    }
    for (const auto& c : *sequence_features) {
      size_t num_elements = 0;
      if (!c.second.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
        EnableAliasing(&stream);
        DataType dtype = sequence_feature_type_and_lengths[c.first].first;
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.first, " in example ",
                                           example_name);
          }
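          // A Feature with at least one value serializes to more than 2 bytes;
          // exactly 2 bytes is a list field that is present but empty (a tag
          // byte plus a zero length), and 0 bytes is an entirely empty
          // Feature.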
          if (feature_length > 2) {
            auto limit = stream.PushLimit(feature_length);
            int64 num;
            switch (dtype) {
              case DT_STRING:
                num = ParseBytesFeature(&stream, nullptr);
                break;
              case DT_FLOAT:
                num = ParseFloatFeature(&stream, nullptr);
                break;
              case DT_INT64:
                num = ParseInt64Feature(&stream, nullptr);
                break;
              default:
                num = -1;
                break;
            }
            if (num == -1) {
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.first, " in example ",
                                             example_name);
            }
            num_elements += num;
            stream.PopLimit(limit);
          } else if (feature_length == 2) {
            if (!SkipEmptyFeature(&stream, dtype)) {
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.first, " in example ",
                                             example_name);
            }
          } else if (feature_length != 0) {
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.first, " in example ",
                                           example_name);
          }
        }
      }
      if (sequence_is_sparse[c.first]) {
        sequence_feature_type_and_lengths[c.first].second += num_elements;
      } else {
        size_t current_max = sequence_feature_type_and_lengths[c.first].second;
        sequence_feature_type_and_lengths[c.first].second =
            std::max(current_max, num_elements);
      }
    }
  }

  // Allocate memory.
  context_result->sparse_values.resize(context_config.sparse.size());
  context_result->sparse_indices.resize(context_config.sparse.size());
  context_result->sparse_shapes.resize(context_config.sparse.size());
  context_result->dense_values.resize(context_config.dense.size());
  feature_list_result->sparse_values.resize(feature_list_config.sparse.size());
  feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
  feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
  feature_list_result->dense_values.resize(feature_list_config.dense.size());
  dense_feature_lengths->resize(feature_list_config.dense.size());

  int t = 0;
  for (const auto& c : context_config.dense) {
    TensorShape dense_shape, example_shape;
    DataType dtype = c.dtype;
    const size_t expected_max_elements =
        context_feature_type_and_lengths[c.feature_name].second;
    if (!c.shape.AsTensorShape(&example_shape) ||
        expected_max_elements != example_shape.num_elements()) {
      return errors::InvalidArgument(
          "Inconsistent number of elements for feature ", c.feature_name, ": ",
          expected_max_elements, " vs ", example_shape.num_elements());
    }
    dense_shape.AddDim(num_examples);
    for (const int dim : c.shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    context_result->dense_values[t] = Tensor(dtype, dense_shape);

    // TODO(sundberg): Refactor to reduce code duplication, and add bounds
    // checking for the outputs.
    string* out_bytes = nullptr;
    float* out_float = nullptr;
    int64* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = context_result->dense_values[t].flat<string>().data();
        break;
      case DT_FLOAT:
        out_float = context_result->dense_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = context_result->dense_values[t].flat<int64>().data();
        break;
      default:
        return errors::InvalidArgument("Unexpected dtype ", dtype,
                                       " in feature ", c.feature_name);
    }
    t++;

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0;
      const auto feature_iter = all_context_features[e].find(c.feature_name);
      const string& example_name =
          example_names.empty() ? kUnknown : example_names[e];
      if (feature_iter == all_context_features[e].end()) {
        // Copy the default value, if present. If not, return an error.
        if (c.default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: ", c.feature_name,
              " (data type: ", DataTypeString(c.dtype), ")",
              " is required but could not be found.");
        }
        const string* in_bytes = nullptr;
        const float* in_float = nullptr;
        const int64* in_int64 = nullptr;
        size_t num = 0;
        switch (dtype) {
          case DT_STRING:
            in_bytes = c.default_value.flat<string>().data();
            num = c.default_value.NumElements();
            for (int p = 0; p < num; p++) {
              *out_bytes++ = *in_bytes++;
            }
            break;
          case DT_FLOAT:
            in_float = c.default_value.flat<float>().data();
            num = c.default_value.NumElements();
            for (int p = 0; p < num; p++) {
              *out_float++ = *in_float++;
            }
            break;
          case DT_INT64:
            in_int64 = c.default_value.flat<int64>().data();
            num = c.default_value.NumElements();
            for (int p = 0; p < num; p++) {
              *out_int64++ = *in_int64++;
            }
            break;
          default:
            return errors::InvalidArgument("Unexpected dtype ", dtype,
                                           " in example ", example_name);
        }
        num_elements += num;
      } else if (!feature_iter->second.empty()) {
        const auto& feature = feature_iter->second;
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature.data()), feature.size());
        EnableAliasing(&stream);
        size_t num_added;
        switch (dtype) {
          case DT_STRING:
            num_added = ParseBytesFeature(&stream, out_bytes);
            out_bytes += num_added;
            break;
          case DT_FLOAT:
            num_added = ParseFloatFeature(&stream, out_float);
            out_float += num_added;
            break;
          case DT_INT64:
            num_added = ParseInt64Feature(&stream, out_int64);
            out_int64 += num_added;
            break;
          default:
            return errors::InvalidArgument("Unexpected dtype ", dtype,
                                           " in example ", example_name);
        }
        num_elements += num_added;
      }
      if (num_elements != expected_max_elements) {
        return errors::InvalidArgument(
            "Unexpected number of elements in example ", example_name);
      }
    }
  }
  t = 0;
  for (const auto& c : context_config.sparse) {
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    size_t expected_num_elements =
        context_feature_type_and_lengths[c.feature_name].second;
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(2);
    values_shape.AddDim(expected_num_elements);
    context_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
    context_result->sparse_values[t] = Tensor(dtype, values_shape);
    context_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({2}));
    // TODO(sundberg): Refactor to reduce code duplication, and add bounds
    // checking for the outputs.
    string* out_bytes = nullptr;
    float* out_float = nullptr;
    int64* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = context_result->sparse_values[t].flat<string>().data();
        break;
      case DT_FLOAT:
        out_float = context_result->sparse_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = context_result->sparse_values[t].flat<int64>().data();
        break;
      default:
        return errors::InvalidArgument("Unexpected dtype ", dtype,
                                       " in feature ", c.feature_name);
    }
    int64* out_indices =
        context_result->sparse_indices[t].flat<int64>().data();
    auto out_shape = context_result->sparse_shapes[t].vec<int64>();
    t++;

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_cols = 0;
    for (int e = 0; e < num_examples; e++) {
      const auto& feature = all_context_features[e][c.feature_name];
      const string& example_name =
          example_names.empty() ? kUnknown : example_names[e];
      if (!feature.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature.data()), feature.size());
        EnableAliasing(&stream);
        size_t num_added;
        switch (dtype) {
          case DT_STRING:
            num_added = ParseBytesFeature(&stream, out_bytes);
            out_bytes += num_added;
            break;
          case DT_FLOAT:
            num_added = ParseFloatFeature(&stream, out_float);
            out_float += num_added;
            break;
          case DT_INT64:
            num_added = ParseInt64Feature(&stream, out_int64);
            out_int64 += num_added;
            break;
          default:
            return errors::InvalidArgument("Unexpected dtype ", dtype,
                                           " in example ", example_name);
        }
        num_elements += num_added;
        max_num_cols = std::max(max_num_cols, num_added);
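        // Context sparse indices are (example, position-within-feature)
        // pairs, matching the rank-2 indices tensor allocated above.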
        for (int i = 0; i < num_added; i++) {
          *out_indices++ = e;
          *out_indices++ = i;
        }
      }
    }
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected total number of elements in feature ", c.feature_name);
    }
    out_shape(0) = num_examples;
    out_shape(1) = max_num_cols;
  }
  t = 0;
  TensorShape dense_length_shape({num_examples});
  for (const auto& c : feature_list_config.dense) {
    TensorShape dense_shape, row_shape;
    DataType dtype = c.dtype;
    const size_t expected_max_elements =
        sequence_feature_type_and_lengths[c.feature_name].second;
    if (!c.shape.AsTensorShape(&row_shape) ||
        expected_max_elements !=
            (expected_max_elements / row_shape.num_elements()) *
                row_shape.num_elements()) {
      return errors::InvalidArgument("Unexpected shape error in feature ",
                                     c.feature_name);
    }
    int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
    dense_shape.AddDim(num_examples);
    dense_shape.AddDim(expected_max_rows);
    for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
    (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape);
    int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
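    // (*dense_feature_lengths)[t] records, per example, how many Feature rows
    // of this FeatureList were actually present before padding.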

    string* out_bytes = nullptr;
    float* out_float = nullptr;
    int64* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = feature_list_result->dense_values[t].flat<string>().data();
        break;
      case DT_FLOAT:
        out_float = feature_list_result->dense_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
        break;
      default:
        return errors::InvalidArgument("Unexpected dtype ", dtype,
                                       " in feature ", c.feature_name);
    }
    t++;

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0, num_rows = 0;
      const auto feature_iter = all_sequence_features[e].find(c.feature_name);
      const string& example_name =
          example_names.empty() ? kUnknown : example_names[e];
      if (feature_iter == all_sequence_features[e].end()) {
        // Return an error if this feature was not allowed to be missing.
        // Otherwise, we'll pad as needed below.
        if (!c.variable_length) {
          return errors::InvalidArgument("Missing feature ", c.feature_name,
                                         " in example ", example_name);
        }
      } else if (!feature_iter->second.empty()) {
        const auto& feature = feature_iter->second;
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature.data()), feature.size());
        EnableAliasing(&stream);
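        // A FeatureList is a repeated Feature (field 1); each iteration below
        // consumes one length-delimited Feature, i.e. one row of the padded
        // dense output.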
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           example_name);
          }
          auto limit = stream.PushLimit(feature_length);
          size_t num_added;
          switch (dtype) {
            case DT_STRING:
              num_added = ParseBytesFeature(&stream, out_bytes);
              out_bytes += num_added;
              break;
            case DT_FLOAT:
              num_added = ParseFloatFeature(&stream, out_float);
              out_float += num_added;
              break;
            case DT_INT64:
              num_added = ParseInt64Feature(&stream, out_int64);
              out_int64 += num_added;
              break;
            default:
              return errors::InvalidArgument("Unexpected dtype ", dtype,
                                             " in example ", example_name);
          }
          num_elements += num_added;
          num_rows++;
          if (num_added != row_shape.num_elements()) {
            return errors::InvalidArgument(
                "Unexpected number of elements in feature ", c.feature_name,
                ", example ", example_name);
          }
          stream.PopLimit(limit);
        }
      }
      *out_lengths++ = num_rows;
      // Pad as necessary.
      int num_to_pad = expected_max_elements - num_elements;
      switch (dtype) {
        case DT_STRING:
          out_bytes += num_to_pad;
          break;
        case DT_FLOAT:
          PadFloatFeature(num_to_pad, out_float);
          out_float += num_to_pad;
          break;
        case DT_INT64:
          PadInt64Feature(num_to_pad, out_int64);
          out_int64 += num_to_pad;
          break;
        default:
          return errors::InvalidArgument("Unexpected dtype ", dtype,
                                         " in example ", example_name);
      }
    }
  }
  t = 0;
  for (const auto& c : feature_list_config.sparse) {
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    size_t expected_num_elements =
        sequence_feature_type_and_lengths[c.feature_name].second;
    indices_shape.AddDim(expected_num_elements);
    indices_shape.AddDim(3);
    values_shape.AddDim(expected_num_elements);
    feature_list_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
    feature_list_result->sparse_values[t] = Tensor(dtype, values_shape);
    feature_list_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({3}));

    string* out_bytes = nullptr;
    float* out_float = nullptr;
    int64* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = feature_list_result->sparse_values[t].flat<string>().data();
        break;
      case DT_FLOAT:
        out_float = feature_list_result->sparse_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
        break;
      default:
        return errors::InvalidArgument("Unexpected dtype ", dtype,
                                       " in feature ", c.feature_name);
    }
    int64* out_indices =
        feature_list_result->sparse_indices[t].flat<int64>().data();
    auto out_shape = feature_list_result->sparse_shapes[t].vec<int64>();
    t++;

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_rows = 0;
    size_t max_num_cols = 0;
    for (int e = 0; e < num_examples; e++) {
      const auto& feature = all_sequence_features[e][c.feature_name];
      const string& example_name =
          example_names.empty() ? kUnknown : example_names[e];
      if (!feature.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature.data()), feature.size());
        EnableAliasing(&stream);
        size_t num_rows = 0;
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           example_name);
          }
          if (feature_length > 2) {
            auto limit = stream.PushLimit(feature_length);
            size_t num_added;
            switch (dtype) {
              case DT_STRING:
                num_added = ParseBytesFeature(&stream, out_bytes);
                out_bytes += num_added;
                break;
              case DT_FLOAT:
                num_added = ParseFloatFeature(&stream, out_float);
                out_float += num_added;
                break;
              case DT_INT64:
                num_added = ParseInt64Feature(&stream, out_int64);
                out_int64 += num_added;
                break;
              default:
                return errors::InvalidArgument("Unexpected dtype ", dtype,
                                               " in example ", example_name);
            }
            num_elements += num_added;
            max_num_cols = std::max(max_num_cols, num_added);
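            // Sequence sparse indices are (example, row, column) triples,
            // matching the rank-3 dense_shape written below.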
            for (int i = 0; i < num_added; i++) {
              *out_indices++ = e;
              *out_indices++ = num_rows;
              *out_indices++ = i;
            }
            stream.PopLimit(limit);
          } else if (feature_length == 2) {
            if (!SkipEmptyFeature(&stream, dtype)) {
              return errors::InvalidArgument("Error in sequence feature ",
                                             c.feature_name, " in example ",
                                             example_name);
            }
          } else if (feature_length != 0) {
            return errors::InvalidArgument("Error in sequence feature ",
                                           c.feature_name, " in example ",
                                           example_name);
          }
          num_rows++;
        }
        max_num_rows = std::max(max_num_rows, num_rows);
      }
    }
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements in feature ", c.feature_name);
    }
    out_shape(0) = num_examples;
    out_shape(1) = max_num_rows;
    out_shape(2) = max_num_cols;
  }

  return Status::OK();
}

}  // namespace example
}  // namespace tensorflow