• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <fstream>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "utils/file_utils.h"
24 #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
25 #include "minddata/dataset/core/config_manager.h"
26 #include "minddata/dataset/util/wait_post.h"
27 #include "minddata/dataset/util/random.h"
28 #include "minddata/dataset/engine/datasetops/source/io_block.h"
29 #include "minddata/dataset/engine/execution_tree.h"
30 
31 namespace mindspore {
32 namespace dataset {
TextFileOp(int32_t num_workers,int64_t total_rows,int32_t worker_connector_size,std::unique_ptr<DataSchema> schema,std::vector<std::string> text_files_list,int32_t op_connector_size,bool shuffle_files,int32_t num_devices,int32_t device_id)33 TextFileOp::TextFileOp(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size,
34                        std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
35                        int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
36     : NonMappableLeafOp(num_workers, worker_connector_size, total_rows, op_connector_size, shuffle_files, num_devices,
37                         device_id),
38       text_files_list_(std::move(text_files_list)),
39       data_schema_(std::move(schema)) {}
40 
41 // A print method typically used for debugging
Print(std::ostream & out,bool show_all) const42 void TextFileOp::Print(std::ostream &out, bool show_all) const {
43   if (!show_all) {
44     // Call the super class for displaying any common 1-liner info
45     ParallelOp::Print(out, show_all);
46     // Then show any custom derived-internal 1-liner info for this op
47     out << "\n";
48   } else {
49     // Call the super class for displaying any common detailed info
50     ParallelOp::Print(out, show_all);
51     // Then show any custom derived-internal stuff
52     out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
53         << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\n"
54         << DatasetName(true) << " list:\n";
55     for (size_t i = 0; i < text_files_list_.size(); ++i) {
56       out << " " << text_files_list_[i];
57     }
58     out << "\nData Schema:\n";
59     out << *data_schema_ << "\n\n";
60   }
61 }
62 
Init()63 Status TextFileOp::Init() {
64   RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_));
65 
66   int32_t safe_queue_size = static_cast<int32_t>(std::ceil(text_files_list_.size() / num_workers_) + 1);
67   io_block_queues_.Init(num_workers_, safe_queue_size);
68 
69   RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
70 
71   jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
72   return Status::OK();
73 }
74 
LoadTensor(const std::string & line,TensorRow * out_row)75 Status TextFileOp::LoadTensor(const std::string &line, TensorRow *out_row) {
76   std::shared_ptr<Tensor> tensor;
77   RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
78   (*out_row)[0] = std::move(tensor);
79   return Status::OK();
80 }
81 
LoadFile(const std::string & file,int64_t start_offset,int64_t end_offset,int32_t worker_id)82 Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
83   auto realpath = FileUtils::GetRealPath(file.data());
84   if (!realpath.has_value()) {
85     MS_LOG(ERROR) << "Invalid file, " + DatasetName() + " get real path failed, path=" << file;
86     RETURN_STATUS_UNEXPECTED("Invalid file, " + DatasetName() + " get real path failed, path=" + file);
87   }
88 
89   std::ifstream handle(realpath.value());
90   if (!handle.is_open()) {
91     RETURN_STATUS_UNEXPECTED("Invalid file, failed to open " + DatasetName() + ": " + file);
92   }
93 
94   int64_t rows_total = 0;
95   std::string line;
96 
97   while (getline(handle, line)) {
98     if (line.empty()) {
99       continue;
100     }
101     // If read to the end offset of this file, break.
102     if (rows_total >= end_offset) {
103       break;
104     }
105     // Skip line before start offset.
106     if (rows_total < start_offset) {
107       rows_total++;
108       continue;
109     }
110 
111     TensorRow tRow(1, nullptr);
112     tRow.setPath({file});
113     RETURN_IF_NOT_OK(LoadTensor(line, &tRow));
114     RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow)));
115 
116     rows_total++;
117   }
118 
119   return Status::OK();
120 }
121 
FillIOBlockQueue(const std::vector<int64_t> & i_keys)122 Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
123   int32_t queue_index = 0;
124   int64_t pre_count = 0;
125   int64_t start_offset = 0;
126   int64_t end_offset = 0;
127   bool finish = false;
128   while (!finish) {
129     std::vector<std::pair<std::string, int64_t>> file_index;
130     if (!i_keys.empty()) {
131       for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
132         {
133           if (!load_io_block_queue_) {
134             break;
135           }
136         }
137         file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it));
138       }
139     } else {
140       for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
141         {
142           if (!load_io_block_queue_) {
143             break;
144           }
145         }
146         file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
147       }
148     }
149     for (auto file_info : file_index) {
150       if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
151         auto ioBlock =
152           std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
153         RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
154         queue_index = (queue_index + 1) % num_workers_;
155       }
156 
157       pre_count += filename_numrows_[file_info.first];
158     }
159 
160     if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
161       finish = false;
162     } else {
163       finish = true;
164     }
165   }
166 
167   RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
168   return Status::OK();
169 }
170 
171 // Internal helper function to calculate rows
CountTotalRows(const std::string & file)172 int64_t CountTotalRows(const std::string &file) {
173   auto realpath = FileUtils::GetRealPath(file.data());
174   if (!realpath.has_value()) {
175     MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file;
176     return 0;
177   }
178 
179   std::ifstream handle(realpath.value());
180   if (!handle.is_open()) {
181     MS_LOG(ERROR) << "Invalid file, failed to open file: " << file;
182     return 0;
183   }
184 
185   std::string line;
186   int64_t count = 0;
187   while (getline(handle, line)) {
188     if (!line.empty()) {
189       count++;
190     }
191   }
192 
193   return count;
194 }
195 
CalculateNumRowsPerShard()196 Status TextFileOp::CalculateNumRowsPerShard() {
197   for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
198     int64_t count = CountTotalRows(it.value());
199     filename_numrows_[it.value()] = count;
200     num_rows_ += count;
201   }
202   if (num_rows_ == 0) {
203     std::stringstream ss;
204     for (int i = 0; i < text_files_list_.size(); ++i) {
205       ss << " " << text_files_list_[i];
206     }
207     std::string file_list = ss.str();
208     RETURN_STATUS_UNEXPECTED("Invalid data, " + DatasetName(true) +
209                              "Dataset API can't read the data file (interface mismatch or no data found). Check " +
210                              DatasetName() + ": " + file_list);
211   }
212 
213   num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
214   MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
215   return Status::OK();
216 }
217 
CountAllFileRows(const std::vector<std::string> & files,int64_t * count)218 Status TextFileOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
219   *count = 0;
220   for (auto file : files) {
221     *count += CountTotalRows(file);
222   }
223   return Status::OK();
224 }
225 
ComputeColMap()226 Status TextFileOp::ComputeColMap() {
227   // Set the column name mapping (base class field)
228   if (column_name_id_map_.empty()) {
229     for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
230       column_name_id_map_[data_schema_->Column(i).Name()] = i;
231     }
232   } else {
233     MS_LOG(WARNING) << "Column name map is already set!";
234   }
235   return Status::OK();
236 }
237 
238 }  // namespace dataset
239 }  // namespace mindspore
240