• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "minddata/dataset/kernels/data/parse_example_op.h"

#include <google/protobuf/io/coded_stream.h>

#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "proto/example.pb.h"

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/kernels/tensor_op.h"
30 
31 namespace mindspore::dataset {
32 namespace protobuf = ::google::protobuf;
33 
34 constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
35 constexpr size_t kInlinedVectorSize = 4;
36 
37 template <typename T>
38 using SmallVector = absl::InlinedVector<T, kInlinedVectorSize>;
39 using StringPiece = std::string_view;
40 
41 template <typename T>
42 class LimitedArraySlice {
43  public:
44   using value_type = T;
45 
LimitedArraySlice(T * begin,size_t num_elements)46   LimitedArraySlice(T *begin, size_t num_elements) : current_(begin), begin_(begin), end_(begin + num_elements) {}
47 
48   /// \brief Get the left space in the slice.
EndDistance() const49   int64_t EndDistance() const { return end_ - current_; }
50 
51   /// \brief Push value to back of slice. If the slice is full, only change the
52   /// total number without modify the data.
push_back(T && value)53   void push_back(T &&value) {
54     if (EndDistance() > 0) {
55       *current_ = std::move(value);
56     }
57     ++current_;
58   }
59 
60   /// \brief Construct an element at the back of slice and return a mutable
61   /// reference to the new element.
construct_at_end()62   T &construct_at_end() {
63     if (EndDistance() <= 0) {
64       MS_EXCEPTION(RuntimeError) << "LimitedArraySlice has no space left.";
65     }
66     return *(current_++);
67   }
68 
69   /// \brief Get the mutable reference to the last element in slice.
back()70   T &back() { return *(current_ - 1); }
71 
72   /// \brief Get the number of elements in slice.
size() const73   size_t size() const { return std::min(current_ - begin_, end_ - begin_); }
74 
75   /// \brief Resize the slice to the given size by advancing the pointer to
76   /// the current element.
resize(size_t size)77   void resize(size_t size) { current_ = begin_ + size; }
78 
79   /// \brief Get the data buffer.
data()80   T *data() { return begin_; }
81 
82  private:
83   T *current_;
84   T *begin_;
85   T *end_;
86 };
87 
PeekTag(protobuf::io::CodedInputStream * stream)88 uint8_t PeekTag(protobuf::io::CodedInputStream *stream) {
89   if (stream == nullptr) {
90     MS_EXCEPTION(RuntimeError) << "CodedInputStream is nullptr.";
91   }
92   const void *ptr;
93   int size;
94   if (!stream->GetDirectBufferPointer(&ptr, &size)) {
95     return 0;
96   }
97   return *static_cast<const uint8_t *>(ptr);
98 }
99 
/// \brief Build the tag byte for field `tag` with varint wire type (0).
constexpr uint8_t kVarintTag(const uint32_t tag) { return static_cast<uint8_t>(tag << 3); }
/// \brief Build the tag byte for field `tag` with length-delimited wire type (2).
constexpr uint8_t kDelimitedTag(const uint32_t tag) { return static_cast<uint8_t>((tag << 3) | 2); }
/// \brief Build the tag byte for field `tag` with fixed32 wire type (5).
constexpr uint8_t kFixed32Tag(const uint32_t tag) { return static_cast<uint8_t>((tag << 3) | 5); }
103 
104 namespace parsed {
/// \brief A lazily-decoded view over the serialized bytes of one `Feature`
/// proto. Only a StringPiece into the caller's buffer is stored; the value
/// list is decoded on demand by ParseBytesList / ParseFloatList /
/// ParseInt64List, which hand-roll the proto wire format for speed.
class Feature {
 public:
  Feature() = default;
  explicit Feature(const StringPiece &serialized) : serialized_(serialized) {}

  /// \brief Identify which oneof list (bytes/float/int64) this Feature holds
  /// by inspecting its first tag byte.
  /// \param[out] dtype DE_STRING, DE_FLOAT32 or DE_INT64; DE_UNKNOWN when the
  ///   serialized Feature is empty.
  /// \return Error status when the tag byte matches none of the three lists.
  Status ParseDataType(DataType *dtype) {
    RETURN_UNEXPECTED_IF_NULL(dtype);
    if (serialized_.empty()) {
      *dtype = DataType(DataType::DE_UNKNOWN);
      return Status::OK();
    }
    const auto oneof_tag = static_cast<uint8_t>(*serialized_.data());
    // Consume the tag byte so the Parse*List methods start at the length varint.
    serialized_.remove_prefix(1);
    constexpr uint8_t kStringTag = 1;
    constexpr uint8_t kFloat32Tag = 2;
    constexpr uint8_t kInt64Tag = 3;
    switch (oneof_tag) {
      case kDelimitedTag(kStringTag):
        *dtype = DataType(DataType::DE_STRING);
        break;
      case kDelimitedTag(kFloat32Tag):
        *dtype = DataType(DataType::DE_FLOAT32);
        break;
      case kDelimitedTag(kInt64Tag):
        *dtype = DataType(DataType::DE_INT64);
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DataType(DataType::DE_UNKNOWN);
        RETURN_STATUS_UNEXPECTED("Unsupported datatype.");
    }
    return Status::OK();
  }

  /// \brief Count the strings in the serialized BytesList without copying
  /// them, so callers can reserve capacity before calling ParseBytesList.
  /// \param[out] num_elements Receives the number of strings found.
  /// \return False on null output pointer or malformed wire data.
  bool GetNumElementsInBytesList(int *num_elements) const {
    if (num_elements == nullptr) {
      return false;
    }
    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));
    uint32_t length = 0;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));
    *num_elements = 0;
    // Each element is a length-delimited field 1; skip the payloads and count.
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) {
        return false;
      }
      uint32_t bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) {
        return false;
      }
      if (!stream.Skip(static_cast<int>(bytes_length))) {
        return false;
      }
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  /// \brief Append slot allocator for the LimitedArraySlice result type;
  /// returns nullptr when the slice is already full.
  static std::string *construct_at_end(LimitedArraySlice<std::string> *bytes_list) {
    if (bytes_list->EndDistance() <= 0) {
      return nullptr;
    }
    return &bytes_list->construct_at_end();
  }

  /// \brief Append slot allocator for the std::vector result type.
  static std::string *construct_at_end(std::vector<std::string> *bytes_list) { return &bytes_list->emplace_back(); }

  /// \brief Decode the BytesList into `bytes_list` (a vector or a
  /// LimitedArraySlice of std::string).
  /// \return False on malformed wire data or when the result has no room.
  template <typename Result>
  bool ParseBytesList(Result *bytes_list) const {
    if (bytes_list == nullptr) {
      return false;
    }

    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));

    uint32_t length;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) {
        return false;
      }
      // parse string
      uint32_t bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) {
        return false;
      }
      std::string *bytes = construct_at_end(bytes_list);
      if (bytes == nullptr) {
        return false;
      }
      bytes->resize(bytes_length);
      if (!stream.ReadRaw(bytes->data(), static_cast<int>(bytes_length))) {
        return false;
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  /// \brief Decode the FloatList into `float_list`, handling both the packed
  /// (length-delimited field 1) and non-packed (repeated fixed32 field 1)
  /// encodings of the wire format.
  /// \return False on malformed wire data.
  template <typename Result>
  bool ParseFloatList(Result *float_list) const {
    if (float_list == nullptr) {
      return false;
    }
    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));
    uint32_t length;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));

    if (!stream.ExpectAtEnd()) {
      const uint8_t peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      constexpr int32_t kNumFloatBytes = 4;
      if (peek_tag == kDelimitedTag(1)) {           // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) {  // packed tag
          return false;
        }
        uint32_t packed_length;
        if (!stream.ReadVarint32(&packed_length)) {
          return false;
        }
        const auto packed_limit = stream.PushLimit(static_cast<int>(packed_length));

        // Store the initial size to know the offset we have to start writing
        // data from before resizing the output "vector".
        const size_t initial_size = float_list->size();
        float_list->resize(initial_size + packed_length / kNumFloatBytes);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        if (kLittleEndian && sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Calculate the length of the buffer available what can be less than
          // what we requested in resize in case of a LimitedArraySlice.
          const uint32_t bytes_to_copy =
            std::min(static_cast<uint32_t>((float_list->size() - initial_size) * kNumFloatBytes), packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy)) {
            return false;
          }
        } else {
          // Slow path: decode one fixed32 at a time and bit-cast to float,
          // dropping values that do not fit in the result buffer.
          int64_t index = initial_size;
          while (!stream.ExpectAtEnd()) {
            uint32_t buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) {
              return false;
            }
            if (index < float_list->size()) {
              float_list->data()[index] = absl::bit_cast<float>(buffer32);
              ++index;
            }
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        const size_t initial_size = float_list->size();
        // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
        // the value.
        const int64_t num_elements = stream.BytesUntilLimit() / (1 + kNumFloatBytes);
        float_list->resize(initial_size + num_elements);
        int64_t index = initial_size;
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) {
            return false;
          }
          uint32_t buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) {
            return false;
          }
          float_list->data()[index] = absl::bit_cast<float>(buffer32);
          ++index;
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  /// \brief Decode the Int64List into `int64_list`, handling both the packed
  /// and non-packed (repeated varint field 1) encodings.
  /// \return False on malformed wire data.
  template <typename Result>
  bool ParseInt64List(Result *int64_list) const {
    if (int64_list == nullptr) {
      return false;
    }
    protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized_.data()),
                                          static_cast<int>(serialized_.size()));
    uint32_t length;
    if (!stream.ReadVarint32(&length)) {
      return false;
    }
    const auto limit = stream.PushLimit(static_cast<int>(length));

    if (!stream.ExpectAtEnd()) {
      const uint8_t peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {           // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) {  // packed tag
          return false;
        }
        uint32_t packed_length;
        if (!stream.ReadVarint32(&packed_length)) {
          return false;
        }
        const auto packed_limit = stream.PushLimit(static_cast<int>(packed_length));

        while (!stream.ExpectAtEnd()) {
          uint64_t n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) {
            return false;
          }
          int64_list->push_back(static_cast<int64_t>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) {
            return false;
          }
          uint64_t n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) {
            return false;
          }
          int64_list->push_back(static_cast<int64_t>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

 private:
  // View into the caller-owned serialized Feature; ParseDataType advances it
  // past the oneof tag byte.
  StringPiece serialized_;
};
357 
358 using FeatureMapEntry = std::pair<StringPiece, Feature>;
359 using Example = std::vector<FeatureMapEntry>;
360 }  // namespace parsed
361 
SkipExtraneousTag(protobuf::io::CodedInputStream * stream)362 inline bool SkipExtraneousTag(protobuf::io::CodedInputStream *stream) {
363   uint32_t data;
364   uint64_t dummy;
365   constexpr uint32_t kVarint = 0;
366   constexpr uint32_t kFixed64 = 1;
367   constexpr uint32_t kLengthDelimited = 2;
368   constexpr uint32_t kGroupBegin = 3;
369   constexpr uint32_t kGroupEnd = 4;
370   constexpr uint32_t kFixed32 = 5;
371   switch (stream->ReadTag() & 0x7) {
372     case kVarint:  // varint
373       return stream->ReadVarint32(&data);
374     case kFixed64:  // fixed64
375       return stream->ReadLittleEndian64(&dummy);
376     case kLengthDelimited:  // length delimited
377       if (!stream->ReadVarint32(&data)) {
378         return false;
379       }
380       stream->Skip(static_cast<int>(data));
381       return true;
382     case kGroupBegin:  // group begin
383     case kGroupEnd:    // group end
384       return false;    // groups not supported.
385     case kFixed32:     // fixed32
386       return stream->ReadLittleEndian32(&data);
387     default:
388       return false;
389   }
390   return false;  // unrecognized tag type
391 }
392 
ParseString(protobuf::io::CodedInputStream * stream,StringPiece * result)393 bool ParseString(protobuf::io::CodedInputStream *stream, StringPiece *result) {
394   if (stream == nullptr) {
395     return false;
396   }
397   if (result == nullptr) {
398     return false;
399   }
400   uint32_t length;
401   if (!stream->ReadVarint32(&length)) {
402     return false;
403   }
404   if (length == 0) {
405     *result = StringPiece(nullptr, 0);
406     return true;
407   }
408   const void *stream_alias;
409   int stream_size;
410   if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) {
411     return false;
412   }
413   if (static_cast<uint32_t>(stream_size) < length) {
414     return false;
415   }
416   *result = StringPiece(static_cast<const char *>(stream_alias), length);
417   stream->Skip(static_cast<int>(length));
418   return true;
419 }
420 
/// \brief Parse one entry of the `Features.feature` map: a (name, Feature)
/// pair whose key and value fields may arrive in either order on the wire.
/// \param[in] stream Positioned at the length-delimited map entry.
/// \param[out] feature_map_entry Receives the name view and the raw Feature bytes.
/// \return False on null arguments, unknown fields, or trailing bytes.
bool ParseFeatureMapEntry(protobuf::io::CodedInputStream *stream, parsed::FeatureMapEntry *feature_map_entry) {
  if (stream == nullptr) {
    return false;
  }
  if (feature_map_entry == nullptr) {
    return false;
  }
  uint32_t length;
  if (!stream->ReadVarint32(&length)) {
    return false;
  }
  const auto limit = stream->PushLimit(static_cast<int>(length));

  // Protobufs allow an arbitrary order for the key and value fields.
  for (int n = 0; n <= 1; ++n) {
    constexpr uint32_t kNameTag = 1;
    constexpr uint32_t kFeatureTag = 2;
    switch (stream->ReadTag()) {
      case kDelimitedTag(kNameTag):
        if (!ParseString(stream, &feature_map_entry->first)) {
          return false;
        }
        break;

      case kDelimitedTag(kFeatureTag): {
        // The Feature payload itself is kept as raw bytes and decoded lazily.
        StringPiece feature_string_piece;
        if (!ParseString(stream, &feature_string_piece)) {
          return false;
        }
        feature_map_entry->second = parsed::Feature(feature_string_piece);
        break;
      }

      default:
        return false;
    }
  }

  // Exactly two fields must fill the entry; anything extra is malformed.
  if (!stream->ExpectAtEnd()) {
    return false;
  }
  stream->PopLimit(limit);
  return true;
}
465 
ParseFeatures(protobuf::io::CodedInputStream * stream,parsed::Example * example)466 bool ParseFeatures(protobuf::io::CodedInputStream *stream, parsed::Example *example) {
467   if (stream == nullptr) {
468     return false;
469   }
470   if (example == nullptr) {
471     return false;
472   }
473   uint32_t length;
474   if (!stream->ReadVarint32(&length)) {
475     return false;
476   }
477   const auto limit = stream->PushLimit(static_cast<int>(length));
478   while (!stream->ExpectAtEnd()) {
479     parsed::FeatureMapEntry feature_map_entry;
480     if (!stream->ExpectTag(kDelimitedTag(1))) {
481       return false;
482     }
483     if (!ParseFeatureMapEntry(stream, &feature_map_entry)) {
484       return false;
485     }
486     example->push_back(std::move(feature_map_entry));
487   }
488   stream->PopLimit(limit);
489   return true;
490 }
491 
ParseExample(protobuf::io::CodedInputStream * stream,parsed::Example * example)492 bool ParseExample(protobuf::io::CodedInputStream *stream, parsed::Example *example) {
493   if (stream == nullptr) {
494     return false;
495   }
496   if (example == nullptr) {
497     return false;
498   }
499   // Loop over the input stream which may contain multiple serialized Example
500   // protos merged together as strings. This behavior is consistent with Proto's
501   // ParseFromString when string representations are concatenated.
502   while (!stream->ExpectAtEnd()) {
503     if (!stream->ExpectTag(kDelimitedTag(1))) {
504       if (!SkipExtraneousTag(stream)) {
505         return false;
506       }
507     } else {
508       if (!ParseFeatures(stream, example)) {
509         return false;
510       }
511     }
512   }
513   return true;
514 }
515 
ParseExample(const StringPiece & serialized,parsed::Example * example)516 bool ParseExample(const StringPiece &serialized, parsed::Example *example) {
517   if (example == nullptr) {
518     return false;
519   }
520   protobuf::io::CodedInputStream stream(reinterpret_cast<const uint8_t *>(serialized.data()),
521                                         static_cast<int>(serialized.size()));
522   return ParseExample(&stream, example);
523 }
524 
/// \brief A single-use, vector-like adapter whose storage is a dataset
/// Tensor, letting float lists be parsed directly into the output tensor
/// without an extra copy.
template <typename T>
class TensorVector {
 public:
  using value_type = T;

