1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/data/parse_example_op.h"
17
18 #include <google/protobuf/io/coded_stream.h>
19
20 #include <algorithm>
21 #include <memory>
22
23 #include "absl/base/casts.h"
24 #include "absl/container/inlined_vector.h"
25 #include "proto/example.pb.h"
26
27 #include "minddata/dataset/core/tensor.h"
28 #include "minddata/dataset/kernels/data/data_utils.h"
29 #include "minddata/dataset/kernels/tensor_op.h"
30
namespace mindspore::dataset {
namespace protobuf = ::google::protobuf;

// True when the build target is little endian; enables a raw memcpy of packed
// little-endian float data straight from the wire into the output buffer.
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
// Inline capacity of SmallVector before it spills to the heap.
constexpr size_t kInlinedVectorSize = 4;

template <typename T>
using SmallVector = absl::InlinedVector<T, kInlinedVectorSize>;
// Non-owning view into serialized proto bytes; the backing buffer must outlive it.
using StringPiece = std::string_view;
40
41 template <typename T>
42 class LimitedArraySlice {
43 public:
44 using value_type = T;
45
LimitedArraySlice(T * begin,size_t num_elements)46 LimitedArraySlice(T *begin, size_t num_elements) : current_(begin), begin_(begin), end_(begin + num_elements) {}
47
48 /// \brief Get the left space in the slice.
EndDistance() const49 int64_t EndDistance() const { return end_ - current_; }
50
51 /// \brief Push value to back of slice. If the slice is full, only change the
52 /// total number without modify the data.
push_back(T && value)53 void push_back(T &&value) {
54 if (EndDistance() > 0) {
55 *current_ = std::move(value);
56 }
57 ++current_;
58 }
59
60 /// \brief Construct an element at the back of slice and return a mutable
61 /// reference to the new element.
construct_at_end()62 T &construct_at_end() {
63 if (EndDistance() <= 0) {
64 MS_EXCEPTION(RuntimeError) << "LimitedArraySlice has no space left.";
65 }
66 return *(current_++);
67 }
68
69 /// \brief Get the mutable reference to the last element in slice.
back()70 T &back() { return *(current_ - 1); }
71
72 /// \brief Get the number of elements in slice.
size() const73 size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
74
75 /// \brief Resize the slice to the given size by advancing the pointer to
76 /// the current element.
resize(size_t size)77 void resize(size_t size) { current_ = begin_ + size; }
78
79 /// \brief Get the data buffer.
data()80 T *data() { return begin_; }
81
82 private:
83 T *current_;
84 T *begin_;
85 T *end_;
86 };
87
PeekTag(protobuf::io::CodedInputStream * stream)88 uint8_t PeekTag(protobuf::io::CodedInputStream *stream) {
89 if (stream == nullptr) {
90 MS_EXCEPTION(RuntimeError) << "CodedInputStream is nullptr.";
91 }
92 const void *ptr;
93 int size;
94 if (!stream->GetDirectBufferPointer(&ptr, &size)) {
95 return 0;
96 }
97 return *static_cast<const uint8_t *>(ptr);
98 }
99
// Assemble a single-byte protobuf tag from a field number and a wire type:
// varint (0), length-delimited (2), and 32-bit fixed (5) respectively.
constexpr uint8_t kVarintTag(const uint32_t tag) { return static_cast<uint8_t>((tag << 3) | 0); }
constexpr uint8_t kDelimitedTag(const uint32_t tag) { return static_cast<uint8_t>((tag << 3) | 2); }
constexpr uint8_t kFixed32Tag(const uint32_t tag) { return static_cast<uint8_t>((tag << 3) | 5); }
103
namespace parsed {
// A lazily-decoded view over a serialized tf.Example `Feature` proto. Only
// the raw bytes are stored; the Parse* methods decode on demand, so the
// buffer backing `serialized_` must outlive this object.
class Feature {
 public:
  Feature() = default;
  explicit Feature(const StringPiece &serialized) : serialized_(serialized) {}

  /// \brief Read and consume the oneof tag byte to determine which value list
  /// the Feature holds (bytes_list = 1, float_list = 2, int64_list = 3).
  /// An empty feature yields DE_UNKNOWN with an OK status.
  Status ParseDataType(DataType *dtype) {
    RETURN_UNEXPECTED_IF_NULL(dtype);
    if (serialized_.empty()) {
      *dtype = DataType(DataType::DE_UNKNOWN);
      return Status::OK();
    }
    const auto oneof_tag = static_cast<uint8_t>(*serialized_.data());
    serialized_.remove_prefix(1);
    constexpr uint8_t kStringTag = 1;
    constexpr uint8_t kFloat32Tag = 2;
    constexpr uint8_t kInt64Tag = 3;
    switch (oneof_tag) {
      case kDelimitedTag(kStringTag):
        *dtype = DataType(DataType::DE_STRING);
        break;
      case kDelimitedTag(kFloat32Tag):
        *dtype = DataType(DataType::DE_FLOAT32);
        break;
      case kDelimitedTag(kInt64Tag):
        *dtype = DataType(DataType::DE_INT64);
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DataType(DataType::DE_UNKNOWN);
        RETURN_STATUS_UNEXPECTED("Unsupported datatype.");
    }
    return Status::OK();
  }

  /// \brief Count the strings inside the BytesList without materializing
  /// them, so callers can reserve capacity before a real parse.
  bool GetNumElementsInBytesList(int *num_elements) const {
    if (num_elements == nullptr) {
      return false;
    }
    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));
    uint32_t length = 0;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      // Each element is a length-delimited field with field number 1.
      if (!stream.ExpectTag(kDelimitedTag(1))) {
        return false;
      }
      uint32_t bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) {
        return false;
      }
      if (!stream.Skip(static_cast<int>(bytes_length))) {
        return false;
      }
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Append-slot helper: returns nullptr when the fixed-capacity slice is full.
  static std::string *construct_at_end(LimitedArraySlice<std::string> *bytes_list) {
    if (bytes_list->EndDistance() <= 0) {
      return nullptr;
    }
    return &bytes_list->construct_at_end();
  }

  // Append-slot helper for a growable vector; never fails.
  static std::string *construct_at_end(std::vector<std::string> *bytes_list) { return &bytes_list->emplace_back(); }

