• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
17 
18 #include <algorithm>
19 #include <fstream>
20 #include <future>
21 #include <memory>
22 #include <mutex>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "utils/file_utils.h"
28 #include "proto/example.pb.h"
29 #include "minddata/dataset/core/config_manager.h"
30 #include "minddata/dataset/core/global_context.h"
31 #include "minddata/dataset/engine/data_schema.h"
32 #include "minddata/dataset/engine/datasetops/source/io_block.h"
33 #include "minddata/dataset/engine/db_connector.h"
34 #include "minddata/dataset/engine/execution_tree.h"
35 #include "minddata/dataset/engine/jagged_connector.h"
36 #include "minddata/dataset/util/status.h"
37 #include "minddata/dataset/util/task_manager.h"
38 #include "minddata/dataset/util/wait_post.h"
39 #include "utils/system/crc32c.h"
40 
41 namespace mindspore {
42 namespace dataset {
// File size (0x140000000 bytes == 5 * 1024^3, i.e. 5 GiB) above which ValidateFirstRowCrc logs a
// performance warning suggesting the TFRecord file be split.
constexpr int64_t kTFRecordFileLimit = 5LL * 1024 * 1024 * 1024;
44 
ValidateFirstRowCrc(const std::string & filename)45 bool TFReaderOp::ValidateFirstRowCrc(const std::string &filename) {
46   auto realpath = FileUtils::GetRealPath(filename.data());
47   if (!realpath.has_value()) {
48     MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << filename;
49     return false;
50   }
51 
52   std::ifstream reader;
53   reader.open(realpath.value());
54   if (!reader) {
55     return false;
56   }
57   int64_t file_len = reader.seekg(0, std::ios::end).tellg();
58   if (file_len > kTFRecordFileLimit) {
59     MS_LOG(WARNING) << "The file size of " << filename
60                     << " is larger than 5G, there may be performance problems in "
61                        "distributed scenarios, and it can be split into sub-files "
62                        "smaller than 5G to get better performance.";
63   }
64   (void)reader.seekg(0, std::ios::beg);
65 
66   // read data
67   int64_t record_length = 0;
68   (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
69 
70   // read crc from file
71   uint32_t masked_crc = 0;
72   (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));
73 
74   // generate crc from data
75   uint32_t generated_crc =
76     system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));
77 
78   return masked_crc == generated_crc;
79 }
80 
TFReaderOp(int32_t num_workers,int32_t worker_connector_size,int64_t total_num_rows,std::vector<std::string> dataset_files_list,std::unique_ptr<DataSchema> data_schema,int32_t op_connector_size,std::vector<std::string> columns_to_load,bool shuffle_files,int32_t num_devices,int32_t device_id,bool equal_rows_per_shard)81 TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows,
82                        std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
83                        int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
84                        int32_t num_devices, int32_t device_id, bool equal_rows_per_shard)
85     : NonMappableLeafOp(num_workers, worker_connector_size, total_num_rows, op_connector_size, shuffle_files,
86                         num_devices, device_id),
87       dataset_files_list_(std::move(dataset_files_list)),
88       columns_to_load_(std::move(columns_to_load)),
89       data_schema_(std::move(data_schema)),
90       equal_rows_per_shard_(equal_rows_per_shard) {}
91 
92 // A print method typically used for debugging
Print(std::ostream & out,bool show_all) const93 void TFReaderOp::Print(std::ostream &out, bool show_all) const {
94   if (!show_all) {
95     // Call the super class for displaying any common 1-liner info
96     ParallelOp::Print(out, show_all);
97     // Then show any custom derived-internal 1-liner info for this op
98     out << "\n";
99   } else {
100     // Call the super class for displaying any common detailed info
101     ParallelOp::Print(out, show_all);
102     // Then show any custom derived-internal stuff
103     out << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
104         << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
105         << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n";
106     for (size_t i = 0; i < dataset_files_list_.size(); ++i) {
107       out << " " << dataset_files_list_[i];
108     }
109     if (!columns_to_load_.empty()) {
110       out << "\nColumns to load:\n";
111       for (size_t j = 0; j < columns_to_load_.size(); ++j) {
112         out << " " << columns_to_load_[j];
113       }
114     }
115     out << "\nData Schema:\n";
116     out << *data_schema_ << "\n\n";
117   }
118 }
119 
Init()120 Status TFReaderOp::Init() {
121   if (data_schema_->Empty()) {
122     RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_));
123   }
124 
125   if (total_rows_ == 0) {
126     total_rows_ = data_schema_->NumRows();
127   }
128   if (total_rows_ < 0) {
129     RETURN_STATUS_UNEXPECTED(
130       "Invalid parameter, num_samples or num_rows for TFRecordDataset must be greater than 0, but got: " +
131       std::to_string(total_rows_));
132   }
133 
134   // Build the index with our files such that each file corresponds to a key id.
135   RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
136 
137   // The creation of the internal connector has been delayed until now, since we may have adjusted the
138   // number of workers.  Now that the worker count is established, create the connector now in the
139   // parallel op base.
140   RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
141 
142   jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
143 
144   // temporary: make size large enough to hold all files + EOE to avoid hangs
145   int32_t safe_queue_size = static_cast<int32_t>(std::ceil(dataset_files_list_.size() / num_workers_)) + 1;
146   io_block_queues_.Init(num_workers_, safe_queue_size);
147 
148   return Status::OK();
149 }
150 
CalculateNumRowsPerShard()151 Status TFReaderOp::CalculateNumRowsPerShard() {
152   if (!equal_rows_per_shard_) {
153     return Status::OK();
154   }
155 
156   for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
157     std::vector<std::string> file(1, it.value());
158     int64_t num = CountTotalRowsSectioned(file, 0, 1);
159     filename_numrows_[it.value()] = num;
160     num_rows_ += num;
161   }
162   num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
163   if (num_rows_per_shard_ == 0) {
164     std::stringstream ss;
165     for (int i = 0; i < dataset_files_list_.size(); ++i) {
166       ss << " " << dataset_files_list_[i];
167     }
168     std::string file_list = ss.str();
169     RETURN_STATUS_UNEXPECTED(
170       "Invalid data, TFRecordDataset API can't read the data file (interface mismatch or no data under the file). "
171       "Check file path." +
172       file_list);
173   }
174   return Status::OK();
175 }
176 
FillIOBlockShuffle(const std::vector<int64_t> & i_keys)177 Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
178   int32_t queue_index = 0;
179   int32_t key_index = 0;
180   int64_t pre_count = 0;
181   int64_t start_offset = 0;
182   int64_t end_offset = 0;
183   bool finish = false;
184   bool end_of_epoch = false;
185   while (!finish) {
186     for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
187       {
188         std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
189         if (load_io_block_queue_ == false) {
190           end_of_epoch = true;
191           break;
192         }
193       }
194       if (!equal_rows_per_shard_) {
195         if (key_index++ % num_devices_ == device_id_) {
196           auto ioBlock = std::make_unique<FilenameBlock>(*it, kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone);
197           RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
198           queue_index = (queue_index + 1) % num_workers_;
199         }
200       } else {
201         // Do an index lookup using that key to get the filename.
202         std::string file_name = (*filename_index_)[*it];
203         if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
204           auto ioBlock = std::make_unique<FilenameBlock>(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone);
205           RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
206           MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset;
207           queue_index = (queue_index + 1) % num_workers_;
208         }
209 
210         pre_count += filename_numrows_[file_name];
211       }
212     }
213     if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
214         !end_of_epoch) {
215       finish = false;
216     } else {
217       finish = true;
218     }
219   }
220   RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
221   return Status::OK();
222 }
223 
FillIOBlockNoShuffle()224 Status TFReaderOp::FillIOBlockNoShuffle() {
225   int32_t queue_index = 0;
226   int32_t key_index = 0;
227   int64_t pre_count = 0;
228   int64_t start_offset = 0;
229   int64_t end_offset = 0;
230   bool finish = false;
231   bool end_of_epoch = false;
232   while (!finish) {
233     // Iterate over all the keys and add one key to each block.
234     for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
235       {
236         std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
237         if (load_io_block_queue_ == false) {
238           end_of_epoch = true;
239           break;
240         }
241       }
242       if (!equal_rows_per_shard_) {
243         if (key_index++ % num_devices_ == device_id_) {
244           auto ioBlock =
245             std::make_unique<FilenameBlock>(it.key(), kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone);
246           RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
247           queue_index = (queue_index + 1) % num_workers_;
248         }
249       } else {
250         std::string file_name = it.value();
251         if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
252           auto ioBlock = std::make_unique<FilenameBlock>(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone);
253           RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
254           queue_index = (queue_index + 1) % num_workers_;
255         }
256 
257         pre_count += filename_numrows_[file_name];
258       }
259     }
260     if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
261         !end_of_epoch) {
262       finish = false;
263     } else {
264       finish = true;
265     }
266   }
267 
268   RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
269   return Status::OK();
270 }
271 
272 // Reads a tf_file file and loads the data into multiple TensorRows.
LoadFile(const std::string & filename,int64_t start_offset,int64_t end_offset,int32_t worker_id)273 Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
274   auto realpath = FileUtils::GetRealPath(filename.data());
275   if (!realpath.has_value()) {
276     MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << filename;
277     RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + filename);
278   }
279 
280   std::ifstream reader;
281   reader.open(realpath.value());
282   if (!reader) {
283     RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + filename);
284   }
285 
286   int64_t rows_read = 0;
287   int64_t rows_total = 0;
288 
289   while (reader.peek() != EOF) {
290     if (!load_jagged_connector_) {
291       break;
292     }
293     RETURN_IF_INTERRUPTED();
294 
295     // read length
296     int64_t record_length = 0;
297     (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
298 
299     // ignore crc header
300     (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
301 
302     // read serialized Example
303     std::string serialized_example;
304     serialized_example.resize(record_length);
305     (void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));
306 
307     int32_t num_columns = data_schema_->NumColumns();
308     TensorRow newRow(num_columns, nullptr);
309 
310     if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) {
311       dataengine::Example tf_file;
312       if (!tf_file.ParseFromString(serialized_example)) {
313         std::string errMsg = "Invalid file, failed to parse tfrecord file : " + filename;
314         MS_LOG(DEBUG) << errMsg + ", details of string: " << serialized_example;
315         RETURN_STATUS_UNEXPECTED(errMsg);
316       }
317 
318       std::vector<std::string> file_path(num_columns, filename);
319       newRow.setPath(file_path);
320       RETURN_IF_NOT_OK(LoadExample(&tf_file, &newRow));
321       rows_read++;
322       RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(newRow)));
323     }
324 
325     // ignore crc footer
326     (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
327     rows_total++;
328   }
329 
330   return Status::OK();
331 }
332 
333 // Parses a single row and puts the data into a tensor table.
LoadExample(const dataengine::Example * tf_file,TensorRow * out_row)334 Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, TensorRow *out_row) {
335   int32_t num_columns = data_schema_->NumColumns();
336   for (int32_t col = 0; col < num_columns; ++col) {
337     const ColDescriptor current_col = data_schema_->Column(col);
338     const dataengine::Features &example_features = tf_file->features();
339     const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
340     auto iter_column = feature_map.find(current_col.Name());
341     if (iter_column == feature_map.end()) {
342       RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.Name() + " does not exist.");
343     }
344     const dataengine::Feature &column_values_list = iter_column->second;
345     RETURN_IF_NOT_OK(LoadFeature(out_row, column_values_list, current_col, col));
346   }
347 
348   return Status::OK();
349 }
350 
351 // Parses a single cell and puts the data into a tensor table.
LoadFeature(TensorRow * tensor_row,const dataengine::Feature & column_values_list,const ColDescriptor & current_col,int32_t col)352 Status TFReaderOp::LoadFeature(TensorRow *tensor_row, const dataengine::Feature &column_values_list,
353                                const ColDescriptor &current_col, int32_t col) {
354   const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
355   std::unique_ptr<float[]> float_array;     // For staging data from protobuf deserialization
356   const unsigned char *data_ptr = nullptr;  // Generic pointer used for populating the Tensor
357 
358   // This variable will point into the above staging variables.
359   // Also used for creating shape attributes.
360   int32_t num_elements = 0;
361 
362   // we build a tensor first a read directly into it if we need to cast
363   std::shared_ptr<Tensor> ts;
364 
365   // Depending on the type of data from the tf_file, we want to extract 2 things:
366   // 1) A pointer to the data as a const unsigned char *
367   // 2) The number of elements of the data
368   // After those are determined, we can then build the tensor to represent this data.
369   switch (column_list_type) {
370     case dataengine::Feature::KindCase::kBytesList: {
371       RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts));
372 
373       break;
374     }
375     case dataengine::Feature::KindCase::kFloatList: {
376       RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));
377 
378       data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());
379 
380       // only floatList needs to create the tensor here, other two lists read directly
381       // into the tensor
382       TensorShape current_shape = TensorShape::CreateUnknownRankShape();
383       RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, &current_shape));
384       RETURN_IF_NOT_OK(Tensor::CreateFromMemory(current_shape, current_col.Type(), data_ptr, &ts));
385       break;
386     }
387     case dataengine::Feature::KindCase::kInt64List: {
388       RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts));
389       break;
390     }
391     case dataengine::Feature::KindCase::KIND_NOT_SET: {
392       std::string err_msg = "Invalid data, column type in tf record file must be uint8, int64 or float32.";
393       RETURN_STATUS_UNEXPECTED(err_msg);
394     }
395     default: {
396       std::string err_msg = "Invalid data, column type in tf record file must be uint8, int64 or float32.";
397       RETURN_STATUS_UNEXPECTED(err_msg);
398     }
399   }
400 
401   (*tensor_row)[col] = std::move(ts);
402 
403   return Status::OK();
404 }
405 
LoadBytesList(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::shared_ptr<Tensor> * tensor)406 Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
407                                  int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
408   // kBytesList can map to the following DE types ONLY!
409   // DE_UINT8, DE_INT8
410   // Must be single byte type for each element!
411   if (current_col.Type() != DataType::DE_UINT8 && current_col.Type() != DataType::DE_INT8 &&
412       current_col.Type() != DataType::DE_STRING) {
413     std::string err_msg = "Invalid data, invalid data type for Tensor at column: " + current_col.Name() +
414                           ", data type should be int8, uint8 or string, but got " + current_col.Type().ToString();
415     RETURN_STATUS_UNEXPECTED(err_msg);
416   }
417 
418   const dataengine::BytesList &bytes_list = column_values_list.bytes_list();
419 
420   *num_elements = bytes_list.value_size();
421 
422   if (current_col.Type() == DataType::DE_STRING) {
423     TensorShape shape = TensorShape::CreateScalar();
424     RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape));
425     RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, shape, tensor));
426     return Status::OK();
427   }
428 
429   uint64_t max_size = 0;
430   for (uint32_t i = 0; i < bytes_list.value_size(); ++i) {
431 #if defined(__APPLE__)
432     max_size = fmax(max_size, bytes_list.value(i).size());
433 #else
434     max_size = std::max(max_size, bytes_list.value(i).size());
435 #endif
436   }
437 
438   int64_t pad_size = max_size;
439 
440   // if user provides a shape in the form of [-1, d1, 2d, ... , dn], we need to pad to d1 * d2 * ... * dn
441   if (current_col.HasShape()) {
442     TensorShape cur_shape = current_col.Shape();
443     if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) {
444       int64_t new_pad_size = 1;
445       for (int i = 1; i < cur_shape.Size(); ++i) {
446         if (cur_shape[i] == TensorShape::kDimUnknown) {
447           std::string err_msg =
448             "Invalid data, more than one unknown dimension in the shape of column: " + current_col.Name();
449           RETURN_STATUS_UNEXPECTED(err_msg);
450         }
451         new_pad_size *= cur_shape[i];
452       }
453       pad_size = new_pad_size;
454     } else {
455       if (cur_shape.known() && cur_shape.NumOfElements() != max_size) {
456         std::string err_msg = "Invalid data, shape in schema's column '" + current_col.Name() + "' is incorrect." +
457                               "\nshape received: " + cur_shape.ToString() +
458                               "\ntotal elements in shape received: " + std::to_string(cur_shape.NumOfElements()) +
459                               "\nexpected total elements in shape: " + std::to_string(max_size);
460         RETURN_STATUS_UNEXPECTED(err_msg);
461       }
462     }
463   }
464 
465   // know how many elements there are and the total bytes, create tensor here:
466   TensorShape current_shape = TensorShape::CreateScalar();
467   RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, &current_shape));
468   RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, current_shape, current_col.Type(), pad_size, tensor));
469 
470   return Status::OK();
471 }
472 
LoadFloatList(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::unique_ptr<float[]> * float_array)473 Status TFReaderOp::LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
474                                  int32_t *num_elements, std::unique_ptr<float[]> *float_array) {
475   // KFloatList can only map to DE types:
476   // DE_FLOAT32
477   if (current_col.Type() != DataType::DE_FLOAT32) {
478     std::string err_msg = "Invalid data, invalid data type for Tensor at column: " + current_col.Name() +
479                           ", data type should be string, but got " + current_col.Type().ToString();
480     RETURN_STATUS_UNEXPECTED(err_msg);
481   }
482 
483   const dataengine::FloatList &float_list = column_values_list.float_list();
484 
485   // Identify how many values we have and then create a local array of these
486   // to deserialize into
487   *num_elements = float_list.value_size();
488   *float_array = std::make_unique<float[]>(*num_elements);
489   for (int i = 0; i < float_list.value_size(); ++i) {
490     (*float_array)[i] = float_list.value(i);
491   }
492 
493   return Status::OK();
494 }
495 
496 // Determines which template type to use and calls LoadIntList
LoadIntListSwitch(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::shared_ptr<Tensor> * tensor)497 Status TFReaderOp::LoadIntListSwitch(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
498                                      int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
499   if (current_col.Type() == DataType::DE_UINT64) {
500     RETURN_IF_NOT_OK(LoadIntList<uint64_t>(current_col, column_values_list, num_elements, tensor));
501   } else if (current_col.Type() == DataType::DE_INT64) {
502     RETURN_IF_NOT_OK(LoadIntList<int64_t>(current_col, column_values_list, num_elements, tensor));
503   } else if (current_col.Type() == DataType::DE_UINT32) {
504     RETURN_IF_NOT_OK(LoadIntList<uint32_t>(current_col, column_values_list, num_elements, tensor));
505   } else if (current_col.Type() == DataType::DE_INT32) {
506     RETURN_IF_NOT_OK(LoadIntList<int32_t>(current_col, column_values_list, num_elements, tensor));
507   } else if (current_col.Type() == DataType::DE_UINT16) {
508     RETURN_IF_NOT_OK(LoadIntList<uint16_t>(current_col, column_values_list, num_elements, tensor));
509   } else if (current_col.Type() == DataType::DE_INT16) {
510     RETURN_IF_NOT_OK(LoadIntList<int16_t>(current_col, column_values_list, num_elements, tensor));
511   } else if (current_col.Type() == DataType::DE_UINT8) {
512     RETURN_IF_NOT_OK(LoadIntList<uint8_t>(current_col, column_values_list, num_elements, tensor));
513   } else if (current_col.Type() == DataType::DE_INT8) {
514     RETURN_IF_NOT_OK(LoadIntList<int8_t>(current_col, column_values_list, num_elements, tensor));
515   } else {
516     std::string err_msg = "Invalid data, invalid datatype for Tensor at column: " + current_col.Name() +
517                           ", data type should be uint64, int64, uint32, int32, uint16, int16, uint8 or int8" +
518                           ", but got " + current_col.Type().ToString();
519     RETURN_STATUS_UNEXPECTED(err_msg);
520   }
521 
522   return Status::OK();
523 }
524 
525 // Reads values from a bytes list and casts the value to type T, must be an integral type
526 // compatible with int64_t
527 template <typename T>
LoadIntList(const ColDescriptor & current_col,const dataengine::Feature & column_values_list,int32_t * num_elements,std::shared_ptr<Tensor> * tensor)528 Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
529                                int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
530   if (!(current_col.Type().IsInt())) {
531     std::string err_msg = "Invalid data, invalid data type for Tensor at column: " + current_col.Name() +
532                           ", data type should be int, but got " + current_col.Type().ToString();
533     RETURN_STATUS_UNEXPECTED(err_msg);
534   }
535 
536   const dataengine::Int64List &int64_list = column_values_list.int64_list();
537 
538   // Identify how many values we have and then create a local array of these
539   // to deserialize into
540   *num_elements = int64_list.value_size();
541 
542   // know how many elements there are, create tensor here:
543   TensorShape current_shape = TensorShape::CreateUnknownRankShape();
544   RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &current_shape));
545   RETURN_IF_NOT_OK(Tensor::CreateEmpty(current_shape, current_col.Type(), tensor));
546 
547   int64_t i = 0;
548   auto it = (*tensor)->begin<T>();
549   for (; it != (*tensor)->end<T>(); i++, ++it) {
550     T element = static_cast<T>(int64_list.value(i));
551     *it = element;
552   }
553 
554   return Status::OK();
555 }
556 
CreateSchema(const std::string tf_file,std::vector<std::string> columns_to_load)557 Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector<std::string> columns_to_load) {
558   auto realpath = FileUtils::GetRealPath(tf_file.data());
559   if (!realpath.has_value()) {
560     MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << tf_file;
561     RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + tf_file);
562   }
563 
564   std::ifstream reader;
565   reader.open(realpath.value());
566 
567   // read length
568   int64_t record_length = 0;
569   (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
570 
571   // ignore crc header
572   (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
573 
574   // read serialized Example
575   std::string serialized_example;
576   serialized_example.resize(record_length);
577   (void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));
578 
579   dataengine::Example example;
580   if (!example.ParseFromString(serialized_example)) {
581     RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse tfrecord file: " + serialized_example);
582   }
583 
584   const dataengine::Features &example_features = example.features();
585   const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
586 
587   if (columns_to_load.empty()) {
588     (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load),
589                          [](const auto &it) -> std::string { return it.first; });
590     std::sort(columns_to_load.begin(), columns_to_load.end());
591   }
592 
593   for (const auto &curr_col_name : columns_to_load) {
594     auto it = feature_map.find(curr_col_name);
595     if (it == feature_map.end()) {
596       RETURN_STATUS_UNEXPECTED("Invalid data, failed to find column name: " + curr_col_name);
597     }
598     std::string column_name = it->first;
599 
600     std::string column_type;
601 
602     const dataengine::Feature &feature = it->second;
603     const dataengine::Feature::KindCase kind_case = feature.kind_case();
604     switch (kind_case) {
605       case dataengine::Feature::KindCase::kBytesList:
606         column_type = "uint8";
607         break;
608 
609       case dataengine::Feature::KindCase::kFloatList:
610         column_type = "float32";
611         break;
612 
613       case dataengine::Feature::KindCase::kInt64List:
614         column_type = "int64";
615         break;
616 
617       case dataengine::Feature::KindCase::KIND_NOT_SET:
618         RETURN_STATUS_UNEXPECTED("Invalid data, column type of tf record file must be uint8, int64 or float32.");
619 
620       default:
621         RETURN_STATUS_UNEXPECTED("Invalid data, column type of tf record file must be uint8, int64 or float32.");
622     }
623 
624     RETURN_IF_NOT_OK(
625       data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1)));
626   }
627 
628   return Status::OK();
629 }
630 
CountTotalRows(int64_t * out_total_rows,const std::vector<std::string> & filenames,int64_t threads,bool estimate)631 Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads,
632                                   bool estimate) {
633   try {
634     if (threads > filenames.size()) {
635       threads = filenames.size();
636     }
637 
638     std::vector<std::future<int64_t>> async_results;
639 
640     int64_t chunk_size = filenames.size() / threads;
641     int64_t remainder = filenames.size() % threads;
642 
643     int64_t begin = 0;
644     int64_t end = begin;
645     for (int i = 0; i < threads; i++) {
646       end += chunk_size;
647       if (remainder > 0) {
648         end++;
649         remainder--;
650       }
651 
652       if (estimate) {
653         // Parse a single file for each chunk with estimate mode on
654         async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1));
655       } else {
656         // Parse the whole chunk with estimate mode off
657         async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end));
658       }
659 
660       begin = end;
661     }
662 
663     int64_t total_rows = 0;
664     for (int i = 0; i < async_results.size(); i++) {
665       total_rows += async_results[i].get();
666     }
667 
668     if (estimate) {
669       // Each thread only scans 1 file
670       // Estimated total rows = Average rows * total number of files
671       total_rows = total_rows / threads * filenames.size();
672     }
673 
674     *out_total_rows = total_rows;
675   } catch (const std::exception &e) {
676     std::string err_msg = "Unexpected error occurred: ";
677     err_msg += e.what();
678     RETURN_STATUS_UNEXPECTED(err_msg);
679   }
680 
681   return Status::OK();
682 }
683 
CountTotalRowsSectioned(const std::vector<std::string> & filenames,int64_t begin,int64_t end)684 int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &filenames, int64_t begin, int64_t end) {
685   int64_t rows_read = 0;
686   for (int i = begin; i < end; i++) {
687     auto realpath = FileUtils::GetRealPath(filenames[i].data());
688     if (!realpath.has_value()) {
689       MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << filenames[i];
690       continue;
691     }
692 
693     std::ifstream reader;
694     reader.open(realpath.value());
695     if (!reader) {
696       MS_LOG(DEBUG) << "TFReader operator failed to open file " << filenames[i] << ".";
697     }
698 
699     while (reader.peek() != EOF) {
700       // read length
701       int64_t record_length = 0;
702       (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
703 
704       // ignore crc header
705       (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
706 
707       // ignore tf_file contents
708       (void)reader.ignore(static_cast<std::streamsize>(record_length));
709 
710       // ignore crc footer
711       (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
712 
713       rows_read++;
714     }
715   }
716 
717   return rows_read;
718 }
719 
ComputeColMap()720 Status TFReaderOp::ComputeColMap() {
721   // Construct the column name map for this operator (base class field)
722   if (column_name_id_map_.empty()) {
723     for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
724       column_name_id_map_[data_schema_->Column(i).Name()] = i;
725     }
726   } else {
727     MS_LOG(WARNING) << "Column name map is already set!";
728   }
729   return Status::OK();
730 }
FillIOBlockQueue(const std::vector<int64_t> & i_keys)731 Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
732   if (shuffle_files_) {
733     return FillIOBlockShuffle(i_keys);
734   }
735   return FillIOBlockNoShuffle();
736 }
737 
738 }  // namespace dataset
739 }  // namespace mindspore
740