  /// \brief Get the backing tensor, creating an empty one if resize() was
  /// never called.
  std::shared_ptr<Tensor> tensor() {
    if (tensor_ == nullptr) {
      resize(0);
    }
    return tensor_;
  }

  /// \brief Number of elements currently held (0 before initialization).
  int64_t size() const { return tensor_ != nullptr ? tensor_->Size() : 0; }

  /// \brief Allocate the backing 1-D tensor with `new_size` elements. May be
  /// called only once; the tensor cannot be grown afterwards.
  void resize(int64_t new_size) {
    if (tensor_ != nullptr) {
      MS_EXCEPTION(RuntimeError) << "TensorVector has already initialized.";
    }
    Status s = Tensor::CreateEmpty(TensorShape({new_size}), DataType::FromCType<T>(), &tensor_);
    if (s.IsError()) {
      MS_EXCEPTION(RuntimeError) << s.ToString();
    }
    // NOTE(review): dereferences begin() even when new_size == 0 — presumably
    // Tensor iterators tolerate this on an empty tensor; confirm against the
    // Tensor API before relying on it.
    data_ = &*(tensor_->begin<T>());
  }

  T *data() { return data_; }

  const T *data() const { return data_; }

 private:
  std::shared_ptr<Tensor> tensor_ = nullptr;
  T *data_ = nullptr;  // the raw data inside the tensor
};
558 
/// \brief Copy the elements in [b, e) into the destination buffer starting at
/// t. The generic overload always copies.
template <typename T>
void CopyOrMoveBlock(const T *b, const T *e, T *t) {
  for (const T *src = b; src != e; ++src, ++t) {
    *t = *src;
  }
}
563 
/// \brief Warn that a feature name occurred more than once in an Example;
/// following standard proto map semantics, only the last occurrence is kept.
void LogFeatureRepeated(const StringPiece &feature_name) {
  MS_LOG(WARNING) << "Feature name: " << feature_name << " is repeated in Example. Ignoring all but last one.";
}
567 
/// \brief Build the error status for a feature whose bytes failed to parse.
inline Status ReportUnexpectedParseFailure(const StringPiece &feature_name) {
  RETURN_STATUS_UNEXPECTED("Failed to parse serialized Example of feature name: " + std::string(feature_name));
}
571 
/// \brief Build the error status for a feature whose decoded element type is
/// not one of the supported list types.
inline Status ReportUnexpectedDataType(const StringPiece &feature_name, const DataType &dtype) {
  RETURN_STATUS_UNEXPECTED("Got unexpected data type: " + dtype.ToString() +
                           " of feature name: " + std::string(feature_name));
}
576 
/// \brief Build the error status for a feature whose element count disagrees
/// with the shape declared in the schema.
inline Status ReportUnexpectedDataShape(const StringPiece &feature_name) {
  RETURN_STATUS_UNEXPECTED("Column shape of " + std::string(feature_name) +
                           " defined in schema does not match the shape actually load.");
}
581 
CreateUint8TensorFromString(const std::vector<std::string> & bytes_list,std::shared_ptr<Tensor> * column_tensor,const TensorShape & column_shape,const std::string & column_name)582 Status CreateUint8TensorFromString(const std::vector<std::string> &bytes_list, std::shared_ptr<Tensor> *column_tensor,
583                                    const TensorShape &column_shape, const std::string &column_name) {
584   dsize_t total_size =
585     std::accumulate(bytes_list.begin(), bytes_list.end(), 0,
586                     [](dsize_t size, const std::string &str) { return size + static_cast<dsize_t>(str.size()); });
587   TensorShape output_shape = column_shape;
588   if (!column_shape.known()) {
589     output_shape = TensorShape({total_size});
590   } else {
591     CHECK_FAIL_RETURN_UNEXPECTED(
592       output_shape.NumOfElements() == total_size,
593       "Column shape of " + column_name + " defined in schema does not match the shape actually load.");
594   }
595   RETURN_IF_NOT_OK(Tensor::CreateEmpty(output_shape, DataType(DataType::DE_UINT8), column_tensor));
596   ptrdiff_t offset = 0;
597   for (const auto &str : bytes_list) {
598     int ret_code = memcpy_s((*column_tensor)->GetMutableBuffer() + offset, (*column_tensor)->SizeInBytes() - offset,
599                             common::SafeCStr(str), str.size());
600     CHECK_FAIL_RETURN_UNEXPECTED(ret_code == EOK, "Failed to copy string into Tensor.");
601     offset += static_cast<ptrdiff_t>(str.size());
602   }
603   return Status::OK();
604 }
605 
Compute(const TensorRow & input,TensorRow * output)606 Status ParseExampleOp::Compute(const TensorRow &input, TensorRow *output) {
607   IO_CHECK_VECTOR(input, output);
608   if (parallel_parse_) {
609     return ParallelParseExample(input, output);
610   } else {
611     return ParseSingleExample(input, output);
612   }
613 }
614 
/// \brief Parse a feature with a known schema shape directly into the
/// pre-allocated column tensor.
/// \param[in] feature The raw feature bytes to decode.
/// \param[in,out] column_tensor Pre-allocated tensor written in place for the
///   numeric cases; replaced with a newly created tensor for strings.
/// \param[in] feature_name Feature name used in error messages.
/// \param[in] column_descriptor Schema column this feature maps to.
/// \param[in] example_dtype Element type actually found in the example.
/// \return Status error on parse failure, shape mismatch, or unsupported type.
Status ParseSingleKnownShapeColumn(const parsed::Feature &feature, std::shared_ptr<Tensor> *column_tensor,
                                   const StringPiece &feature_name, const ColDescriptor &column_descriptor,
                                   const DataType &example_dtype) {
  const size_t num_elements = column_descriptor.Shape().NumOfElements();
  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      // Decode int64 values straight into the tensor buffer through a bounded
      // slice; extra values are counted but not written.
      const auto data_buffer = reinterpret_cast<int64_t *>((*column_tensor)->GetMutableBuffer());
      LimitedArraySlice<int64_t> slice(data_buffer, num_elements);
      if (!feature.ParseInt64List(&slice)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      // Nonzero EndDistance means the example held too few or too many values.
      if (slice.EndDistance() != 0) {
        return ReportUnexpectedDataShape(feature_name);
      }
      break;
    }
    case DataType::DE_FLOAT32: {
      const auto data_buffer = reinterpret_cast<float *>((*column_tensor)->GetMutableBuffer());
      LimitedArraySlice<float> slice(data_buffer, num_elements);
      if (!feature.ParseFloatList(&slice)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      if (slice.EndDistance() != 0) {
        return ReportUnexpectedDataShape(feature_name);
      }
      break;
    }
    case DataType::DE_STRING: {
      std::vector<std::string> bytes_list;
      bytes_list.reserve(num_elements);
      if (!feature.ParseBytesList(&bytes_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      if (column_descriptor.Type().value() == DataType::DE_STRING) {
        if (bytes_list.size() != num_elements) {
          return ReportUnexpectedDataShape(feature_name);
        }
        TensorShape string_tensor_shape = TensorShape::CreateUnknownRankShape();
        RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &string_tensor_shape));
        RETURN_IF_NOT_OK(
          Tensor::CreateFromVector(bytes_list, string_tensor_shape, DataType(DataType::DE_STRING), column_tensor));
      } else {
        // load string or bytes as uint8 tensor
        RETURN_IF_NOT_OK(
          CreateUint8TensorFromString(bytes_list, column_tensor, column_descriptor.Shape(), std::string(feature_name)));
      }
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }
  return Status::OK();
}
668 
/// \brief Parse a feature whose schema shape is not fully known: decode the
/// value list first, then materialize the column shape from the element count
/// and create/fill the output tensor.
/// \param[in] feature The raw feature bytes to decode.
/// \param[out] column_tensor The created column tensor.
/// \param[in] feature_name Feature name used in error messages.
/// \param[in] column_descriptor Schema column this feature maps to.
/// \param[in] example_dtype Element type actually found in the example.
/// \return Status error on parse failure or unsupported type.
Status ParseSingleVarLenColumn(const parsed::Feature &feature, std::shared_ptr<Tensor> *column_tensor,
                               const StringPiece &feature_name, const ColDescriptor &column_descriptor,
                               const DataType &example_dtype) {
  std::vector<std::string> bytes_list;
  TensorVector<float> float_list;  // floats are parsed straight into a Tensor
  SmallVector<int64_t> int64_list;

  // Set by every reachable switch branch below; unsupported types return early.
  size_t num_elements;
  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      if (!feature.ParseInt64List(&int64_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = int64_list.size();
      break;
    }
    case DataType::DE_FLOAT32: {
      if (!feature.ParseFloatList(&float_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = float_list.size();
      break;
    }
    case DataType::DE_STRING: {
      // Count first so the vector can reserve before parsing the strings.
      int actual_num_elements = 0;
      if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      bytes_list.reserve(actual_num_elements);
      if (!feature.ParseBytesList(&bytes_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = bytes_list.size();
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }

  // Fill in the unknown dimension(s) of the schema shape from the count.
  TensorShape column_shape = TensorShape::CreateUnknownRankShape();
  RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &column_shape));

  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_shape, example_dtype, column_tensor));
      CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
                      reinterpret_cast<int64_t *>((*column_tensor)->GetMutableBuffer()));
      break;
    }
    case DataType::DE_FLOAT32: {
      // The floats already live in a Tensor; reuse it and fix up the shape.
      RETURN_IF_NOT_OK(Tensor::CreateFromTensor(std::shared_ptr<Tensor>(float_list.tensor()), column_tensor));
      RETURN_IF_NOT_OK((*column_tensor)->Reshape(column_shape));
      break;
    }
    case DataType::DE_STRING: {
      if (column_descriptor.Type().value() == DataType::DE_STRING) {
        RETURN_IF_NOT_OK(
          Tensor::CreateFromVector(bytes_list, column_shape, DataType(DataType::DE_STRING), column_tensor));
      } else {
        // load string or bytes as uint8 tensor
        RETURN_IF_NOT_OK(CreateUint8TensorFromString(bytes_list, column_tensor, TensorShape::CreateUnknownRankShape(),
                                                     std::string(feature_name)));
      }
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }
  return Status::OK();
}
739 
ParseSingleExample(const TensorRow & raw_bytes,TensorRow * parsed_row)740 Status ParseExampleOp::ParseSingleExample(const TensorRow &raw_bytes, TensorRow *parsed_row) {
741   const auto filename = raw_bytes.getPath().empty() ? "" : raw_bytes.getPath()[0];
742   const auto tensor_iterator = raw_bytes[0]->begin<std::string_view>();
743 
744   const auto example_bytes = std::string(*tensor_iterator);
745   RETURN_IF_NOT_OK(ConstructColumnMap(example_bytes));
746 
747   parsed::Example parsed_example;
748   CHECK_FAIL_RETURN_UNEXPECTED(ParseExample(example_bytes, &parsed_example),
749                                "Failed to parse example bytes: " + example_bytes + " in tfrecord file: " + filename);
750 
751   parsed_row->reserve(data_schema_.NumColumns());
752 
753   for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
754     const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
755     if (column_descriptor.HasKnownShape()) {
756       if (!column_descriptor.Type().IsString()) {
757         DataType type;
758         if (column_descriptor.Type().IsInt() || column_descriptor.Type().IsBool()) {
759           type = DataType(DataType::DE_INT64);
760         } else if (column_descriptor.Type().IsFloat()) {
761           type = DataType(DataType::DE_FLOAT32);
762         }
763         std::shared_ptr<Tensor> column_tensor;
764         RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_descriptor.Shape(), type, &column_tensor));
765         parsed_row->emplace_back(std::move(column_tensor));
766       } else {
767         parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
768       }
769     } else {
770       MS_LOG(INFO) << "Shape of column name: " << column_descriptor.Name() << " is not defined.";
771       parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
772     }
773   }
774 
775   std::vector<bool> feature_already_seen(data_schema_.NumColumns(), false);
776   std::vector<std::string> file_paths;
777 
778   const size_t parsed_example_size = parsed_example.size();
779   for (size_t i = 0; i < parsed_example_size; ++i) {
780     // This is a logic that standard protobuf parsing is implementing.
781     // I.e. last entry in the map overwrites all the previous ones.
782     parsed::FeatureMapEntry &name_and_feature = parsed_example[parsed_example_size - i - 1];
783 
784     const StringPiece &feature_name = name_and_feature.first;
785     parsed::Feature &feature = name_and_feature.second;
786 
787     if (column_name_id_map_.find(std::string(feature_name)) == column_name_id_map_.end()) {
788       MS_LOG(INFO) << "Feature name: " << feature_name << " is not in schema, skip it.";
789       continue;
790     }
791 
792     const auto column_index = column_name_id_map_[std::string(feature_name)];
793 
794     DataType example_dtype;
795     RETURN_IF_NOT_OK(feature.ParseDataType(&example_dtype));
796     if (example_dtype == DataType::DE_UNKNOWN) {
797       continue;
798     }
799 
800     // If feature was already visited, skip.
801     if (feature_already_seen[column_index]) {
802       LogFeatureRepeated(feature_name);
803       continue;
804     }
805     feature_already_seen[column_index] = true;
806 
807     const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
808     bool type_cast_flag = false;
809     if (example_dtype != column_descriptor.Type()) {
810       const std::string msg =
811         "The data type loaded from the example for feature name: " + column_descriptor.Name() +
812         " does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
813         ", but the predefined type: " + column_descriptor.Type().ToString();
814       if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
815         MS_LOG(INFO) << msg << ". This will cause a type cast.";
816         type_cast_flag = true;
817       } else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
818         // allow to read data of type string or bytes into an uint8 tensor
819         RETURN_STATUS_UNEXPECTED(msg);
820       }
821     }
822 
823     if (column_descriptor.HasKnownShape()) {
824       RETURN_IF_NOT_OK(ParseSingleKnownShapeColumn(feature, &(*parsed_row)[column_index], feature_name,
825                                                    column_descriptor, example_dtype));
826     } else {  // if variable length
827       RETURN_IF_NOT_OK(
828         ParseSingleVarLenColumn(feature, &(*parsed_row)[column_index], feature_name, column_descriptor, example_dtype));
829     }
830     if (type_cast_flag) {
831       std::shared_ptr<Tensor> cast_out;
832       RETURN_IF_NOT_OK(TypeCast((*parsed_row)[column_index], &cast_out, column_descriptor.Type()));
833       (*parsed_row)[column_index] = cast_out;
834     }
835     file_paths.push_back(filename);
836   }
837 
838   for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
839     CHECK_FAIL_RETURN_UNEXPECTED(feature_already_seen[column_index],
840                                  "Feature name: " + data_schema_.Column(column_index).Name() +
841                                    " is required in schema but could not be found in tfrecord file.");
842   }
843 
844   parsed_row->setPath(file_paths);
845   return Status::OK();
846 }
847 
CalculateNumMiniBatch(const std::shared_ptr<Tensor> & batch_tensor)848 size_t CalculateNumMiniBatch(const std::shared_ptr<Tensor> &batch_tensor) {
849   // This parameter affects performance in a big and data-dependent way.
850   constexpr size_t kMiniBatchSizeBytes = 50000;
851 
852   const size_t batch_size = batch_tensor->shape()[0];
853 
854   size_t result = 0;
855   size_t minibatch_bytes = 0;
856   for (size_t i = 0; i < batch_size; i++) {
857     if (minibatch_bytes == 0) {  // start minibatch
858       result++;
859     }
860     std::string_view tensor_value;
861     batch_tensor->GetItemAt(&tensor_value, {static_cast<dsize_t>(i)});
862     minibatch_bytes += tensor_value.size() + 1;
863     if (minibatch_bytes > kMiniBatchSizeBytes) {
864       minibatch_bytes = 0;
865     }
866   }
867   // 'special logic'
868   const size_t min_minibatches = std::min<size_t>(8, batch_size);
869   constexpr size_t max_minibatches = 64;
870   return std::max<size_t>(min_minibatches, std::min<size_t>(max_minibatches, result));
871 }
872 
// One-shot synchronization barrier: constructed with the number of outstanding
// workers, each worker calls DecrementCount() exactly once, and a single
// waiter blocks in Wait()/WaitFor() until the count reaches zero.
// Encoding: state_ holds (remaining_count << 1); the low bit records whether
// a waiter has announced itself, so the worker that brings the count to zero
// knows whether it must signal the condition variable.
class BlockingCounter {
 public:
  explicit BlockingCounter(const uint32_t initial_count) : state_(initial_count << 1), notified_(false) {
    // The shift must be reversible: if the top bit of initial_count is set,
    // the count cannot be represented alongside the waiter flag.
    if ((initial_count << 1) >> 1 != initial_count) {
      MS_EXCEPTION(RuntimeError) << "Value of initial_count exceeds upper limit: " << initial_count;
    }
  }