  /// \brief Decode the BytesList into `bytes_list`. Result may be a
  /// std::vector<std::string> or a LimitedArraySlice<std::string>.
  template <typename Result>
  bool ParseBytesList(Result *bytes_list) const {
    if (bytes_list == nullptr) {
      return false;
    }

    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));

    uint32_t length;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) {
        return false;
      }
      // parse string
      uint32_t bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) {
        return false;
      }
      std::string *bytes = construct_at_end(bytes_list);
      if (bytes == nullptr) {
        return false;
      }
      bytes->resize(bytes_length);
      if (!stream.ReadRaw(bytes->data(), static_cast<int>(bytes_length))) {
        return false;
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  /// \brief Decode the FloatList into `float_list`, accepting both the packed
  /// (length-delimited) and the non-packed (repeated fixed32) encodings.
  template <typename Result>
  bool ParseFloatList(Result *float_list) const {
    if (float_list == nullptr) {
      return false;
    }
    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));
    uint32_t length;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));

    if (!stream.ExpectAtEnd()) {
      const uint8_t peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      constexpr int32_t kNumFloatBytes = 4;
      if (peek_tag == kDelimitedTag(1)) {          // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) {  // packed tag
          return false;
        }
        uint32_t packed_length;
        if (!stream.ReadVarint32(&packed_length)) {
          return false;
        }
        const auto packed_limit = stream.PushLimit(static_cast<int>(packed_length));

        // Store the initial size to know the offset we have to start writing
        // data from before resizing the output "vector".
        const size_t initial_size = float_list->size();
        float_list->resize(initial_size + packed_length / kNumFloatBytes);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        if (kLittleEndian && sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Calculate the length of the buffer available what can be less than
          // what we requested in resize in case of a LimitedArraySlice.
          const uint32_t bytes_to_copy =
            std::min(static_cast<uint32_t>((float_list->size() - initial_size) * kNumFloatBytes), packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy)) {
            return false;
          }
        } else {
          // Element-by-element fallback that byte-swaps via ReadLittleEndian32
          // and drops values that do not fit the (possibly limited) output.
          int64_t index = initial_size;
          while (!stream.ExpectAtEnd()) {
            uint32_t buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) {
              return false;
            }
            if (index < float_list->size()) {
              float_list->data()[index] = absl::bit_cast<float>(buffer32);
              ++index;
            }
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        const size_t initial_size = float_list->size();
        // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
        // the value.
        const int64_t num_elements = stream.BytesUntilLimit() / (1 + kNumFloatBytes);
        float_list->resize(initial_size + num_elements);
        int64_t index = initial_size;
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) {
            return false;
          }
          uint32_t buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) {
            return false;
          }
          float_list->data()[index] = absl::bit_cast<float>(buffer32);
          ++index;
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  /// \brief Decode the Int64List into `int64_list`, accepting both the packed
  /// (length-delimited) and the non-packed (repeated varint) encodings.
  template <typename Result>
  bool ParseInt64List(Result *int64_list) const {
    if (int64_list == nullptr) {
      return false;
    }
    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));
    uint32_t length;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));

    if (!stream.ExpectAtEnd()) {
      const uint8_t peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {          // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) {  // packed tag
          return false;
        }
        uint32_t packed_length;
        if (!stream.ReadVarint32(&packed_length)) {
          return false;
        }
        const auto packed_limit = stream.PushLimit(static_cast<int>(packed_length));

        while (!stream.ExpectAtEnd()) {
          uint64_t n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) {
            return false;
          }
          int64_list->push_back(static_cast<int64_t>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) {
            return false;
          }
          uint64_t n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) {
            return false;
          }
          int64_list->push_back(static_cast<int64_t>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

 private:
  StringPiece serialized_;
};

// One (feature name, feature) pair, and the flattened feature map of a parsed Example.
using FeatureMapEntry = std::pair<StringPiece, Feature>;
using Example = std::vector<FeatureMapEntry>;
}  // namespace parsed
361
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)362 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream *stream) {
363 uint32_t data;
364 uint64_t dummy;
365 constexpr uint32_t kVarint = 0;
366 constexpr uint32_t kFixed64 = 1;
367 constexpr uint32_t kLengthDelimited = 2;
368 constexpr uint32_t kGroupBegin = 3;
369 constexpr uint32_t kGroupEnd = 4;
370 constexpr uint32_t kFixed32 = 5;
371 switch (stream->ReadTag() & 0x7) {
372 case kVarint: // varint
373 return stream->ReadVarint32(&data);
374 case kFixed64: // fixed64
375 return stream->ReadLittleEndian64(&dummy);
376 case kLengthDelimited: // length delimited
377 if (!stream->ReadVarint32(&data)) {
378 return false;
379 }
380 stream->Skip(static_cast<int>(data));
381 return true;
382 case kGroupBegin: // group begin
383 case kGroupEnd: // group end
384 return false; // groups not supported.
385 case kFixed32: // fixed32
386 return stream->ReadLittleEndian32(&data);
387 default:
388 return false;
389 }
390 return false; // unrecognized tag type
391 }
392
ParseString(protobuf::io::CodedInputStream * stream,StringPiece * result)393 bool ParseString(protobuf::io::CodedInputStream *stream, StringPiece *result) {
394 if (stream == nullptr) {
395 return false;
396 }
397 if (result == nullptr) {
398 return false;
399 }
400 uint32_t length;
401 if (!stream->ReadVarint32(&length)) {
402 return false;
403 }
404 if (length == 0) {
405 *result = StringPiece(nullptr, 0);
406 return true;
407 }
408 const void *stream_alias;
409 int stream_size;
410 if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
411 return false;
412 }
413 if (static_cast<uint32_t>(stream_size) < length) {
414 return false;
415 }
416 *result = StringPiece(static_cast<const char *>(stream_alias), length);
417 stream->Skip(static_cast<int>(length));
418 return true;
419 }
420
ParseFeatureMapEntry(protobuf::io::CodedInputStream * stream,parsed::FeatureMapEntry * feature_map_entry)421 bool ParseFeatureMapEntry(protobuf::io::CodedInputStream *stream, parsed::FeatureMapEntry *feature_map_entry) {
422 if (stream == nullptr) {
423 return false;
424 }
425 if (feature_map_entry == nullptr) {
426 return false;
427 }
428 uint32_t length;
429 if (!stream->ReadVarint32(&length)) {
430 return false;
431 }
432 const auto limit = stream->PushLimit(static_cast<int>(length));
433
434 // Protobufs allow an arbitrary order for the key and value fields.
435 for (int n = 0; n <= 1; ++n) {
436 constexpr uint32_t kNameTag = 1;
437 constexpr uint32_t kFeatureTag = 2;
438 switch (stream->ReadTag()) {
439 case kDelimitedTag(kNameTag):
440 if (!ParseString(stream, &feature_map_entry->first)) {
441 return false;
442 }
443 break;
444
445 case kDelimitedTag(kFeatureTag): {
446 StringPiece feature_string_piece;
447 if (!ParseString(stream, &feature_string_piece)) {
448 return false;
449 }
450 feature_map_entry->second = parsed::Feature(feature_string_piece);
451 break;
452 }
453
454 default:
455 return false;
456 }
457 }
458
459 if (!stream->ExpectAtEnd()) {
460 return false;
461 }
462 stream->PopLimit(limit);
463 return true;
464 }
465
ParseFeatures(protobuf::io::CodedInputStream * stream,parsed::Example * example)466 bool ParseFeatures(protobuf::io::CodedInputStream *stream, parsed::Example *example) {
467 if (stream == nullptr) {
468 return false;
469 }
470 if (example == nullptr) {
471 return false;
472 }
473 uint32_t length;
474 if (!stream->ReadVarint32(&length)) {
475 return false;
476 }
477 const auto limit = stream->PushLimit(static_cast<int>(length));
478 while (!stream->ExpectAtEnd()) {
479 parsed::FeatureMapEntry feature_map_entry;
480 if (!stream->ExpectTag(kDelimitedTag(1))) {
481 return false;
482 }
483 if (!ParseFeatureMapEntry(stream, &feature_map_entry)) {
484 return false;
485 }
486 example->push_back(std::move(feature_map_entry));
487 }
488 stream->PopLimit(limit);
489 return true;
490 }
491
ParseExample(protobuf::io::CodedInputStream * stream,parsed::Example * example)492 bool ParseExample(protobuf::io::CodedInputStream *stream, parsed::Example *example) {
493 if (stream == nullptr) {
494 return false;
495 }
496 if (example == nullptr) {
497 return false;
498 }
499 // Loop over the input stream which may contain multiple serialized Example
500 // protos merged together as strings. This behavior is consistent with Proto's
501 // ParseFromString when string representations are concatenated.
502 while (!stream->ExpectAtEnd()) {
503 if (!stream->ExpectTag(kDelimitedTag(1))) {
504 if (!SkipExtraneousTag(stream)) {
505 return false;
506 }
507 } else {
508 if (!ParseFeatures(stream, example)) {
509 return false;
510 }
511 }
512 }
513 return true;
514 }
515
ParseExample(const StringPiece & serialized,parsed::Example * example)516 bool ParseExample(const StringPiece &serialized, parsed::Example *example) {
517 if (example == nullptr) {
518 return false;
519 }
520 protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized.data()),
521 static_cast<int>(serialized.size()));
522 return ParseExample(&stream, example);
523 }
524
// A minimal vector-like adapter that writes parsed values directly into a
// freshly created Tensor, so the result can be handed back without a copy.
// resize() may be called only once: it creates the backing tensor.
template <typename T>
class TensorVector {
 public:
  using value_type = T;

