1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
18
19 #include <fstream>
20
21 #include "minddata/dataset/core/config_manager.h"
22 #include "minddata/dataset/engine/datasetops/source/io_block.h"
23 #include "minddata/dataset/engine/execution_tree.h"
24 #include "minddata/dataset/util/random.h"
25 #include "minddata/dataset/util/wait_post.h"
26 #include "utils/file_utils.h"
27
28 namespace mindspore {
29 namespace dataset {
TextFileOp(int32_t num_workers,int64_t total_rows,int32_t worker_connector_size,std::unique_ptr<DataSchema> schema,std::vector<std::string> text_files_list,int32_t op_connector_size,bool shuffle_files,int32_t num_devices,int32_t device_id)30 TextFileOp::TextFileOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
31 std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
32 int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
33 : NonMappableLeafOp(num_workers, worker_connector_size, total_rows, op_connector_size, shuffle_files, num_devices,
34 device_id),
35 text_files_list_(std::move(text_files_list)),
36 data_schema_(std::move(schema)) {}
37
38 // A print method typically used for debugging
Print(std::ostream & out,bool show_all) const39 void TextFileOp::Print(std::ostream &out, bool show_all) const {
40 if (!show_all) {
41 // Call the super class for displaying any common 1-liner info
42 ParallelOp::Print(out, show_all);
43 // Then show any custom derived-internal 1-liner info for this op
44 out << "\n";
45 } else {
46 // Call the super class for displaying any common detailed info
47 ParallelOp::Print(out, show_all);
48 // Then show any custom derived-internal stuff
49 out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
50 << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\n"
51 << DatasetName(true) << " list:\n";
52 for (size_t i = 0; i < text_files_list_.size(); ++i) {
53 out << " " << text_files_list_[i];
54 }
55 out << "\nData Schema:\n";
56 out << *data_schema_ << "\n\n";
57 }
58 }
59
Init()60 Status TextFileOp::Init() {
61 RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_));
62
63 int32_t safe_queue_size = static_cast<int32_t>(std::ceil(text_files_list_.size() / num_workers_) + 1);
64 io_block_queues_.Init(num_workers_, safe_queue_size);
65
66 jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
67 return Status::OK();
68 }
69
LoadTensor(const std::string & line,TensorRow * out_row) const70 Status TextFileOp::LoadTensor(const std::string &line, TensorRow *out_row) const {
71 std::shared_ptr<Tensor> tensor;
72 RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
73 (*out_row)[0] = std::move(tensor);
74 return Status::OK();
75 }
76
LoadFile(const std::string & file,int64_t start_offset,int64_t end_offset,int32_t worker_id)77 Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
78 auto realpath = FileUtils::GetRealPath(file.c_str());
79 if (!realpath.has_value()) {
80 MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist.";
81 RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist.");
82 }
83
84 std::ifstream handle(realpath.value(), std::ios::in);
85 if (!handle.is_open()) {
86 RETURN_STATUS_UNEXPECTED("Invalid file, failed to open text:" + file +
87 ", the file is damaged or permission denied.");
88 }
89
90 int64_t rows_total = 0;
91 std::string line;
92
93 while (getline(handle, line)) {
94 if (line.empty()) {
95 continue;
96 }
97 // If read to the end offset of this file, break.
98 if (rows_total >= end_offset) {
99 break;
100 }
101 // Skip line before start offset.
102 if (rows_total < start_offset) {
103 rows_total++;
104 continue;
105 }
106
107 TensorRow tRow(1, nullptr);
108 tRow.setPath({file});
109 auto s = LoadTensor(line, &tRow);
110 if (s != Status::OK()) {
111 handle.close();
112 return s;
113 }
114 s = jagged_rows_connector_->Add(worker_id, std::move(tRow));
115 if (s != Status::OK()) {
116 handle.close();
117 return s;
118 }
119
120 rows_total++;
121 }
122 handle.close();
123
124 return Status::OK();
125 }
126
FillIOBlockQueue(const std::vector<int64_t> & i_keys)127 Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
128 int32_t queue_index = 0;
129 int64_t pre_count = 0;
130 int64_t start_offset = 0;
131 int64_t end_offset = 0;
132 bool finish = false;
133 while (!finish) {
134 std::vector<std::pair<std::string, int64_t>> file_index;
135 if (!i_keys.empty()) {
136 for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
137 {
138 if (!GetLoadIoBlockQueue()) {
139 break;
140 }
141 }
142 (void)file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it));
143 }
144 } else {
145 for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
146 {
147 if (!GetLoadIoBlockQueue()) {
148 break;
149 }
150 }
151 (void)file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
152 }
153 }
154 for (auto file_info : file_index) {
155 if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
156 auto ioBlock = std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kFlagNone);
157 RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
158 queue_index = (queue_index + 1) % num_workers_;
159 }
160
161 pre_count += filename_numrows_[file_info.first];
162 }
163
164 if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
165 finish = false;
166 } else {
167 finish = true;
168 }
169 }
170
171 RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
172 return Status::OK();
173 }
174
CountTotalRows(const std::string & file)175 int64_t TextFileOp::CountTotalRows(const std::string &file) {
176 auto realpath = FileUtils::GetRealPath(file.c_str());
177 if (!realpath.has_value()) {
178 MS_LOG(ERROR) << "Invalid file, " << file << " does not exist.";
179 return 0;
180 }
181
182 std::ifstream handle(realpath.value(), std::ios::in);
183 if (!handle.is_open()) {
184 MS_LOG(ERROR) << "Invalid file, failed to open text file:" << file << ", the file is damaged or permission denied.";
185 return 0;
186 }
187
188 std::string line;
189 int64_t count = 0;
190 while (getline(handle, line)) {
191 if (!line.empty()) {
192 count++;
193 }
194 }
195 handle.close();
196
197 return count;
198 }
199
CalculateNumRowsPerShard()200 Status TextFileOp::CalculateNumRowsPerShard() {
201 for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
202 int64_t count = CountTotalRows(it.value());
203 filename_numrows_[it.value()] = count;
204 num_rows_ += count;
205 }
206 if (num_rows_ == 0) {
207 std::stringstream ss;
208 for (int i = 0; i < text_files_list_.size(); ++i) {
209 ss << " " << text_files_list_[i];
210 }
211 std::string file_list = ss.str();
212 RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
213 "Dataset API can't read the data file (interface mismatch or no data found). Check " +
214 DatasetName() + ": " + file_list);
215 }
216
217 num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
218 MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
219 return Status::OK();
220 }
221
CountAllFileRows(const std::vector<std::string> & files,int64_t * count)222 Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
223 RETURN_UNEXPECTED_IF_NULL(count);
224 int32_t num_workers = GlobalContext::config_manager()->num_parallel_workers();
225 int32_t connector_que_size = GlobalContext::config_manager()->op_connector_size();
226 int32_t worker_connector_size = GlobalContext::config_manager()->worker_connector_size();
227 const int32_t shard_id = 0;
228 const int32_t num_shards = 1;
229 const int64_t num_samples = 0;
230 bool shuffle_files = false;
231 // Do internal Schema generation.
232 auto schema = std::make_unique<DataSchema>();
233
234 // Create and initialize
235 std::shared_ptr<TextFileOp> op =
236 std::make_shared<TextFileOp>(num_workers, num_samples, worker_connector_size, std::move(schema), files,
237 connector_que_size, shuffle_files, num_shards, shard_id);
238 RETURN_IF_NOT_OK(op->Init());
239 *count = 0;
240 for (auto file : files) {
241 *count += op->CountTotalRows(file);
242 }
243 return Status::OK();
244 }
245
ComputeColMap()246 Status TextFileOp::ComputeColMap() {
247 // Set the column name mapping (base class field)
248 if (column_name_id_map_.empty()) {
249 for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
250 column_name_id_map_[data_schema_->Column(i).Name()] = i;
251 }
252 } else {
253 MS_LOG(WARNING) << "Column name map is already set!";
254 }
255 return Status::OK();
256 }
257 } // namespace dataset
258 } // namespace mindspore
259