/**
 * Copyright 2020-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"

#include <algorithm>
#include <fstream>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "proto/example.pb.h"

#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"
#include "utils/file_utils.h"
#include "utils/system/crc32c.h"

namespace mindspore {
namespace dataset {
TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows,
                       std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
                       int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
                       int32_t num_devices, int32_t device_id, bool equal_rows_per_shard,
                       const CompressionType &compression_type, bool decode)
    : NonMappableLeafOp(num_workers, worker_connector_size, total_num_rows, op_connector_size, shuffle_files,
                        num_devices, device_id, compression_type),
      dataset_files_list_(std::move(dataset_files_list)),
      columns_to_load_(std::move(columns_to_load)),
      data_schema_(std::move(data_schema)),
      equal_rows_per_shard_(equal_rows_per_shard),
      decode_(decode) {}

// A print method, typically used for debugging
void TFReaderOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
        << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n";
    for (const auto &i : dataset_files_list_) {
      out << " " << i;
    }
    if (!columns_to_load_.empty()) {
      out << "\nColumns to load:\n";
      for (const auto &j : columns_to_load_) {
        out << " " << j;
      }
    }
    out << "\nData Schema:\n";
    out << *data_schema_ << "\n\n";
  }
}

Status TFReaderOp::Init() {
  if (data_schema_->Empty()) {
    RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_));
  }

  if (total_rows_ == 0) {
    total_rows_ = data_schema_->NumRows();
  }
  if (total_rows_ < 0) {
    RETURN_STATUS_UNEXPECTED(
      "[Internal ERROR] num_samples or num_rows for TFRecordDataset must be greater than or equal to 0, but got: " +
      std::to_string(total_rows_));
  } else if (compression_type_ != CompressionType::NONE && total_rows_ == 0) {
    MS_LOG(WARNING) << "Since compression_type is set, but neither num_samples nor num_rows (from the schema file) "
                    << "is provided, performance might be degraded.";
  }

  // Build the index with our files such that each file corresponds to a key id.
  RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));

  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  // temporary: make size large enough to hold all files + EOE to avoid hangs
  // (the cast to double keeps std::ceil meaningful; integer division would truncate first)
  int32_t safe_queue_size =
    static_cast<int32_t>(std::ceil(static_cast<double>(dataset_files_list_.size()) / num_workers_)) + 1;
  io_block_queues_.Init(num_workers_, safe_queue_size);

  return Status::OK();
}
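
// A sketch of the queue sizing above: with 10 dataset files and 4 workers,
// safe_queue_size = ceil(10 / 4) + 1 = 4, so each worker's IO-block queue can hold its share of
// file blocks plus the trailing EOE block without stalling the filler thread.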

Status TFReaderOp::RegisterAndLaunchThreads() {
  RETURN_UNEXPECTED_IF_NULL(tree_);
  worker_in_queues_.Init(num_workers_, worker_connector_size_);
  worker_out_queues_.Init(num_workers_, worker_connector_size_);

  // Register the QueueList and individual Queues for interrupt services
  RETURN_IF_NOT_OK(worker_in_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(worker_out_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));

  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1),
                                        &worker_tasks_, Name() + "::WorkerEntry", id()));
  // if decode is true, launch some workers to parse the protobuf
  if (decode_) {
    RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
                                          std::bind(&TFReaderOp::ParsingWorkerEntry, this, std::placeholders::_1),
                                          Name() + "::ParsingWorkerEntry", id()));
  }
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::Collector, this), Name() + "::Collector", id()));

  return Status::OK();
}

Status TFReaderOp::operator()() {
  RETURN_IF_NOT_OK(PrepareData());
  while (!finished_reading_dataset_) {
    int32_t workers_done = 0;
    int64_t rows_read = 0;
    {
      std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
      load_io_block_queue_ = true;
    }
    TensorRow fetched_row;
    while (workers_done < num_workers_) {
      RETURN_IF_NOT_OK(jagged_rows_connector_->Pop(0, &fetched_row));
      if (fetched_row.eoe()) {
        workers_done++;
      } else if ((compression_type_ == CompressionType::NONE || compression_type_ == CompressionType::GZIP_WITH_COUNT ||
                  compression_type_ == CompressionType::ZLIB_WITH_COUNT) &&
                 (total_rows_ == 0 || rows_read < total_rows_)) {
        if (decode_) {
          // get record bytes from jagged_rows_connector and send them to workers for parsing
          const auto parse_worker_id = NextWorkerID();
          RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(std::move(fetched_row)));
        } else {
          // get record bytes from jagged_rows_connector and send them to out_connector
          RETURN_IF_NOT_OK(out_connector_->Add(std::move(fetched_row)));
        }
        rows_read++;
      } else if ((compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::ZLIB) &&
                 (rows_read < total_rows_ * num_devices_)) {
        // for the compressed case, total_rows_ is the number of rows that will be read per shard
        if (decode_) {
          // get record bytes from jagged_rows_connector and send them to workers for parsing
          const auto parse_worker_id = NextWorkerID();
          RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(std::move(fetched_row)));
        } else {
          // get record bytes from jagged_rows_connector and send them to out_connector
          RETURN_IF_NOT_OK(out_connector_->Add(std::move(fetched_row)));
        }
        rows_read++;
      } else {
        // IOBlockQueue thread needs to:
        // -stop pushing stuff to IOBlockQueue
        // -call PostEndOfEpoch (will send EOE)
        // -wait for reset
        //
        // Worker threads need to:
        // -stop reading the file they are currently reading and throw it away
        // -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
        //
        // Master thread needs to:
        // -tell IOBlockQueue thread to stop pushing
        // -tell worker threads to stop reading the file they are currently reading
        // -keep pulling until EOE

        // these flags are shared with the filler and worker threads, so take their locks before writing
        {
          std::unique_lock<std::mutex> lock(load_jagged_connector_mutex_);
          load_jagged_connector_ = false;
        }
        {
          std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
          load_io_block_queue_ = false;
        }
      }
    }

    if (decode_) {
      // finished reading this epoch, send an EOE flag to the next parsing worker
      const auto parse_worker_id = NextWorkerID();
      RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(TensorRow(TensorRow::kFlagEOE)));
    } else {
      // finished reading this epoch, send an EOE flag to out_connector
      RETURN_IF_NOT_OK(out_connector_->SendEOE());
    }

    RETURN_IF_NOT_OK(ResetAndUpdateRepeat());
  }

  if (decode_) {
    // finished reading all the data, send an EOF flag to the next parsing worker
    auto parse_worker_id = NextWorkerID();
    RETURN_IF_NOT_OK(worker_in_queues_[parse_worker_id]->EmplaceBack(TensorRow::kFlagEOF));
    // tell all the parsing workers to quit
    for (auto i = 0; i < num_workers_; ++i) {
      RETURN_IF_NOT_OK(worker_in_queues_[i]->EmplaceBack(TensorRow::kFlagQuit));
    }
  } else {
    // finished reading all the data, send an EOF flag to out_connector
    RETURN_IF_NOT_OK(out_connector_->SendEOF());
  }

  RETURN_IF_NOT_OK(PostEndOfData());

  return Status::OK();
}
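
// Dataflow of the master loop above (a sketch): reader workers push raw record rows into
// jagged_rows_connector_; this thread pops them and either forwards the bytes straight to
// out_connector_ (decode_ == false) or round-robins them to the parsing workers via
// worker_in_queues_ (decode_ == true). One EOE per reader worker marks the end of an epoch;
// EOF followed by kFlagQuit rows shuts the parsing workers down.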

Status TFReaderOp::CalculateNumRowsPerShard() {
  if (!equal_rows_per_shard_) {
    return Status::OK();
  }

  if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::ZLIB) {
    num_rows_per_shard_ = total_rows_;
  } else {
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      std::vector<std::string> file(1, it.value());
      int64_t num = CountTotalRowsSectioned(file, 0, 1, compression_type_);
      filename_numrows_[it.value()] = num;
      num_rows_ += num;
    }
    num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
  }
  if (num_rows_per_shard_ == 0) {
    std::stringstream ss;
    for (auto &i : dataset_files_list_) {
      ss << " " << i;
    }
    std::string file_list = ss.str();
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, TFRecordDataset API can't read the data files (possibly due to a schema mismatch or empty "
      "files). Check the file paths:" +
      file_list);
  }
  return Status::OK();
}
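
// Worked example for the uncompressed path above: three files with 10, 20, and 30 rows on
// num_devices_ == 4 gives num_rows_ == 60 and num_rows_per_shard_ == ceil(60 / 4) == 15.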

Status TFReaderOp::ParsingWorkerEntry(int32_t worker_id) {
  // must be called first if called by worker spawned by taskgroup
  TaskManager::FindMe()->Post();

  TensorRow next_row;
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&next_row));
  while (!next_row.quit()) {
    if (!next_row.empty()) {
      TensorRow parsed_row;
      RETURN_IF_NOT_OK(ParseExample(next_row, &parsed_row));
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(parsed_row));
    } else if (next_row.eoe() || next_row.eof()) {
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(next_row));
    } else {
      RETURN_STATUS_UNEXPECTED("TFReaderOp: parsing worker got an unexpected empty tensor row.");
    }
    RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&next_row));
  }
  return Status::OK();
}

Status TFReaderOp::ParseExample(const TensorRow &raw_bytes, TensorRow *parsed_row) {
  auto filename = raw_bytes.getPath()[0];
  auto itr = raw_bytes[0]->begin<std::string_view>();
  dataengine::Example tf_record_example;
  CHECK_FAIL_RETURN_UNEXPECTED(tf_record_example.ParseFromString(static_cast<std::string>(*itr)),
                               "TFReaderOp: failed to parse example in tfrecord file: " + filename +
                                 ". Perhaps the version of protobuf is not compatible. The example bytes are " +
                                 static_cast<std::string>(*itr));

  auto num_columns = data_schema_->NumColumns();
  TensorRow parsed_example(num_columns, nullptr);
  std::vector<std::string> file_path(num_columns, filename);
  parsed_example.setPath(file_path);
  RETURN_IF_NOT_OK(LoadExample(&tf_record_example, &parsed_example));

  *parsed_row = std::move(parsed_example);
  return Status::OK();
}

// Reads a TFRecord file and loads the data into multiple TensorRows.
Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  auto realpath = FileUtils::GetRealPath(filename.c_str());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, " << filename << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, " + filename + " does not exist.");
  }
  std::string realpath_value = realpath.value();

  if (compression_type_ == CompressionType::NONE) {
    RETURN_IF_NOT_OK(HelperLoadNonCompFile(filename, start_offset, end_offset, worker_id, realpath_value));
  }
#if !defined(_WIN32) && !defined(_WIN64)
  if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::GZIP_WITH_COUNT) {
    RETURN_IF_NOT_OK(HelperLoadCompGZIPFile(filename, start_offset, end_offset, worker_id, realpath_value));
  } else if (compression_type_ == CompressionType::ZLIB || compression_type_ == CompressionType::ZLIB_WITH_COUNT) {
    RETURN_IF_NOT_OK(HelperLoadCompZLIBFile(filename, start_offset, end_offset, worker_id, realpath_value));
  }
#endif

  return Status::OK();
}

Status TFReaderOp::SendRecordBytesRow(const std::string &filename, const std::string &serialized_example,
                                      int32_t worker_id) {
  std::vector<std::string> filenames(1, filename);
  TensorRow record_bytes_row(1, nullptr);
  record_bytes_row.setPath(filenames);
  std::shared_ptr<Tensor> record_bytes_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateScalar(serialized_example, &record_bytes_tensor));
  record_bytes_row[0] = std::move(record_bytes_tensor);
  RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(record_bytes_row)));
  return Status::OK();
}
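
// On-disk layout of one uncompressed TFRecord entry, which the helpers below walk record by
// record (a sketch for orientation, assuming kTFRecordRecLenSize covers the 8-byte length field
// and kTFRecordHeadFootSize covers a 4-byte masked CRC32C):
//
//   uint64 length                    // little-endian payload size (kTFRecordRecLenSize bytes)
//   uint32 masked_crc32_of_length    // header crc (kTFRecordHeadFootSize bytes)
//   byte   data[length]              // serialized dataengine::Example protobuf
//   uint32 masked_crc32_of_data      // footer crc (kTFRecordHeadFootSize bytes)
//
// HelperLoadNonCompFile skips both crcs; the compressed helpers validate the header crc of the
// first record only, via system::Crc32c::GetMaskCrc32cValue.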

Status TFReaderOp::HelperLoadNonCompFile(const std::string &filename, int64_t start_offset, int64_t end_offset,
                                         int32_t worker_id, const std::string &realpath_value) {
  std::ifstream reader;
  reader.open(realpath_value, std::ios::in);
  if (!reader) {
    RETURN_STATUS_UNEXPECTED("Invalid file, " + filename + " open failed: permission denied!");
  }

  int64_t rows_total = 0;

  while (reader.peek() != EOF) {
    if (!GetLoadJaggedConnector()) {
      break;
    }
    RETURN_IF_INTERRUPTED();

    // read length
    std::streamsize record_length = 0;
    (void)reader.read(reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);

    // ignore crc header
    (void)reader.ignore(kTFRecordHeadFootSize);

    // read serialized Example
    std::string serialized_example;
    serialized_example.resize(static_cast<size_t>(record_length));
    (void)reader.read(&serialized_example[0], record_length);

    if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) {
      auto s = SendRecordBytesRow(filename, serialized_example, worker_id);
      if (s != Status::OK()) {
        reader.close();
        return s;
      }
    }

    // ignore crc footer
    (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));
    rows_total++;
  }
  reader.close();
  return Status::OK();
}

#if !defined(_WIN32) && !defined(_WIN64)
Status TFReaderOp::HelperLoadCompGZIPFile(const std::string &filename, int64_t start_offset, int64_t end_offset,
                                          int32_t worker_id, const std::string &realpath_value) {
  gzFile file = gzopen(realpath_value.c_str(), "rb");
  if (file == nullptr) {
    RETURN_STATUS_UNEXPECTED("Invalid file, " + filename + " open failed: permission denied!");
  }

  int64_t rows_read = 0;
  int64_t rows_total = 0;

  while (gzeof(file) != 1) {
    if (compression_type_ == CompressionType::GZIP && rows_read >= end_offset) {
      break;
    }

    if (!GetLoadJaggedConnector()) {
      break;
    }
    RETURN_IF_INTERRUPTED();

    // read length
    int64_t record_length = 0;
    (void)gzread(file, reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);
    if (record_length == 0) {
      continue;
    }

    if (rows_total == 0) {
      // do the delayed checking; read crc from file
      uint32_t masked_crc = 0;
      (void)gzread(file, reinterpret_cast<char *>(&masked_crc), sizeof(uint32_t));

      // generate crc from data
      uint32_t generated_crc =
        system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);

      // invalid tfrecord file
      if (masked_crc != generated_crc) {
        (void)gzclose(file);
        RETURN_STATUS_UNEXPECTED("Invalid TFRecord file: " + filename);
      }
    } else {
      // ignore crc header
      (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
    }

    // read serialized Example
    std::string serialized_example;
    serialized_example.resize(static_cast<size_t>(record_length));
    (void)gzread(file, &serialized_example[0], static_cast<unsigned int>(record_length));

    if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) {
      auto s = SendRecordBytesRow(filename, serialized_example, worker_id);
      if (s != Status::OK()) {
        (void)gzclose(file);
        return s;
      }
      rows_read++;
    }
    // ignore crc footer
    (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
    rows_total++;
  }

  (void)gzclose(file);
  if (compression_type_ == CompressionType::GZIP && rows_read < end_offset) {
    std::string errMsg = "This tfrecord file: " + filename +
                         " does not meet the minimum rows per shard requirement: " + std::to_string(total_rows_) +
                         " rows per shard and " + std::to_string(static_cast<int>(total_rows_ / num_devices_)) +
                         " rows per file are required, but only " + std::to_string(rows_read) +
                         " rows were read from this file.";
    RETURN_STATUS_UNEXPECTED(errMsg);
  }

  return Status::OK();
}

Status TFReaderOp::HelperLoadCompZLIBFile(const std::string &filename, int64_t start_offset, int64_t end_offset,
                                          int32_t worker_id, const std::string &realpath_value) {
  // ZLIB stream setup (based on the zlib.h tutorial)
  ZLIBStreamInf zlib_stream;
  std::ifstream reader(realpath_value, std::ios::in | std::ios::binary);
  if (!reader) {
    RETURN_STATUS_UNEXPECTED("Invalid file, " + filename + " open failed: permission denied!");
  }

  zlib_stream.inflate_status = inflateInit(&zlib_stream.strm);
  if (zlib_stream.inflate_status != Z_OK) {
    reader.close();
    RETURN_STATUS_UNEXPECTED("Failed to initialize inflate stream for ZLIB for file " + filename + "!");
  }

  int64_t rows_read = 0;
  int64_t rows_total = 0;

  // decompress until the inflate stream ends or end of file
  do {
    if (compression_type_ == CompressionType::ZLIB && rows_read >= end_offset) {
      break;
    }

    if (!GetLoadJaggedConnector()) {
      break;
    }
    RETURN_IF_INTERRUPTED();

    (void)reader.read(zlib_stream.input_stream, kZLIBChunkSize);
    zlib_stream.strm.avail_in = static_cast<unsigned int>(reader.gcount());
    if (zlib_stream.strm.avail_in == 0) {
      break;
    }
    zlib_stream.strm.next_in = reinterpret_cast<unsigned char *>(zlib_stream.input_stream);

    // run inflate() on the input buffer until the current output buffer still needs more input
    // to be filled, or until rows_read reaches the required number of rows (end_offset)
    do {
      if (compression_type_ == CompressionType::ZLIB && rows_read >= end_offset) {
        break;
      }

      // inflate the stream
      auto s = HelperInflateZLIB(&zlib_stream, filename);
      if (s != Status::OK()) {
        reader.close();
        return s;
      }
      if (zlib_stream.left_to_read != 0) {
        break;
      }

      // Process inflated data depending on read flag
      s = HelperProcessZLIBData(&zlib_stream, &rows_read, &rows_total, filename, start_offset, end_offset, worker_id);
      if (s != Status::OK()) {
        reader.close();
        return s;
      }
      zlib_stream.read_flag = (zlib_stream.read_flag + 1) %
                              (static_cast<int>(ZLIBReadFlag::Footer) + 1);  // resets flag to reading record length
    } while (zlib_stream.strm.avail_out == 0);
  } while (zlib_stream.inflate_status != Z_STREAM_END);

  (void)inflateEnd(&zlib_stream.strm);
  if (zlib_stream.inflate_status != Z_STREAM_END && rows_read < end_offset) {
    reader.close();
    RETURN_STATUS_UNEXPECTED("Decompression of ZLIB file failed for file " + filename + "!");
  }

  if (compression_type_ == CompressionType::ZLIB && rows_read < end_offset) {
    reader.close();
    std::string errMsg = "This tfrecord file: " + filename +
                         " does not meet the minimum rows per shard requirement: " + std::to_string(total_rows_) +
                         " rows per shard and " + std::to_string(static_cast<int>(total_rows_ / num_devices_)) +
                         " rows per file are required, but only " + std::to_string(rows_read) +
                         " rows were read from this file.";
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  reader.close();
  return Status::OK();
}

int64_t TFReaderOp::HelperBinDataToInt(const unsigned char *str_record_size, size_t str_size) {
  int n = 1;
  int new_value_width = 2;
  if (*reinterpret_cast<char *>(&n) == 1) {  // Little-endian system
    std::stringstream ss;
    ss << std::hex << std::setfill('0');
    std::string hex_str = "0x";
    for (int pos = static_cast<int>(str_size) - 1; pos >= 0; pos--) {
      ss << std::setw(new_value_width) << static_cast<unsigned>(str_record_size[static_cast<size_t>(pos)]);
    }
    (void)hex_str.append(ss.str());
    auto result = static_cast<int64_t>(std::stoul(hex_str, nullptr, 16));
    return result;
  } else {  // Big-endian system
    std::stringstream ss;
    ss << std::hex << std::setfill('0');
    std::string hex_str = "0x";
    for (size_t pos = 0; pos < str_size; pos++) {
      ss << std::setw(new_value_width) << static_cast<unsigned>(str_record_size[pos]);
    }
    (void)hex_str.append(ss.str());
    auto result = static_cast<int64_t>(std::stoul(hex_str, nullptr, 16));
    return result;
  }
}
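
// Example of the conversion above on a little-endian machine: the on-disk length bytes
// {0x10, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} are assembled most-significant-byte first
// into the hex string "0x0000000000000210" and parsed back to 0x210 == 528. The hex round-trip
// is effectively a byte-order-aware reinterpretation of the buffer as an integer.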

Status TFReaderOp::HelperInflateZLIB(ZLIBStreamInf *zlib_stream, const std::string &filename) const {
  if (zlib_stream->left_to_read != 0) {
    zlib_stream->strm.avail_out =
      static_cast<unsigned int>(zlib_stream->left_to_read);  // need to read the rest before processing
  } else {
    switch (zlib_stream->read_flag) {
      case ZLIBReadFlag::RecordLength:  // record length
        zlib_stream->strm.avail_out = kTFRecordRecLenSize;
        zlib_stream->strm.next_out = zlib_stream->record_size;
        break;
      case ZLIBReadFlag::Header:  // record header/footer
      case ZLIBReadFlag::Footer:
        zlib_stream->strm.avail_out = kTFRecordHeadFootSize;
        zlib_stream->strm.next_out = zlib_stream->garbage;
        break;
      default:  // record example
        zlib_stream->strm.avail_out = static_cast<unsigned int>(zlib_stream->record_length);
        zlib_stream->content = std::make_unique<unsigned char[]>(static_cast<size_t>(zlib_stream->record_length));
        zlib_stream->strm.next_out = zlib_stream->content.get();
    }
  }

  // Inflate stream
  zlib_stream->inflate_status = inflate(&zlib_stream->strm, Z_NO_FLUSH);
  auto inflate_status = zlib_stream->inflate_status;
  // inflate returns Z_BUF_ERROR if no progress is possible.
  // It is not fatal, and inflate can be called again to continue decompressing.
  if (inflate_status == Z_OK || inflate_status == Z_STREAM_END || inflate_status == Z_BUF_ERROR) {
    zlib_stream->left_to_read = static_cast<unsigned int>(zlib_stream->strm.avail_out);  // after reading
    return Status::OK();
  } else if (inflate_status == Z_STREAM_ERROR) {
    (void)inflateEnd(&zlib_stream->strm);
    RETURN_STATUS_UNEXPECTED("Inflate stream state is inconsistent (clobbered) when inflating file " + filename + "!");
  } else if (inflate_status == Z_NEED_DICT || inflate_status == Z_DATA_ERROR) {
    (void)inflateEnd(&zlib_stream->strm);
    RETURN_STATUS_UNEXPECTED("Invalid or incomplete inflate data when inflating file " + filename + "!");
  } else if (inflate_status == Z_MEM_ERROR) {
    (void)inflateEnd(&zlib_stream->strm);
    RETURN_STATUS_UNEXPECTED("Out of memory when inflating file " + filename + "!");
  } else {
    (void)inflateEnd(&zlib_stream->strm);
    RETURN_STATUS_UNEXPECTED("Got error code " + std::to_string(inflate_status) + " when inflating file " + filename +
                             "! Please refer to the zlib documentation for more details.");
  }
}
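
// The ZLIB path decompresses each TFRecord entry through a small state machine driven by
// zlib_stream->read_flag, cycling RecordLength -> Header -> Content -> Footer for every record
// (this assumes ZLIBReadFlag declares the states in that order, as the modulo arithmetic in
// HelperLoadCompZLIBFile implies). left_to_read carries over a partially inflated field when it
// straddles two kZLIBChunkSize input reads.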

Status TFReaderOp::HelperProcessZLIBData(ZLIBStreamInf *zlib_stream, int64_t *rows_read, int64_t *rows_total,
                                         const std::string &filename, int64_t start_offset, int64_t end_offset,
                                         int32_t worker_id) {
  if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::RecordLength)) {  // read record length
    zlib_stream->record_length = HelperBinDataToInt(zlib_stream->record_size, kTFRecordRecLenSize);
  } else if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::Header) &&
             *rows_total == 0) {  // read header when needed (for tfrecord validation)
    auto masked_crc = static_cast<uint32_t>(HelperBinDataToInt(zlib_stream->garbage, kTFRecordHeadFootSize));
    uint32_t generated_crc =
      system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&zlib_stream->record_length), kTFRecordRecLenSize);

    // invalid tfrecord file
    if (masked_crc != generated_crc) {
      RETURN_STATUS_UNEXPECTED("Invalid TFRecord file: " + filename);
    }
  } else if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::Content)) {  // read serialized example
    std::string serialized_example(reinterpret_cast<char *>(zlib_stream->content.get()), zlib_stream->record_length);

    if (start_offset == kInvalidOffset || (*rows_total >= start_offset && *rows_total < end_offset)) {
      RETURN_IF_NOT_OK(SendRecordBytesRow(filename, serialized_example, worker_id));
      (*rows_read)++;
    }
  } else if (zlib_stream->read_flag == static_cast<int>(ZLIBReadFlag::Footer)) {
    (*rows_total)++;
  }

  return Status::OK();
}
#endif

// Parses a single row and puts the data into a tensor table.
Status TFReaderOp::LoadExample(const dataengine::Example *tf_record_example, TensorRow *out_row) {
  auto num_columns = static_cast<int32_t>(data_schema_->NumColumns());
  for (int32_t col = 0; col < num_columns; ++col) {
    const ColDescriptor current_col = data_schema_->Column(col);
    const dataengine::Features &example_features = tf_record_example->features();
    const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
    auto iter_column = feature_map.find(current_col.Name());
    if (iter_column == feature_map.end()) {
      RETURN_STATUS_UNEXPECTED("Invalid columns_list, column name: " + current_col.Name() +
                               " does not exist in the tfrecord file, check the tfrecord files.");
    }
    const dataengine::Feature &column_values_list = iter_column->second;
    RETURN_IF_NOT_OK(LoadFeature(out_row, column_values_list, current_col, col));
  }

  return Status::OK();
}

// Parses a single cell and puts the data into a tensor table.
Status TFReaderOp::LoadFeature(TensorRow *tensor_row, const dataengine::Feature &column_values_list,
                               const ColDescriptor &current_col, int32_t col) {
  const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
  std::unique_ptr<float[]> float_array;     // For staging data from protobuf deserialization
  const unsigned char *data_ptr = nullptr;  // Generic pointer used for populating the Tensor

  // This variable will point into the above staging variables.
  // Also used for creating shape attributes.
  int32_t num_elements = 0;

  // we build a tensor first and read directly into it when we need to cast
  std::shared_ptr<Tensor> ts;

  // Depending on the type of data in the tfrecord file, we want to extract 2 things:
  // 1) A pointer to the data as a const unsigned char *
  // 2) The number of elements of the data
  // After those are determined, we can then build the tensor to represent this data.
  switch (column_list_type) {
    case dataengine::Feature::KindCase::kBytesList: {
      RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts));

      break;
    }
    case dataengine::Feature::KindCase::kFloatList: {
      RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));

      data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());

      // only kFloatList needs to create the tensor here; the other two lists read directly
      // into the tensor
      TensorShape current_shape = TensorShape::CreateUnknownRankShape();
      RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, &current_shape));
      RETURN_IF_NOT_OK(Tensor::CreateFromMemory(current_shape, current_col.Type(), data_ptr, &ts));
      break;
    }
    case dataengine::Feature::KindCase::kInt64List: {
      RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts));
      break;
    }
    case dataengine::Feature::KindCase::KIND_NOT_SET: {
      std::string err_msg =
        "Unrecognized datatype, column type in tfrecord file must be uint8, int64 or float32, check tfrecord file.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
    default: {
      std::string err_msg =
        "Unrecognized datatype, column type in tfrecord file must be uint8, int64 or float32, check tfrecord file.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
  }

  (*tensor_row)[col] = std::move(ts);

  return Status::OK();
}

Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                                 int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
  // kBytesList can map to the following DE types ONLY!
  // DE_UINT8, DE_INT8, DE_STRING
  // For the numeric types, each element must be a single byte!
  if (current_col.Type() != DataType::DE_UINT8 && current_col.Type() != DataType::DE_INT8 &&
      current_col.Type() != DataType::DE_STRING) {
    std::string err_msg = "Invalid column type, the column type of " + current_col.Name() +
                          " should be int8, uint8 or string, but got " + current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  const dataengine::BytesList &bytes_list = column_values_list.bytes_list();

  *num_elements = bytes_list.value_size();

  if (current_col.Type() == DataType::DE_STRING) {
    TensorShape shape = TensorShape::CreateScalar();
    RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape));
    RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, shape, tensor));
    return Status::OK();
  }

  uint64_t max_size = 0;
  for (int i = 0; i < bytes_list.value_size(); ++i) {
#if defined(__APPLE__)
    max_size = fmax(max_size, bytes_list.value(i).size());
#else
    max_size = std::max(max_size, bytes_list.value(i).size());
#endif
  }

  int64_t pad_size = max_size;

  // if the user provides a shape in the form of [-1, d1, d2, ..., dn], we need to pad to d1 * d2 * ... * dn
  if (current_col.HasShape()) {
    TensorShape cur_shape = current_col.Shape();
    if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) {
      int64_t new_pad_size = 1;
      for (int i = 1; i < cur_shape.Size(); ++i) {
        if (cur_shape[i] == TensorShape::kDimUnknown) {
          std::string err_msg =
            "Invalid data dimension, only one dimension of the shape may be -1, but dimension 0 and dimension " +
            std::to_string(i) + " of " + current_col.Name() + " are both -1.";
          RETURN_STATUS_UNEXPECTED(err_msg);
        }
        new_pad_size *= cur_shape[i];
      }
      pad_size = new_pad_size;
    } else {
      if (cur_shape.known() && cur_shape.NumOfElements() != max_size) {
        std::string err_msg = "Data dimensions of '" + current_col.Name() +
                              "' do not match; the expected total number of elements of shape " + cur_shape.ToString() +
                              " should be " + std::to_string(max_size) + ", but got " +
                              std::to_string(cur_shape.NumOfElements());
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
    }
  }

  // we now know how many elements there are and the total bytes; create the tensor here:
  TensorShape current_shape = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, &current_shape));
  RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, current_shape, current_col.Type(), pad_size, tensor));

  return Status::OK();
}
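
// Padding example for the branch above (a sketch): with a declared shape of [-1, 4] and
// bytes_list values of sizes {3, 4, 2}, pad_size = 4, so MaterializeTensorShape is given
// 3 * 4 = 12 elements and each value is padded to 4 bytes, yielding a [3, 4] tensor.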

Status TFReaderOp::LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                                 int32_t *num_elements, std::unique_ptr<float[]> *float_array) {
  // kFloatList can only map to DE types:
  // DE_FLOAT32
  if (current_col.Type() != DataType::DE_FLOAT32) {
    std::string err_msg = "Invalid column type, the column type of " + current_col.Name() +
                          " should be float32, but got " + current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  const dataengine::FloatList &float_list = column_values_list.float_list();

  // Identify how many values we have and then create a local array of these
  // to deserialize into
  *num_elements = float_list.value_size();
  *float_array = std::make_unique<float[]>(*num_elements);
  for (int i = 0; i < float_list.value_size(); ++i) {
    (*float_array)[i] = float_list.value(i);
  }

  return Status::OK();
}

// Determines which template type to use and calls LoadIntList
Status TFReaderOp::LoadIntListSwitch(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                                     int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
  if (current_col.Type() == DataType::DE_UINT64) {
    RETURN_IF_NOT_OK(LoadIntList<uint64_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT64) {
    RETURN_IF_NOT_OK(LoadIntList<int64_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_UINT32) {
    RETURN_IF_NOT_OK(LoadIntList<uint32_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT32) {
    RETURN_IF_NOT_OK(LoadIntList<int32_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_UINT16) {
    RETURN_IF_NOT_OK(LoadIntList<uint16_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT16) {
    RETURN_IF_NOT_OK(LoadIntList<int16_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_UINT8) {
    RETURN_IF_NOT_OK(LoadIntList<uint8_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT8) {
    RETURN_IF_NOT_OK(LoadIntList<int8_t>(current_col, column_values_list, num_elements, tensor));
  } else {
    std::string err_msg = "Invalid column type, the column type of " + current_col.Name() +
                          " should be uint64, int64, uint32, int32, uint16, int16, uint8 or int8, but got " +
                          current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  return Status::OK();
}

// Reads values from an int64 list and casts each value to type T, which must be an integral type
// compatible with int64_t
template <typename T>
Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                               int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
  if (!(current_col.Type().IsInt())) {
    std::string err_msg = "Invalid column type, the column type of " + current_col.Name() + " should be int, but got " +
                          current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  const dataengine::Int64List &int64_list = column_values_list.int64_list();

  // Identify how many values we have and then create a local array of these
  // to deserialize into
  *num_elements = int64_list.value_size();

  // now that we know how many elements there are, create the tensor here:
  TensorShape current_shape = TensorShape::CreateUnknownRankShape();
  RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &current_shape));
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(current_shape, current_col.Type(), tensor));

  int64_t i = 0;
  auto it = (*tensor)->begin<T>();
  for (; it != (*tensor)->end<T>(); i++, ++it) {
    T element = static_cast<T>(int64_list.value(i));
    *it = element;
  }

  return Status::OK();
}
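
// For instance (a sketch): an int64_list of {300, -1, 7} loaded into a DE_INT16 column creates a
// [3]-shaped int16 tensor holding {300, -1, 7}; values outside the range of T are narrowed by
// static_cast rather than range-checked, so the schema type must be chosen to fit the data.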

Status TFReaderOp::CreateSchema(const std::string &tf_record_file, std::vector<std::string> columns_to_load) {
  auto realpath = FileUtils::GetRealPath(tf_record_file.c_str());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, " << tf_record_file << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, " + tf_record_file + " does not exist.");
  }

  std::string serialized_example;
  RETURN_IF_NOT_OK(HelperGetExampleSchema(&serialized_example, realpath.value(), tf_record_file));

  dataengine::Example example;
  if (!example.ParseFromString(serialized_example)) {
    RETURN_STATUS_UNEXPECTED("Failed to parse tfrecord file: " + realpath.value() +
                             ", fields that failed to parse: " + serialized_example);
  }

  const dataengine::Features &example_features = example.features();
  const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();

  if (columns_to_load.empty()) {
    (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load),
                         [](const auto &it) -> std::string { return it.first; });
    std::sort(columns_to_load.begin(), columns_to_load.end());
  }

  for (const auto &curr_col_name : columns_to_load) {
    auto it = feature_map.find(curr_col_name);
    if (it == feature_map.end()) {
      RETURN_STATUS_UNEXPECTED("Invalid columns_list, tfrecord file failed to find column name: " + curr_col_name);
    }
    std::string column_name = it->first;

    std::string column_type;

    const dataengine::Feature &feature = it->second;
    const dataengine::Feature::KindCase kind_case = feature.kind_case();
    switch (kind_case) {
      case dataengine::Feature::KindCase::kBytesList:
        column_type = "uint8";
        break;

      case dataengine::Feature::KindCase::kFloatList:
        column_type = "float32";
        break;

      case dataengine::Feature::KindCase::kInt64List:
        column_type = "int64";
        break;

      case dataengine::Feature::KindCase::KIND_NOT_SET:
        RETURN_STATUS_UNEXPECTED("Unrecognized column type, the column type of " + column_name +
                                 " should be uint8, int64 or float32, but got unrecognized column type.");

      default:
        RETURN_STATUS_UNEXPECTED("Unsupported column type, the column type of " + column_name +
                                 " should be uint8, int64 or float32, but got unsupported column type.");
    }

    RETURN_IF_NOT_OK(
      data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1)));
  }

  return Status::OK();
}
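
// Schema inference example for the function above (hypothetical feature names): if the first
// record of the first file holds features {"img" -> bytes_list, "label" -> int64_list} and
// columns_to_load is empty, the derived schema gets the sorted columns "img" (uint8) and
// "label" (int64), each added as a flexible rank-1 column.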

Status TFReaderOp::HelperGetExampleSchema(std::string *const serialized_example, const std::string &realpath_value,
                                          const std::string &filename) const {
  if (compression_type_ == CompressionType::NONE) {
    std::ifstream reader;
    reader.open(realpath_value, std::ios::in);

    // read length
    int64_t record_length = 0;
    (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(kTFRecordRecLenSize));

    // ignore crc header
    (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));

    // read serialized Example
    (*serialized_example).resize(static_cast<size_t>(record_length));
    (void)reader.read(&(*serialized_example)[0], static_cast<std::streamsize>(record_length));
    reader.close();
  }
#if !defined(_WIN32) && !defined(_WIN64)
  if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::GZIP_WITH_COUNT) {
    gzFile file = gzopen(realpath_value.c_str(), "rb");

    // read length
    int64_t record_length = 0;
    (void)gzread(file, reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);

    // ignore crc header
    (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);

    // read serialized Example
    (*serialized_example).resize(static_cast<size_t>(record_length));
    (void)gzread(file, &(*serialized_example)[0], static_cast<unsigned int>(record_length));
    (void)gzclose(file);
  } else if (compression_type_ == CompressionType::ZLIB || compression_type_ == CompressionType::ZLIB_WITH_COUNT) {
    // ZLIB stream setup (based on the zlib.h tutorial)
    ZLIBStreamInf zlib_stream;

    std::ifstream reader(realpath_value.c_str(), std::ios::in | std::ios::binary);
    zlib_stream.inflate_status = inflateInit(&zlib_stream.strm);
    if (zlib_stream.inflate_status != Z_OK) {
      reader.close();
      RETURN_STATUS_UNEXPECTED("Failed to initialize inflate stream for ZLIB for file " + filename + "!");
    }

    // decompress until the first row is read
    do {
      (void)reader.read(zlib_stream.input_stream, kZLIBChunkSize);
      zlib_stream.strm.avail_in = static_cast<unsigned int>(reader.gcount());
      zlib_stream.strm.next_in = reinterpret_cast<unsigned char *>(zlib_stream.input_stream);

      // run inflate() on the input until the output buffer is not full
      do {
        auto s = HelperInflateZLIB(&zlib_stream, filename);
        if (s != Status::OK()) {
          reader.close();
          return s;
        }
        if (zlib_stream.left_to_read != 0) {
          break;
        }

        // Process inflated data depending on read flag
        if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::RecordLength)) {  // read record length
          zlib_stream.record_length = HelperBinDataToInt(zlib_stream.record_size, kTFRecordRecLenSize);
        } else if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::Content)) {  // read serialized example
          (*serialized_example).resize(static_cast<size_t>(zlib_stream.record_length));
          (void)(*serialized_example)
            .assign(reinterpret_cast<char *>(zlib_stream.content.get()),
                    static_cast<size_t>(zlib_stream.record_length));
        }
        zlib_stream.read_flag++;
      } while (zlib_stream.strm.avail_out == 0 && zlib_stream.read_flag != static_cast<int>(ZLIBReadFlag::Footer));
    } while (zlib_stream.inflate_status != Z_STREAM_END &&
             zlib_stream.read_flag != static_cast<int>(ZLIBReadFlag::Footer));

    (void)inflateEnd(&zlib_stream.strm);
    if (zlib_stream.inflate_status != Z_STREAM_END && zlib_stream.read_flag < static_cast<int>(ZLIBReadFlag::Footer)) {
      reader.close();
      RETURN_STATUS_UNEXPECTED("Decompression of ZLIB file failed for file " + filename + "!");
    }

    reader.close();
  }
#endif

  return Status::OK();
}

Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads,
                                  bool estimate, CompressionType compression_type) {
  RETURN_UNEXPECTED_IF_NULL(out_total_rows);
  try {
    if (threads > static_cast<int64_t>(filenames.size())) {
      threads = static_cast<int64_t>(filenames.size());
    }

    std::vector<std::future<int64_t>> async_results;

    if (threads <= 0) {
      RETURN_STATUS_UNEXPECTED(
        "Invalid threads number, the threads number of TFReader should be greater than zero, but got " +
        std::to_string(threads) + ".");
    }
    int64_t chunk_size = filenames.size() / threads;
    int64_t remainder = filenames.size() % threads;

    int64_t begin = 0;
    int64_t end = begin;
    for (int i = 0; i < threads; i++) {
      end += chunk_size;
      if (remainder > 0) {
        end++;
        remainder--;
      }

      if (estimate) {
        // Parse a single file for each chunk with estimate mode on
        async_results.push_back(
          std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1, compression_type));
      } else {
        // Parse the whole chunk with estimate mode off
        async_results.push_back(
          std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end, compression_type));
      }

      begin = end;
    }

    int64_t total_rows = 0;
    for (auto &async_result : async_results) {
      total_rows += async_result.get();
    }

    if (estimate) {
      // Each thread only scans 1 file
      // Estimated total rows = Average rows * total number of files
      total_rows = total_rows / threads * filenames.size();
    }

    *out_total_rows = total_rows;
  } catch (const std::exception &e) {
    std::string err_msg = "Unexpected error occurred: ";
    err_msg += std::string(e.what());
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  return Status::OK();
}
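
// Chunking example for the counter above: 10 files on 3 threads gives chunk sizes 4, 3, and 3
// (chunk_size == 3 plus one remainder file for the first thread). With estimate == true, each
// thread scans only the first file of its chunk, and the sum is scaled by
// filenames.size() / threads to approximate the total row count.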

int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &filenames, int64_t begin, int64_t end,
                                            CompressionType compression_type) {
  int64_t rows_read = 0;
  for (int64_t i = begin; i < end; i++) {
    auto realpath = FileUtils::GetRealPath(filenames[i].c_str());
    if (!realpath.has_value()) {
      MS_LOG(ERROR) << "Invalid file path, " << filenames[i] << " does not exist.";
      continue;
    }

    if (compression_type == CompressionType::NONE) {
      HelperCountNonCompRows(realpath.value(), filenames[i], &rows_read);
    }
#if !defined(_WIN32) && !defined(_WIN64)
    if (compression_type == CompressionType::GZIP_WITH_COUNT) {
      HelperCountGZIPRows(realpath.value(), filenames[i], &rows_read);
    } else if (compression_type == CompressionType::ZLIB_WITH_COUNT) {
      HelperCountZLIBRows(realpath.value(), filenames[i], &rows_read);
    }
#endif
  }

  return rows_read;
}

void TFReaderOp::HelperCountNonCompRows(const std::string &realpath_value, const std::string &filename,
                                        int64_t *rows_read) {
  std::ifstream reader;
  reader.open(realpath_value, std::ios::in);
  if (!reader) {
    MS_LOG(DEBUG) << "TFReader operator failed to open file " << filename << ".";
  }

  while (reader.peek() != EOF) {
    // read length
    int64_t record_length = 0;
    (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(kTFRecordRecLenSize));

    // ignore crc header
    (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));

    // ignore TFRecord file contents
    (void)reader.ignore(static_cast<std::streamsize>(record_length));

    // ignore crc footer
    (void)reader.ignore(static_cast<std::streamsize>(kTFRecordHeadFootSize));
    (*rows_read)++;
  }
  reader.close();
}

#if !defined(_WIN32) && !defined(_WIN64)
void TFReaderOp::HelperCountGZIPRows(const std::string &realpath_value, const std::string &filename,
                                     int64_t *rows_read) {
  gzFile file = gzopen(realpath_value.c_str(), "rb");

  if (file == nullptr) {
    MS_LOG(DEBUG) << "TFReader operator failed to open file " << filename << " with GZIP.";
  }

  while (gzeof(file) != 1) {
    // read length
    int64_t record_length = 0;
    (void)gzread(file, reinterpret_cast<char *>(&record_length), kTFRecordRecLenSize);
    if (record_length == 0) {
      continue;
    }

    // ignore crc header
    (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);

    // ignore TFRecord file contents
    (void)gzseek(file, record_length, SEEK_CUR);

    // ignore crc footer
    (void)gzseek(file, kTFRecordHeadFootSize, SEEK_CUR);
    (*rows_read)++;
  }
  (void)gzclose(file);
}

void TFReaderOp::HelperCountZLIBRows(const std::string &realpath_value, const std::string &filename,
                                     int64_t *rows_read) {
  // ZLIB stream setup (based on the zlib.h tutorial)
  ZLIBStreamInf zlib_stream;

  std::ifstream reader(realpath_value.c_str(), std::ios::in | std::ios::binary);

  if (!reader) {
    MS_LOG(DEBUG) << "TFReader operator failed to open file " << filename << " with ZLIB.";
  }

  zlib_stream.inflate_status = inflateInit(&zlib_stream.strm);
  if (zlib_stream.inflate_status != Z_OK) {
    reader.close();
    MS_LOG(DEBUG) << "Failed to initialize inflate stream for ZLIB when counting rows for file " << filename << "!";
  }

  // decompress until the inflate stream ends, counting one row per record footer
  do {
    (void)reader.read(zlib_stream.input_stream, kZLIBChunkSize);
    zlib_stream.strm.avail_in = static_cast<unsigned int>(reader.gcount());
    zlib_stream.strm.next_in = reinterpret_cast<unsigned char *>(zlib_stream.input_stream);

    // run inflate() on the input until the output buffer is not full
    do {
      if (zlib_stream.left_to_read != 0) {
        zlib_stream.strm.avail_out = zlib_stream.left_to_read;  // need to read the rest before processing
      } else {
        switch (zlib_stream.read_flag) {
          case ZLIBReadFlag::RecordLength:  // record length
            zlib_stream.strm.avail_out = kTFRecordRecLenSize;
            zlib_stream.strm.next_out = zlib_stream.record_size;
            break;
          default:  // record header, example, and footer together, since we just want to count rows
            zlib_stream.strm.avail_out = zlib_stream.record_length + kTFRecordHeadFootSize + kTFRecordHeadFootSize;
            zlib_stream.content = std::make_unique<unsigned char[]>(zlib_stream.record_length + kTFRecordHeadFootSize +
                                                                    kTFRecordHeadFootSize);
            zlib_stream.strm.next_out = zlib_stream.content.get();
        }
      }

      // Inflate stream
      zlib_stream.inflate_status = inflate(&zlib_stream.strm, Z_NO_FLUSH);
      if (zlib_stream.inflate_status == Z_OK || zlib_stream.inflate_status == Z_STREAM_END) {
        zlib_stream.left_to_read = zlib_stream.strm.avail_out;  // after reading
      } else {
        MS_LOG(DEBUG) << "An error was found during inflation when counting rows for file: " << filename << "!";
      }

      if (zlib_stream.left_to_read != 0) {
        break;
      }

      // Process inflated data depending on read flag
      if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::RecordLength)) {  // read record length
        zlib_stream.record_length = HelperBinDataToInt(zlib_stream.record_size, kTFRecordRecLenSize);
      } else if (zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::Footer)) {
        (*rows_read)++;
      }
      zlib_stream.read_flag = zlib_stream.read_flag == static_cast<int>(ZLIBReadFlag::Footer)
                                ? static_cast<int>(ZLIBReadFlag::RecordLength)
                                : static_cast<int>(ZLIBReadFlag::Footer);  // alternate between length and footer
    } while (zlib_stream.strm.avail_out == 0 && zlib_stream.inflate_status == Z_OK);
  } while (zlib_stream.inflate_status != Z_STREAM_END && zlib_stream.inflate_status == Z_OK);

  (void)inflateEnd(&zlib_stream.strm);
  if (zlib_stream.inflate_status != Z_STREAM_END) {
    MS_LOG(DEBUG) << "Decompression of ZLIB file failed when counting rows for file " << filename << "!";
  }

  reader.close();
}

#endif

Status TFReaderOp::ComputeColMap() {
  // Construct the column name map for this operator (base class field)
  if (column_name_id_map_.empty()) {
    if (decode_) {
      for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
        column_name_id_map_[data_schema_->Column(i).Name()] = i;
      }
    } else {
      // if decode is false, the output will only have one column containing the record bytes
      column_name_id_map_["proto"] = 0;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}

Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int32_t key_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool end_of_epoch = false;
  if (shuffle_files_) {
    do {
      // Iterate over all the keys and add one key to each block.
      for (auto i_key : i_keys) {
        if (!GetLoadIoBlockQueue()) {
          end_of_epoch = true;
          break;
        }
        RETURN_IF_NOT_OK(HelperIOBlockFiller(&queue_index, &key_index, &pre_count, &start_offset, &end_offset, i_key,
                                             (*filename_index_)[i_key]));
      }
    } while ((compression_type_ == CompressionType::NONE || compression_type_ == CompressionType::GZIP_WITH_COUNT ||
              compression_type_ == CompressionType::ZLIB_WITH_COUNT) &&
             equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
             !end_of_epoch);
  } else {
    do {
      // Iterate over all the keys and add one key to each block.
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        if (!GetLoadIoBlockQueue()) {
          end_of_epoch = true;
          break;
        }
        RETURN_IF_NOT_OK(
          HelperIOBlockFiller(&queue_index, &key_index, &pre_count, &start_offset, &end_offset, it.key(), it.value()));
      }
    } while ((compression_type_ == CompressionType::NONE || compression_type_ == CompressionType::GZIP_WITH_COUNT ||
              compression_type_ == CompressionType::ZLIB_WITH_COUNT) &&
             equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
             !end_of_epoch);
  }
  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}

Status TFReaderOp::HelperIOBlockFiller(int32_t *queue_index, int32_t *key_index, int64_t *pre_count,
                                       int64_t *start_offset, int64_t *end_offset, int64_t key,
                                       const std::string &file_name) {
  if (compression_type_ == CompressionType::GZIP || compression_type_ == CompressionType::ZLIB) {
    int num_files_to_read =
      static_cast<int>(dataset_files_list_.size() - dataset_files_list_.size() % static_cast<size_t>(num_devices_));
    if (*key_index % num_devices_ == device_id_ && *key_index < num_files_to_read) {
      *end_offset = static_cast<int>(total_rows_ /
                                     static_cast<int>(dataset_files_list_.size() / static_cast<size_t>(num_devices_)));
      auto ioBlock = std::make_unique<FilenameBlock>(key, 0, *end_offset, IOBlock::kFlagNone);
      RETURN_IF_NOT_OK(PushIoBlockQueue(*queue_index, std::move(ioBlock)));
      *queue_index = (*queue_index + 1) % num_workers_;
    }
    (*key_index)++;
  } else if (!equal_rows_per_shard_) {
    if ((*key_index)++ % num_devices_ == device_id_) {
      auto ioBlock = std::make_unique<FilenameBlock>(key, kInvalidOffset, kInvalidOffset, IOBlock::kFlagNone);
      RETURN_IF_NOT_OK(PushIoBlockQueue(*queue_index, std::move(ioBlock)));
      *queue_index = (*queue_index + 1) % num_workers_;
    }
  } else {
    if (NeedPushFileToBlockQueue(file_name, start_offset, end_offset, *pre_count)) {
      auto ioBlock = std::make_unique<FilenameBlock>(key, *start_offset, *end_offset, IOBlock::kFlagNone);
      RETURN_IF_NOT_OK(PushIoBlockQueue(*queue_index, std::move(ioBlock)));
      *queue_index = (*queue_index + 1) % num_workers_;
    }

    *pre_count += filename_numrows_[file_name];
  }
  return Status::OK();
}
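
// Sharding example for the compressed branch above: with 8 GZIP files, num_devices_ == 3 and
// total_rows_ == 100, num_files_to_read == 6 (the 2 leftover files are dropped so shards stay
// equal), device d reads the files whose index satisfies index % 3 == d, and each such file
// gets end_offset == 100 / 2 == 50 rows.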

Status TFReaderOp::GetNextRowPullMode(TensorRow *const row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  RETURN_IF_NOT_OK(NonMappableLeafOp::GetNextRowPullMode(row));
  if (decode_) {
    if (!row->empty()) {
      // data obtained from jagged_rows_connector is raw bytes, so we need to parse it before returning
      TensorRow res;
      RETURN_IF_NOT_OK(ParseExample(*row, &res));
      *row = std::move(res);
    }
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore