• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
17 
18 #include <algorithm>
19 #include <fstream>
20 #include <future>
21 #include <memory>
22 #include <mutex>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "proto/example.pb.h"
28 
29 #include "minddata/dataset/engine/data_schema.h"
30 #include "minddata/dataset/engine/datasetops/source/io_block.h"
31 #include "minddata/dataset/engine/execution_tree.h"
32 #include "minddata/dataset/engine/jagged_connector.h"
33 #include "minddata/dataset/util/status.h"
34 #include "minddata/dataset/util/task_manager.h"
35 #include "minddata/dataset/util/wait_post.h"
36 #include "utils/file_utils.h"
37 #include "utils/system/crc32c.h"
38 
39 namespace mindspore {
40 namespace dataset {
TFReaderOp(int32_t num_workers,int32_t worker_connector_size,int64_t total_num_rows,std::vector<std::string> dataset_files_list,std::unique_ptr<DataSchema> data_schema,int32_t op_connector_size,std::vector<std::string> columns_to_load,bool shuffle_files,int32_t num_devices,int32_t device_id,bool equal_rows_per_shard,const CompressionType & compression_type,bool decode)41 TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows,
42                        std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
43                        int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
44                        int32_t num_devices, int32_t device_id, bool equal_rows_per_shard,
45                        const CompressionType &compression_type, bool decode)
46     : NonMappableLeafOp(num_workers, worker_connector_size, total_num_rows, op_connector_size, shuffle_files,
47                         num_devices, device_id, compression_type),
48       dataset_files_list_(std::move(dataset_files_list)),
49       columns_to_load_(std::move(columns_to_load)),
50       data_schema_(std::move(data_schema)),
51       equal_rows_per_shard_(equal_rows_per_shard),
52       decode_(decode) {}
53 
54 // A print method typically used for debugging
Print(std::ostream & out,bool show_all) const55 void TFReaderOp::Print(std::ostream &out, bool show_all) const {
56   if (!show_all) {
57     // Call the super class for displaying any common 1-liner info
58     ParallelOp::Print(out, show_all);
59     // Then show any custom derived-internal 1-liner info for this op
60     out << "\n";
61   } else {
62     // Call the super class for displaying any common detailed info
63     ParallelOp::Print(out, show_all);
64     // Then show any custom derived-internal stuff
65     out << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
66         << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
67         << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n";
68     for (const auto &i : dataset_files_list_) {
69       out << " " << i;
70     }
71     if (!columns_to_load_.empty()) {
72       out << "\nColumns to load:\n";
73       for (const auto &j : columns_to_load_) {
74         out << " " << j;
75       }
76     }
77     out << "\nData Schema:\n";
78     out << *data_schema_ << "\n\n";
79   }
80 }
81 
Init()82 Status TFReaderOp::Init() {
83   if (data_schema_->Empty()) {
84     RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_));
85   }
86 
87   if (total_rows_ == 0) {
88     total_rows_ = data_schema_->NumRows();
89   }
90   if (total_rows_ < 0) {
91     RETURN_STATUS_UNEXPECTED(
92       "[Internal ERROR] num_samples or num_rows for TFRecordDataset must be greater than 0, but got: " +
93       std::to_string(total_rows_));
94   } else if (compression_type_ != CompressionType::NONE && total_rows_ == 0) {
95     MS_LOG(WARNING) << "Since compression_type is set, but neither num_samples nor numRows (from schema file) "
96                     << "is provided, performance might be degraded.";
97   }
98 
99   // Build the index with our files such that each file corresponds to a key id.
100   RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
101 
102   jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
103 
104   // temporary: make size large enough to hold all files + EOE to avoid hangs
105   int32_t safe_queue_size = static_cast<int32_t>(std::ceil(dataset_files_list_.size() / num_workers_)) + 1;
106   io_block_queues_.Init(num_workers_, safe_queue_size);
107 
108   return Status::OK();
109 }
110 
RegisterAndLaunchThreads()111 Status TFReaderOp::RegisterAndLaunchThreads() {
112   RETURN_UNEXPECTED_IF_NULL(tree_);
113   worker_in_queues_.Init(num_workers_, worker_connector_size_);
114   worker_out_queues_.Init(num_workers_, worker_connector_size_);
115 
116   // Registers QueueList and individual Queues for interrupt services
117   RETURN_IF_NOT_OK(worker_in_queues_.Register(tree_->AllTasks()));
118   RETURN_IF_NOT_OK(worker_out_queues_.Register(tree_->AllTasks()));
119   RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
120 
121   RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1),
122                                         &worker_tasks_, Name() + "::WorkerEntry", id()));
123   // if decode is true, launch some workers to parse the protobuf
124   if (decode_) {
125     RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
126                                           std::bind(&TFReaderOp::ParsingWorkerEntry, this, std::placeholders::_1),
127                                           Name() + "::ParsingWorkerEntry", id()));
128   }
129   RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::Collector, this), Name() + "::Collector", id()));
130 
131   return Status::OK();
132 }
133 
operator ()()134 Status TFReaderOp::operator()() {
135   RETURN_IF_NOT_OK(PrepareData());
136   while (!finished_reading_dataset_) {
137     int32_t workers_done = 0;
138     int64_t rows_read = 0;
139     {
140       std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
141       load_io_block_queue_ = true;
142     }
143     TensorRow fetched_row;
144     while (workers_done < num_workers_) {
145       RETURN_IF_NOT_OK(jagged_rows_connector_->Pop(0, &fetched_row));
146       if (fetched_row.eoe()) {
147         workers_done++;
148       } else if ((compression_type_ == CompressionType::NONE || compression_type_ == CompressionType::GZIP_WITH_COUNT ||
149                   compression_type_ == CompressionType::ZLIB_WITH_COUNT) &&
150                  (total_rows_ == 0 || rows_read < total_rows_)) {
151         if (decode_) {
152           // get record bytes from jagged_rows_connector and send them to workers for parsing
153           const auto parse_worker_id = NextWorkerID();
154           RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(std::move(fetched_row)));
155         } else {
156           // get record bytes from jagged_rows_connector and send them to out_connector
157           RETURN_IF_NOT_OK(out_connector_->Add(std::move(fetched_row)));
158         }
159         rows_read++;
160       } else if ((compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::ZLIB) &&
161                  (rows_read < total_rows_ * num_devices_)) {
162         // for compressed version, total_rows_ is total rows that will be read per shard
163         if (decode_) {
164           // get record bytes from jagged_rows_connector and send them to workers for parsing
165           const auto parse_worker_id = NextWorkerID();
166           RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(std::move(fetched_row)));
167         } else {
168           // get record bytes from jagged_rows_connector and send them to out_connector
169           RETURN_IF_NOT_OK(out_connector_->Add(std::move(fetched_row)));
170         }
171         rows_read++;
172       } else {
173         // IOBlockQueue thread needs to:
174         // -stop pushing stuff to IOBlockQueue
175         // -call PostEndOfEpoch (will send EOE)
176         // -wait for reset
177         //
178         // Worker threads need to:
179         // -stop reading the file they are currently reading and throw it away
180         // -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE)
181         //
182         // Master thread needs to:
183         // -tell IOBlockQueue thread to stop pushing
184         // -tell worker threads to stop reading the file they are currently reading
185         // -keep pulling until EOE
186 
187         // don't think we need a lock for now
188         {
189           std::unique_lock<std::mutex> lock(load_jagged_connector_mutex_);
190           load_jagged_connector_ = false;
191         }
192         {
193           std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
194           load_io_block_queue_ = false;
195         }
196       }
197     }
198 
199     if (decode_) {
200       // finish reading this epoch, send an EOE flag to next parsing worker
201       const auto parse_worker_id = NextWorkerID();
202       RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(TensorRow(TensorRow::kFlagEOE)));
203     } else {
204       // finish reading this epoch, send an EOE flag to out_connector
205       RETURN_IF_NOT_OK(out_connector_->SendEOE());
206     }
207 
208     RETURN_IF_NOT_OK(ResetAndUpdateRepeat());
209   }
210 
211   if (decode_) {
212     // finish reading all the data, send an EOF flag to next parsing worker
213     auto parse_worker_id = NextWorkerID();
214     RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(TensorRow::kFlagEOF));
215     // tell all the parsing workers to quit
216     for (auto i = 0; i < num_workers_; ++i) {
217       RETURN_IF_NOT_OK(worker_in_queues_[i]->EmplaceBack(TensorRow::kFlagQuit));
218     }
219   } else {
220     // finish reading all the data, send an EOF flag to out_connector
221     RETURN_IF_NOT_OK(out_connector_->SendEOF());
222   }
223 
224   RETURN_IF_NOT_OK(PostEndOfData());
225 
226   return Status::OK();
227 }
228 
CalculateNumRowsPerShard()229 Status TFReaderOp::CalculateNumRowsPerShard() {
230   if (!equal_rows_per_shard_) {
231     return Status::OK();
232   }
233 
234   if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::ZLIB) {
235     num_rows_per_shard_ = total_rows_;
236   } else {
237     for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
238       std::vector<std::string> file(1, it.value());
239       int64_t num = CountTotalRowsSectioned(file, 0, 1, compression_type_);
240       filename_numrows_[it.value()] = num;
241       num_rows_ += num;
242     }
243     num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
244   }
245   if (num_rows_per_shard_ == 0) {
246     std::stringstream ss;
247     for (auto &i : dataset_files_list_) {
248       ss << " " << i;
249     }
250     std::string file_list = ss.str();
251     RETURN_STATUS_UNEXPECTED(
252       "Invalid data, TFRecordDataset API can't read the data file (interface mismatch or no data under the file). "
253       "Check file path." +
254       file_list);
255   }
256   return Status::OK();
257 }
258 
ParsingWorkerEntry(int32_t worker_id)259 Status TFReaderOp::ParsingWorkerEntry(int32_t worker_id) {
260   // must be called first if called by worker spawned by taskgroup
261   TaskManager::FindMe()->Post();
262 
263   TensorRow next_row;
264   RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&next_row));
265   while (!next_row.quit()) {
266     if (!next_row.empty()) {
267       TensorRow parsed_row;
268       RETURN_IF_NOT_OK(ParseExample(next_row, &parsed_row));
269       RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(parsed_row));
270     } else if (next_row.eoe() || next_row.eof()) {
271       RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(next_row));
272     } else {
273       RETURN_STATUS_UNEXPECTED("TFReaderOp: parsing worker got an unexpected empty tensor row.");
274     }
275     RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&next_row));
276   }
277   return Status::OK();
278 }
279 
ParseExample(const TensorRow & raw_bytes,TensorRow * parsed_row)280 Status TFReaderOp::ParseExample(const TensorRow &raw_bytes, TensorRow *parsed_row) {
281   auto filename = raw_bytes.getPath()[0];
282   auto itr = raw_bytes[0]->begin<std::string_view>();
283   dataengine::Example tf_record_example;
284   CHECK_FAIL_RETURN_UNEXPECTED(tf_record_example.ParseFromString(static_cast<std::string>(*itr)),
285                                "TFReaderOp: failed to parse example in tfrecord file: " + filename +
286                                  ". Perhaps the version of protobuf is not compatible. The example bytes is " +
287                                  static_cast<std::string>(*itr));
288 
289   auto num_columns = data_schema_->NumColumns();
290   TensorRow parsed_example(num_columns, nullptr);
291   std::vector<std::string> file_path(num_columns, filename);
292   parsed_example.setPath(file_path);
293   RETURN_IF_NOT_OK(LoadExample(&tf_record_example, &parsed_example));
294 
295   *parsed_row = std::move(parsed_example);
296   return Status::OK();
297 }
298 
299 // Reads a tf_record_file file and loads the data into multiple TensorRows.
LoadFile(const std::string & filename,int64_t start_offset,int64_t end_offset,int32_t worker_id)300 Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
301   auto realpath = FileUtils::GetRealPath(filename.c_str());
302   if (!realpath.has_value()) {
303     MS_LOG(ERROR) << "Invalid file path, " << filename << " does not exist.";
304     RETURN_STATUS_UNEXPECTED("Invalid file path, " + filename + " does not exist.");
305   }
306   std::string realpath_value = realpath.value();
307 
308   if (compression_type_ == CompressionType::NONE) {
309     RETURN_IF_NOT_OK(HelperLoadNonCompFile(filename, start_offset, end_offset, worker_id, realpath_value));
310   }
311 #if !defined(_WIN32) && !defined(_WIN64)
312   if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::GZIP_WITH_COUNT) {
313     RETURN_IF_NOT_OK(HelperLoadCompGZIPFile(filename, start_offset, end_offset, worker_id, realpath_value));
314   } else if (compression_type_ == CompressionType::ZLIB || compression_type_ == CompressionType::ZLIB_WITH_COUNT) {
315     RETURN_IF_NOT_OK(HelperLoadCompZLIBFile(filename, start_offset, end_offset, worker_id, realpath_value));
316   }
317 #endif
318 
319   return Status::OK();
320 }
321 
SendRecordBytesRow(const std::string & filename,const std::string & serialized_example,int32_t worker_id)322 Status TFReaderOp::SendRecordBytesRow(const std::string &filename, const std::string &serialized_example,
323                                       int32_t worker_id) {
324   std::vector<std::string> filenames(1, filename);
325   TensorRow record_bytes_row(1, nullptr);
326   record_bytes_row.setPath(filenames);
327   std::shared_ptr<Tensor> record_bytes_tensor;
328   RETURN_IF_NOT_OK(Tensor::CreateScalar(serialized_example, &record_bytes_tensor));
329   record_bytes_row[0] = std::move(record_bytes_tensor);
330   RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(record_bytes_row)));
331   return Status::OK();
332 }
333 
HelperLoadNonCompFile(const std::string & filename,int64_t start_offset,int64_t end_offset,int32_t worker_id,const std::string & realpath_value)334 Status TFReaderOp::HelperLoadNonCompFile(const std::string &filename, int64_t start_offset, int64_t end_offset,
335                                          int32_t worker_id, const std::string &realpath_value) {
336   std::ifstream reader;
337   reader.open(realpath_value, std::ios::in);
338   if (!reader) {
339     RETURN_STATUS_UNEXPECTED("Invalid file, " + filename + " open failed: permission denied!");
340   }
341 
342   int64_t rows_total = 0;
343 
344   while (reader.peek() != EOF) {
345     if (!GetLoadJaggedConnector()) {
346       break;
347     }
348     RETURN_IF_INTERRUPTED();
349 
350     // read length
351     std::streamsize record_length = 0;
352     (void)reader.read(reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);
353 
354     // ignore crc header
355     (void)reader.ignore(kTFRecordHeadFootSize);
356 
357     // read serialized Example
358     std::string serialized_example;
359     serialized_example.resize(static_cast<size_t>(record_length));
360     (void)reader.read(&serialized_example[0], record_length);
361 
362     if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) {
363       auto s = SendRecordBytesRow(filename, serialized_example, worker_id);
364       if (s != Status::OK()) {
365         reader.close();
366         return s;
367       }
368     }
369 
370     // ignore crc footer
371     (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));
372     rows_total++;
373   }
374   reader.close();
375   return Status::OK();
376 }
377 
378 #if !defined(_WIN32) && !defined(_WIN64)
HelperLoadCompGZIPFile(const std::string & filename,int64_t start_offset,int64_t end_offset,int32_t worker_id,const std::string & realpath_value)379 Status TFReaderOp::HelperLoadCompGZIPFile(const std::string &filename, int64_t start_offset, int64_t end_offset,
380                                           int32_t worker_id, const std::string &realpath_value) {
381   gzFile file = gzopen(realpath_value.c_str(), "rb");
382   if (file == nullptr) {
383     RETURN_STATUS_UNEXPECTED("Invalid file, " + filename + " open failed: permission denied!");
384   }
385 
386   int64_t rows_read = 0;
387   int64_t rows_total = 0;
388 
389   while (gzeof(file) != 1) {
390     if (compression_type_ == CompressionType::GZIP && rows_read >= end_offset) {
391       break;
392     }
393 
394     if (!GetLoadJaggedConnector()) {
395       break;
396     }
397     RETURN_IF_INTERRUPTED();
398 
399     // read length
400     int64_t record_length = 0;
401     (void)gzread(file, reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);
402     if (record_length == 0) {
403       continue;
404     }
405 
406     if (rows_total == 0) {
407       // do the delayed checking; read crc from file
408       uint32_t masked_crc = 0;
409       (void)gzread(file, reinterpret_cast<char *>(&masked_crc), sizeof(uint32_t));
410 
411       // generate crc from data
412       uint32_t generated_crc =
413         system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);
414 
415       // invalid tfrecord file
416       if (masked_crc != generated_crc) {
417         (void)gzclose(file);
418         RETURN_STATUS_UNEXPECTED("Invalid TFRecord file: " + filename);
419       }
420     } else {
421       // ignore crc header
422       (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
423     }
424 
425     // read serialized Example
426     std::string serialized_example;
427     serialized_example.resize(static_cast<size_t>(record_length));
428     (void)gzread(file, &serialized_example[0], static_cast<unsigned int>(record_length));
429 
430     if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) {
431       auto s = SendRecordBytesRow(filename, serialized_example, worker_id);
432       if (s != Status::OK()) {
433         (void)gzclose(file);
434         return s;
435       }
436       rows_read++;
437     }
438     // ignore crc footer
439     (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
440     rows_total++;
441   }
442 
443   (void)gzclose(file);
444   if (compression_type_ == CompressionType::GZIP && rows_read < end_offset) {
445     std::string errMsg = "This tfrecord file: " + filename +
446                          ", does not meet minimum rows per shard requirement: " + std::to_string(total_rows_) +
447                          " and " + std::to_string(static_cast<int>(total_rows_ / num_devices_)) +
448                          " number of rows per file, but got " + std::to_string(rows_read) +
449                          " number of rows in this file.";
450     RETURN_STATUS_UNEXPECTED(errMsg);
451   }
452 
453   return Status::OK();
454 }
455 
HelperLoadCompZLIBFile(const std::string & filename,int64_t start_offset,int64_t end_offset,int32_t worker_id,const std::string & realpath_value)456 Status TFReaderOp::HelperLoadCompZLIBFile(const std::string &filename, int64_t start_offset, int64_t end_offset,
457                                           int32_t worker_id, const std::string &realpath_value) {
458   // ZLIB stream setup (based on zlib.h tutorial)
459   ZLIBStreamInf zlib_stream;
460   std::ifstream reader(realpath_value, std::ios::in | std::ios::binary);
461   if (!reader) {
462     RETURN_STATUS_UNEXPECTED("Invalid file, " + filename + " open failed: permission denied!");
463   }
464 
465   zlib_stream.inflate_status = inflateInit(&zlib_stream.strm);
466   if (zlib_stream.inflate_status != Z_OK) {
467     reader.close();
468     RETURN_STATUS_UNEXPECTED("Failed to initialize inflate stream for ZLIB for file " + filename + "!");
469   }
470 
471   int64_t rows_read = 0;
472   int64_t rows_total = 0;
473 
474   // decompress until inflate stream ends or end of file
475   do {
476     if (compression_type_ == CompressionType::ZLIB && rows_read >= end_offset) {
477       break;
478     }
479 
480     if (!GetLoadJaggedConnector()) {
481       break;
482     }
483     RETURN_IF_INTERRUPTED();
484 
485     (void)reader.read(zlib_stream.input_stream, kZLIBChunkSize);
486     zlib_stream.strm.avail_in = static_cast<unsigned int>(reader.gcount());
487     if (zlib_stream.strm.avail_in == 0) {
488       break;
489     }
490     zlib_stream.strm.next_in = reinterpret_cast<unsigned char *>(zlib_stream.input_stream);
491 
492     // run inflate() on input buffer until current output buffer is not full yet but still need more from input buffer,
493     // or rows_read have exceeded the required number of rows to be read (end_offset)
494     do {
495       if (compression_type_ == CompressionType::ZLIB && rows_read >= end_offset) {
496         break;
497       }
498 
499       // inflate the stream
500       auto s = HelperInflateZLIB(&zlib_stream, filename);
501       if (s != Status::OK()) {
502         reader.close();
503         return s;
504       }
505       if (zlib_stream.left_to_read != 0) {
506         break;
507       }
508 
509       // Process inflated data depending on read flag
510       s = HelperProcessZLIBData(&zlib_stream, &rows_read, &rows_total, filename, start_offset, end_offset, worker_id);
511       if (s != Status::OK()) {
512         reader.close();
513         return s;
514       }
515       zlib_stream.read_flag = (zlib_stream.read_flag + 1) %
516                               (static_cast<int>(ZLIBReadFlag::Footer) + 1);  // resets flag to reading record length
517     } while (zlib_stream.strm.avail_out == 0);
518   } while (zlib_stream.inflate_status != Z_STREAM_END);
519 
520   (void)inflateEnd(&zlib_stream.strm);
521   if (zlib_stream.inflate_status != Z_STREAM_END && rows_read < end_offset) {
522     reader.close();
523     RETURN_STATUS_UNEXPECTED("Decompression of ZLIB file failed for file " + filename + "!");
524   }
525 
526   if (compression_type_ == CompressionType::ZLIB && rows_read < end_offset) {
527     reader.close();
528     std::string errMsg = "This tfrecord file: " + filename +
529                          ", does not meet minimum rows per shard requirement: " + std::to_string(total_rows_) +
530                          " and " + std::to_string(static_cast<int>(total_rows_ / num_devices_)) +
531                          " number of rows per file, but got " + std::to_string(rows_read) +
532                          " number of rows in this file.";
533     RETURN_STATUS_UNEXPECTED(errMsg);
534   }
535   reader.close();
536   return Status::OK();
537 }
538 
HelperBinDataToInt(const unsigned char * str_record_size,size_t str_size)539 int64_t TFReaderOp::HelperBinDataToInt(const unsigned char *str_record_size, size_t str_size) {
540   int n = 1;
541   int new_value_width = 2;
542   if (*reinterpret_cast<char *>(&n) == 1) {  // Little-endian system
543     std::stringstream ss;
544     ss << std::hex << std::setfill('0');
545     std::string hex_str = "0x";
546     for (int pos = static_cast<int>(str_size) - 1; pos >= 0; pos--) {
547       ss << std::setw(new_value_width) << static_cast<unsigned>(str_record_size[static_cast<size_t>(pos)]);
548     }
549     (void)hex_str.append(ss.str());
550     auto result = static_cast<int64_t>(std::stoul(hex_str, nullptr, 16));
551     return result;
552   } else {  // Big-endian system
553     std::stringstream ss;
554     ss << std::hex << std::setfill('0');
555     std::string hex_str = "0x";
556     for (size_t pos = 0; pos < str_size; pos++) {
557       ss << std::setw(new_value_width) << static_cast<unsigned>(str_record_size[pos]);
558     }
559     (void)hex_str.append(ss.str());
560     auto result = static_cast<int64_t>(std::stoul(hex_str, nullptr, 16));
561     return result;
562   }
563 }
564 
HelperInflateZLIB(ZLIBStreamInf * zlib_stream,const std::string & filename) const565 Status TFReaderOp::HelperInflateZLIB(ZLIBStreamInf *zlib_stream, const std::string &filename) const {
566   if (zlib_stream->left_to_read != 0) {
567     zlib_stream->strm.avail_out =
568       static_cast<unsigned int>(zlib_stream->left_to_read);  // need to read the rest before process
569   } else {
570     switch (zlib_stream->read_flag) {
571       case ZLIBReadFlag::RecordLength:  // record length
572         zlib_stream->strm.avail_out = kTFRecordRecLenSize;
573         zlib_stream->strm.next_out = zlib_stream->record_size;
574         break;
575       case ZLIBReadFlag::Header:  // record header/footer
576       case ZLIBReadFlag::Footer:
577         zlib_stream->strm.avail_out = kTFRecordHeadFootSize;
578         zlib_stream->strm.next_out = zlib_stream->garbage;
579         break;
580       default:  // record example
581         zlib_stream->strm.avail_out = static_cast<unsigned int>(zlib_stream->record_length);
582         zlib_stream->content = std::make_unique<unsigned char[]>(static_cast<size_t>(zlib_stream->record_length));
583         zlib_stream->strm.next_out = zlib_stream->content.get();
584     }
585   }
586 
587   // Inflate stream
588   zlib_stream->inflate_status = inflate(&zlib_stream->strm, Z_NO_FLUSH);
589   auto inflate_status = zlib_stream->inflate_status;
590   // inflate returns Z_BUF_ERROR if no progress is possible.
591   // It is not fatal and inflate can be called again to continue compressing.
592   if (inflate_status == Z_OK || inflate_status == Z_STREAM_END || inflate_status == Z_BUF_ERROR) {
593     zlib_stream->left_to_read = static_cast<unsigned int>(zlib_stream->strm.avail_out);  // after reading
594     return Status::OK();
595   } else if (inflate_status == Z_STREAM_ERROR) {
596     (void)inflateEnd(&zlib_stream->strm);
597     RETURN_STATUS_UNEXPECTED("State not clobbered when inflating file " + filename + "!");
598   } else if (inflate_status == Z_NEED_DICT || inflate_status == Z_DATA_ERROR) {
599     (void)inflateEnd(&zlib_stream->strm);
600     RETURN_STATUS_UNEXPECTED("Invalid or incomplete inflate data when inflating file " + filename + "!");
601   } else if (inflate_status == Z_MEM_ERROR) {
602     (void)inflateEnd(&zlib_stream->strm);
603     RETURN_STATUS_UNEXPECTED("Out of memory when inflating file " + filename + "!");
604   } else {
605     (void)inflateEnd(&zlib_stream->strm);
606     RETURN_STATUS_UNEXPECTED("Got error code " + std::to_string(inflate_status) + " when inflating file " + filename +
607                              "! Please refer to the zilb documentation for more details.");
608   }
609 }
610 
HelperProcessZLIBData(ZLIBStreamInf * zlib_stream,int64_t * rows_read,int64_t * rows_total,const std::string & filename,int64_t start_offset,int64_t end_offset,int32_t worker_id)611 Status TFReaderOp::HelperProcessZLIBData(ZLIBStreamInf *zlib_stream, int64_t *rows_read, int64_t *rows_total,
612                                          const std::string &filename, int64_t start_offset, int64_t end_offset,
613                                          int32_t worker_id) {
614   if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::RecordLength)) {  // read record length
615     zlib_stream->record_length = HelperBinDataToInt(zlib_stream->record_size, kTFRecordRecLenSize);
616   } else if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::Header) &&
617              *rows_total == 0) {  // read header when needed (for tfrecord validation)
618     auto masked_crc = static_cast<uint32_t>(HelperBinDataToInt(zlib_stream->garbage, kTFRecordHeadFootSize));
619     uint32_t generated_crc =
620       system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&zlib_stream->record_length), kTFRecordRecLenSize);
621 
622     // invalid tfrecord file
623     if (masked_crc != generated_crc) {
624       RETURN_STATUS_UNEXPECTED("Invalid TFRecord file: " + filename);
625     }
626   } else if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::Content)) {  // read serialized example
627     std::string serialized_example(reinterpret_cast<char *>(zlib_stream->content.get()), zlib_stream->record_length);
628 
629     if (start_offset == kInvalidOffset || (*rows_total >= start_offset && *rows_total < end_offset)) {
630       RETURN_IF_NOT_OK(SendRecordBytesRow(filename, serialized_example, worker_id));
631       (*rows_read)++;
632     }
633   } else if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::Footer)) {
634     (*rows_total)++;
635   }
636 
637   return Status::OK();
638 }
639 #endif
640 
641 // Parses a single row and puts the data into a tensor table.
LoadExample(const dataengine::Example * tf_record_file,TensorRow * out_row)642 Status TFReaderOp::LoadExample(const dataengine::Example *tf_record_file, TensorRow *out_row) {
643   auto num_columns = static_cast<int32_t>(data_schema_->NumColumns());
644   for (int32_t col = 0; col < num_columns; ++col) {
645     const ColDescriptor current_col = data_schema_->Column(col);
646     const dataengine::Features &example_features = tf_record_file->features();
647     const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
648     auto iter_column = feature_map.find(current_col.Name());
649     if (iter_column == feature_map.end()) {
650       RETURN_STATUS_UNEXPECTED("Invalid columns_list, column name: " + current_col.Name() +
651                                " does not exist in tfrecord file, check tfrecord files.");
652     }
653     const dataengine::Feature &column_values_list = iter_column->second;
654     RETURN_IF_NOT_OK(LoadFeature(out_row, column_values_list, current_col, col));
655   }
656 
657   return Status::OK();
658 }
659 
660 // Parses a single cell and puts the data into a tensor table.
LoadFeature(TensorRow * tensor_row,const dataengine::Feature & column_values_list,const ColDescriptor & current_col,int32_t col)661 Status TFReaderOp::LoadFeature(TensorRow *tensor_row, const dataengine::Feature &column_values_list,
662                                const ColDescriptor &current_col, int32_t col) {
663   const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
664   std::unique_ptr<float[]> float_array;     // For staging data from protobuf deserialization
665   const unsigned char *data_ptr = nullptr;  // Generic pointer used for populating the Tensor
666 
667   // This variable will point into the above staging variables.
668   // Also used for creating shape attributes.
669   int32_t num_elements = 0;
670 
671   // we build a tensor first a read directly into it if we need to cast
672   std::shared_ptr<Tensor> ts;
673 
674   // Depending on the type of data from the tf_record_file, we want to extract 2 things:
675   // 1) A pointer to the data as a const unsigned char *
676   // 2) The number of elements of the data
677   // After those are determined, we can then build the tensor to represent this data.
678   switch (column_list_type) {
679     case dataengine::Feature::KindCase::kBytesList: {
680       RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts));
681 
682       break;
683     }
684     case dataengine::Feature::KindCase::kFloatList: {
685       RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));
686 
687       data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());
688 
689       // only floatList needs to create the tensor here, other two lists read directly
690       // into the tensor
691       TensorShape current_shape = TensorShape::CreateUnknownRankShape();
692       RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, &current_shape));
693       RETURN_IF_NOT_OK(Tensor::CreateFromMemory(current_shape, current_col.Type(), data_ptr, &ts));
694       break;
695     }
696     case dataengine::Feature::KindCase::kInt64List: {
697       RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts));
698       break;
699     }
700     case dataengine::Feature::KindCase::KIND_NOT_SET: {
701       std::string err_msg =
702         "Unrecognized datatype, column type in tfrecord file must be uint8, int64 or float32, check tfrecord file.";
703       RETURN_STATUS_UNEXPECTED(err_msg);
704     }
705     default: {
706       std::string err_msg =
707         "Unrecognized datatype, column type in tfrecord file must be uint8, int64 or float32, check tfrecord file.";
708       RETURN_STATUS_UNEXPECTED(err_msg);
709     }
710   }
711 
712   (*tensor_row)[col] = std::move(ts);
713 
714   return Status::OK();
715 }
716 
LoadBytesList(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::shared_ptr<Tensor> * tensor)717 Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
718                                  int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
719   // kBytesList can map to the following DE types ONLY!
720   // DE_UINT8, DE_INT8
721   // Must be single byte type for each element!
722   if (current_col.Type() != DataType::DE_UINT8 && current_col.Type() != DataType::DE_INT8 &&
723       current_col.Type() != DataType::DE_STRING) {
724     std::string err_msg = "Invalid column type, the column type of " + current_col.Name() +
725                           " should be int8, uint8 or string, but got " + current_col.Type().ToString();
726     RETURN_STATUS_UNEXPECTED(err_msg);
727   }
728 
729   const dataengine::BytesList &bytes_list = column_values_list.bytes_list();
730 
731   *num_elements = bytes_list.value_size();
732 
733   if (current_col.Type() == DataType::DE_STRING) {
734     TensorShape shape = TensorShape::CreateScalar();
735     RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape));
736     RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, shape, tensor));
737     return Status::OK();
738   }
739 
740   uint64_t max_size = 0;
741   for (uint32_t i = 0; i < bytes_list.value_size(); ++i) {
742 #if defined(__APPLE__)
743     max_size = fmax(max_size, bytes_list.value(i).size());
744 #else
745     max_size = std::max(max_size, bytes_list.value(i).size());
746 #endif
747   }
748 
749   int64_t pad_size = max_size;
750 
751   // if user provides a shape in the form of [-1, d1, 2d, ... , dn], we need to pad to d1 * d2 * ... * dn
752   if (current_col.HasShape()) {
753     TensorShape cur_shape = current_col.Shape();
754     if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) {
755       int64_t new_pad_size = 1;
756       for (int i = 1; i < cur_shape.Size(); ++i) {
757         if (cur_shape[i] == TensorShape::kDimUnknown) {
758           std::string err_msg =
759             "Invalid data dimension, only one dimension shape supported is -1, but the 0th and the" +
760             std::to_string(i) + "th dimension shape of " + current_col.Name() + " are both -1.";
761           RETURN_STATUS_UNEXPECTED(err_msg);
762         }
763         new_pad_size *= cur_shape[i];
764       }
765       pad_size = new_pad_size;
766     } else {
767       if (cur_shape.known() && cur_shape.NumOfElements() != max_size) {
768         std::string err_msg = "Data dimensions of '" + current_col.Name() +
769                               "' do not match, the expected total elements of shape " + cur_shape.ToString() +
770                               " should be " + std::to_string(max_size) + ", but got " +
771                               std::to_string(cur_shape.NumOfElements());
772         RETURN_STATUS_UNEXPECTED(err_msg);
773       }
774     }
775   }
776 
777   // know how many elements there are and the total bytes, create tensor here:
778   TensorShape current_shape = TensorShape::CreateScalar();
779   RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, &current_shape));
780   RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, current_shape, current_col.Type(), pad_size, tensor));
781 
782   return Status::OK();
783 }
784 
LoadFloatList(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::unique_ptr<float[]> * float_array)785 Status TFReaderOp::LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
786                                  int32_t *num_elements, std::unique_ptr<float[]> *float_array) {
787   // KFloatList can only map to DE types:
788   // DE_FLOAT32
789   if (current_col.Type() != DataType::DE_FLOAT32) {
790     std::string err_msg = "Invalid column type, the column type of " + current_col.Name() +
791                           " should be string, but got " + current_col.Type().ToString();
792     RETURN_STATUS_UNEXPECTED(err_msg);
793   }
794 
795   const dataengine::FloatList &float_list = column_values_list.float_list();
796 
797   // Identify how many values we have and then create a local array of these
798   // to deserialize into
799   *num_elements = float_list.value_size();
800   *float_array = std::make_unique<float[]>(*num_elements);
801   for (int i = 0; i < float_list.value_size(); ++i) {
802     (*float_array)[i] = float_list.value(i);
803   }
804 
805   return Status::OK();
806 }
807 
808 // Determines which template type to use and calls LoadIntList
LoadIntListSwitch(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::shared_ptr<Tensor> * tensor)809 Status TFReaderOp::LoadIntListSwitch(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
810                                      int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
811   if (current_col.Type() == DataType::DE_UINT64) {
812     RETURN_IF_NOT_OK(LoadIntList<uint64_t>(current_col, column_values_list, num_elements, tensor));
813   } else if (current_col.Type() == DataType::DE_INT64) {
814     RETURN_IF_NOT_OK(LoadIntList<int64_t>(current_col, column_values_list, num_elements, tensor));
815   } else if (current_col.Type() == DataType::DE_UINT32) {
816     RETURN_IF_NOT_OK(LoadIntList<uint32_t>(current_col, column_values_list, num_elements, tensor));
817   } else if (current_col.Type() == DataType::DE_INT32) {
818     RETURN_IF_NOT_OK(LoadIntList<int32_t>(current_col, column_values_list, num_elements, tensor));
819   } else if (current_col.Type() == DataType::DE_UINT16) {
820     RETURN_IF_NOT_OK(LoadIntList<uint16_t>(current_col, column_values_list, num_elements, tensor));
821   } else if (current_col.Type() == DataType::DE_INT16) {
822     RETURN_IF_NOT_OK(LoadIntList<int16_t>(current_col, column_values_list, num_elements, tensor));
823   } else if (current_col.Type() == DataType::DE_UINT8) {
824     RETURN_IF_NOT_OK(LoadIntList<uint8_t>(current_col, column_values_list, num_elements, tensor));
825   } else if (current_col.Type() == DataType::DE_INT8) {
826     RETURN_IF_NOT_OK(LoadIntList<int8_t>(current_col, column_values_list, num_elements, tensor));
827   } else {
828     std::string err_msg = "Invalid column type, the column type of " + current_col.Name() +
829                           " should be uint64, int64, uint32, int32, uint16, int16, uint8 or int8, but got " +
830                           current_col.Type().ToString();
831     RETURN_STATUS_UNEXPECTED(err_msg);
832   }
833 
834   return Status::OK();
835 }
836 
837 // Reads values from a bytes list and casts the value to type T, must be an integral type
838 // compatible with int64_t
839 template <typename T>
LoadIntList(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::shared_ptr<Tensor> * tensor)840 Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
841                                int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
842   if (!(current_col.Type().IsInt())) {
843     std::string err_msg = "Invalid column type, the column type of " + current_col.Name() + " should be int, but got " +
844                           current_col.Type().ToString();
845     RETURN_STATUS_UNEXPECTED(err_msg);
846   }
847 
848   const dataengine::Int64List &int64_list = column_values_list.int64_list();
849 
850   // Identify how many values we have and then create a local array of these
851   // to deserialize into
852   *num_elements = int64_list.value_size();
853 
854   // know how many elements there are, create tensor here:
855   TensorShape current_shape = TensorShape::CreateUnknownRankShape();
856   RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &current_shape));
857   RETURN_IF_NOT_OK(Tensor::CreateEmpty(current_shape, current_col.Type(), tensor));
858 
859   int64_t i = 0;
860   auto it = (*tensor)->begin<T>();
861   for (; it != (*tensor)->end<T>(); i++, ++it) {
862     T element = static_cast<T>(int64_list.value(i));
863     *it = element;
864   }
865 
866   return Status::OK();
867 }
868 
CreateSchema(const std::string & tf_record_file,std::vector<std::string> columns_to_load)869 Status TFReaderOp::CreateSchema(const std::string &tf_record_file, std::vector<std::string> columns_to_load) {
870   auto realpath = FileUtils::GetRealPath(tf_record_file.c_str());
871   if (!realpath.has_value()) {
872     MS_LOG(ERROR) << "Invalid file path, " << tf_record_file << " does not exist.";
873     RETURN_STATUS_UNEXPECTED("Invalid file path, " + tf_record_file + " does not exist.");
874   }
875 
876   std::string serialized_example;
877   RETURN_IF_NOT_OK(HelperGetExampleSchema(&serialized_example, realpath.value(), tf_record_file));
878 
879   dataengine::Example example;
880   if (!example.ParseFromString(serialized_example)) {
881     RETURN_STATUS_UNEXPECTED("Failed to parse tfrecord file: " + realpath.value() +
882                              ", fields that failed to parse: " + serialized_example);
883   }
884 
885   const dataengine::Features &example_features = example.features();
886   const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
887 
888   if (columns_to_load.empty()) {
889     (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load),
890                          [](const auto &it) -> std::string { return it.first; });
891     std::sort(columns_to_load.begin(), columns_to_load.end());
892   }
893 
894   for (const auto &curr_col_name : columns_to_load) {
895     auto it = feature_map.find(curr_col_name);
896     if (it == feature_map.end()) {
897       RETURN_STATUS_UNEXPECTED("Invalid columns_list, tfrecord file failed to find column name: " + curr_col_name);
898     }
899     std::string column_name = it->first;
900 
901     std::string column_type;
902 
903     const dataengine::Feature &feature = it->second;
904     const dataengine::Feature::KindCase kind_case = feature.kind_case();
905     switch (kind_case) {
906       case dataengine::Feature::KindCase::kBytesList:
907         column_type = "uint8";
908         break;
909 
910       case dataengine::Feature::KindCase::kFloatList:
911         column_type = "float32";
912         break;
913 
914       case dataengine::Feature::KindCase::kInt64List:
915         column_type = "int64";
916         break;
917 
918       case dataengine::Feature::KindCase::KIND_NOT_SET:
919         RETURN_STATUS_UNEXPECTED("Unrecognized column type, the column type of " + column_name +
920                                  " should be uint8, int64 or float32, but got unrecognized column type.");
921 
922       default:
923         RETURN_STATUS_UNEXPECTED("Unsupported column type, the column type of " + column_name +
924                                  " should be uint8, int64 or float32, but got unsupported column type.");
925     }
926 
927     RETURN_IF_NOT_OK(
928       data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1)));
929   }
930 
931   return Status::OK();
932 }
933 
HelperGetExampleSchema(std::string * const serialized_example,const std::string & realpath_value,const std::string & filename) const934 Status TFReaderOp::HelperGetExampleSchema(std::string *const serialized_example, const std::string &realpath_value,
935                                           const std::string &filename) const {
936   if (compression_type_ == CompressionType::NONE) {
937     std::ifstream reader;
938     reader.open(realpath_value, std::ios::in);
939 
940     // read length
941     int64_t record_length = 0;
942     (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(kTFRecordRecLenSize));
943 
944     // ignore crc header
945     (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));
946 
947     // read serialized Example
948     (*serialized_example).resize(static_cast<size_t>(record_length));
949     (void)reader.read(&(*serialized_example)[0], static_cast<std::streamsize>(record_length));
950     reader.close();
951   }
952 #if !defined(_WIN32) && !defined(_WIN64)
953   if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::GZIP_WITH_COUNT) {
954     gzFile file = gzopen(realpath_value.c_str(), "rb");
955 
956     // read length
957     int64_t record_length = 0;
958     (void)gzread(file, reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);
959 
960     // ignore crc header
961     (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
962 
963     // read serialized Example
964     (*serialized_example).resize(static_cast<size_t>(record_length));
965     (void)gzread(file, &(*serialized_example)[0], static_cast<unsigned int>(record_length));
966     (void)gzclose(file);
967   } else if (compression_type_ == CompressionType::ZLIB || compression_type_ == CompressionType::ZLIB_WITH_COUNT) {
968     // ZLIB stream setup (based on zlib.h tutorial)
969     ZLIBStreamInf zlib_stream;
970 
971     std::ifstream reader(realpath_value.c_str(), std::ios::in | std::ios::binary);
972     zlib_stream.inflate_status = inflateInit(&zlib_stream.strm);
973     if (zlib_stream.inflate_status != Z_OK) {
974       reader.close();
975       RETURN_STATUS_UNEXPECTED("Failed to initialize inflate stream for ZLIB for file " + filename + "!");
976     }
977 
978     // decompress until first row is read
979     do {
980       (void)reader.read(zlib_stream.input_stream, kZLIBChunkSize);
981       zlib_stream.strm.avail_in = static_cast<unsigned int>(reader.gcount());
982       zlib_stream.strm.next_in = reinterpret_cast<unsigned char *>(zlib_stream.input_stream);
983 
984       // run inflate() on input until output buffer not full
985       do {
986         auto s = HelperInflateZLIB(&zlib_stream, filename);
987         if (s != Status::OK()) {
988           reader.close();
989           return s;
990         }
991         if (zlib_stream.left_to_read != 0) {
992           break;
993         }
994 
995         // Process inflated data depending on read flag
996         if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::RecordLength)) {  // read record length
997           zlib_stream.record_length = HelperBinDataToInt(zlib_stream.record_size, kTFRecordRecLenSize);
998         } else if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::Content)) {  // read serialized example
999           (*serialized_example).resize(static_cast<size_t>(zlib_stream.record_length));
1000           (void)(*serialized_example)
1001             .assign(reinterpret_cast<char *>(zlib_stream.content.get()),
1002                     static_cast<size_t>(zlib_stream.record_length));
1003         }
1004         zlib_stream.read_flag++;
1005       } while (zlib_stream.strm.avail_out == 0 && zlib_stream.read_flag != static_cast<int>(ZLIBReadFlag::Footer));
1006     } while (zlib_stream.inflate_status != Z_STREAM_END &&
1007              zlib_stream.read_flag != static_cast<int>(ZLIBReadFlag::Footer));
1008 
1009     (void)inflateEnd(&zlib_stream.strm);
1010     if (zlib_stream.inflate_status != Z_STREAM_END && zlib_stream.read_flag < static_cast<int>(ZLIBReadFlag::Footer)) {
1011       reader.close();
1012       RETURN_STATUS_UNEXPECTED("Decompression of ZLIB file failed for file " + filename + "!");
1013     }
1014 
1015     reader.close();
1016   }
1017 #endif
1018 
1019   return Status::OK();
1020 }
1021 
CountTotalRows(int64_t * out_total_rows,const std::vector<std::string> & filenames,int64_t threads,bool estimate,CompressionType compression_type)1022 Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads,
1023                                   bool estimate, CompressionType compression_type) {
1024   RETURN_UNEXPECTED_IF_NULL(out_total_rows);
1025   try {
1026     if (threads > filenames.size()) {
1027       threads = filenames.size();
1028     }
1029 
1030     std::vector<std::future<int64_t>> async_results;
1031 
1032     if (threads <= 0) {
1033       RETURN_STATUS_UNEXPECTED(
1034         "Invalid threads number, the threads number of TFReader should be greater than zero, but got " +
1035         std::to_string(threads) + ".");
1036     }
1037     int64_t chunk_size = filenames.size() / threads;
1038     int64_t remainder = filenames.size() % threads;
1039 
1040     int64_t begin = 0;
1041     int64_t end = begin;
1042     for (int i = 0; i < threads; i++) {
1043       end += chunk_size;
1044       if (remainder > 0) {
1045         end++;
1046         remainder--;
1047       }
1048 
1049       if (estimate) {
1050         // Parse a single file for each chunk with estimate mode on
1051         async_results.push_back(
1052           std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1, compression_type));
1053       } else {
1054         // Parse the whole chunk with estimate mode off
1055         async_results.push_back(
1056           std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end, compression_type));
1057       }
1058 
1059       begin = end;
1060     }
1061 
1062     int64_t total_rows = 0;
1063     for (auto &async_result : async_results) {
1064       total_rows += async_result.get();
1065     }
1066 
1067     if (estimate) {
1068       // Each thread only scans 1 file
1069       // Estimated total rows = Average rows * total number of files
1070       total_rows = total_rows / threads * filenames.size();
1071     }
1072 
1073     *out_total_rows = total_rows;
1074   } catch (const std::exception &e) {
1075     std::string err_msg = "Unexpected error occurred: ";
1076     err_msg += std::string(e.what());
1077     RETURN_STATUS_UNEXPECTED(err_msg);
1078   }
1079 
1080   return Status::OK();
1081 }
1082 
CountTotalRowsSectioned(const std::vector<std::string> & filenames,int64_t begin,int64_t end,CompressionType compression_type)1083 int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &filenames, int64_t begin, int64_t end,
1084                                             CompressionType compression_type) {
1085   int64_t rows_read = 0;
1086   for (size_t i = begin; i < end; i++) {
1087     auto realpath = FileUtils::GetRealPath(filenames[i].c_str());
1088     if (!realpath.has_value()) {
1089       MS_LOG(ERROR) << "Invalid file path, " << filenames[i] << " does not exist.";
1090       continue;
1091     }
1092 
1093     if (compression_type == CompressionType::NONE) {
1094       HelperCountNonCompRows(realpath.value(), filenames[i], &rows_read);
1095     }
1096 #if !defined(_WIN32) && !defined(_WIN64)
1097     if (compression_type == CompressionType::GZIP_WITH_COUNT) {
1098       HelperCountGZIPRows(realpath.value(), filenames[i], &rows_read);
1099     } else if (compression_type == CompressionType::ZLIB_WITH_COUNT) {
1100       HelperCountZLIBRows(realpath.value(), filenames[i], &rows_read);
1101     }
1102 #endif
1103   }
1104 
1105   return rows_read;
1106 }
1107 
HelperCountNonCompRows(const std::string & realpath_value,const std::string & filename,int64_t * rows_read)1108 void TFReaderOp::HelperCountNonCompRows(const std::string &realpath_value, const std::string &filename,
1109                                         int64_t *rows_read) {
1110   std::ifstream reader;
1111   reader.open(realpath_value, std::ios::in);
1112   if (!reader) {
1113     MS_LOG(DEBUG) << "TFReader operator failed to open file " << filename << ".";
1114   }
1115 
1116   while (reader.peek() != EOF) {
1117     // read length
1118     int64_t record_length = 0;
1119     (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(kTFRecordRecLenSize));
1120 
1121     // ignore crc header
1122     (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));
1123 
1124     // ignore TFRecord file contents
1125     (void)reader.ignore(static_cast<std::streamsize>(record_length));
1126 
1127     // ignore crc footer
1128     (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));
1129     (*rows_read)++;
1130   }
1131   reader.close();
1132 }
1133 
1134 #if !defined(_WIN32) && !defined(_WIN64)
HelperCountGZIPRows(const std::string & realpath_value,const std::string & filename,int64_t * rows_read)1135 void TFReaderOp::HelperCountGZIPRows(const std::string &realpath_value, const std::string &filename,
1136                                      int64_t *rows_read) {
1137   gzFile file = gzopen(realpath_value.c_str(), "rb");
1138 
1139   if (file == nullptr) {
1140     MS_LOG(DEBUG) << "TFReader operator failed to open file " << filename << " with GZIP.";
1141   }
1142 
1143   while (gzeof(file) != 1) {
1144     // read length
1145     int64_t record_length = 0;
1146     (void)gzread(file, reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);
1147     if (record_length == 0) {
1148       continue;
1149     }
1150 
1151     // ignore crc header
1152     (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
1153 
1154     // ignore TFRecord file contents
1155     (void)gzseek(file, record_length, SEEK_CUR);
1156 
1157     // ignore crc footer
1158     (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
1159     (*rows_read)++;
1160   }
1161   (void)gzclose(file);
1162 }
1163 
HelperCountZLIBRows(const std::string & realpath_value,const std::string & filename,int64_t * rows_read)1164 void TFReaderOp::HelperCountZLIBRows(const std::string &realpath_value, const std::string &filename,
1165                                      int64_t *rows_read) {
1166   // ZLIB stream setup (based on zlib.h tutorial)
1167   ZLIBStreamInf zlib_stream;
1168 
1169   std::ifstream reader(realpath_value.c_str(), std::ios::in | std::ios::binary);
1170 
1171   if (!reader) {
1172     MS_LOG(DEBUG) << "TFReader operator failed to open file " << filename << " with ZLIB.";
1173   }
1174 
1175   zlib_stream.inflate_status = inflateInit(&zlib_stream.strm);
1176   if (zlib_stream.inflate_status != Z_OK) {
1177     reader.close();
1178     MS_LOG(DEBUG) << "Failed to initialize inflate stream for ZLIB when counting rows for file " << filename << "!";
1179   }
1180 
1181   // decompress until first row is read
1182   do {
1183     (void)reader.read(zlib_stream.input_stream, kZLIBChunkSize);
1184     zlib_stream.strm.avail_in = static_cast<unsigned int>(reader.gcount());
1185     zlib_stream.strm.next_in = reinterpret_cast<unsigned char *>(zlib_stream.input_stream);
1186 
1187     // run inflate() on input until output buffer not full
1188     do {
1189       if (zlib_stream.left_to_read != 0) {
1190         zlib_stream.strm.avail_out = zlib_stream.left_to_read;  // need to read the rest before process
1191       } else {
1192         switch (zlib_stream.read_flag) {
1193           case ZLIBReadFlag::RecordLength:  // record length
1194             zlib_stream.strm.avail_out = kTFRecordRecLenSize;
1195             zlib_stream.strm.next_out = zlib_stream.record_size;
1196             break;
1197           default:  // record header, example, and footer since we just want to count rows
1198             zlib_stream.strm.avail_out = zlib_stream.record_length + kTFRecordHeadFootSize + kTFRecordHeadFootSize;
1199             zlib_stream.content = std::make_unique<unsigned char[]>(zlib_stream.record_length + kTFRecordHeadFootSize +
1200                                                                     kTFRecordHeadFootSize);
1201             zlib_stream.strm.next_out = zlib_stream.content.get();
1202         }
1203       }
1204 
1205       // Inflate stream
1206       zlib_stream.inflate_status = inflate(&zlib_stream.strm, Z_NO_FLUSH);
1207       if (zlib_stream.inflate_status == Z_OK || zlib_stream.inflate_status == Z_STREAM_END) {
1208         zlib_stream.left_to_read = zlib_stream.strm.avail_out;  // after reading
1209       } else {
1210         MS_LOG(DEBUG) << "An error is found during inflation when counting rows for file: " << filename << "!";
1211       }
1212 
1213       if (zlib_stream.left_to_read != 0) {
1214         break;
1215       }
1216 
1217       // Process inflated data depending on read flag
1218       if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::RecordLength)) {  // read record length
1219         zlib_stream.record_length = HelperBinDataToInt(zlib_stream.record_size, kTFRecordRecLenSize);
1220       } else if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::Footer)) {
1221         (*rows_read)++;
1222       }
1223       zlib_stream.read_flag = zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::Footer)
1224                                 ? static_cast<int>(ZLIBReadFlag::RecordLength)
1225                                 : static_cast<int>(ZLIBReadFlag::Footer);  // resets flag to reading record length
1226     } while (zlib_stream.strm.avail_out == 0 && zlib_stream.inflate_status == Z_OK);
1227   } while (zlib_stream.inflate_status != Z_STREAM_END && zlib_stream.inflate_status == Z_OK);
1228 
1229   (void)inflateEnd(&zlib_stream.strm);
1230   if (zlib_stream.inflate_status != Z_STREAM_END) {
1231     MS_LOG(DEBUG) << "Decompression of ZLIB file failed when counting rows for file " << filename << "!";
1232   }
1233 
1234   reader.close();
1235 }
1236 
1237 #endif
1238 
ComputeColMap()1239 Status TFReaderOp::ComputeColMap() {
1240   // Construct the column name map for this operator (base class field)
1241   if (column_name_id_map_.empty()) {
1242     if (decode_) {
1243       for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
1244         column_name_id_map_[data_schema_->Column(i).Name()] = i;
1245       }
1246     } else {
1247       // if decode is false, the output will only have one column containing the record bytes
1248       column_name_id_map_["proto"] = 0;
1249     }
1250   } else {
1251     MS_LOG(WARNING) << "Column name map is already set!";
1252   }
1253   return Status::OK();
1254 }
1255 
FillIOBlockQueue(const std::vector<int64_t> & i_keys)1256 Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
1257   int32_t queue_index = 0;
1258   int32_t key_index = 0;
1259   int64_t pre_count = 0;
1260   int64_t start_offset = 0;
1261   int64_t end_offset = 0;
1262   bool end_of_epoch = false;
1263   if (shuffle_files_) {
1264     do {
1265       // Iterate over all the keys and add one key to each block.
1266       for (auto i_key : i_keys) {
1267         {
1268           if (!GetLoadIoBlockQueue()) {
1269             end_of_epoch = true;
1270             break;
1271           }
1272         }
1273         RETURN_IF_NOT_OK(HelperIOBlockFiller(&queue_index, &key_index, &pre_count, &start_offset, &end_offset, i_key,
1274                                              (*filename_index_)[i_key]));
1275       }
1276     } while ((compression_type_ == CompressionType::NONE || compression_type_ == CompressionType::GZIP_WITH_COUNT ||
1277               compression_type_ == CompressionType::ZLIB_WITH_COUNT) &&
1278              equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
1279              !end_of_epoch);
1280   } else {
1281     do {
1282       // Iterate over all the keys and add one key to each block.
1283       for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
1284         {
1285           if (!GetLoadIoBlockQueue()) {
1286             end_of_epoch = true;
1287             break;
1288           }
1289         }
1290         RETURN_IF_NOT_OK(
1291           HelperIOBlockFiller(&queue_index, &key_index, &pre_count, &start_offset, &end_offset, it.key(), it.value()));
1292       }
1293     } while ((compression_type_ == CompressionType::NONE || compression_type_ == CompressionType::GZIP_WITH_COUNT ||
1294               compression_type_ == CompressionType::ZLIB_WITH_COUNT) &&
1295              equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
1296              !end_of_epoch);
1297   }
1298   RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
1299   return Status::OK();
1300 }
1301 
HelperIOBlockFiller(int32_t * queue_index,int32_t * key_index,int64_t * pre_count,int64_t * start_offset,int64_t * end_offset,int64_t key,const std::string & file_name)1302 Status TFReaderOp::HelperIOBlockFiller(int32_t *queue_index, int32_t *key_index, int64_t *pre_count,
1303                                        int64_t *start_offset, int64_t *end_offset, int64_t key,
1304                                        const std::string &file_name) {
1305   if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::ZLIB) {
1306     int num_files_to_read =
1307       static_cast<int>(dataset_files_list_.size() - dataset_files_list_.size() % static_cast<size_t>(num_devices_));
1308     if (*key_index % num_devices_ == device_id_ && *key_index < num_files_to_read) {
1309       *end_offset = static_cast<int>(total_rows_ /
1310                                      static_cast<int>(dataset_files_list_.size() / static_cast<size_t>(num_devices_)));
1311       auto ioBlock = std::make_unique<FilenameBlock>(key, 0, *end_offset, IOBlock::kFlagNone);
1312       RETURN_IF_NOT_OK(PushIoBlockQueue(*queue_index, std::move(ioBlock)));
1313       *queue_index = (*queue_index + 1) % num_workers_;
1314     }
1315     (*key_index)++;
1316   } else if (!equal_rows_per_shard_) {
1317     if ((*key_index)++ % num_devices_ == device_id_) {
1318       auto ioBlock = std::make_unique<FilenameBlock>(key, kInvalidOffset, kInvalidOffset, IOBlock::kFlagNone);
1319       RETURN_IF_NOT_OK(PushIoBlockQueue(*queue_index, std::move(ioBlock)));
1320       *queue_index = (*queue_index + 1) % num_workers_;
1321     }
1322   } else {
1323     if (NeedPushFileToBlockQueue(file_name, start_offset, end_offset, *pre_count)) {
1324       auto ioBlock = std::make_unique<FilenameBlock>(key, *start_offset, *end_offset, IOBlock::kFlagNone);
1325       RETURN_IF_NOT_OK(PushIoBlockQueue(*queue_index, std::move(ioBlock)));
1326       *queue_index = (*queue_index + 1) % num_workers_;
1327     }
1328 
1329     *pre_count += filename_numrows_[file_name];
1330   }
1331   return Status::OK();
1332 }
1333 
GetNextRowPullMode(TensorRow * const row)1334 Status TFReaderOp::GetNextRowPullMode(TensorRow *const row) {
1335   RETURN_UNEXPECTED_IF_NULL(row);
1336   RETURN_IF_NOT_OK(NonMappableLeafOp::GetNextRowPullMode(row));
1337   if (decode_) {
1338     if (!row->empty()) {
1339       // data got from jagged_rows_connector is raw bytes so we need to parse it before return
1340       TensorRow res;
1341       RETURN_IF_NOT_OK(ParseExample(*row, &res));
1342       *row = std::move(res);
1343     }
1344   }
1345   return Status::OK();
1346 }
1347 }  // namespace dataset
1348 }  // namespace mindspore
1349