/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"

#include <algorithm>
#include <cmath>
#include <fstream>
#include <future>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils/file_utils.h"
#include "proto/example.pb.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"
#include "utils/system/crc32c.h"
namespace mindspore {
namespace dataset {
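// 0x140000000 bytes == 5 GiB; files larger than this trigger the performance warning below.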
const int64_t kTFRecordFileLimit = 0x140000000;

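// Note on the TFRecord on-disk layout (per the public TFRecord spec): each record is stored as
//   uint64 length
//   uint32 masked_crc32c(length)
//   byte   data[length]
//   uint32 masked_crc32c(data)
// where masked_crc = rotate_right_15(crc32c(x)) + 0xa282ead8. Validating the CRC of the very
// first length field is therefore a cheap way to check that a file really is a TFRecord file.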
bool TFReaderOp::ValidateFirstRowCrc(const std::string &filename) {
  auto realpath = FileUtils::GetRealPath(filename.data());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << filename;
    return false;
  }

  std::ifstream reader;
  reader.open(realpath.value());
  if (!reader) {
    return false;
  }
  int64_t file_len = reader.seekg(0, std::ios::end).tellg();
  if (file_len > kTFRecordFileLimit) {
    MS_LOG(WARNING) << "The file size of " << filename
                    << " is larger than 5G; this may cause performance problems in distributed "
                       "scenarios. The file can be split into sub-files smaller than 5G for "
                       "better performance.";
  }
  (void)reader.seekg(0, std::ios::beg);

  // read the length of the first record
  int64_t record_length = 0;
  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));

  // read the masked crc of the length from the file
  uint32_t masked_crc = 0;
  (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));

  // generate the masked crc from the length we just read and compare
  uint32_t generated_crc =
    system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));

  return masked_crc == generated_crc;
}

TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows,
                       std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
                       int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
                       int32_t num_devices, int32_t device_id, bool equal_rows_per_shard)
    : NonMappableLeafOp(num_workers, worker_connector_size, total_num_rows, op_connector_size, shuffle_files,
                        num_devices, device_id),
      dataset_files_list_(std::move(dataset_files_list)),
      columns_to_load_(std::move(columns_to_load)),
      data_schema_(std::move(data_schema)),
      equal_rows_per_shard_(equal_rows_per_shard) {}

// A print method typically used for debugging
void TFReaderOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no")
        << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n";
    for (size_t i = 0; i < dataset_files_list_.size(); ++i) {
      out << " " << dataset_files_list_[i];
    }
    if (!columns_to_load_.empty()) {
      out << "\nColumns to load:\n";
      for (size_t j = 0; j < columns_to_load_.size(); ++j) {
        out << " " << columns_to_load_[j];
      }
    }
    out << "\nData Schema:\n";
    out << *data_schema_ << "\n\n";
  }
}

Status TFReaderOp::Init() {
  if (data_schema_->Empty()) {
    RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_));
  }

  if (total_rows_ == 0) {
    total_rows_ = data_schema_->NumRows();
  }
  if (total_rows_ < 0) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid parameter, num_samples or num_rows for TFRecordDataset must be greater than or equal to 0, but got: " +
      std::to_string(total_rows_));
  }

  // Build the index with our files such that each file corresponds to a key id.
  RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));

  // The creation of the internal connector has been delayed until now, since we may have adjusted the
  // number of workers. Now that the worker count is established, create the connector in the
  // parallel op base.
  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));

  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  // temporary: make size large enough to hold all files + EOE to avoid hangs
  int32_t safe_queue_size =
    static_cast<int32_t>(std::ceil(static_cast<double>(dataset_files_list_.size()) / num_workers_)) + 1;
  io_block_queues_.Init(num_workers_, safe_queue_size);

  return Status::OK();
}

Status TFReaderOp::CalculateNumRowsPerShard() {
  if (!equal_rows_per_shard_) {
    return Status::OK();
  }

  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    std::vector<std::string> file(1, it.value());
    int64_t num = CountTotalRowsSectioned(file, 0, 1);
    filename_numrows_[it.value()] = num;
    num_rows_ += num;
  }
  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
  if (num_rows_per_shard_ == 0) {
    std::stringstream ss;
    for (size_t i = 0; i < dataset_files_list_.size(); ++i) {
      ss << " " << dataset_files_list_[i];
    }
    std::string file_list = ss.str();
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, TFRecordDataset API can't read the data file (possibly an interface mismatch or no data in "
      "the file). Check the file path:" +
      file_list);
  }
  return Status::OK();
}

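// Two shard-distribution strategies are used when filling IO blocks:
//  - equal_rows_per_shard_ == false: whole files are dealt out round-robin across devices
//    (file k goes to device k % num_devices_);
//  - equal_rows_per_shard_ == true: every device reads the row range [device_id * num_rows_per_shard_,
//    (device_id + 1) * num_rows_per_shard_), so files may be read partially via start/end offsets.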
Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int32_t key_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  bool end_of_epoch = false;
  while (!finish) {
    for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
      {
        std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
        if (!load_io_block_queue_) {
          end_of_epoch = true;
          break;
        }
      }
      if (!equal_rows_per_shard_) {
        if (key_index++ % num_devices_ == device_id_) {
          auto ioBlock = std::make_unique<FilenameBlock>(*it, kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone);
          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
          queue_index = (queue_index + 1) % num_workers_;
        }
      } else {
        // Do an index lookup using that key to get the filename.
        std::string file_name = (*filename_index_)[*it];
        if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
          auto ioBlock = std::make_unique<FilenameBlock>(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone);
          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
          MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset;
          queue_index = (queue_index + 1) % num_workers_;
        }

        pre_count += filename_numrows_[file_name];
      }
    }
    if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
        !end_of_epoch) {
      finish = false;
    } else {
      finish = true;
    }
  }
  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}

Status TFReaderOp::FillIOBlockNoShuffle() {
  int32_t queue_index = 0;
  int32_t key_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  bool end_of_epoch = false;
  while (!finish) {
    // Iterate over all the keys and add one key to each block.
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      {
        std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
        if (!load_io_block_queue_) {
          end_of_epoch = true;
          break;
        }
      }
      if (!equal_rows_per_shard_) {
        if (key_index++ % num_devices_ == device_id_) {
          auto ioBlock =
            std::make_unique<FilenameBlock>(it.key(), kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone);
          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
          queue_index = (queue_index + 1) % num_workers_;
        }
      } else {
        std::string file_name = it.value();
        if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
          auto ioBlock = std::make_unique<FilenameBlock>(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone);
          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
          queue_index = (queue_index + 1) % num_workers_;
        }

        pre_count += filename_numrows_[file_name];
      }
    }
    if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
        !end_of_epoch) {
      finish = false;
    } else {
      finish = true;
    }
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}

// Reads a TFRecord file and loads the data into multiple TensorRows.
Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  auto realpath = FileUtils::GetRealPath(filename.data());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << filename;
    RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + filename);
  }

  std::ifstream reader;
  reader.open(realpath.value());
  if (!reader) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + filename);
  }

  int64_t rows_read = 0;
  int64_t rows_total = 0;

  while (reader.peek() != EOF) {
    if (!load_jagged_connector_) {
      break;
    }
    RETURN_IF_INTERRUPTED();

    // read length
    int64_t record_length = 0;
    (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));

    // ignore crc header
    (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));

    // read serialized Example
    std::string serialized_example;
    serialized_example.resize(record_length);
    (void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));

    int32_t num_columns = data_schema_->NumColumns();
    TensorRow newRow(num_columns, nullptr);

    if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) {
      dataengine::Example tf_file;
      if (!tf_file.ParseFromString(serialized_example)) {
        std::string errMsg = "Invalid file, failed to parse tfrecord file: " + filename;
        MS_LOG(DEBUG) << errMsg + ", details of string: " << serialized_example;
        RETURN_STATUS_UNEXPECTED(errMsg);
      }

      std::vector<std::string> file_path(num_columns, filename);
      newRow.setPath(file_path);
      RETURN_IF_NOT_OK(LoadExample(&tf_file, &newRow));
      rows_read++;
      RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(newRow)));
    }

    // ignore crc footer
    (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
    rows_total++;
  }

  return Status::OK();
}

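// Note: an Example proto is essentially a map<string, Feature>, where every Feature holds exactly
// one of a bytes_list, a float_list, or an int64_list. LoadExample below looks up each schema
// column by name in that map and dispatches on the list type.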
// Parses a single row and puts the data into a tensor table.
Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, TensorRow *out_row) {
  int32_t num_columns = data_schema_->NumColumns();
  for (int32_t col = 0; col < num_columns; ++col) {
    const ColDescriptor current_col = data_schema_->Column(col);
    const dataengine::Features &example_features = tf_file->features();
    const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
    auto iter_column = feature_map.find(current_col.Name());
    if (iter_column == feature_map.end()) {
      RETURN_STATUS_UNEXPECTED("Invalid parameter, column name: " + current_col.Name() + " does not exist.");
    }
    const dataengine::Feature &column_values_list = iter_column->second;
    RETURN_IF_NOT_OK(LoadFeature(out_row, column_values_list, current_col, col));
  }

  return Status::OK();
}

// Parses a single cell and puts the data into a tensor table.
Status TFReaderOp::LoadFeature(TensorRow *tensor_row, const dataengine::Feature &column_values_list,
                               const ColDescriptor &current_col, int32_t col) {
  const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
  std::unique_ptr<float[]> float_array;     // For staging data from protobuf deserialization
  const unsigned char *data_ptr = nullptr;  // Generic pointer used for populating the Tensor

  // The number of elements in the feature; also used for creating shape attributes.
  int32_t num_elements = 0;

  // We build a tensor and read directly into it, unless we need to cast, in which case the data
  // is staged in a local buffer first.
  std::shared_ptr<Tensor> ts;

  // Depending on the type of data from the tf_file, we want to extract 2 things:
  // 1) A pointer to the data as a const unsigned char *
  // 2) The number of elements of the data
  // After those are determined, we can then build the tensor to represent this data.
  switch (column_list_type) {
    case dataengine::Feature::KindCase::kBytesList: {
      RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts));
      break;
    }
    case dataengine::Feature::KindCase::kFloatList: {
      RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));

      data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());

      // only kFloatList needs to create the tensor here; the other two lists read directly
      // into the tensor
      TensorShape current_shape = TensorShape::CreateUnknownRankShape();
      RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, &current_shape));
      RETURN_IF_NOT_OK(Tensor::CreateFromMemory(current_shape, current_col.Type(), data_ptr, &ts));
      break;
    }
    case dataengine::Feature::KindCase::kInt64List: {
      RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts));
      break;
    }
    case dataengine::Feature::KindCase::KIND_NOT_SET:
    default: {
      std::string err_msg = "Invalid data, column type in tf record file must be uint8, int64 or float32.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
  }

  (*tensor_row)[col] = std::move(ts);

  return Status::OK();
}

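// A BytesList is loaded in one of two ways: as a string tensor (one element per bytes value), or,
// for uint8/int8 columns, as a fixed-width numeric tensor where every bytes value is padded to
// pad_size bytes so the rows are rectangular.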
Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                                 int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
  // kBytesList can map to the following DE types ONLY!
  // DE_UINT8, DE_INT8, DE_STRING
  // Numeric types must be single byte types, one byte per element!
  if (current_col.Type() != DataType::DE_UINT8 && current_col.Type() != DataType::DE_INT8 &&
      current_col.Type() != DataType::DE_STRING) {
    std::string err_msg = "Invalid data, invalid data type for Tensor at column: " + current_col.Name() +
                          ", data type should be int8, uint8 or string, but got " + current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  const dataengine::BytesList &bytes_list = column_values_list.bytes_list();

  *num_elements = bytes_list.value_size();

  if (current_col.Type() == DataType::DE_STRING) {
    TensorShape shape = TensorShape::CreateScalar();
    RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape));
    RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, shape, tensor));
    return Status::OK();
  }

  uint64_t max_size = 0;
  for (int i = 0; i < bytes_list.value_size(); ++i) {
#if defined(__APPLE__)
    max_size = fmax(max_size, bytes_list.value(i).size());
#else
    max_size = std::max(max_size, bytes_list.value(i).size());
#endif
  }

  int64_t pad_size = max_size;

  // if the user provides a shape in the form of [-1, d1, d2, ..., dn], we need to pad to d1 * d2 * ... * dn
  if (current_col.HasShape()) {
    TensorShape cur_shape = current_col.Shape();
    if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) {
      int64_t new_pad_size = 1;
      for (int i = 1; i < cur_shape.Size(); ++i) {
        if (cur_shape[i] == TensorShape::kDimUnknown) {
          std::string err_msg =
            "Invalid data, more than one unknown dimension in the shape of column: " + current_col.Name();
          RETURN_STATUS_UNEXPECTED(err_msg);
        }
        new_pad_size *= cur_shape[i];
      }
      pad_size = new_pad_size;
    } else {
      if (cur_shape.known() && cur_shape.NumOfElements() != max_size) {
        std::string err_msg = "Invalid data, shape in schema's column '" + current_col.Name() + "' is incorrect." +
                              "\nshape received: " + cur_shape.ToString() +
                              "\ntotal elements in shape received: " + std::to_string(cur_shape.NumOfElements()) +
                              "\nexpected total elements in shape: " + std::to_string(max_size);
        RETURN_STATUS_UNEXPECTED(err_msg);
      }
    }
  }

  // now that we know how many elements there are and the total bytes, create the tensor here:
  TensorShape current_shape = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, &current_shape));
  RETURN_IF_NOT_OK(Tensor::CreateFromByteList(bytes_list, current_shape, current_col.Type(), pad_size, tensor));

  return Status::OK();
}

Status TFReaderOp::LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                                 int32_t *num_elements, std::unique_ptr<float[]> *float_array) {
  // kFloatList can only map to DE type:
  // DE_FLOAT32
  if (current_col.Type() != DataType::DE_FLOAT32) {
    std::string err_msg = "Invalid data, invalid data type for Tensor at column: " + current_col.Name() +
                          ", data type should be float32, but got " + current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  const dataengine::FloatList &float_list = column_values_list.float_list();

  // Identify how many values we have and then create a local array of these
  // to deserialize into
  *num_elements = float_list.value_size();
  *float_array = std::make_unique<float[]>(*num_elements);
  for (int i = 0; i < float_list.value_size(); ++i) {
    (*float_array)[i] = float_list.value(i);
  }

  return Status::OK();
}

// Determines which template type to use and calls LoadIntList
Status TFReaderOp::LoadIntListSwitch(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                                     int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
  if (current_col.Type() == DataType::DE_UINT64) {
    RETURN_IF_NOT_OK(LoadIntList<uint64_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT64) {
    RETURN_IF_NOT_OK(LoadIntList<int64_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_UINT32) {
    RETURN_IF_NOT_OK(LoadIntList<uint32_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT32) {
    RETURN_IF_NOT_OK(LoadIntList<int32_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_UINT16) {
    RETURN_IF_NOT_OK(LoadIntList<uint16_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT16) {
    RETURN_IF_NOT_OK(LoadIntList<int16_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_UINT8) {
    RETURN_IF_NOT_OK(LoadIntList<uint8_t>(current_col, column_values_list, num_elements, tensor));
  } else if (current_col.Type() == DataType::DE_INT8) {
    RETURN_IF_NOT_OK(LoadIntList<int8_t>(current_col, column_values_list, num_elements, tensor));
  } else {
    std::string err_msg = "Invalid data, invalid datatype for Tensor at column: " + current_col.Name() +
                          ", data type should be uint64, int64, uint32, int32, uint16, int16, uint8 or int8" +
                          ", but got " + current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  return Status::OK();
}

// Reads values from an int64 list and casts each value to type T, which must be an integral type
// compatible with int64_t
template <typename T>
Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
                               int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
  if (!(current_col.Type().IsInt())) {
    std::string err_msg = "Invalid data, invalid data type for Tensor at column: " + current_col.Name() +
                          ", data type should be int, but got " + current_col.Type().ToString();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  const dataengine::Int64List &int64_list = column_values_list.int64_list();

  // Identify how many values we have and then create a local array of these
  // to deserialize into
  *num_elements = int64_list.value_size();

  // now that we know how many elements there are, create the tensor here:
  TensorShape current_shape = TensorShape::CreateUnknownRankShape();
  RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &current_shape));
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(current_shape, current_col.Type(), tensor));

  int64_t i = 0;
  auto it = (*tensor)->begin<T>();
  for (; it != (*tensor)->end<T>(); i++, ++it) {
    T element = static_cast<T>(int64_list.value(i));
    *it = element;
  }

  return Status::OK();
}

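// Builds the schema by inspecting the first record of the given TFRecord file. Feature kinds are
// mapped to column types as bytes_list -> uint8, float_list -> float32, int64_list -> int64; when
// columns_to_load is empty, every feature found in the first record is loaded, in sorted name order.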
Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector<std::string> columns_to_load) {
  auto realpath = FileUtils::GetRealPath(tf_file.data());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << tf_file;
    RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + tf_file);
  }

  std::ifstream reader;
  reader.open(realpath.value());
  if (!reader) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + tf_file);
  }

  // read length
  int64_t record_length = 0;
  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));

  // ignore crc header
  (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));

  // read serialized Example
  std::string serialized_example;
  serialized_example.resize(record_length);
  (void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));

  dataengine::Example example;
  if (!example.ParseFromString(serialized_example)) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to parse tfrecord file: " + tf_file);
  }

  const dataengine::Features &example_features = example.features();
  const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();

  if (columns_to_load.empty()) {
    (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load),
                         [](const auto &it) -> std::string { return it.first; });
    std::sort(columns_to_load.begin(), columns_to_load.end());
  }

  for (const auto &curr_col_name : columns_to_load) {
    auto it = feature_map.find(curr_col_name);
    if (it == feature_map.end()) {
      RETURN_STATUS_UNEXPECTED("Invalid data, failed to find column name: " + curr_col_name);
    }
    std::string column_name = it->first;

    std::string column_type;

    const dataengine::Feature &feature = it->second;
    const dataengine::Feature::KindCase kind_case = feature.kind_case();
    switch (kind_case) {
      case dataengine::Feature::KindCase::kBytesList:
        column_type = "uint8";
        break;

      case dataengine::Feature::KindCase::kFloatList:
        column_type = "float32";
        break;

      case dataengine::Feature::KindCase::kInt64List:
        column_type = "int64";
        break;

      case dataengine::Feature::KindCase::KIND_NOT_SET:
      default:
        RETURN_STATUS_UNEXPECTED("Invalid data, column type of tf record file must be uint8, int64 or float32.");
    }

    RETURN_IF_NOT_OK(
      data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1)));
  }

  return Status::OK();
}

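// Counting is parallelized over chunks of the file list. With estimate == true, each thread scans
// only the first file of its chunk and the result is extrapolated: average rows per scanned file
// multiplied by the total number of files.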
Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads,
                                  bool estimate) {
  try {
    if (threads > static_cast<int64_t>(filenames.size())) {
      threads = filenames.size();
    }

    std::vector<std::future<int64_t>> async_results;

    int64_t chunk_size = filenames.size() / threads;
    int64_t remainder = filenames.size() % threads;

    int64_t begin = 0;
    int64_t end = begin;
    for (int i = 0; i < threads; i++) {
      end += chunk_size;
      if (remainder > 0) {
        end++;
        remainder--;
      }

      if (estimate) {
        // Parse a single file for each chunk with estimate mode on
        async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1));
      } else {
        // Parse the whole chunk with estimate mode off
        async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end));
      }

      begin = end;
    }

    int64_t total_rows = 0;
    for (size_t i = 0; i < async_results.size(); i++) {
      total_rows += async_results[i].get();
    }

    if (estimate) {
      // Each thread only scans 1 file
      // Estimated total rows = Average rows * total number of files
      total_rows = total_rows / threads * filenames.size();
    }

    *out_total_rows = total_rows;
  } catch (const std::exception &e) {
    std::string err_msg = "Unexpected error occurred: ";
    err_msg += e.what();
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  return Status::OK();
}

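// Counts rows by walking the TFRecord framing without parsing payloads: read the 8-byte record
// length, then skip over the length crc, the record data, and the data crc.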
int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &filenames, int64_t begin, int64_t end) {
  int64_t rows_read = 0;
  for (int64_t i = begin; i < end; i++) {
    auto realpath = FileUtils::GetRealPath(filenames[i].data());
    if (!realpath.has_value()) {
      MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << filenames[i];
      continue;
    }

    std::ifstream reader;
    reader.open(realpath.value());
    if (!reader) {
      MS_LOG(DEBUG) << "TFReader operator failed to open file " << filenames[i] << ".";
      continue;
    }

    while (reader.peek() != EOF) {
      // read length
      int64_t record_length = 0;
      (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));

      // ignore crc header
      (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));

      // ignore the record contents
      (void)reader.ignore(static_cast<std::streamsize>(record_length));

      // ignore crc footer
      (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));

      rows_read++;
    }
  }

  return rows_read;
}

Status TFReaderOp::ComputeColMap() {
  // Construct the column name map for this operator (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->Column(i).Name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}

Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  if (shuffle_files_) {
    return FillIOBlockShuffle(i_keys);
  }
  return FillIOBlockNoShuffle();
}

}  // namespace dataset
}  // namespace mindspore