  ~BlockingCounter() = default;

  // Called once per worker when its task completes. The worker that drops the
  // count to zero wakes the waiter, if one is blocked.
  inline void DecrementCount() {
    constexpr uint32_t kStep = 2;  // the count lives in the high bits, so one unit == 2
    uint32_t new_state = state_.fetch_sub(kStep, std::memory_order_acq_rel) - kStep;
    // new_state == 1 means: count just reached zero AND the waiter flag is set.
    if (new_state != 1) {
      // (new_state + kStep) is the state before this decrement; masking off the
      // waiter bit leaves the previous count — zero means we decremented too often.
      if (((new_state + kStep) & ~1) == 0) {
        MS_EXCEPTION(RuntimeError) << "The number of remaining worker threads is already 0.";
      }
      return;  // either count has not dropped to 0, or waiter is not waiting
    }
    // Count hit zero with a waiter blocked: signal it under the mutex.
    std::unique_lock<std::mutex> lock(mutex_);
    if (notified_) {
      MS_EXCEPTION(RuntimeError) << "Try to awake a notified worker.";
    }
    notified_ = true;
    cond_var_.notify_all();
  }

  // Blocks the (single) waiter until the count reaches zero.
  inline void Wait() {
    // Announce the waiter by setting the low bit; fetch_or returns the prior state.
    uint32_t new_state = state_.fetch_or(1, std::memory_order_acq_rel);
    if ((new_state >> 1) == 0) {
      // Count already zero — no need to block.
      return;
    }
    std::unique_lock<std::mutex> lock(mutex_);
    while (!notified_) {
      cond_var_.wait(lock);
    }
  }