  // Return the backing tensor, creating an empty one on first access.
  std::shared_ptr<Tensor> tensor() {
    if (tensor_ == nullptr) {
      resize(0);
    }
    return tensor_;
  }

  int64_t size() const { return tensor_ != nullptr ? tensor_->Size() : 0; }

  // Allocate the backing tensor with `new_size` elements. Must not be called
  // after the tensor exists (single-shot initialization).
  void resize(int64_t new_size) {
    if (tensor_ != nullptr) {
      MS_EXCEPTION(RuntimeError) << "TensorVector has already initialized.";
    }
    Status s = Tensor::CreateEmpty(TensorShape({new_size}), DataType::FromCType<T>(), &tensor_);
    if (s.IsError()) {
      MS_EXCEPTION(RuntimeError) << s.ToString();
    }
    // NOTE(review): dereferences begin() even when new_size == 0 — relies on
    // Tensor iterators tolerating an empty buffer; confirm this is safe.
    data_ = &*(tensor_->begin<T>());
  }

  T *data() { return data_; }

  const T *data() const { return data_; }

 private:
  std::shared_ptr<Tensor> tensor_ = nullptr;
  T *data_ = nullptr;  // the raw data inside the tensor
};
558
// Copy the range [b, e) into the destination buffer starting at t.
// (The name mirrors the specialization point where string types could be
// moved instead of copied.)
template <typename T>
void CopyOrMoveBlock(const T *b, const T *e, T *t) {
  for (const T *src = b; src != e; ++src, ++t) {
    *t = *src;
  }
}
563
// Warn that a feature name occurred more than once in a single Example;
// standard protobuf map semantics keep only the last occurrence.
void LogFeatureRepeated(const StringPiece &feature_name) {
  MS_LOG(WARNING) << "Feature name: " << feature_name << " is repeated in Example. Ignoring all but last one.";
}
567
// Build the error status returned when the serialized bytes of a feature
// cannot be decoded.
inline Status ReportUnexpectedParseFailure(const StringPiece &feature_name) {
  RETURN_STATUS_UNEXPECTED("Failed to parse serialized Example of feature name: " + std::string(feature_name));
}
571
// Build the error status returned when a feature carries a dtype the parser
// does not handle.
inline Status ReportUnexpectedDataType(const StringPiece &feature_name, const DataType &dtype) {
  RETURN_STATUS_UNEXPECTED("Got unexpected data type: " + dtype.ToString() +
                           " of feature name: " + std::string(feature_name));
}
576
// Build the error status returned when the number of parsed elements does not
// match the shape declared for the column in the schema.
inline Status ReportUnexpectedDataShape(const StringPiece &feature_name) {
  RETURN_STATUS_UNEXPECTED("Column shape of " + std::string(feature_name) +
                           " defined in schema does not match the shape actually load.");
}
581
CreateUint8TensorFromString(const std::vector<std::string> & bytes_list,std::shared_ptr<Tensor> * column_tensor,const TensorShape & column_shape,const std::string & column_name)582 Status CreateUint8TensorFromString(const std::vector<std::string> &bytes_list, std::shared_ptr<Tensor> *column_tensor,
583 const TensorShape &column_shape, const std::string &column_name) {
584 dsize_t total_size =
585 std::accumulate(bytes_list.begin(), bytes_list.end(), 0,
586 [](dsize_t size, const std::string &str) { return size + static_cast<dsize_t>(str.size()); });
587 TensorShape output_shape = column_shape;
588 if (!column_shape.known()) {
589 output_shape = TensorShape({total_size});
590 } else {
591 CHECK_FAIL_RETURN_UNEXPECTED(
592 output_shape.NumOfElements() == total_size,
593 "Column shape of " + column_name + " defined in schema does not match the shape actually load.");
594 }
595 RETURN_IF_NOT_OK(Tensor::CreateEmpty(output_shape, DataType(DataType::DE_UINT8), column_tensor));
596 ptrdiff_t offset = 0;
597 for (const auto &str : bytes_list) {
598 int ret_code = memcpy_s((*column_tensor)->GetMutableBuffer() + offset, (*column_tensor)->SizeInBytes() - offset,
599 common::SafeCStr(str), str.size());
600 CHECK_FAIL_RETURN_UNEXPECTED(ret_code == EOK, "Failed to copy string into Tensor.");
601 offset += static_cast<ptrdiff_t>(str.size());
602 }
603 return Status::OK();
604 }
605
Compute(const TensorRow & input,TensorRow * output)606 Status ParseExampleOp::Compute(const TensorRow &input, TensorRow *output) {
607 IO_CHECK_VECTOR(input, output);
608 if (parallel_parse_) {
609 return ParallelParseExample(input, output);
610 } else {
611 return ParseSingleExample(input, output);
612 }
613 }
614
// Parse one feature whose column shape is fully known from the schema.
// Numeric values are decoded straight into the pre-allocated tensor buffer;
// the element count is validated against the schema shape afterwards.
Status ParseSingleKnownShapeColumn(const parsed::Feature &feature, std::shared_ptr<Tensor> *column_tensor,
                                   const StringPiece &feature_name, const ColDescriptor &column_descriptor,
                                   const DataType &example_dtype) {
  const size_t num_elements = column_descriptor.Shape().NumOfElements();
  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      // The slice both prevents overflowing the tensor buffer and records how
      // many elements were actually parsed.
      const auto data_buffer = reinterpret_cast<int64_t *>((*column_tensor)->GetMutableBuffer());
      LimitedArraySlice<int64_t> slice(data_buffer, num_elements);
      if (!feature.ParseInt64List(&slice)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      // EndDistance() == 0 iff exactly num_elements values were present.
      if (slice.EndDistance() != 0) {
        return ReportUnexpectedDataShape(feature_name);
      }
      break;
    }
    case DataType::DE_FLOAT32: {
      const auto data_buffer = reinterpret_cast<float *>((*column_tensor)->GetMutableBuffer());
      LimitedArraySlice<float> slice(data_buffer, num_elements);
      if (!feature.ParseFloatList(&slice)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      if (slice.EndDistance() != 0) {
        return ReportUnexpectedDataShape(feature_name);
      }
      break;
    }
    case DataType::DE_STRING: {
      std::vector<std::string> bytes_list;
      bytes_list.reserve(num_elements);
      if (!feature.ParseBytesList(&bytes_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      if (column_descriptor.Type().value() == DataType::DE_STRING) {
        if (bytes_list.size() != num_elements) {
          return ReportUnexpectedDataShape(feature_name);
        }
        TensorShape string_tensor_shape = TensorShape::CreateUnknownRankShape();
        RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &string_tensor_shape));
        RETURN_IF_NOT_OK(
          Tensor::CreateFromVector(bytes_list, string_tensor_shape, DataType(DataType::DE_STRING), column_tensor));
      } else {
        // load string or bytes as uint8 tensor
        RETURN_IF_NOT_OK(
          CreateUint8TensorFromString(bytes_list, column_tensor, column_descriptor.Shape(), std::string(feature_name)));
      }
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }
  return Status::OK();
}
668
// Parse one feature whose column shape is not (fully) known: first decode the
// values into a temporary container to learn the element count, then
// materialize the output tensor shape from the schema and that count.
Status ParseSingleVarLenColumn(const parsed::Feature &feature, std::shared_ptr<Tensor> *column_tensor,
                               const StringPiece &feature_name, const ColDescriptor &column_descriptor,
                               const DataType &example_dtype) {
  std::vector<std::string> bytes_list;
  TensorVector<float> float_list;  // floats decode directly into a Tensor
  SmallVector<int64_t> int64_list;

  // First pass: decode into the matching container and record the count.
  size_t num_elements;
  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      if (!feature.ParseInt64List(&int64_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = int64_list.size();
      break;
    }
    case DataType::DE_FLOAT32: {
      if (!feature.ParseFloatList(&float_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = float_list.size();
      break;
    }
    case DataType::DE_STRING: {
      // Pre-count to reserve exactly the needed capacity before parsing.
      int actual_num_elements = 0;
      if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      bytes_list.reserve(actual_num_elements);
      if (!feature.ParseBytesList(&bytes_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = bytes_list.size();
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }

  // Resolve the concrete output shape from the schema and the parsed count.
  TensorShape column_shape = TensorShape::CreateUnknownRankShape();
  RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &column_shape));

  // Second pass: build the output tensor from the temporary container.
  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_shape, example_dtype, column_tensor));
      CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
                      reinterpret_cast<int64_t *>((*column_tensor)->GetMutableBuffer()));
      break;
    }
    case DataType::DE_FLOAT32: {
      // The data already lives in float_list's tensor; copy and reshape it.
      RETURN_IF_NOT_OK(Tensor::CreateFromTensor(std::shared_ptr<Tensor>(float_list.tensor()), column_tensor));
      RETURN_IF_NOT_OK((*column_tensor)->Reshape(column_shape));
      break;
    }
    case DataType::DE_STRING: {
      if (column_descriptor.Type().value() == DataType::DE_STRING) {
        RETURN_IF_NOT_OK(
          Tensor::CreateFromVector(bytes_list, column_shape, DataType(DataType::DE_STRING), column_tensor));
      } else {
        // load string or bytes as uint8 tensor
        RETURN_IF_NOT_OK(CreateUint8TensorFromString(bytes_list, column_tensor, TensorShape::CreateUnknownRankShape(),
                                                     std::string(feature_name)));
      }
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }
  return Status::OK();
}
739
// Parse one serialized Example (the first string of the input row) into a
// TensorRow laid out according to the column schema. Known-shape numeric
// columns are pre-allocated and filled in place; all other columns are
// created while their feature is parsed.
Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow *parsed_row) {
  const auto filename = raw_bytes.getPath().empty() ? "" : raw_bytes.getPath()[0];
  const auto tensor_iterator = raw_bytes[0]->begin<std::string_view>();

  const auto example_bytes = std::string(*tensor_iterator);
  RETURN_IF_NOT_OK(ConstructColumnMap(example_bytes));

  parsed::Example parsed_example;
  CHECK_FAIL_RETURN_UNEXPECTED(ParseExample(example_bytes, &parsed_example),
                               "Failed to parse example bytes: " + example_bytes + " in tfrecord file: " + filename);

  parsed_row->reserve(data_schema_.NumColumns());

  // Pre-create one output tensor per schema column. Known-shape numeric
  // columns get a real buffer (filled in place later); string and
  // unknown-shape columns get a placeholder replaced during parsing.
  for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
    const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
    if (column_descriptor.HasKnownShape()) {
      if (!column_descriptor.Type().IsString()) {
        // Parsing always produces int64/float32; a cast to the schema type
        // happens afterwards if needed.
        DataType type;
        if (column_descriptor.Type().IsInt() || column_descriptor.Type().IsBool()) {
          type = DataType(DataType::DE_INT64);
        } else if (column_descriptor.Type().IsFloat()) {
          type = DataType(DataType::DE_FLOAT32);
        }
        std::shared_ptr<Tensor> column_tensor;
        RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_descriptor.Shape(), type, &column_tensor));
        parsed_row->emplace_back(std::move(column_tensor));
      } else {
        parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
      }
    } else {
      MS_LOG(INFO) << "Shape of column name: " << column_descriptor.Name() << " is not defined.";
      parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
    }
  }

  std::vector<bool> feature_already_seen(data_schema_.NumColumns(), false);
  std::vector<std::string> file_paths;

  // Iterate the parsed features in reverse insertion order so that, combined
  // with the repeated-feature skip below, the LAST occurrence of a name wins.
  const size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry &name_and_feature = parsed_example[parsed_example_size - i - 1];

    const StringPiece &feature_name = name_and_feature.first;
    parsed::Feature &feature = name_and_feature.second;

    if (column_name_id_map_.find(std::string(feature_name)) == column_name_id_map_.end()) {
      MS_LOG(INFO) << "Feature name: " << feature_name << " is not in schema, skip it.";
      continue;
    }

    const auto column_index = column_name_id_map_[std::string(feature_name)];

    DataType example_dtype;
    RETURN_IF_NOT_OK(feature.ParseDataType(&example_dtype));
    if (example_dtype == DataType::DE_UNKNOWN) {
      // Empty feature: keep the pre-created tensor untouched.
      continue;
    }

    // If feature was already visited, skip.
    if (feature_already_seen[column_index]) {
      LogFeatureRepeated(feature_name);
      continue;
    }
    feature_already_seen[column_index] = true;

    const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
    bool type_cast_flag = false;
    if (example_dtype != column_descriptor.Type()) {
      const std::string msg =
        "The data type loaded from the example for feature name: " + column_descriptor.Name() +
        " does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
        ", but the predefined type: " + column_descriptor.Type().ToString();
      if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
        MS_LOG(INFO) << msg << ". This will cause a type cast.";
        type_cast_flag = true;
      } else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
        // allow to read data of type string or bytes into an uint8 tensor
        RETURN_STATUS_UNEXPECTED(msg);
      }
    }

    if (column_descriptor.HasKnownShape()) {
      RETURN_IF_NOT_OK(ParseSingleKnownShapeColumn(feature, &(*parsed_row)[column_index], feature_name,
                                                   column_descriptor, example_dtype));
    } else {  // if variable length
      RETURN_IF_NOT_OK(
        ParseSingleVarLenColumn(feature, &(*parsed_row)[column_index], feature_name, column_descriptor, example_dtype));
    }
    if (type_cast_flag) {
      // Convert the numeric result to the exact dtype declared in the schema.
      std::shared_ptr<Tensor> cast_out;
      RETURN_IF_NOT_OK(TypeCast((*parsed_row)[column_index], &cast_out, column_descriptor.Type()));
      (*parsed_row)[column_index] = cast_out;
    }
    file_paths.push_back(filename);
  }

  // Every schema column must have appeared in the Example.
  for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
    CHECK_FAIL_RETURN_UNEXPECTED(feature_already_seen[column_index],
                                 "Feature name: " + data_schema_.Column(column_index).Name() +
                                   " is required in schema but could not be found in tfrecord file.");
  }

  parsed_row->setPath(file_paths);
  return Status::OK();
}
847
// Decide how many mini-batches to split a batch of serialized examples into
// for parallel parsing: aim for roughly kMiniBatchSizeBytes of serialized
// data per mini-batch, then clamp into [min(8, batch_size), 64].
size_t CalculateNumMiniBatch(const std::shared_ptr<Tensor> &batch_tensor) {
  // This parameter affects performance in a big and data-dependent way.
  constexpr size_t kMiniBatchSizeBytes = 50000;

  const size_t batch_size = batch_tensor->shape()[0];

  size_t result = 0;
  size_t minibatch_bytes = 0;
  for (size_t i = 0; i < batch_size; i++) {
    if (minibatch_bytes == 0) {  // start minibatch
      result++;
    }
    std::string_view tensor_value;
    // NOTE(review): the GetItemAt status is ignored; a failed read leaves
    // tensor_value empty and merely skews this size estimate.
    batch_tensor->GetItemAt(&tensor_value, {static_cast<dsize_t>(i)});
    minibatch_bytes += tensor_value.size() + 1;
    if (minibatch_bytes > kMiniBatchSizeBytes) {
      minibatch_bytes = 0;
    }
  }
  // 'special logic'
  const size_t min_minibatches = std::min<size_t>(8, batch_size);
  constexpr size_t max_minibatches = 64;
  return std::max<size_t>(min_minibatches, std::min<size_t>(max_minibatches, result));
}
872
// A one-shot synchronization barrier: Wait() blocks until DecrementCount()
// has been invoked initial_count times. The remaining count is kept in the
// upper 31 bits of state_, and the lowest bit records whether a waiter has
// arrived, letting the common decrement path avoid taking the mutex.
class BlockingCounter {
 public:
  explicit BlockingCounter(const uint32_t initial_count) : state_(initial_count << 1), notified_(false) {
    // The shifted count must round-trip; otherwise the top bit was lost.
    if ((initial_count << 1) >> 1 != initial_count) {
      MS_EXCEPTION(RuntimeError) << "Value of initial_count exceeds upper limit: " << initial_count;
    }
  }