  // Wait for the specified time, return false iff the count has not dropped to
  // zero before the timeout expired.
  inline bool WaitFor(std::chrono::milliseconds millisecond) {
    uint32_t new_state = state_.fetch_or(1, std::memory_order_acq_rel);
    if ((new_state >> 1) == 0) {
      return true;
    }
    std::unique_lock<std::mutex> lock(mutex_);
    while (!notified_) {
      const std::cv_status status = cond_var_.wait_for(lock, millisecond);
      if (status == std::cv_status::timeout) {
        return false;
      }
    }
    return true;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_var_;
  std::atomic<uint32_t> state_;  // low bit is waiter flag
  bool notified_;                // guarded by mutex_; set once when count hits zero
};
934 
ParallelFor(const std::function<void (size_t)> & function,const size_t task_count,const std::unique_ptr<Eigen::ThreadPool> & thread_pool)935 void ParallelFor(const std::function<void(size_t)> &function, const size_t task_count,
936                  const std::unique_ptr<Eigen::ThreadPool> &thread_pool) {
937   if (task_count == 0) {
938     return;
939   }
940   if (thread_pool == nullptr) {
941     for (size_t i = 0; i < task_count; ++i) {
942       function(i);
943     }
944   } else {
945     BlockingCounter counter(task_count - 1);
946     for (size_t i = 1; i < task_count; ++i) {
947       thread_pool->Schedule([i, &function, &counter] {
948         function(i);
949         counter.DecrementCount();
950       });
951     }
952     function(0);
953     counter.Wait();
954   }
955 }
956 
FillAndCopyVarLenTensor(const std::vector<std::vector<VarLenTensorBuffer>> & minibatch_row_buffer,std::shared_ptr<Tensor> * column_tensor,const size_t column_index)957 Status FillAndCopyVarLenTensor(const std::vector<std::vector<VarLenTensorBuffer>> &minibatch_row_buffer,
958                                std::shared_ptr<Tensor> *column_tensor, const size_t column_index) {
959   ptrdiff_t buffer_offset = 0;
960   for (const auto &minibatch_row : minibatch_row_buffer) {
961     const auto &minibatch_tensor = minibatch_row[column_index].numeric_tensor;
962     for (const auto &varlen_tensor : minibatch_tensor) {
963       const auto tensor_buffer_size = varlen_tensor->SizeInBytes();
964       const errno_t copy_status =
965         memcpy_s((*column_tensor)->GetMutableBuffer() + buffer_offset, (*column_tensor)->SizeInBytes() - buffer_offset,
966                  varlen_tensor->GetBuffer(), tensor_buffer_size);
967       CHECK_FAIL_RETURN_UNEXPECTED(copy_status == EOK,
968                                    "Failed to copy tensor to batch, got error_t: " + std::to_string(copy_status));
969       buffer_offset += tensor_buffer_size;
970     }
971   }
972   return Status::OK();
973 }
974 
FillAndCopyVarLenString(const std::vector<std::vector<VarLenTensorBuffer>> & minibatch_row_buffer,std::shared_ptr<Tensor> * column_tensor,const size_t column_index,const ColDescriptor & column_descriptor,dsize_t batch_size)975 Status FillAndCopyVarLenString(const std::vector<std::vector<VarLenTensorBuffer>> &minibatch_row_buffer,
976                                std::shared_ptr<Tensor> *column_tensor, const size_t column_index,
977                                const ColDescriptor &column_descriptor, dsize_t batch_size) {
978   std::vector<std::string> string_buffer;
979   dsize_t element_size = 0;
980   for (const auto &minibatch_row : minibatch_row_buffer) {
981     const auto string_length = minibatch_row[column_index].string_length;
982     if (element_size == 0) {
983       element_size = static_cast<dsize_t>(string_length);
984     } else {
985       CHECK_FAIL_RETURN_UNEXPECTED(string_length == element_size,
986                                    "Could not batch string or bytes tensors with different shapes.");
987     }
988     const auto &minibatch_string = minibatch_row[column_index].string_tensor;
989     string_buffer.insert(string_buffer.end(), minibatch_string.begin(), minibatch_string.end());
990   }
991 
992   std::vector<dsize_t> shape;
993   if (element_size != 0) {
994     shape = {batch_size, element_size};
995   } else {
996     shape = {batch_size};
997   }
998   const auto column_shape = TensorShape(shape);
999   if (column_descriptor.Type().value() == DataType::DE_STRING) {
1000     RETURN_IF_NOT_OK(
1001       Tensor::CreateFromVector(string_buffer, column_shape, DataType(DataType::DE_STRING), column_tensor));
1002   } else {
1003     RETURN_IF_NOT_OK(CreateUint8TensorFromString(string_buffer, column_tensor, column_shape, column_descriptor.Name()));
1004   }
1005   return Status::OK();
1006 }
1007 
MergeDenseVarLenMiniBatches(const std::vector<std::vector<VarLenTensorBuffer>> & varlen_dense_buffers,TensorRow * parsed_row,int32_t column_index,const DataSchema & data_schema,dsize_t batch_size)1008 Status MergeDenseVarLenMiniBatches(const std::vector<std::vector<VarLenTensorBuffer>> &varlen_dense_buffers,
1009                                    TensorRow *parsed_row, int32_t column_index, const DataSchema &data_schema,
1010                                    dsize_t batch_size) {
1011   const ColDescriptor &column_descriptor = data_schema.Column(column_index);
1012   if (column_descriptor.HasKnownShape()) {
1013     return Status::OK();
1014   }
1015   std::shared_ptr<Tensor> column_tensor;
1016   if (!varlen_dense_buffers[0][column_index].numeric_tensor.empty()) {
1017     const TensorShape column_shape =
1018       varlen_dense_buffers[0][column_index].numeric_tensor[0]->shape().InsertDim(0, batch_size);
1019     RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_shape, column_descriptor.Type(), &column_tensor));
1020     RETURN_IF_NOT_OK(FillAndCopyVarLenTensor(varlen_dense_buffers, &column_tensor, column_index));
1021   } else {
1022     RETURN_IF_NOT_OK(
1023       FillAndCopyVarLenString(varlen_dense_buffers, &column_tensor, column_index, column_descriptor, batch_size));
1024   }
1025   (*parsed_row)[column_index] = column_tensor;
1026   return Status::OK();
1027 }
1028 
// Parses a whole batch of serialized tf.train.Example protos in parallel:
// pre-allocates per-column output tensors, splits the batch into minibatches
// parsed on the thread pool, then assembles string columns, applies pending
// type casts, and merges variable-length columns.
Status ParseExampleOp::ParallelParseExample(const TensorRow &raw_bytes, TensorRow *parsed_row) {
  // Use the first serialized example to lazily build the column-name -> index
  // map (and, when the schema is empty, the schema itself).
  Tensor::TensorIterator tensor_iterator = raw_bytes[0]->begin<std::string_view>();
  RETURN_IF_NOT_OK(ConstructColumnMap(std::string(*tensor_iterator)));
  parsed_row->reserve(data_schema_.NumColumns());

  auto batch_size = raw_bytes[0]->shape()[0];
  // Per-column bookkeeping:
  // - type_cast_flag: parsed as int64/float32 but schema wants another numeric type,
  //   so a cast is applied after parsing;
  // - varlen_column: shape unknown, parsed into per-minibatch buffers and merged later;
  // - string_column_map: flat per-column string buffers for string/uint8 columns.
  std::vector<bool> type_cast_flag(data_schema_.NumColumns(), false);
  std::vector<bool> varlen_column(data_schema_.NumColumns(), false);
  std::unordered_map<int32_t, std::vector<std::string>> string_column_map;
  for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
    const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
    if (column_descriptor.HasKnownShape()) {
      if (!column_descriptor.Type().IsString()) {
        // Numeric fixed-shape column: pre-allocate the whole batch tensor so
        // worker threads can write their samples in place concurrently.
        auto column_shape = column_descriptor.Shape().InsertDim(0, batch_size);
        DataType type;
        if (column_descriptor.Type().IsInt() || column_descriptor.Type().IsBool()) {
          if (column_descriptor.Type().value() != DataType::DE_INT64) {
            type_cast_flag[column_index] = true;
          }
          // Examples store integral data as int64; parse into int64 first.
          type = DataType(DataType::DE_INT64);
        } else if (column_descriptor.Type().IsFloat()) {
          if (column_descriptor.Type().value() != DataType::DE_FLOAT32) {
            type_cast_flag[column_index] = true;
          }
          // Examples store floating-point data as float32; parse into float32 first.
          type = DataType(DataType::DE_FLOAT32);
        }
        std::shared_ptr<Tensor> column_tensor;
        RETURN_IF_NOT_OK(Tensor::CreateEmpty(column_shape, type, &column_tensor));
        parsed_row->emplace_back(std::move(column_tensor));
        if (column_descriptor.Type().value() == DataType::DE_UINT8) {
          // uint8 columns may carry bytes data, so also provision a string buffer.
          string_column_map[column_index] =
            std::vector<std::string>(batch_size * column_descriptor.Shape().NumOfElements());
        }
      } else {
        // String column: parse into a flat string buffer; the tensor is built
        // after all minibatches finish. Placeholder tensor keeps row indices aligned.
        parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
        string_column_map[column_index] =
          std::vector<std::string>(batch_size * column_descriptor.Shape().NumOfElements());
      }
    } else {
      MS_LOG(INFO) << "Shape of column name: " << column_descriptor.Name() << " is not defined.";
      varlen_column[column_index] = true;
      parsed_row->emplace_back(std::make_shared<Tensor>(TensorShape({}), DataType(DataType::DE_UNKNOWN)));
    }
  }

  // Calculate number of minibatches.
  // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
  // Apply 'special logic' below for small and big regimes.
  const size_t num_minibatches = CalculateNumMiniBatch(raw_bytes[0]);

  auto first_example_of_minibatch = [&](const size_t minibatch) -> size_t {
    return (batch_size * minibatch) / num_minibatches;
  };

  std::vector<std::vector<VarLenTensorBuffer>> varlen_dense_buffers(num_minibatches);
  std::vector<Status> status_of_minibatch(num_minibatches);
  // Each worker parses its contiguous range of examples and stops at the first
  // failure; statuses are collected per minibatch and checked after the join.
  auto ProcessMiniBatch = [&](const size_t minibatch) {
    varlen_dense_buffers[minibatch].resize(data_schema_.NumColumns());
    const auto start = first_example_of_minibatch(minibatch);
    const auto end = first_example_of_minibatch(minibatch + 1);
    for (auto tensor_index = start; tensor_index < end; ++tensor_index) {
      status_of_minibatch[minibatch] =
        ParseSerializedExample(static_cast<std::string>(*tensor_iterator.operator+(static_cast<dsize_t>(tensor_index))),
                               parsed_row, &string_column_map, &varlen_dense_buffers[minibatch], tensor_index);
      if (!status_of_minibatch[minibatch].IsOk()) {
        break;
      }
    }
  };

  ParallelFor(ProcessMiniBatch, num_minibatches, pool_);

  // Propagate the first error reported by any minibatch.
  for (Status &status : status_of_minibatch) {
    RETURN_IF_NOT_OK(status);
  }

  // Build the final tensors for string/uint8 columns from the collected strings.
  for (auto string_column = string_column_map.begin(); string_column != string_column_map.end(); ++string_column) {
    auto column_index = string_column->first;
    const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
    auto column_shape = column_descriptor.Shape().InsertDim(0, batch_size);
    std::shared_ptr<Tensor> string_tensor;
    if (column_descriptor.Type().value() == DataType::DE_STRING) {
      RETURN_IF_NOT_OK(
        Tensor::CreateFromVector(string_column->second, column_shape, DataType(DataType::DE_STRING), &string_tensor));
    } else {
      // load string or bytes as uint8 tensor
      RETURN_IF_NOT_OK(
        CreateUint8TensorFromString(string_column->second, &string_tensor, column_shape, column_descriptor.Name()));
      // The tensor is already in the schema's type; no cast needed anymore.
      type_cast_flag[column_index] = false;
    }
    (*parsed_row)[column_index] = string_tensor;
  }

  // Final pass: apply deferred numeric casts and merge variable-length columns.
  for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
    if (type_cast_flag[column_index]) {
      const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
      std::shared_ptr<Tensor> cast_out;
      RETURN_IF_NOT_OK(TypeCast((*parsed_row)[column_index], &cast_out, column_descriptor.Type()));
      (*parsed_row)[column_index] = cast_out;
    } else if (varlen_column[column_index]) {
      RETURN_IF_NOT_OK(
        MergeDenseVarLenMiniBatches(varlen_dense_buffers, parsed_row, column_index, data_schema_, batch_size));
    }
  }
  return Status::OK();
}
1135 
ParseSerializedKnownShapeColumn(const parsed::Feature & feature,TensorRow * parsed_row,std::unordered_map<int32_t,std::vector<std::string>> * string_col_map,const int32_t column_index,const size_t tensor_index,const StringPiece & feature_name,const ColDescriptor & column_descriptor,const DataType & example_dtype)1136 Status ParseSerializedKnownShapeColumn(const parsed::Feature &feature, TensorRow *parsed_row,
1137                                        std::unordered_map<int32_t, std::vector<std::string>> *string_col_map,
1138                                        const int32_t column_index, const size_t tensor_index,
1139                                        const StringPiece &feature_name, const ColDescriptor &column_descriptor,
1140                                        const DataType &example_dtype) {
1141   std::shared_ptr<Tensor> &column_tensor = (*parsed_row)[column_index];
1142   if (example_dtype != column_descriptor.Type()) {
1143     const std::string msg =
1144       "The data type loaded from the example for feature name: " + column_descriptor.Name() +
1145       " does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
1146       ", but the predefined type: " + column_descriptor.Type().ToString();
1147     if (example_dtype == column_tensor->type()) {
1148       // if the actual data type is the same as the pre-allocated tensor,
1149       // we can first read it into the tensor, then cast to the type specified by the schema
1150       MS_LOG(INFO) << msg << ". This will cause a type cast.";
1151     } else if (!example_dtype.IsString() || column_descriptor.Type().value() != DataType::DE_UINT8) {
1152       // allow to read data of type string or bytes into an uint8 tensor
1153       RETURN_STATUS_UNEXPECTED(msg);
1154     }
1155   }
1156 
1157   const std::size_t num_elements = column_descriptor.Shape().NumOfElements();
1158   switch (example_dtype.value()) {
1159     case DataType::DE_INT64: {
1160       const auto data_buffer =
1161         reinterpret_cast<int64_t *>(column_tensor->GetMutableBuffer()) + tensor_index * num_elements;
1162       LimitedArraySlice<int64_t> slice(data_buffer, num_elements);
1163       if (!feature.ParseInt64List(&slice)) {
1164         return ReportUnexpectedParseFailure(feature_name);
1165       }
1166       if (slice.EndDistance() != 0) {
1167         return ReportUnexpectedDataShape(feature_name);
1168       }
1169       break;
1170     }
1171     case DataType::DE_FLOAT32: {
1172       const auto data_buffer =
1173         reinterpret_cast<float *>(column_tensor->GetMutableBuffer()) + tensor_index * num_elements;
1174       LimitedArraySlice<float> slice(data_buffer, num_elements);
1175       if (!feature.ParseFloatList(&slice)) {
1176         return ReportUnexpectedParseFailure(feature_name);
1177       }
1178       if (slice.EndDistance() != 0) {
1179         return ReportUnexpectedDataShape(feature_name);
1180       }
1181       break;
1182     }
1183     case DataType::DE_STRING: {
1184       const auto data_buffer = &(*string_col_map)[column_index][tensor_index * num_elements];
1185       LimitedArraySlice<std::string> slice(data_buffer, num_elements);
1186       if (!feature.ParseBytesList(&slice)) {
1187         return ReportUnexpectedParseFailure(feature_name);
1188       }
1189       if (slice.EndDistance() != 0) {
1190         return ReportUnexpectedDataShape(feature_name);
1191       }
1192       break;
1193     }
1194     default:
1195       return ReportUnexpectedDataType(feature_name, example_dtype);
1196   }
1197   return Status::OK();
1198 }
1199 
PushStringToBuffer(const std::vector<std::string> & bytes_list,VarLenTensorBuffer * varlen_tensor_buffer,const ColDescriptor & column_descriptor)1200 Status PushStringToBuffer(const std::vector<std::string> &bytes_list, VarLenTensorBuffer *varlen_tensor_buffer,
1201                           const ColDescriptor &column_descriptor) {
1202   if (column_descriptor.Type().value() == DataType::DE_STRING) {
1203     // check that each sample contains the same number of strings
1204     if (varlen_tensor_buffer->string_length != 0) {
1205       CHECK_FAIL_RETURN_UNEXPECTED(varlen_tensor_buffer->string_length == bytes_list.size(),
1206                                    "Could not batch string Tensors with different shapes.");
1207     } else {
1208       if (column_descriptor.Rank() != 0) {
1209         varlen_tensor_buffer->string_length = bytes_list.size();
1210       } else {
1211         varlen_tensor_buffer->string_length = 0;
1212       }
1213     }
1214     for (auto &bytes : bytes_list) {
1215       varlen_tensor_buffer->string_tensor.emplace_back(bytes);
1216     }
1217   } else if (column_descriptor.Type().value() == DataType::DE_UINT8) {
1218     size_t total_size = 0;
1219     for (auto &bytes : bytes_list) {
1220       total_size += bytes.size();
1221       varlen_tensor_buffer->string_tensor.emplace_back(bytes);
1222     }
1223     if (varlen_tensor_buffer->string_length != 0) {
1224       CHECK_FAIL_RETURN_UNEXPECTED(varlen_tensor_buffer->string_length == total_size,
1225                                    "Could not batch bytes Tensors with different shapes.");
1226     } else {
1227       varlen_tensor_buffer->string_length = total_size;
1228     }
1229   }
1230   return Status::OK();
1231 }
1232 
// Parses one variable-length feature of one example into the per-minibatch
// buffer: numeric data becomes a per-sample tensor appended to numeric_tensor
// (with an optional deferred type cast), string/bytes data is appended to the
// string buffer via PushStringToBuffer.
Status ParseSerializedVarLenColumn(const parsed::Feature &feature, VarLenTensorBuffer *varlen_tensor_buffer,
                                   const StringPiece &feature_name, const ColDescriptor &column_descriptor,
                                   const DataType &example_dtype) {
  bool type_cast_flag = false;
  if (example_dtype != column_descriptor.Type()) {
    const std::string msg =
      "The data type loaded from the example for feature name: " + column_descriptor.Name() +
      " does not match the predefined type in schema, the actual type: " + example_dtype.ToString() +
      ", but the predefined type: " + column_descriptor.Type().ToString();
    if (!example_dtype.IsString() && !column_descriptor.Type().IsString()) {
      // Numeric-to-numeric mismatch is tolerated: parse first, cast afterwards.
      MS_LOG(INFO) << msg << ". This will cause a type cast.";
      type_cast_flag = true;
    } else if (column_descriptor.Type().value() != DataType::DE_UINT8) {
      // allow to read data of type string or bytes into an uint8 tensor
      RETURN_STATUS_UNEXPECTED(msg);
    }
  }

  // First pass: decode the feature's value list into a type-specific
  // temporary, and record how many elements it holds.
  size_t num_elements;
  SmallVector<int64_t> int64_list;
  TensorVector<float> float_list;
  std::vector<std::string> bytes_list;
  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      if (!feature.ParseInt64List(&int64_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = int64_list.size();
      break;
    }
    case DataType::DE_FLOAT32: {
      if (!feature.ParseFloatList(&float_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = float_list.size();
      break;
    }
    case DataType::DE_STRING: {
      // Pre-count the strings so bytes_list allocates exactly once.
      int actual_num_elements = 0;
      if (!feature.GetNumElementsInBytesList(&actual_num_elements)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      bytes_list.reserve(actual_num_elements);
      if (!feature.ParseBytesList(&bytes_list)) {
        return ReportUnexpectedParseFailure(feature_name);
      }
      num_elements = bytes_list.size();
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }

  // Resolve the column's unknown dimension(s) from the observed element count.
  TensorShape varlen_tensor_shape = TensorShape::CreateUnknownRankShape();
  RETURN_IF_NOT_OK(column_descriptor.MaterializeTensorShape(num_elements, &varlen_tensor_shape));
  std::shared_ptr<Tensor> varlen_tensor;
  // Second pass: move the decoded values into a tensor (numeric) or into the
  // string buffer, applying the deferred cast when the schema type differs.
  switch (example_dtype.value()) {
    case DataType::DE_INT64: {
      RETURN_IF_NOT_OK(Tensor::CreateEmpty(varlen_tensor_shape, example_dtype, &varlen_tensor));
      CopyOrMoveBlock(int64_list.begin(), int64_list.end(),
                      reinterpret_cast<int64_t *>(varlen_tensor->GetMutableBuffer()));
      if (type_cast_flag) {
        std::shared_ptr<Tensor> casted_varlen_tensor;
        RETURN_IF_NOT_OK(TypeCast(varlen_tensor, &casted_varlen_tensor, column_descriptor.Type()));
        varlen_tensor_buffer->numeric_tensor.emplace_back(casted_varlen_tensor);
      } else {
        varlen_tensor_buffer->numeric_tensor.emplace_back(varlen_tensor);
      }
      break;
    }
    case DataType::DE_FLOAT32: {
      // float_list already owns a tensor; wrap it and give it the proper shape.
      RETURN_IF_NOT_OK(Tensor::CreateFromTensor(std::shared_ptr<Tensor>(float_list.tensor()), &varlen_tensor));
      RETURN_IF_NOT_OK(varlen_tensor->Reshape(varlen_tensor_shape));
      if (type_cast_flag) {
        std::shared_ptr<Tensor> casted_varlen_tensor;
        RETURN_IF_NOT_OK(TypeCast(varlen_tensor, &casted_varlen_tensor, column_descriptor.Type()));
        varlen_tensor_buffer->numeric_tensor.emplace_back(casted_varlen_tensor);
      } else {
        varlen_tensor_buffer->numeric_tensor.emplace_back(varlen_tensor);
      }
      break;
    }
    case DataType::DE_STRING: {
      RETURN_IF_NOT_OK(PushStringToBuffer(bytes_list, varlen_tensor_buffer, column_descriptor));
      break;
    }
    default:
      return ReportUnexpectedDataType(feature_name, example_dtype);
  }
  return Status::OK();
}
1324 
ParseSerializedExample(const std::string & example_bytes,TensorRow * parsed_row,std::unordered_map<int32_t,std::vector<std::string>> * string_column_map,std::vector<VarLenTensorBuffer> * varlen_tensor_vector,const size_t tensor_index)1325 Status ParseExampleOp::ParseSerializedExample(const std::string &example_bytes, TensorRow *parsed_row,
1326                                               std::unordered_map<int32_t, std::vector<std::string>> *string_column_map,
1327                                               std::vector<VarLenTensorBuffer> *varlen_tensor_vector,
1328                                               const size_t tensor_index) {
1329   parsed::Example parsed_example;
1330   CHECK_FAIL_RETURN_UNEXPECTED(ParseExample(example_bytes, &parsed_example),
1331                                "Failed to parse example bytes: " + example_bytes);
1332 
1333   const size_t parsed_example_size = parsed_example.size();
1334   std::vector<bool> feature_already_seen(data_schema_.NumColumns(), false);
1335   for (size_t i = 0; i < parsed_example_size; ++i) {
1336     // This is a logic that standard protobuf parsing is implementing.
1337     // I.e. last entry in the map overwrites all the previous ones.
1338     parsed::FeatureMapEntry &name_and_feature = parsed_example[parsed_example_size - i - 1];
1339     const StringPiece &feature_name = name_and_feature.first;
1340     parsed::Feature &feature = name_and_feature.second;
1341 
1342     if (column_name_id_map_.find(std::string(feature_name)) == column_name_id_map_.end()) {
1343       MS_LOG(INFO) << "Feature name: " << feature_name << " is not in schema, skip it.";
1344       continue;
1345     }
1346 
1347     DataType example_dtype;
1348     RETURN_IF_NOT_OK(feature.ParseDataType(&example_dtype));
1349     if (example_dtype == DataType::DE_UNKNOWN) {
1350       continue;
1351     }
1352 
1353     const auto column_index = column_name_id_map_[std::string(feature_name)];
1354     // If feature was already visited, skip.
1355     if (feature_already_seen[column_index]) {
1356       LogFeatureRepeated(feature_name);
1357       continue;
1358     }
1359     feature_already_seen[column_index] = true;
1360 
1361     const ColDescriptor &column_descriptor = data_schema_.Column(column_index);
1362     if (column_descriptor.HasKnownShape()) {
1363       RETURN_IF_NOT_OK(ParseSerializedKnownShapeColumn(feature, parsed_row, string_column_map, column_index,
1364                                                        tensor_index, feature_name, column_descriptor, example_dtype));
1365     } else {  // if variable length
1366       RETURN_IF_NOT_OK(ParseSerializedVarLenColumn(feature, &(*varlen_tensor_vector)[column_index], feature_name,
1367                                                    column_descriptor, example_dtype));
1368     }
1369   }
1370 
1371   for (int32_t column_index = 0; column_index < data_schema_.NumColumns(); ++column_index) {
1372     if (!feature_already_seen[column_index]) {
1373       RETURN_STATUS_UNEXPECTED("Feature name: " + data_schema_.Column(column_index).Name() +
1374                                " is required in schema but could not be found in tfrecord file.");
1375     }
1376   }
1377   return Status::OK();
1378 }
1379 
ConstructColumnMap(const std::string & example_bytes)1380 Status ParseExampleOp::ConstructColumnMap(const std::string &example_bytes) {
1381   if (column_name_id_map_.empty()) {
1382     if (data_schema_.Empty()) {
1383       dataengine::Example example;
1384       if (!example.ParseFromString(example_bytes)) {
1385         RETURN_STATUS_UNEXPECTED("Failed to parse example bytes: " + std::string(example_bytes));
1386       }
1387 
1388       const dataengine::Features &example_features = example.features();
1389       const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
1390       if (column_list_.empty()) {
1391         (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(column_list_),
1392                              [](const auto &it) -> std::string { return it.first; });
1393         std::sort(column_list_.begin(), column_list_.end());
1394       }
1395 
1396       for (const auto &column_name : column_list_) {
1397         auto it = feature_map.find(column_name);
1398         if (it == feature_map.end()) {
1399           RETURN_STATUS_UNEXPECTED("Invalid column list, failed to find column name: " + column_name + " in example.");
1400         }
1401 
1402         std::string column_type;
1403         const dataengine::Feature &feature = it->second;
1404         switch (feature.kind_case()) {
1405           case dataengine::Feature::KindCase::kBytesList:
1406             column_type = "uint8";
1407             break;
1408           case dataengine::Feature::KindCase::kFloatList:
1409             column_type = "float32";
1410             break;
1411           case dataengine::Feature::KindCase::kInt64List:
1412             column_type = "int64";
1413             break;
1414           default:
1415             RETURN_STATUS_UNEXPECTED("Unsupported column type, the column type of " + column_name +
1416                                      " should be int64, float32 or string.");
1417         }
1418         RETURN_IF_NOT_OK(
1419           data_schema_.AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1)));
1420       }
1421     }
1422     RETURN_IF_NOT_OK(data_schema_.GetColumnNameMap(&column_name_id_map_));
1423     CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map_.empty(), "Can not get column name map, it is empty.");
1424   }
1425   return Status::OK();
1426 }
1427 }  // namespace mindspore::dataset
1428