  ~BlockingCounter() = default;

  inline void DecrementCount() {
    constexpr uint32_t kStep = 2;  // subtract 2: the count lives above the waiter bit
    uint32_t new_state = state_.fetch_sub(kStep, std::memory_order_acq_rel) - kStep;
    // new_state == 1 means count reached zero AND a waiter is blocked.
    if (new_state != 1) {
      // Decrementing below zero indicates more DecrementCount() calls than
      // the initial count — a usage error.
      if (((new_state + kStep) & ~1) == 0) {
        MS_EXCEPTION(RuntimeError) << "The number of remaining worker threads is already 0.";
      }
      return;  // either count has not dropped to 0, or waiter is not waiting
    }
    std::unique_lock<std::mutex> lock(mutex_);
    if (notified_) {
      MS_EXCEPTION(RuntimeError) << "Try to awake a notified worker.";
    }
    notified_ = true;
    cond_var_.notify_all();
  }

  inline void Wait() {
    // Publish the waiter bit; if the count is already zero, return at once.
    uint32_t new_state = state_.fetch_or(1, std::memory_order_acq_rel);
    if ((new_state >> 1) == 0) {
      return;
    }
    std::unique_lock<std::mutex> lock(mutex_);
    while (!notified_) {
      cond_var_.wait(lock);
    }
  }

  // Wait for the specified time, return false iff the count has not dropped to
  // zero before the timeout expired.
  inline bool WaitFor(std::chrono::milliseconds millisecond) {
    uint32_t new_state = state_.fetch_or(1, std::memory_order_acq_rel);
    if ((new_state >> 1) == 0) {
      return true;
    }
    std::unique_lock<std::mutex> lock(mutex_);
    while (!notified_) {
      const std::cv_status status = cond_var_.wait_for(lock, millisecond);
      if (status == std::cv_status::timeout) {
        return false;
      }
    }
    return true;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_var_;
  std::atomic<uint32_t> state_;  // low bit is waiter flag
  bool notified_;  // guarded by mutex_; set exactly once when count hits zero
};
934
ParallelFor(const std::function<void (size_t)> & function,const size_t task_count,const std::unique_ptr<Eigen::ThreadPool> & thread_pool)935 void ParallelFor(const std::function<void(size_t)> &function, const size_t task_count,
936 const std::unique_ptr<Eigen::ThreadPool> &thread_pool) {
937 if (task_count == 0) {
938 return;
939 }
940 if (thread_pool == nullptr) {
941 for (size_t i = 0; i < task_count; ++i) {
942 function(i);
943 }
944 } else {
945 BlockingCounter counter(task_count - 1);
946 for (size_t i = 1; i < task_count; ++i) {
947 thread_pool->Schedule([i, &function, &counter] {
948 function(i);
949 counter.DecrementCount();
950 });
951 }
952 function(0);
953 counter.Wait();
954 }
955 }
956
FillAndCopyVarLenTensor(const std::vector<std::vector<VarLenTensorBuffer>> & minibatch_row_buffer,std::shared_ptr<Tensor> * column_tensor,const size_t column_index)957 Status FillAndCopyVarLenTensor(const std::vector<std::vector<VarLenTensorBuffer>> &minibatch_row_buffer,
958 std::shared_ptr<Tensor> *column_tensor, const size_t column_index) {
959 ptrdiff_t buffer_offset = 0;
960 for (const auto &minibatch_row : minibatch_row_buffer) {
961 const auto &minibatch_tensor = minibatch_row[column_index].numeric_tensor;
962 for (const auto &varlen_tensor : minibatch_tensor) {
963 const auto tensor_buffer_size = varlen_tensor->SizeInBytes();
964 const errno_t copy_status =
965 memcpy_s((*column_tensor)->GetMutableBuffer() + buffer_offset, (*column_tensor)->SizeInBytes() - buffer_offset,
966 varlen_tensor->GetBuffer(), tensor_buffer_size);
967 CHECK_FAIL_RETURN_UNEXPECTED(copy_status == EOK,
968 "Failed to copy tensor to batch, got error_t: " + std::to_string(copy_status));
969 buffer_offset += tensor_buffer_size;
970 }
971 }
972 return Status::OK();
973 }
974
FillAndCopyVarLenString(const std::vector<std::vector<VarLenTensorBuffer>> & minibatch_row_buffer,std::shared_ptr<Tensor> * column_tensor,const size_t column_index,const ColDescriptor & column_descriptor,dsize_t batch_size)975 Status FillAndCopyVarLenString(const std::vector<std::vector<VarLenTensorBuffer>> &minibatch_row_buffer,
976 std::shared_ptr<Tensor> *column_tensor, const size_t column_index,
977 const ColDescriptor &column_descriptor, dsize_t batch_size) {
978 std::vector<std::string> string_buffer;
979 dsize_t element_size = 0;
980 for (const auto &minibatch_row : minibatch_row_buffer) {
981 const auto string_length = minibatch_row[column_index].string_length;
982 if (element_size == 0) {
983 element_size = static_cast<dsize_t>(string_length);
984 } else {
985 CHECK_FAIL_RETURN_UNEXPECTED(string_length == element_size,
986 "Could not batch string or bytes tensors with different shapes.");
987 }
988 const auto &minibatch_string = minibatch_row[column_index].string_tensor;
989 string_buffer.insert(string_buffer.end(), minibatch_string.begin(), minibatch_string.end());
990 }
991
992 std::vector<dsize_t> shape;
993 if (element_size != 0) {
994 shape = {batch_size, element_size};
995 } else {
996 shape = {batch_size};
997 }
998 const auto column_shape = TensorShape(shape);
999 if (column_descriptor.Type().value() == DataType::DE_STRING) {
1000 RETURN_IF_NOT_OK(
1001 Tensor::CreateFromVector(string_buffer, column_shape, DataType(DataType::DE_STRING), column_tensor));
1002 } else {
1003 RETURN_IF_NOT_OK(CreateUint8TensorFromString(string_buffer, column_tensor, column_shape, column_descriptor.Name()));
1004 }
1005 return Status::OK();
1006 }
1007
MergeDenseVarLenMiniBatches(const std::vector<std::vector<VarLenTensorBuffer>> & varlen_dense_buffers,TensorRow * parsed_row,int32_t column_index,const DataSchema & data_schema,dsize_t batch_size)1008 Status MergeDenseVarLenMiniBatches(const std::vector<std::vector<VarLenTensorBuffer>> &varlen_dense_buffers,
1009 TensorRow *parsed_row, int32_t column_index, const DataSchema &data_schema,
1010 dsize_t batch_size) {
1011 const ColDescriptor &column_descriptor = data_schema.Column(column_index);
1012 if (column_descriptor.HasKnownShape()) {
1013 return Status::OK();
1014 }
1015 std::shared_ptr<Tensor> column_tensor;
1016 if (!varlen_dense_buffers[0][column_index].numeric_tensor.empty()) {
1017 const TensorShape column_shape =
1018 varlen_dense_buffers[0][column_index].numeric_tensor[0]->shape().InsertDim(0, batch_size);
1019 RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_shape, column_descriptor.Type(), &column_tensor));
1020 RETURN_IF_NOT_OK(FillAndCopyVarLenTensor(varlen_dense_buffers, &column_tensor, column_index));
1021 } else {
1022 RETURN_IF_NOT_OK(
1023 FillAndCopyVarLenString(varlen_dense_buffers, &column_tensor, column_index, column_descriptor, batch_size));
1024 }
1025 (*parsed_row)[column_index] = column_tensor;
1026 return Status::OK();
1027 }
1028
ParallelParseExample(const TensorRow & raw_bytes,TensorRow * parsed_row)1029 Status ParseExampleOp::ParallelParseExample(const TensorRow &raw_bytes, TensorRow *parsed_row) {
1030 Tensor::TensorIterator tensor_iterator = raw_bytes[0]->begin<std::string_view>();
1031 RETURN_IF_NOT_OK(ConstructColumnMap(std::string(*tensor_iterator)));
1032 parsed_row->reserve(data_schema_.NumColumns());
1033
1034 auto batch_size = raw_bytes[0]->shape()[0];
1035 std::vector<bool> type_cast_flag(data_schema_.NumColumns(), false);
1036 std::vector<bool> varlen_column(data_schema_.NumColumns(), false);
1037 std::unordered_map<int32_t, std::vector<std::string>> string_column_map;
1038 for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
1039 const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
1040 if (column_descriptor.HasKnownShape()) {
1041 if (!column_descriptor.Type().IsString()) {
1042 auto column_shape = column_descriptor.Shape().InsertDim(0, batch_size);
1043 DataType type;
1044 if (column_descriptor.Type().IsInt() || column_descriptor.Type().IsBool()) {
1045 if (column_descriptor.Type().value() != DataType::DE_INT64) {
1046 type_cast_flag[column_index] = true;
1047 }
1048 type = DataType(DataType::DE_INT64);
1049 } else if (column_descriptor.Type().IsFloat()) {
1050 if (column_descriptor.Type().value() != DataType::DE_FLOAT32) {
1051 type_cast_flag[column_index] = true;
1052 }
1053 type = DataType(DataType::DE_FLOAT32);
1054 }
1055 std::shared_ptr<Tensor> column_tensor;
1056 RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_shape, type, &column_tensor));
1057 parsed_row->emplace_back(std::move(column_tensor));
1058 if (column_descriptor.Type().value() == DataType::DE_UINT8) {
1059 string_column_map[column_index] =
1060 std::vector<std::string>(batch_size * column_descriptor.Shape().NumOfElements());
1061 }
1062 } else {
1063 parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
1064 string_column_map[column_index] =
1065 std::vector<std::string>(batch_size * column_descriptor.Shape().NumOfElements());
1066 }
1067 } else {
1068 MS_LOG(INFO) << "Shape of column name: " << column_descriptor.Name() << " is not defined.";
1069 varlen_column[column_index] = true;
1070 parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
1071 }
1072 }
1073
1074 // Calculate number of minibatches.
1075 // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
1076 // Apply 'special logic' below for small and big regimes.
1077 const size_t num_minibatches = CalculateNumMiniBatch(raw_bytes[0]);
1078
1079 auto first_example_of_minibatch = [&](const size_t minibatch) -> size_t {
1080 return (batch_size * minibatch) / num_minibatches;
1081 };
1082
1083 std::vector<std::vector<VarLenTensorBuffer>> varlen_dense_buffers(num_minibatches);
1084 std::vector<Status> status_of_minibatch(num_minibatches);
1085 auto ProcessMiniBatch = [&](const size_t minibatch) {
1086 varlen_dense_buffers[minibatch].resize(data_schema_.NumColumns());
1087 const auto start = first_example_of_minibatch(minibatch);
1088 const auto end = first_example_of_minibatch(minibatch + 1);
1089 for (auto tensor_index = start; tensor_index < end; ++tensor_index) {
1090 status_of_minibatch[minibatch] =
1091 ParseSerializedExample(static_cast<std::string>(*tensor_iterator.operator+(static_cast<dsize_t>(tensor_index))),
1092 parsed_row, &string_column_map, &varlen_dense_buffers[minibatch], tensor_index);
1093 if (!status_of_minibatch[minibatch].IsOk()) {
1094 break;
1095 }
1096 }
1097 };
1098
1099 ParallelFor(ProcessMiniBatch, num_minibatches, pool_);
1100
1101 for (Status &status : status_of_minibatch) {
1102 RETURN_IF_NOT_OK(status);
1103 }
1104
1105 for (auto string_column = string_column_map.begin(); string_column != string_column_map.end(); ++string_column) {
1106 auto column_index = string_column->first;
1107 const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
1108 auto column_shape = column_descriptor.Shape().InsertDim(0, batch_size);
1109 std::shared_ptr<Tensor> string_tensor;
1110 if (column_descriptor.Type().value() == DataType::DE_STRING) {
1111 RETURN_IF_NOT_OK(
1112 Tensor::CreateFromVector(string_column->second, column_shape, DataType(DataType::DE_STRING), &string_tensor));
1113 } else {
1114 // load string or bytes as uint8 tensor
1115 RETURN_IF_NOT_OK(
1116 CreateUint8TensorFromString(string_column->second, &string_tensor, column_shape, column_descriptor.Name()));
1117 type_cast_flag[column_index] = false;
1118 }
1119 (*parsed_row)[column_index] = string_tensor;
1120 }
1121
1122 for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
1123 if (type_cast_flag[column_index]) {
1124 const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
1125 std::shared_ptr<Tensor> cast_out;
1126 RETURN_IF_NOT_OK(TypeCast((*parsed_row)[column_index], &cast_out, column_descriptor.Type()));
1127 (*parsed_row)[column_index] = cast_out;
1128 } else if (varlen_column[column_index]) {
1129 RETURN_IF_NOT_OK(
1130 MergeDenseVarLenMiniBatches(varlen_dense_buffers, parsed_row, column_index, data_schema_, batch_size));
1131 }
1132 }
1133 return Status::OK();
1134 }
1135
ParseSerializedKnownShapeColumn(const parsed::Feature & feature,TensorRow * parsed_row,std::unordered_map<int32_t,std::vector<std::string>> * string_col_map,const int32_t column_index,const size_t tensor_index,const StringPiece & feature_name,const ColDescriptor & column_descriptor,const DataType & example_dtype)1136 Status ParseSerializedKnownShapeColumn(const parsed::Feature &feature, TensorRow *parsed_row,
1137 std::unordered_map<int32_t, std::vector<std::string>> *string_col_map,
1138 const int32_t column_index, const size_t tensor_index,
1139 const StringPiece &feature_name, const ColDescriptor &column_descriptor,
1140 const DataType &example_dtype) {
1141 std::shared_ptr<Tensor> &column_tensor = (*parsed_row)[column_index];
1142 if (example_dtype != column_descriptor.Type()) {
1143 const std::string msg =
1144 "The data type loaded from the example for feature name: " + column_descriptor.Name() +
1145 " does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
1146 ", but the predefined type: " + column_descriptor.Type().ToString();
1147 if (example_dtype == column_tensor->type()) {
1148 // if the actual data type is the same as the pre-allocated tensor,
1149 // we can first read it into the tensor, then cast to the type specified by the schema
1150 MS_LOG(INFO) << msg << ". This will cause a type cast.";
1151 } else if (!example_dtype.IsString() || column_descriptor.Type().value() != DataType::DE_UINT8) {
1152 // allow to read data of type string or bytes into an uint8 tensor
1153 RETURN_STATUS_UNEXPECTED(msg);
1154 }
1155 }
1156
1157 const std::size_t num_elements = column_descriptor.Shape().NumOfElements();
1158 switch (example_dtype.value()) {
1159 case DataType::DE_INT64: {
1160 const auto data_buffer =
1161 reinterpret_cast<int64_t *>(column_tensor->GetMutableBuffer()) + tensor_index * num_elements;
1162 LimitedArraySlice<int64_t> slice(data_buffer, num_elements);
1163 if (!feature.ParseInt64List(&slice)) {
1164 return ReportUnexpectedParseFailure(feature_name);
1165 }
1166 if (slice.EndDistance() != 0) {
1167 return ReportUnexpectedDataShape(feature_name);
1168 }
1169 break;
1170 }
1171 case DataType::DE_FLOAT32: {
1172 const auto data_buffer =
1173 reinterpret_cast<float *>(column_tensor->GetMutableBuffer()) + tensor_index * num_elements;
1174 LimitedArraySlice<float> slice(data_buffer, num_elements);
1175 if (!feature.ParseFloatList(&slice)) {
1176 return ReportUnexpectedParseFailure(feature_name);
1177 }
1178 if (slice.EndDistance() != 0) {
1179 return ReportUnexpectedDataShape(feature_name);
1180 }
1181 break;
1182 }
1183 case DataType::DE_STRING: {
1184 const auto data_buffer = &(*string_col_map)[column_index][tensor_index * num_elements];
1185 LimitedArraySlice<std::string> slice(data_buffer, num_elements);
1186 if (!feature.ParseBytesList(&slice)) {
1187 return ReportUnexpectedParseFailure(feature_name);
1188 }
1189 if (slice.EndDistance() != 0) {
1190 return ReportUnexpectedDataShape(feature_name);
1191 }
1192 break;
1193 }
1194 default:
1195 return ReportUnexpectedDataType(feature_name, example_dtype);
1196 }
1197 return Status::OK();
1198 }
1199
PushStringToBuffer(const std::vector<std::string> & bytes_list,VarLenTensorBuffer * varlen_tensor_buffer,const ColDescriptor & column_descriptor)1200 Status PushStringToBuffer(const std::vector<std::string> &bytes_list, VarLenTensorBuffer *varlen_tensor_buffer,
1201 const ColDescriptor &column_descriptor) {
1202 if (column_descriptor.Type().value() == DataType::DE_STRING) {
1203 // check that each sample contains the same number of strings
1204 if (varlen_tensor_buffer->string_length != 0) {
1205 CHECK_FAIL_RETURN_UNEXPECTED(varlen_tensor_buffer->string_length == bytes_list.size(),
1206 "Could not batch string Tensors with different shapes.");
1207 } else {
1208 if (column_descriptor.Rank() != 0) {
1209 varlen_tensor_buffer->string_length = bytes_list.size();
1210 } else {
1211 varlen_tensor_buffer->string_length = 0;
1212 }
1213 }
1214 for (auto &bytes : bytes_list) {
1215 varlen_tensor_buffer->string_tensor.emplace_back(bytes);
1216 }
1217 } else if (column_descriptor.Type().value() == DataType::DE_UINT8) {
1218 size_t total_size = 0;
1219 for (auto &bytes : bytes_list) {
1220 total_size += bytes.size();
1221 varlen_tensor_buffer->string_tensor.emplace_back(bytes);
1222 }
1223 if (varlen_tensor_buffer->string_length != 0) {
1224 CHECK_FAIL_RETURN_UNEXPECTED(varlen_tensor_buffer->string_length == total_size,
1225 "Could not batch bytes Tensors with different shapes.");
1226 } else {
1227 varlen_tensor_buffer->string_length = total_size;
1228 }
1229 }
1230 return Status::OK();
1231 }
1232
ParseSerializedVarLenColumn(const parsed::Feature & feature,VarLenTensorBuffer * varlen_tensor_buffer,const StringPiece & feature_name,const ColDescriptor & column_descriptor,const DataType & example_dtype)1233 Status ParseSerializedVarLenColumn(const parsed::Feature &feature, VarLenTensorBuffer *varlen_tensor_buffer,
1234 const StringPiece &feature_name, const ColDescriptor &column_descriptor,
1235 const DataType &example_dtype) {
1236 bool type_cast_flag = false;
1237 if (example_dtype != column_descriptor.Type()) {
1238 const std::string msg =
1239 "The data type loaded from the example for feature name: " + column_descriptor.Name() +
1240 " does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
1241 ", but the predefined type: " + column_descriptor.Type().ToString();
1242 if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
1243 MS_LOG(INFO) << msg << ". This will cause a type cast.";
1244 type_cast_flag = true;
1245 } else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
1246 // allow to read data of type string or bytes into an uint8 tensor
1247 RETURN_STATUS_UNEXPECTED(msg);
1248 }
1249 }
1250
1251 size_t num_elements;
1252 SmallVector<int64_t> int64_list;
1253 TensorVector<float> float_list;
1254 std::vector<std::string> bytes_list;
1255 switch (example_dtype.value()) {
1256 case DataType::DE_INT64: {
1257 if (!feature.ParseInt64List(&int64_list)) {
1258 return ReportUnexpectedParseFailure(feature_name);
1259 }
1260 num_elements = int64_list.size();
1261 break;
1262 }
1263 case DataType::DE_FLOAT32: {
1264 if (!feature.ParseFloatList(&float_list)) {
1265 return ReportUnexpectedParseFailure(feature_name);
1266 }
1267 num_elements = float_list.size();
1268 break;
1269 }
1270 case DataType::DE_STRING: {
1271 int actual_num_elements = 0;
1272 if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
1273 return ReportUnexpectedParseFailure(feature_name);
1274 }
1275 bytes_list.reserve(actual_num_elements);
1276 if (!feature.ParseBytesList(&bytes_list)) {
1277 return ReportUnexpectedParseFailure(feature_name);
1278 }
1279 num_elements = bytes_list.size();
1280 break;
1281 }
1282 default:
1283 return ReportUnexpectedDataType(feature_name, example_dtype);
1284 }
1285
1286 TensorShape varlen_tensor_shape = TensorShape::CreateUnknownRankShape();
1287 RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &varlen_tensor_shape));
1288 std::shared_ptr<Tensor> varlen_tensor;
1289 switch (example_dtype.value()) {
1290 case DataType::DE_INT64: {
1291 RETURN_IF_NOT_OK(Tensor::CreateEmpty(varlen_tensor_shape, example_dtype, &varlen_tensor));
1292 CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
1293 reinterpret_cast<int64_t *>(varlen_tensor->GetMutableBuffer()));
1294 if (type_cast_flag) {
1295 std::shared_ptr<Tensor> casted_varlen_tensor;
1296 RETURN_IF_NOT_OK(TypeCast(varlen_tensor, &casted_varlen_tensor, column_descriptor.Type()));
1297 varlen_tensor_buffer->numeric_tensor.emplace_back(casted_varlen_tensor);
1298 } else {
1299 varlen_tensor_buffer->numeric_tensor.emplace_back(varlen_tensor);
1300 }
1301 break;
1302 }
1303 case DataType::DE_FLOAT32: {
1304 RETURN_IF_NOT_OK(Tensor::CreateFromTensor(std::shared_ptr<Tensor>(float_list.tensor()), &varlen_tensor));
1305 RETURN_IF_NOT_OK(varlen_tensor->Reshape(varlen_tensor_shape));
1306 if (type_cast_flag) {
1307 std::shared_ptr<Tensor> casted_varlen_tensor;
1308 RETURN_IF_NOT_OK(TypeCast(varlen_tensor, &casted_varlen_tensor, column_descriptor.Type()));
1309 varlen_tensor_buffer->numeric_tensor.emplace_back(casted_varlen_tensor);
1310 } else {
1311 varlen_tensor_buffer->numeric_tensor.emplace_back(varlen_tensor);
1312 }
1313 break;
1314 }
1315 case DataType::DE_STRING: {
1316 RETURN_IF_NOT_OK(PushStringToBuffer(bytes_list, varlen_tensor_buffer, column_descriptor));
1317 break;
1318 }
1319 default:
1320 return ReportUnexpectedDataType(feature_name, example_dtype);
1321 }
1322 return Status::OK();
1323 }
1324
ParseSerializedExample(const std::string & example_bytes,TensorRow * parsed_row,std::unordered_map<int32_t,std::vector<std::string>> * string_column_map,std::vector<VarLenTensorBuffer> * varlen_tensor_vector,const size_t tensor_index)1325 Status ParseExampleOp::ParseSerializedExample(const std::string &example_bytes, TensorRow *parsed_row,
1326 std::unordered_map<int32_t, std::vector<std::string>> *string_column_map,
1327 std::vector<VarLenTensorBuffer> *varlen_tensor_vector,
1328 const size_t tensor_index) {
1329 parsed::Example parsed_example;
1330 CHECK_FAIL_RETURN_UNEXPECTED(ParseExample(example_bytes, &parsed_example),
1331 "Failed to parse example bytes: " + example_bytes);
1332
1333 const size_t parsed_example_size = parsed_example.size();
1334 std::vector<bool> feature_already_seen(data_schema_.NumColumns(), false);
1335 for (size_t i = 0; i < parsed_example_size; ++i) {
1336 // This is a logic that standard protobuf parsing is implementing.
1337 // I.e. last entry in the map overwrites all the previous ones.
1338 parsed::FeatureMapEntry &name_and_feature = parsed_example[parsed_example_size - i - 1];
1339 const StringPiece &feature_name = name_and_feature.first;
1340 parsed::Feature &feature = name_and_feature.second;
1341
1342 if (column_name_id_map_.find(std::string(feature_name)) == column_name_id_map_.end()) {
1343 MS_LOG(INFO) << "Feature name: " << feature_name << " is not in schema, skip it.";
1344 continue;
1345 }
1346
1347 DataType example_dtype;
1348 RETURN_IF_NOT_OK(feature.ParseDataType(&example_dtype));
1349 if (example_dtype == DataType::DE_UNKNOWN) {
1350 continue;
1351 }
1352
1353 const auto column_index = column_name_id_map_[std::string(feature_name)];
1354 // If feature was already visited, skip.
1355 if (feature_already_seen[column_index]) {
1356 LogFeatureRepeated(feature_name);
1357 continue;
1358 }
1359 feature_already_seen[column_index] = true;
1360
1361 const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
1362 if (column_descriptor.HasKnownShape()) {
1363 RETURN_IF_NOT_OK(ParseSerializedKnownShapeColumn(feature, parsed_row, string_column_map, column_index,
1364 tensor_index, feature_name, column_descriptor, example_dtype));
1365 } else { // if variable length
1366 RETURN_IF_NOT_OK(ParseSerializedVarLenColumn(feature, &(*varlen_tensor_vector)[column_index], feature_name,
1367 column_descriptor, example_dtype));
1368 }
1369 }
1370
1371 for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
1372 if (!feature_already_seen[column_index]) {
1373 RETURN_STATUS_UNEXPECTED("Feature name: " + data_schema_.Column(column_index).Name() +
1374 " is required in schema but could not be found in tfrecord file.");
1375 }
1376 }
1377 return Status::OK();
1378 }
1379
ConstructColumnMap(const std::string & example_bytes)1380 Status ParseExampleOp::ConstructColumnMap(const std::string &example_bytes) {
1381 if (column_name_id_map_.empty()) {
1382 if (data_schema_.Empty()) {
1383 dataengine::Example example;
1384 if (!example.ParseFromString(example_bytes)) {
1385 RETURN_STATUS_UNEXPECTED("Failed to parse example bytes: " + std::string(example_bytes));
1386 }
1387
1388 const dataengine::Features &example_features = example.features();
1389 const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
1390 if (column_list_.empty()) {
1391 (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(column_list_),
1392 [](const auto &it) -> std::string { return it.first; });
1393 std::sort(column_list_.begin(), column_list_.end());
1394 }
1395
1396 for (const auto &column_name : column_list_) {
1397 auto it = feature_map.find(column_name);
1398 if (it == feature_map.end()) {
1399 RETURN_STATUS_UNEXPECTED("Invalid column list, failed to find column name: " + column_name + " in example.");
1400 }
1401
1402 std::string column_type;
1403 const dataengine::Feature &feature = it->second;
1404 switch (feature.kind_case()) {
1405 case dataengine::Feature::KindCase::kBytesList:
1406 column_type = "uint8";
1407 break;
1408 case dataengine::Feature::KindCase::kFloatList:
1409 column_type = "float32";
1410 break;
1411 case dataengine::Feature::KindCase::kInt64List:
1412 column_type = "int64";
1413 break;
1414 default:
1415 RETURN_STATUS_UNEXPECTED("Unsupported column type, the column type of " + column_name +
1416 " should be int64, float32 or string.");
1417 }
1418 RETURN_IF_NOT_OK(
1419 data_schema_.AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1)));
1420 }
1421 }
1422 RETURN_IF_NOT_OK(data_schema_.GetColumnNameMap(&column_name_id_map_));
1423 CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map_.empty(), "Can not get column name map, it is empty.");
1424 }
1425 return Status::OK();
1426 }
1427 } // namespace mindspore::dataset
1428