• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_reader.h"
18 
19 #include <algorithm>
20 #include <thread>
21 
22 #include "utils/file_utils.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "utils/ms_utils.h"
25 
26 namespace mindspore {
27 namespace mindrecord {
// Convert a decimal string to an exact numeric type (int32_t/int64_t/float/double).
// Returns 0 (value-initialized) when `str` does not parse as Type.
template <class Type>
Type StringToNum(const std::string &str) {
  std::istringstream iss(str);
  // Value-initialize so a failed extraction can never yield an indeterminate
  // value (since C++11 a failed >> writes 0, but this keeps the intent explicit).
  Type num{};
  iss >> num;
  return num;
}
36 
// Default-construct a reader with all counters zeroed and fast load mode.
// Real setup (headers, sqlite handles, row counts) happens in Init()/Open().
ShardReader::ShardReader()
    : header_size_(0),
      page_size_(0),
      shard_count_(0),
      n_consumer_(0),
      num_padded_(0),
      num_rows_(0),        // total sample count across shards, computed in Init()
      total_blob_size_(0),
      sample_id_position_(0),
      deliver_id_(0),
      load_mode_(LoadMode::kFast),  // may be downgraded to kLazy/kSlow by Init() based on row count
      shard_sample_count_() {}
49 
GetMeta(const std::string & file_path,std::shared_ptr<json> meta_data_ptr,std::shared_ptr<std::vector<std::string>> * addresses_ptr)50 Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
51                             std::shared_ptr<std::vector<std::string>> *addresses_ptr) {
52   RETURN_UNEXPECTED_IF_NULL_MR(addresses_ptr);
53   RETURN_IF_NOT_OK_MR(CheckFile(file_path));
54   std::shared_ptr<json> header_ptr;
55   RETURN_IF_NOT_OK_MR(ShardHeader::BuildSingleHeader(file_path, &header_ptr));
56 
57   *meta_data_ptr = {{"header_size", (*header_ptr)["header_size"]}, {"page_size", (*header_ptr)["page_size"]},
58                     {"version", (*header_ptr)["version"]},         {"index_fields", (*header_ptr)["index_fields"]},
59                     {"schema", (*header_ptr)["schema"]},           {"blob_fields", (*header_ptr)["blob_fields"]}};
60   std::vector<std::string> addresses_vec = (*header_ptr)["shard_addresses"];
61   *addresses_ptr = std::make_shared<std::vector<std::string>>(addresses_vec);
62   return Status::OK();
63 }
64 
Init(const std::vector<std::string> & file_paths,bool load_dataset)65 Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
66   std::string file_path = file_paths[0];
67   auto first_meta_data_ptr = std::make_shared<json>();
68   std::shared_ptr<std::vector<std::string>> addresses_ptr;
69   RETURN_IF_NOT_OK_MR(GetMeta(file_path, first_meta_data_ptr, &addresses_ptr));
70   if (file_paths.size() == 1 && load_dataset == true) {
71     auto ds = std::make_shared<std::vector<std::string>>();
72     RETURN_IF_NOT_OK_MR(GetDatasetFiles(file_path, *addresses_ptr, &ds));
73     file_paths_ = *ds;  // load files according to shard_addresses
74   } else if (file_paths.size() >= 1 && load_dataset == false) {
75     file_paths_ = file_paths;  // load files according to the input
76   } else {
77     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] The values of 'load_dataset' and 'file_paths' are not as expected.");
78   }
79   for (const auto &file : file_paths_) {
80     auto meta_data_ptr = std::make_shared<json>();
81     RETURN_IF_NOT_OK_MR(GetMeta(file, meta_data_ptr, &addresses_ptr));
82     CHECK_FAIL_RETURN_UNEXPECTED_MR(
83       *meta_data_ptr == *first_meta_data_ptr,
84       "Invalid file, the metadata of mindrecord file: " + file +
85         " is different from others, please make sure all the mindrecord files generated by the same script.");
86     sqlite3 *db = nullptr;
87     RETURN_IF_NOT_OK_MR(VerifyDataset(&db, file));
88     database_paths_.push_back(db);
89   }
90   ShardHeader sh = ShardHeader();
91   RETURN_IF_NOT_OK_MR(sh.BuildDataset(file_paths_, load_dataset));
92   shard_header_ = std::make_shared<ShardHeader>(sh);
93   header_size_ = shard_header_->GetHeaderSize();
94   page_size_ = shard_header_->GetPageSize();
95   // version < 3.0
96   if ((*first_meta_data_ptr)["version"] < kVersion) {
97     shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
98   } else {
99     shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
100   }
101   num_rows_ = 0;
102   auto row_group_summary = ReadRowGroupSummary();
103 
104   // clear the shard_sample_count_, because it will be insert when Launch func
105   shard_sample_count_.clear();
106 
107   constexpr int64_t get_index = 3;
108   for (const auto &rg : row_group_summary) {
109     num_rows_ += std::get<get_index>(rg);
110   }
111 
112   if (num_rows_ > SLOW_LOAD_THRESHOLD) {
113     load_mode_ = LoadMode::kSlow;
114     tasks_.load_mode_ = LoadMode::kSlow;
115     MS_LOG(INFO) << "The number of samples is larger than " << SLOW_LOAD_THRESHOLD
116                  << ", enable slow load mode. If you want to speed up data loading, "
117                  << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
118                  << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
119                  << "according to the current samples.";
120   } else if (num_rows_ > LAZY_LOAD_THRESHOLD) {
121     load_mode_ = LoadMode::kLazy;
122     tasks_.load_mode_ = LoadMode::kLazy;
123     MS_LOG(INFO) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
124                  << ", enable lazy load mode. If you want to speed up data loading, "
125                  << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
126                  << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
127                  << "according to the current samples.";
128   } else {
129     load_mode_ = LoadMode::kFast;
130     tasks_.load_mode_ = LoadMode::kFast;
131   }
132 
133   auto disk_size = page_size_ * row_group_summary.size();
134   auto compression_size = shard_header_->GetCompressionSize();
135   total_blob_size_ = disk_size + compression_size;
136   MS_LOG(INFO) << "The size of blob data on disk: " << disk_size
137                << " , additional uncompression size: " << compression_size
138                << " , total blob size: " << total_blob_size_;
139 
140   MS_LOG(INFO) << "Succeed to get metadata from mindrecord files";
141 
142   return Status::OK();
143 }
144 
VerifyDataset(sqlite3 ** db,const string & file)145 Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
146   std::string path_utf8 = "";
147 #if defined(_WIN32) || defined(_WIN64)
148   path_utf8 = FileUtils::GB2312ToUTF_8((file + ".db").data());
149 #endif
150   if (path_utf8.empty()) {
151     path_utf8 = file + ".db";
152   }
153 
154   // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
155 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
156   // use "unix-none" to avoid flock and achieve better performance on shared storage platform
157   CHECK_FAIL_RETURN_UNEXPECTED_MR(
158     sqlite3_open_v2(path_utf8.data(), db, SQLITE_OPEN_READONLY, "unix-none") == SQLITE_OK,
159     "Invalid file, failed to open mindrecord meta file. Please check whether the meta file: " + file +
160       ".db exists and do not rename the mindrecord file and meta file.");
161 #else
162   CHECK_FAIL_RETURN_UNEXPECTED_MR(
163     sqlite3_open_v2(path_utf8.data(), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK,
164     "Invalid file, failed to open mindrecord meta file. Please check whether the meta file: " + file +
165       ".db exists and do not rename the mindrecord file and meta file.");
166 #endif
167   MS_LOG(DEBUG) << "Succeed to open meta file, path: " << file << ".db.";
168 
169   // starting a transaction during a read-only select operation can solve the problem of frequently
170   // accessing *-journal / *-wal files.
171   auto sql_code = sqlite3_exec(*db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
172   if (sql_code != SQLITE_OK) {
173     sqlite3_free(*db);
174     RETURN_STATUS_UNEXPECTED_MR("Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: " +
175                                 std::to_string(sql_code));
176   }
177 
178   string sql = "SELECT NAME from SHARD_NAME;";
179   std::vector<std::vector<std::string>> name;
180   char *errmsg = nullptr;
181   if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) {
182     std::ostringstream oss;
183     oss << "Failed to execute the sql [ " << sql << " ] while verifying meta file, " << errmsg
184         << ".\nPlease check the meta file: " + file + ".db";
185     sqlite3_free(errmsg);
186     sqlite3_close(*db);
187     RETURN_STATUS_UNEXPECTED_MR(oss.str());
188   } else {
189     std::shared_ptr<std::string> fn_ptr;
190     RETURN_IF_NOT_OK_MR(GetFileName(file, &fn_ptr));
191     if (name.empty() || name[0][0] != *fn_ptr) {
192       sqlite3_free(errmsg);
193       sqlite3_close(*db);
194       RETURN_STATUS_UNEXPECTED_MR("Invalid file, mindrecord meta file: " + file + ".db and mindrecord file: " + file +
195                                   " can not match. Please do not rename the mindrecord file or meta file.");
196     }
197   }
198   return Status::OK();
199 }
200 
CheckColumnList(const std::vector<std::string> & selected_columns)201 Status ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
202   auto schema_ptr = GetShardHeader()->GetSchemas()[0];
203   auto schema = schema_ptr->GetSchema()["schema"];
204   for (auto i = 0; i < selected_columns.size(); ++i) {
205     CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(selected_columns[i]) != schema.end(),
206                                     "Invalid data, column name: " + selected_columns[i] +
207                                       " can not found in schema. Please check the 'column_list'.");
208   }
209   return Status::OK();
210 }
211 
Open(int n_consumer)212 Status ShardReader::Open(int n_consumer) {
213   file_streams_random_ =
214     std::vector<std::vector<std::shared_ptr<std::fstream>>>(n_consumer, std::vector<std::shared_ptr<std::fstream>>());
215   for (const auto &file : file_paths_) {
216     for (int j = 0; j < n_consumer; ++j) {
217       std::optional<std::string> dir = "";
218       std::optional<std::string> local_file_name = "";
219       FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
220       if (!dir.has_value()) {
221         dir = ".";
222       }
223 
224       auto realpath = FileUtils::GetRealPath(dir.value().c_str());
225       CHECK_FAIL_RETURN_UNEXPECTED_MR(
226         realpath.has_value(),
227         "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);
228 
229       std::optional<std::string> whole_path = "";
230       FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
231 
232       std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
233       fs->open(whole_path.value(), std::ios::in | std::ios::binary);
234       if (!fs->good()) {
235         fs->close();
236         RETURN_STATUS_UNEXPECTED_MR(
237           "Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and "
238           "open files limit(ulimit -a): " +
239           file);
240       }
241       file_streams_random_[j].push_back(fs);
242     }
243     MS_LOG(INFO) << "Succeed to open file, path: " << file;
244   }
245   return Status::OK();
246 }
247 
// Grow the per-consumer random-read stream table by n_new_consumers: appends
// one new stream group per added consumer, opens every shard file into each
// new group, and bumps n_consumer_.
// Must be called after Open(); fails if the new total would exceed the thread
// limit. NOTE(review): on a mid-loop open failure the already-appended groups
// are left in place while n_consumer_ is not updated — presumably the caller
// treats the reader as unusable after an error; confirm before reuse.
// @param n_new_consumers number of consumers to add (must be > 0).
Status ShardReader::ExtendRandomFileStreams(const int n_new_consumers) {
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_new_consumers > 0,
                                  "n_new_consumers must be a positive number. Got: " + std::to_string(n_new_consumers));
  CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_streams_random_.empty(),
                                  "ExtendRandomFileStreams() must not be called prior to calling Open()");
  // make sure we won't exceed the number of allowed threads.
  uint32_t thread_limit = GetMaxThreadNum();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ + n_new_consumers <= thread_limit,
                                  "Requested increase in number of consumers will cause it to be above the number of "
                                  "allowed threads. n_new_consumers: " +
                                    std::to_string(n_new_consumers) +
                                    ", new n_consumers: " + std::to_string(n_consumer_ + n_new_consumers));

  // Reserve one empty stream group per new consumer.
  for (int i = 0; i < n_new_consumers; i++) {
    (void)file_streams_random_.emplace_back(std::vector<std::shared_ptr<std::fstream>>());
  }

  for (const auto &file : file_paths_) {
    // Resolve the file's real path once; reused for every new consumer below.
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().data());
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      realpath.has_value(), "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    // Open this file once for each newly added consumer slot.
    for (int j = n_consumer_; j < n_consumer_ + n_new_consumers; ++j) {
      std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED_MR(
          "Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and "
          "open files limit(ulimit -a): " +
          file);
      }
      file_streams_random_[j].push_back(fs);
    }
    MS_LOG(INFO) << "Succeed to open file, path: " << file;
  }
  n_consumer_ += n_new_consumers;
  MS_LOG(INFO) << "n_consumer_ is increased by " + std::to_string(n_new_consumers) + " to " +
                    std::to_string(n_consumer_);

  return Status::OK();
}
300 
ShrinkRandomFileStreams(const int n_remove_consumers)301 Status ShardReader::ShrinkRandomFileStreams(const int n_remove_consumers) {
302   CHECK_FAIL_RETURN_UNEXPECTED_MR(
303     n_remove_consumers > 0, "n_remove_consumers must be a positive number. Got: " + std::to_string(n_remove_consumers));
304   CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_streams_random_.empty(),
305                                   "ShrinkRandomFileStreams() must not be called prior to calling Open()");
306   // make sure we won't go below the number of allowed threads.
307   CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ - n_remove_consumers >= kMinConsumerCount,
308                                   "Requested decrease in number of consumers will cause it to be below the number of "
309                                   "allowed threads. n_remove_consumers: " +
310                                     std::to_string(n_remove_consumers) +
311                                     ", new n_consumers: " + std::to_string(n_consumer_ - n_remove_consumers));
312 
313   for (int i = n_consumer_ - 1; i >= n_consumer_ - n_remove_consumers; i--) {
314     for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
315       if (file_streams_random_[i][j] != nullptr) {
316         file_streams_random_[i][j]->close();
317       }
318     }
319     file_streams_random_.pop_back();
320   }
321   n_consumer_ -= n_remove_consumers;
322   MS_LOG(INFO) << "n_consumer_ is decreased by " + std::to_string(n_remove_consumers) + " to " +
323                     std::to_string(n_consumer_);
324 
325   return Status::OK();
326 }
327 
FileStreamsOperator()328 void ShardReader::FileStreamsOperator() {
329   for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; --i) {
330     if (file_streams_[i] != nullptr) {
331       file_streams_[i]->close();
332     }
333   }
334   for (int i = static_cast<int>(file_streams_random_.size()) - 1; i >= 0; --i) {
335     for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
336       if (file_streams_random_[i][j] != nullptr) {
337         file_streams_random_[i][j]->close();
338       }
339     }
340   }
341   for (int i = static_cast<int>(database_paths_.size()) - 1; i >= 0; --i) {
342     if (database_paths_[i] != nullptr) {
343       auto sql_code = sqlite3_exec(database_paths_[i], "END TRANSACTION;", nullptr, nullptr, nullptr);
344       if (sql_code != SQLITE_OK) {
345         sqlite3_close(database_paths_[i]);
346         MS_LOG(ERROR) << "Execute SQL statement `END TRANSACTION;` failed, SQLite result code: "
347                       << std::to_string(sql_code);
348         continue;
349       }
350       auto ret = sqlite3_close(database_paths_[i]);
351       if (ret != SQLITE_OK) {
352         MS_LOG(ERROR) << "[Internal ERROR] Failed to close meta file, " << ret << ".";
353       }
354       database_paths_[i] = nullptr;
355     }
356   }
357 }
358 
// Destructor: interrupt worker threads and release all streams/db handles via Close().
ShardReader::~ShardReader() { Close(); }
360 
// Stop delivery: raise the interrupt flag under mtx_delivery_, wake every
// waiter, join all worker threads, then release file streams and db handles.
// Also invoked by the destructor.
void ShardReader::Close() {
  {
    std::lock_guard<std::mutex> lck(mtx_delivery_);
    interrupt_ = true;  // interrupt reading and stop threads
  }
  cv_delivery_.notify_all();

  // Wait for all threads to finish
  for (auto &i_thread : thread_set_) {
    if (i_thread.joinable()) {
      i_thread.join();
    }
  }

  FileStreamsOperator();
}
377 
// Header shared by all shards; populated by Init().
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
379 
// Column metadata/decoding helper; populated by Init().
std::shared_ptr<ShardColumn> ShardReader::GetShardColumn() const { return shard_column_; }
381 
GetShardCount() const382 int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
383 
// Total number of samples across all shards, computed in Init().
int64_t ShardReader::GetNumRows() const { return num_rows_; }
385 
// Number of samples left in the task list once sampling operators are applied.
int64_t ShardReader::GetNumRowsAfterSampling() const { return tasks_.SizeAfterSampling(); }
387 
// Summarize all blob pages of every shard as tuples of
// (shard_id, page_type_id, start_row_id, number_of_rows), and rebuild
// shard_sample_count_ as the cumulative sample count up to and including each
// shard. Returns an empty vector when there are no shards, or when a page
// reports an inconsistent row range (start > end).
std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
  std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
  int shard_count = shard_header_->GetShardCount();
  if (shard_count <= 0) {
    return row_group_summary;
  }

  uint32_t total_count = 0;
  for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
    // return -1 when page's size equals to 0.
    auto last_page_id = shard_header_->GetLastPageId(shard_id);
    if (static_cast<int>(last_page_id) == -1) {
      // Empty mindrecord file which does not contain any samples
      MS_LOG(WARNING) << "The mindrecord file: " << file_paths_[shard_id]
                      << " does not contain any samples, pls remove it.";
      row_group_summary.emplace_back(shard_id, 0, 0, 0);
      shard_sample_count_.push_back(total_count);
      continue;
    }
    for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
      std::shared_ptr<Page> page_ptr;
      (void)shard_header_->GetPage(shard_id, page_id, &page_ptr);
      // Only blob pages contribute row groups; raw (label) pages are skipped.
      if (page_ptr->GetPageType() != kPageTypeBlob) {
        continue;
      }
      uint64_t start_row_id = page_ptr->GetStartRowID();
      if (start_row_id > page_ptr->GetEndRowID()) {
        // Corrupt page metadata: bail out with an empty summary.
        return std::vector<std::tuple<int, int, int, uint64_t>>();
      }
      uint64_t number_of_rows = page_ptr->GetEndRowID() - start_row_id;
      total_count += number_of_rows;
      row_group_summary.emplace_back(shard_id, page_ptr->GetPageTypeID(), start_row_id, number_of_rows);
    }
    // Cumulative across shards: entry i holds the sample count of shards [0, i].
    shard_sample_count_.push_back(total_count);
  }

  return row_group_summary;
}
426 
// Convert one shard's raw index-query rows (`labels`) into blob offsets and
// per-sample json labels.
// Each row carries: [0] group id, [1]/[2] blob offset start/end, and — when
// some requested field is not indexed — [3] raw page id, [4]/[5] label offset
// start/end used to read the msgpack-encoded label from the shard file itself.
// @param labels      rows returned by the INDEXES query for this shard.
// @param fs          open stream of the shard file (used only when !all_in_index_).
// @param offset_ptr  out: per-shard list of {shard_id, group_id, start, end}.
// @param shard_id    index of the shard these rows belong to.
// @param columns     requested columns; empty means "keep everything".
// @param col_val_ptr out: per-shard list of label json objects.
Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
                                       std::shared_ptr<std::fstream> fs,
                                       std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
                                       int shard_id, const std::vector<std::string> &columns,
                                       std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
  auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
  for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
    try {
      uint64_t group_id = std::stoull(labels[i][0]);
      // + kInt64Len skips a leading 8-byte field — presumably a length prefix
      // written by the shard writer; confirm against the file format.
      uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
      uint64_t offset_end = std::stoull(labels[i][2]);
      CHECK_FAIL_RETURN_UNEXPECTED_MR(offset_end >= offset_start,
                                      "The sample's end offset: " + std::to_string(offset_end) +
                                        " should >= start offset: " + std::to_string(offset_start) + ", check fail.");
      (*offset_ptr)[shard_id].emplace_back(
        std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
      if (!all_in_index_) {
        // Some requested field is not indexed: read the msgpack label bytes
        // directly from the raw page of the shard file.
        int raw_page_id = std::stoi(labels[i][3]);
        uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
        uint64_t label_end = std::stoull(labels[i][5]);
        CHECK_FAIL_RETURN_UNEXPECTED_MR(label_end >= label_start,
                                        "The sample's end offset: " + std::to_string(label_end) +
                                          " should >= start offset: " + std::to_string(label_start) + ", check fail.");
        auto len = label_end - label_start;
        auto label_raw = std::vector<uint8_t>(len);
        auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
        if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
        }
        auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
        if (!io_read.good() || io_read.fail() || io_read.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
        }
        json label_json = json::from_msgpack(label_raw);
        json tmp;
        if (!columns.empty()) {
          // Keep only the requested columns that exist in the decoded label.
          for (const auto &col : columns) {
            if (label_json.find(col) != label_json.end()) {
              tmp[col] = label_json[col];
            }
          }
        } else {
          tmp = label_json;
        }
        (*col_val_ptr)[shard_id].emplace_back(tmp);
      } else {
        // Every requested field is indexed: rebuild the label json straight
        // from the query row — no file I/O needed.
        json construct_json;
        RETURN_IF_NOT_OK_MR(ConvertJsonValue(labels[i], columns, schema, &construct_json));
        (*col_val_ptr)[shard_id].emplace_back(construct_json);
      }
    } catch (std::out_of_range &e) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
                                  std::string(e.what()));
    } catch (std::invalid_argument &e) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
                                  std::string(e.what()));
    } catch (...) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Unexpected exception raised in ConvertLabelToJson function.");
    }
  }

  return Status::OK();
}
495 
ConvertJsonValue(const std::vector<std::string> & label,const std::vector<std::string> & columns,const json & schema,json * value)496 Status ShardReader::ConvertJsonValue(const std::vector<std::string> &label, const std::vector<std::string> &columns,
497                                      const json &schema, json *value) {
498   constexpr int64_t index = 3;
499   for (unsigned int j = 0; j < columns.size(); ++j) {
500     if (schema[columns[j]]["type"] == "int32") {
501       (*value)[columns[j]] = StringToNum<int32_t>(label[j + index]);
502     } else if (schema[columns[j]]["type"] == "int64") {
503       (*value)[columns[j]] = StringToNum<int64_t>(label[j + index]);
504     } else if (schema[columns[j]]["type"] == "float32") {
505       (*value)[columns[j]] = StringToNum<float>(label[j + index]);
506     } else if (schema[columns[j]]["type"] == "float64") {
507       (*value)[columns[j]] = StringToNum<double>(label[j + index]);
508     } else {
509       (*value)[columns[j]] = std::string(label[j + index]);
510     }
511   }
512   return Status::OK();
513 }
ReadAllRowsInShard(int shard_id,const int32_t & consumer_id,const std::string & sql,const std::vector<std::string> & columns,std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr)514 Status ShardReader::ReadAllRowsInShard(int shard_id, const int32_t &consumer_id, const std::string &sql,
515                                        const std::vector<std::string> &columns,
516                                        std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
517                                        std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
518   auto db = database_paths_[shard_id];
519   std::vector<std::vector<std::string>> labels;
520   char *errmsg = nullptr;
521   int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
522   if (rc != SQLITE_OK) {
523     std::ostringstream oss;
524     oss << "[Internal ERROR] Failed to execute the sql [ " << sql << " ] while reading meta file, " << errmsg;
525     sqlite3_free(errmsg);
526     sqlite3_close(db);
527     db = nullptr;
528     RETURN_STATUS_UNEXPECTED_MR(oss.str());
529   }
530   MS_LOG(DEBUG) << "Succeed to get " << labels.size() << " records from shard " << std::to_string(shard_id)
531                 << " index.";
532 
533   sqlite3_free(errmsg);
534   return ConvertLabelToJson(labels, file_streams_random_[consumer_id][shard_id], offset_ptr, shard_id, columns,
535                             col_val_ptr);
536 }
537 
// Gather all distinct values of `category_field` (the PKSampler 'class_column')
// across every shard, merging them into `category_ptr`.
// Spawns one GetClassesInShard worker thread per shard; the workers serialize
// their insertions on shard_locker_.
Status ShardReader::GetAllClasses(const std::string &category_field,
                                  std::shared_ptr<std::set<std::string>> category_ptr) {
  std::map<std::string, uint64_t> index_columns;
  for (auto &field : GetShardHeader()->GetFields()) {
    // Build a name -> id lookup (pairs are presumably (schema id, field name)
    // per ShardHeader::GetFields — confirm against the header implementation).
    index_columns[field.second] = field.first;
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    index_columns.find(category_field) != index_columns.end(),
    "Invalid data, 'class_column': " + category_field +
      " can not found in fields of mindrecord files. Please check 'class_column' in PKSampler.");
  std::shared_ptr<std::string> fn_ptr;
  RETURN_IF_NOT_OK_MR(
    ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr));
  std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
  std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
  for (int x = 0; x < shard_count_; x++) {
    threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
  }

  for (int x = 0; x < shard_count_; x++) {
    threads[x].join();
  }
  return Status::OK();
}
562 
GetClassesInShard(sqlite3 * db,int shard_id,const std::string & sql,std::shared_ptr<std::set<std::string>> category_ptr)563 void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
564                                     std::shared_ptr<std::set<std::string>> category_ptr) {
565   if (db == nullptr) {
566     return;
567   }
568 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
569   pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(shard_id)).c_str());
570 #endif
571   std::vector<std::vector<std::string>> columns;
572   char *errmsg = nullptr;
573   int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg);
574   if (ret != SQLITE_OK) {
575     sqlite3_free(errmsg);
576     sqlite3_close(db);
577     db = nullptr;
578     MS_LOG(ERROR) << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql)
579                   << " ] while reading meta file, " << errmsg;
580     return;
581   }
582   MS_LOG(INFO) << "Succeed to get " << columns.size() << " records from shard " << std::to_string(shard_id)
583                << " index.";
584   std::lock_guard<std::mutex> lck(shard_locker_);
585   for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
586     category_ptr->emplace(columns[i][0]);
587   }
588   sqlite3_free(errmsg);
589 }
590 
ReadAllRowGroup(const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUPS> * row_group_ptr)591 Status ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns,
592                                     std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
593   RETURN_UNEXPECTED_IF_NULL_MR(row_group_ptr);
594   std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
595   auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
596     shard_count_, std::vector<std::vector<uint64_t>>{});
597   auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
598 
599   if (all_in_index_) {
600     for (unsigned int i = 0; i < columns.size(); ++i) {
601       fields += ',';
602       std::shared_ptr<std::string> fn_ptr;
603       RETURN_IF_NOT_OK_MR(
604         ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
605       fields += *fn_ptr;
606     }
607   } else {  // fetch raw data from Raw page while some field is not index.
608     fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
609   }
610 
611   std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;";
612 
613   std::vector<std::future<Status>> async_results;
614   auto status = Status::OK();
615   for (int x = 0; x < shard_count_; x++) {
616     async_results.push_back(std::async(std::launch::async, &ShardReader::ReadAllRowsInShard, this, x, 0, sql, columns,
617                                        offset_ptr, col_val_ptr));
618   }
619 
620   for (auto i = 0; i < async_results.size(); i++) {
621     auto res = async_results[i].get();
622     if (res.IsError() && status.IsOk()) {
623       status = res;
624     }
625   }
626   if (status.IsError()) {
627     return status;
628   }
629   *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
630   return Status::OK();
631 }
632 
ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> & columns,const uint32_t & shard_id,const int32_t & consumer_id,const uint32_t & sample_id,std::shared_ptr<ROW_GROUPS> * row_group_ptr)633 Status ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
634                                                      const int32_t &consumer_id, const uint32_t &sample_id,
635                                                      std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
636   RETURN_UNEXPECTED_IF_NULL_MR(row_group_ptr);
637   std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
638   auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
639     shard_count_, std::vector<std::vector<uint64_t>>{});
640   auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
641   if (all_in_index_) {
642     for (unsigned int i = 0; i < columns.size(); ++i) {
643       fields += ',';
644       std::shared_ptr<std::string> fn_ptr;
645       RETURN_IF_NOT_OK_MR(
646         ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
647       fields += *fn_ptr;
648     }
649   } else {  // fetch raw data from Raw page while some field is not index.
650     fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
651   }
652 
653   std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
654 
655   RETURN_IF_NOT_OK_MR(ReadAllRowsInShard(shard_id, consumer_id, sql, columns, offset_ptr, col_val_ptr));
656   *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
657   return Status::OK();
658 }
659 
ReadRowGroupBrief(int group_id,int shard_id,const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUP_BRIEF> * row_group_brief_ptr)660 Status ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
661                                       std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
662   RETURN_UNEXPECTED_IF_NULL_MR(row_group_brief_ptr);
663   std::shared_ptr<Page> page_ptr;
664   RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
665   std::string file_name = file_paths_[shard_id];
666   uint64_t page_length = page_ptr->GetPageSize();
667   uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
668   std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id);
669   auto labels_ptr = std::make_shared<std::vector<json>>();
670   RETURN_IF_NOT_OK_MR(GetLabels(page_ptr->GetPageID(), shard_id, columns, {"", ""}, &labels_ptr));
671   *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
672                                                            std::move(*labels_ptr));
673   return Status::OK();
674 }
675 
ReadRowGroupCriteria(int group_id,int shard_id,const std::pair<std::string,std::string> & criteria,const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUP_BRIEF> * row_group_brief_ptr)676 Status ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
677                                          const std::pair<std::string, std::string> &criteria,
678                                          const std::vector<std::string> &columns,
679                                          std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
680   RETURN_UNEXPECTED_IF_NULL_MR(row_group_brief_ptr);
681   std::shared_ptr<Page> page_ptr;
682   RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
683   vector<string> criteria_list{criteria.first};
684   RETURN_IF_NOT_OK_MR(CheckColumnList(criteria_list));
685   std::string file_name = file_paths_[shard_id];
686   uint64_t page_length = page_ptr->GetPageSize();
687   uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
688   std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id, criteria);
689   if (image_offset.empty()) {
690     *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>();
691   }
692   auto labels_ptr = std::make_shared<std::vector<json>>();
693   RETURN_IF_NOT_OK_MR(GetLabels(page_ptr->GetPageID(), shard_id, columns, criteria, &labels_ptr));
694   *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
695                                                            std::move(*labels_ptr));
696   return Status::OK();
697 }
698 
SelectCallback(void * p_data,int num_fields,char ** p_fields,char ** p_col_names)699 int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) {
700   auto *records = static_cast<std::vector<std::vector<std::string>> *>(p_data);
701   if (num_fields > 0 && num_fields <= kMaxFieldCount) {
702     for (int i = 0; i < num_fields; ++i) {
703       if (p_fields[i] == nullptr) {
704         p_fields[i] = const_cast<char *>("");
705       }
706     }
707   }
708   records->emplace_back(p_fields, p_fields + num_fields);
709   return 0;
710 }
711 
GetImageOffset(int page_id,int shard_id,const std::pair<std::string,std::string> & criteria)712 std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int shard_id,
713                                                                const std::pair<std::string, std::string> &criteria) {
714   auto db = database_paths_[shard_id];
715 
716   std::string sql = "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";
717 
718   // whether use index search
719   if (!criteria.first.empty()) {
720     sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
721   }
722   sql += ";";
723   std::vector<std::vector<std::string>> image_offsets;
724 
725   sqlite3_stmt *stmt = nullptr;
726   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
727     MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to prepare statement [ " << sql << " ].";
728   }
729 
730   // bind the PAGE_ID_BLOB
731   int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
732   if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
733     (void)sqlite3_finalize(stmt);
734     MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to bind parameter of sql, key index: " << std::to_string(index)
735                       << ", value: " << std::to_string(page_id);
736   }
737 
738   // bind the criteria
739   if (!criteria.first.empty()) {
740     index = sqlite3_bind_parameter_index(stmt, ":criteria");
741     if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria.second), -1, SQLITE_STATIC) != SQLITE_OK) {
742       (void)sqlite3_finalize(stmt);
743       MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to bind parameter of sql, key index: " << std::to_string(index)
744                         << ", value: " + criteria.second;
745     }
746   }
747 
748   int rc = sqlite3_step(stmt);
749   while (rc != SQLITE_DONE) {
750     vector<string> tmp;
751     int ncols = sqlite3_column_count(stmt);
752     for (int i = 0; i < ncols; i++) {
753       tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
754     }
755     image_offsets.push_back(tmp);
756     rc = sqlite3_step(stmt);
757   }
758 
759   auto finalize = sqlite3_finalize(stmt);
760   if (finalize != SQLITE_OK) {
761     MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to finalize sql stmt, error code: " << std::to_string(finalize);
762   }
763 
764   MS_LOG(DEBUG) << "Succeed to get " << image_offsets.size() << " records from index.";
765 
766   std::vector<std::vector<uint64_t>> res;
767   for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) {
768     res.emplace_back(std::vector<uint64_t>{0, 0});
769   }
770   for (int i = 0; i < static_cast<int>(image_offsets.size()); i++) {
771     const auto &image_offset = image_offsets[i];
772     res[i][0] = std::stoull(image_offset[0]) + kInt64Len;
773     res[i][1] = std::stoull(image_offset[1]);
774     if (res[i][1] < res[i][0]) {
775       MS_LOG(EXCEPTION) << "The sample's end offset: " << std::to_string(res[i][1])
776                         << " should >= start offset: " << std::to_string(res[i][0]) << ", check fail.";
777     }
778   }
779   return res;
780 }
781 
GetPagesByCategory(int shard_id,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<uint64_t>> * pages_ptr)782 Status ShardReader::GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
783                                        std::shared_ptr<std::vector<uint64_t>> *pages_ptr) {
784   RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
785   auto db = database_paths_[shard_id];
786 
787   std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";
788 
789   if (!criteria.first.empty()) {
790     sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
791   }
792   sql += ";";
793   std::vector<std::vector<std::string>> page_ids;
794 
795   sqlite3_stmt *stmt = nullptr;
796   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
797     (void)sqlite3_finalize(stmt);
798     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
799   }
800 
801   if (!criteria.first.empty()) {
802     int index = sqlite3_bind_parameter_index(stmt, ":criteria");
803     if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria.second), -1, SQLITE_STATIC) != SQLITE_OK) {
804       (void)sqlite3_finalize(stmt);
805       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
806                                   std::to_string(index) + ", value: " + criteria.second);
807     }
808   }
809 
810   int rc = sqlite3_step(stmt);
811   while (rc != SQLITE_DONE) {
812     vector<string> tmp;
813     int ncols = sqlite3_column_count(stmt);
814     for (int i = 0; i < ncols; i++) {
815       tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
816     }
817     page_ids.push_back(tmp);
818     rc = sqlite3_step(stmt);
819   }
820 
821   auto finalize = sqlite3_finalize(stmt);
822   if (finalize != SQLITE_OK) {
823     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
824                                 std::to_string(finalize));
825   }
826 
827   MS_LOG(DEBUG) << "Succeed to get " << page_ids.size() << " pages from index.";
828   for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
829     (*pages_ptr)->emplace_back(std::stoull(page_ids[i][0]));
830   }
831   return Status::OK();
832 }
833 
GetBlobFields()834 std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
835   std::vector<std::string> blob_fields;
836   for (auto &p : GetShardHeader()->GetSchemas()) {
837     // assume one schema
838     const auto &fields = p->GetBlobFields();
839     blob_fields.assign(fields.begin(), fields.end());
840     break;
841   }
842   return std::make_pair(kCV, blob_fields);
843 }
844 
CheckIfColumnInIndex(const std::vector<std::string> & columns)845 void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) {
846   // assume different schemas do not contain same key.
847   if (columns.empty()) {
848     all_in_index_ = false;
849     return;
850   }
851   for (auto &field : GetShardHeader()->GetFields()) {
852     column_schema_id_[field.second] = field.first;
853   }
854   for (auto &col : columns) {
855     if (column_schema_id_.find(col) == column_schema_id_.end()) {
856       all_in_index_ = false;
857       return;
858     }
859   }
860 }
861 
QueryWithPageIdBlobAndCriteria(sqlite3 * db,const string & sql,const int & page_id,const string & criteria,std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr)862 Status ShardReader::QueryWithPageIdBlobAndCriteria(sqlite3 *db, const string &sql, const int &page_id,
863                                                    const string &criteria,
864                                                    std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
865   sqlite3_stmt *stmt = nullptr;
866   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
867     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
868   }
869 
870   // bind the PAGE_ID_BLOB
871   int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
872   if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
873     (void)sqlite3_finalize(stmt);
874     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
875                                 std::to_string(index) + ", value: " + std::to_string(page_id));
876   }
877 
878   // bind the criteria
879   index = sqlite3_bind_parameter_index(stmt, ":criteria");
880   if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) {
881     (void)sqlite3_finalize(stmt);
882     RETURN_STATUS_UNEXPECTED_MR(
883       "[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) + ", value: " + criteria);
884   }
885   int rc = sqlite3_step(stmt);
886   while (rc != SQLITE_DONE) {
887     vector<string> tmp;
888     int ncols = sqlite3_column_count(stmt);
889     for (int i = 0; i < ncols; i++) {
890       tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
891     }
892     labels_ptr->push_back(tmp);
893     rc = sqlite3_step(stmt);
894   }
895 
896   auto finalize = sqlite3_finalize(stmt);
897   if (finalize != SQLITE_OK) {
898     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
899                                 std::to_string(finalize));
900   }
901   return Status::OK();
902 }
903 
GetLabelsFromBinaryFile(int shard_id,const std::vector<std::string> & columns,const std::vector<std::vector<std::string>> & label_offsets,std::shared_ptr<std::vector<json>> * labels_ptr)904 Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
905                                             const std::vector<std::vector<std::string>> &label_offsets,
906                                             std::shared_ptr<std::vector<json>> *labels_ptr) {
907   RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
908   std::shared_ptr<std::fstream> fs = file_streams_random_[0][shard_id];
909 
910   // init the return
911   for (unsigned int i = 0; i < label_offsets.size(); ++i) {
912     (*labels_ptr)->emplace_back(json{});
913   }
914 
915   for (unsigned int i = 0; i < label_offsets.size(); ++i) {
916     const auto &labelOffset = label_offsets[i];
917     if (labelOffset.size() < 3) {
918       fs->close();
919       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] 'labelOffset' size should be less than 3 but got: " +
920                                   std::to_string(labelOffset.size()) + ".");
921     }
922     uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len;
923     uint64_t label_end = std::stoull(labelOffset[2]);
924     CHECK_FAIL_RETURN_UNEXPECTED_MR(label_end >= label_start,
925                                     "The sample's end offset: " + std::to_string(label_end) +
926                                       " should >= start offset: " + std::to_string(label_start) + ", check fail.");
927     int raw_page_id = std::stoi(labelOffset[0]);
928     auto len = label_end - label_start;
929     auto label_raw = std::vector<uint8_t>(len);
930     auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
931     if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
932       fs->close();
933       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file, path: " + file_paths_[shard_id]);
934     }
935 
936     auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
937     if (!io_read.good() || io_read.fail() || io_read.bad()) {
938       fs->close();
939       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file, path: " + file_paths_[shard_id]);
940     }
941 
942     json label_json = json::from_msgpack(label_raw);
943     json tmp = label_json;
944     for (auto &col : columns) {
945       if (label_json.find(col) != label_json.end()) {
946         tmp[col] = label_json[col];
947       }
948     }
949     (*(*labels_ptr))[i] = tmp;
950   }
951   return Status::OK();
952 }
953 
GetLabelsFromPage(int page_id,int shard_id,const std::vector<std::string> & columns,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<json>> * labels_ptr)954 Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
955                                       const std::pair<std::string, std::string> &criteria,
956                                       std::shared_ptr<std::vector<json>> *labels_ptr) {
957   RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
958   // get page info from sqlite
959   auto db = database_paths_[shard_id];
960   std::string sql =
961     "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";
962 
963   auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
964   if (!criteria.first.empty()) {
965     sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria;";
966     RETURN_IF_NOT_OK_MR(QueryWithPageIdBlobAndCriteria(db, sql, page_id, criteria.second, label_offset_ptr));
967   } else {
968     sql += ";";
969     sqlite3_stmt *stmt = nullptr;
970     if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
971       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
972     }
973 
974     // bind the PAGE_ID_BLOB
975     int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
976     if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
977       (void)sqlite3_finalize(stmt);
978       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
979                                   std::to_string(index) + ", value: " + std::to_string(page_id));
980     }
981 
982     int rc = sqlite3_step(stmt);
983     while (rc != SQLITE_DONE) {
984       vector<string> tmp;
985       int ncols = sqlite3_column_count(stmt);
986       for (int i = 0; i < ncols; i++) {
987         tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
988       }
989       label_offset_ptr->push_back(tmp);
990       rc = sqlite3_step(stmt);
991     }
992 
993     auto finalize = sqlite3_finalize(stmt);
994     if (finalize != SQLITE_OK) {
995       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
996                                   std::to_string(finalize));
997     }
998 
999     MS_LOG(DEBUG) << "Succeed to get " << label_offset_ptr->size() << " records from index.";
1000   }
1001   // get labels from binary file
1002   return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr, labels_ptr);
1003 }
1004 
GetLabels(int page_id,int shard_id,const std::vector<std::string> & columns,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<json>> * labels_ptr)1005 Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
1006                               const std::pair<std::string, std::string> &criteria,
1007                               std::shared_ptr<std::vector<json>> *labels_ptr) {
1008   RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
1009   if (all_in_index_) {
1010     auto db = database_paths_[shard_id];
1011     std::string fields;
1012     for (unsigned int i = 0; i < columns.size(); ++i) {
1013       if (i > 0) {
1014         fields += ',';
1015       }
1016       uint64_t schema_id = column_schema_id_[columns[i]];
1017       fields += columns[i] + "_" + std::to_string(schema_id);
1018     }
1019     if (fields.empty()) {
1020       fields = "*";
1021     }
1022     auto labels = std::make_shared<std::vector<std::vector<std::string>>>();
1023     std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";
1024     if (!criteria.first.empty()) {
1025       sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria;";
1026       RETURN_IF_NOT_OK_MR(QueryWithPageIdBlobAndCriteria(db, sql, page_id, criteria.second, labels));
1027     } else {
1028       sql += ";";
1029       sqlite3_stmt *stmt = nullptr;
1030       if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
1031         RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
1032       }
1033 
1034       // bind the PAGE_ID_BLOB
1035       int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
1036       if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
1037         (void)sqlite3_finalize(stmt);
1038         RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
1039                                     std::to_string(index) + ", value: " + std::to_string(page_id));
1040       }
1041 
1042       int rc = sqlite3_step(stmt);
1043       while (rc != SQLITE_DONE) {
1044         vector<string> tmp;
1045         int ncols = sqlite3_column_count(stmt);
1046         for (int i = 0; i < ncols; i++) {
1047           tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
1048         }
1049         labels->push_back(tmp);
1050         rc = sqlite3_step(stmt);
1051       }
1052 
1053       auto finalize = sqlite3_finalize(stmt);
1054       if (finalize != SQLITE_OK) {
1055         RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
1056                                     std::to_string(finalize));
1057       }
1058 
1059       MS_LOG(DEBUG) << "Succeed to get " << labels->size() << " records from index.";
1060     }
1061     for (unsigned int i = 0; i < labels->size(); ++i) {
1062       (*labels_ptr)->emplace_back(json{});
1063     }
1064     for (unsigned int i = 0; i < labels->size(); ++i) {
1065       json construct_json;
1066       for (unsigned int j = 0; j < columns.size(); ++j) {
1067         // construct json "f1": value
1068         auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
1069 
1070         // convert the string to base type by schema
1071         if (schema[columns[j]]["type"] == "int32") {
1072           construct_json[columns[j]] = StringToNum<int32_t>((*labels)[i][j]);
1073         } else if (schema[columns[j]]["type"] == "int64") {
1074           construct_json[columns[j]] = StringToNum<int64_t>((*labels)[i][j]);
1075         } else if (schema[columns[j]]["type"] == "float32") {
1076           construct_json[columns[j]] = StringToNum<float>((*labels)[i][j]);
1077         } else if (schema[columns[j]]["type"] == "float64") {
1078           construct_json[columns[j]] = StringToNum<double>((*labels)[i][j]);
1079         } else {
1080           construct_json[columns[j]] = std::string((*labels)[i][j]);
1081         }
1082       }
1083       (*(*labels_ptr))[i] = construct_json;
1084     }
1085     return Status::OK();
1086   }
1087   return GetLabelsFromPage(page_id, shard_id, columns, criteria, labels_ptr);
1088 }
1089 
// Strict-weak ordering for row-group summary tuples: compare by the second
// element first, breaking ties with the first (per the caller's comment, this
// orders by (group_id, shard_id) for deterministic parallel reading).
bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int, int> b) {
  if (std::get<1>(a) != std::get<1>(b)) {
    return std::get<1>(a) < std::get<1>(b);
  }
  return std::get<0>(a) < std::get<0>(b);
}
1093 
GetNumClasses(const std::string & category_field)1094 int64_t ShardReader::GetNumClasses(const std::string &category_field) {
1095   auto shard_count = file_paths_.size();
1096   auto index_fields = shard_header_->GetFields();
1097 
1098   std::map<std::string, int64_t> map_schema_id_fields;
1099   for (auto &field : index_fields) {
1100     map_schema_id_fields[field.second] = field.first;
1101   }
1102 
1103   if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
1104     MS_LOG(ERROR) << "[Internal ERROR] 'category_field' " << category_field
1105                   << " can not found in index fields of mindrecord files.";
1106     return -1;
1107   }
1108   std::shared_ptr<std::string> fn_ptr;
1109   (void)ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field),
1110                                                &fn_ptr);
1111   std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
1112   std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
1113   auto category_ptr = std::make_shared<std::set<std::string>>();
1114   sqlite3 *db = nullptr;
1115   for (int x = 0; x < shard_count; x++) {
1116     std::string path_utf8 = "";
1117 #if defined(_WIN32) || defined(_WIN64)
1118     path_utf8 = FileUtils::GB2312ToUTF_8((file_paths_[x] + ".db").data());
1119 #endif
1120     if (path_utf8.empty()) {
1121       path_utf8 = file_paths_[x] + ".db";
1122     }
1123 
1124 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
1125     // use "unix-none" to avoid flock and achieve better performance on shared storage platform
1126     int rc = sqlite3_open_v2(path_utf8.data(), &db, SQLITE_OPEN_READONLY, "unix-none");
1127 #else
1128     int rc = sqlite3_open_v2(path_utf8.data(), &db, SQLITE_OPEN_READONLY, nullptr);
1129 #endif
1130     if (SQLITE_OK != rc) {
1131       MS_LOG(ERROR) << "[Internal ERROR] Failed to open meta file: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
1132       return -1;
1133     }
1134 
1135     // starting a transaction during a read-only select operation can solve the problem of frequently
1136     // accessing *-journal / *-wal files.
1137     auto sql_code = sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
1138     if (sql_code != SQLITE_OK) {
1139       sqlite3_free(db);
1140       MS_LOG(ERROR) << "Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: "
1141                     << std::to_string(sql_code);
1142       return -1;
1143     }
1144     threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
1145   }
1146 
1147   for (int x = 0; x < shard_count; x++) {
1148     threads[x].join();
1149   }
1150   sqlite3_close(db);
1151   return category_ptr->size();
1152 }
1153 
// Compute how many samples the configured operator chain will deliver.
// Initializes the reader from 'file_paths', then replays the sampler chain
// (child-first) over the raw row count, adding 'num_padded' padded samples
// exactly once where the chain consumes them.
// @param file_paths   mindrecord files to open
// @param load_dataset forwarded to Init()
// @param ops          tail of the operator chain; children reached via GetChildOp()
// @param count        output: resulting sample count
// @param num_padded   number of padded (dummy) samples appended to the dataset
Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
                                   const std::shared_ptr<ShardOperator> &ops, int64_t *count,
                                   const int64_t num_padded) {
  RETURN_IF_NOT_OK_MR(Init(file_paths, load_dataset));
  int64_t num_samples = num_rows_;
  // 'root' guards that the padded samples are accounted for only once, by the
  // first operator (in application order) that consumes them.
  bool root = true;
  // The chain is linked parent -> child; push onto a stack so operators are
  // applied child-first.
  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
  std::shared_ptr<ShardOperator> op(ops);
  while (op != nullptr) {
    stack_ops.push(op);
    op = op->GetChildOp();
  }
  while (!stack_ops.empty()) {
    op = stack_ops.top();
    stack_ops.pop();
    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
      num_samples = op->GetNumSamples(num_samples, 0);
      if (num_padded > 0 && root == true) {
        num_samples += num_padded;
        root = false;
      }
    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
      std::string category_field = category_op->GetCategoryField();
      // the number of distinct classes bounds per-category sampling
      auto num_classes = GetNumClasses(category_field);
      num_samples = category_op->GetNumSamples(num_samples, num_classes);
      if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
        // PK sampling additionally caps the count by its own num_samples (0 = unlimited)
        auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
        if (tmp != 0 && num_samples != -1) {
          num_samples = std::min(num_samples, tmp);
        }

        // -1 signals an invalid/overflowed sample count
        CHECK_FAIL_RETURN_UNEXPECTED_MR(num_samples != -1,
                                        "Invalid data, 'num_samples': " + std::to_string(num_samples) +
                                          " is out of bound: " + std::to_string(std::numeric_limits<int64_t>::max()));
      }
    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
        if (root == true) {
          // padded samples take part in the per-shard split exactly once
          sampler_op->SetNumPaddedSamples(num_padded);
          num_samples = op->GetNumSamples(num_samples, 0);
          CHECK_FAIL_RETURN_UNEXPECTED_MR(
            num_samples != -1,
            "Invalid data, the size of dataset and padded samples: " + std::to_string(num_padded) +
              " can not be divisible by the value of 'num_shards'.\n Please adjust the value of 'num_padded'.");
          root = false;
        }
      } else {
        num_samples = op->GetNumSamples(num_samples, 0);
        num_samples += num_padded;
      }
    } else {
      // operators that do not subsample pass the padding straight through
      if (num_padded > 0) {
        num_samples += num_padded;
      }
    }
  }
  *count = num_samples;
  return Status::OK();
}
1215 
Open(const std::vector<std::string> & file_paths,bool load_dataset,int n_consumer,const std::vector<std::string> & selected_columns,const std::vector<std::shared_ptr<ShardOperator>> & operators,int64_t num_padded,LoadMode load_mode)1216 Status ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
1217                          const std::vector<std::string> &selected_columns,
1218                          const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded,
1219                          LoadMode load_mode) {
1220   load_mode_ = load_mode;
1221 
1222   // Open file and set header by ShardReader
1223   RETURN_IF_NOT_OK_MR(Init(file_paths, load_dataset));
1224   auto thread_limit = GetMaxThreadNum();
1225   if (n_consumer > thread_limit) {
1226     n_consumer = thread_limit;
1227   }
1228   if (n_consumer < kMinConsumerCount) {
1229     n_consumer = kMinConsumerCount;
1230   }
1231 
1232   selected_columns_ = selected_columns;
1233   RETURN_IF_NOT_OK_MR(CheckColumnList(selected_columns_));
1234 
1235   // Initialize argument
1236   shard_count_ = static_cast<int>(file_paths_.size());
1237   n_consumer_ = n_consumer;
1238   num_padded_ = num_padded;
1239 
1240   operators_ = operators;
1241   RETURN_IF_NOT_OK_MR(Open(n_consumer));
1242   return Status::OK();
1243 }
1244 
Launch(bool is_sample_read)1245 Status ShardReader::Launch(bool is_sample_read) {
1246   // Get all row groups' info
1247   auto row_group_summary = ReadRowGroupSummary();
1248 
1249   // Sort row group by (group_id, shard_id), prepare for parallel reading
1250   std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
1251   auto status = CreateTasks(row_group_summary, operators_);
1252   if (status.IsError()) {
1253     interrupt_ = true;
1254     return status;
1255   }
1256   if (is_sample_read) {
1257     return Status::OK();
1258   }
1259   // Start provider consumer threads
1260   thread_set_ = std::vector<std::thread>(n_consumer_);
1261   CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount,
1262                                   "Invalid data, 'num_parallel_workers' should be less than or equal to " +
1263                                     std::to_string(kMaxConsumerCount) + "but got: " + std::to_string(n_consumer_));
1264 
1265   for (int x = 0; x < n_consumer_; ++x) {
1266     thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
1267   }
1268 
1269   MS_LOG(INFO) << "Succeed to launch read thread.";
1270   return Status::OK();
1271 }
1272 
CreateTasksByCategory(const std::shared_ptr<ShardOperator> & op)1273 Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
1274   CheckIfColumnInIndex(selected_columns_);
1275   auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
1276   auto categories = category_op->GetCategories();
1277   int64_t num_elements = category_op->GetNumElements();
1278   int64_t num_samples = 0;
1279   if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
1280     num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
1281     CHECK_FAIL_RETURN_UNEXPECTED_MR(
1282       num_samples >= 0,
1283       "Invalid data, 'num_samples' should be greater than or equal to 0, but got: " + std::to_string(num_samples));
1284   }
1285   CHECK_FAIL_RETURN_UNEXPECTED_MR(
1286     num_elements > 0,
1287     "[Internal ERROR] 'num_elements' should be greater than 0, but got: " + std::to_string(num_elements));
1288   if (categories.empty() == true) {
1289     std::string category_field = category_op->GetCategoryField();
1290     int64_t num_categories = category_op->GetNumCategories();
1291     CHECK_FAIL_RETURN_UNEXPECTED_MR(
1292       num_categories > 0,
1293       "[Internal ERROR] 'num_categories' should be greater than 0, but got: " + std::to_string(num_categories));
1294     auto category_ptr = std::make_shared<std::set<std::string>>();
1295     RETURN_IF_NOT_OK_MR(GetAllClasses(category_field, category_ptr));
1296     int i = 0;
1297     for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
1298       categories.emplace_back(category_field, *it);
1299       i++;
1300     }
1301   }
1302   // Generate a vector of task lists.  Each catogory has a list of tasks.
1303   std::vector<ShardTaskList> categoryTasks(categories.size());
1304   for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
1305     int category_index = 0;
1306     for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
1307       auto pages_ptr = std::make_shared<std::vector<uint64_t>>();
1308       RETURN_IF_NOT_OK_MR(GetPagesByCategory(shard_id, categories[categoryNo], &pages_ptr));
1309       for (const auto &page_id : *pages_ptr) {
1310         if (category_index >= num_elements) {
1311           break;
1312         }
1313         std::shared_ptr<Page> page_ptr;
1314         RETURN_IF_NOT_OK_MR(shard_header_->GetPage(shard_id, page_id, &page_ptr));
1315         auto group_id = page_ptr->GetPageTypeID();
1316         std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
1317         RETURN_IF_NOT_OK_MR(
1318           ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_, &row_group_brief_ptr));
1319         auto offsets = std::get<3>(*row_group_brief_ptr);
1320 
1321         auto number_of_rows = offsets.size();
1322         for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
1323           if (category_index < num_elements) {
1324             categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
1325                                                  std::get<3>(*row_group_brief_ptr)[iStart],
1326                                                  std::get<4>(*row_group_brief_ptr)[iStart]);
1327             category_index++;
1328           }
1329         }
1330         MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks.";
1331       }
1332     }
1333   }
1334   tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
1335 
1336   tasks_.InitSampleIds();
1337   RETURN_IF_NOT_OK_MR((*category_op)(tasks_));
1338   return Status::OK();
1339 }
1340 
// Build the full in-memory task list for the fast load mode: read every row
// group's index up front, then insert one task per sample, parallelized with
// one thread per shard.
//
// @param row_group_summary summary of all row groups (not used on this path;
//        kept so the signature parallels the other task-creation variants)
// @param operators sampling operators (applied later in CreateTasks, not here)
// @return Status::OK() on success
Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                     const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  std::shared_ptr<ROW_GROUPS> row_group_ptr;
  RETURN_IF_NOT_OK_MR(ReadAllRowGroup(selected_columns_, &row_group_ptr));
  auto &offsets = std::get<0>(*row_group_ptr);
  auto &local_columns = std::get<1>(*row_group_ptr);
  int sample_count = 0;
  for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
    sample_count += offsets[shard_id].size();
  }
  // The index database and the data files must agree on the sample count.
  CHECK_FAIL_RETURN_UNEXPECTED_MR(sample_count == num_rows_, "Unequal number of index entries and data entries.");
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  // Each shard's thread writes into a disjoint [current_offset,
  // current_offset + offsets[shard_id].size()) slice of tasks_, so the
  // threads never touch the same slot.
  uint32_t current_offset = 0;
  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
      pthread_setname_np(pthread_self(), std::string("ParallelCreateTasks" + std::to_string(shard_id)).c_str());
#endif
      auto offset = current_offset;
      for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
        // offsets[shard_id][i] layout (see ConsumerOneTask): [0] shard id,
        // [1] group id, [2] blob start, [3] blob end.
        tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
                          std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]},
                          local_columns[shard_id][i]);
        offset++;
      }
    });
    current_offset += offsets[shard_id].size();
  }

  // Wait for all per-shard insertions to finish before the list is used.
  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1383 
// Build the task list for the lazy load mode: create one placeholder task per
// sample (empty blob range, empty json) without touching the index database;
// the real metadata is fetched later in ConsumerOneTask.  Parallelized with
// one thread per shard.
//
// @param row_group_summary unused on this path (kept for signature parity)
// @param operators sampling operators (applied later in CreateTasks, not here)
// @return Status::OK() on success
Status ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                         const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  // shard_sample_count_ holds cumulative (prefix-sum) sample counts per
  // shard, so the last entry is the dataset total.
  uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    // the offset indicate the shard start
    uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];

    // the count indicate the number of samples in the shard
    uint32_t shard_count =
      shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
    // Each thread fills the disjoint slice [current_offset,
    // current_offset + shard_count) of tasks_, so there is no contention.
    init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
      pthread_setname_np(pthread_self(), std::string("ParallelCreateLazyTasks" + std::to_string(shard_id)).c_str());
#endif
      for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
        // here "i - current_offset" indicate the sample id in the shard
        tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
      }
    });
  }

  // Wait for all per-shard insertions to finish before the list is used.
  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1419 
CreateSlowTasksByRow()1420 Status ShardReader::CreateSlowTasksByRow() {
1421   CheckIfColumnInIndex(selected_columns_);
1422   uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
1423   MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";
1424   tasks_.padded_sample_ = num_padded_;
1425   tasks_.SetShardSampleCount(shard_sample_count_);
1426   return Status::OK();
1427 }
1428 
// Create the task list according to the load mode and the operator pipeline,
// then apply every sampling/shuffle operator to it.
//
// If a ShardCategory operator is present, task creation is delegated entirely
// to CreateTasksByCategory and the category operator is skipped in the
// operator loop below (it has already been applied there).
//
// @param row_group_summary row-group info forwarded to the per-mode builders
// @param operators sampling/shuffle/category operators to apply, in order
// @return Status::OK() on success
Status ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  // Find the first category operator, if any.
  int category_operator = -1;
  for (uint32_t i = 0; i < operators.size(); ++i) {
    const auto &op = operators[i];
    if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      category_operator = static_cast<int>(i);
      break;
    }
  }

  if (-1 == category_operator) {
    if (load_mode_ != LoadMode::kSlow) {
      if (load_mode_ == LoadMode::kLazy) {
        RETURN_IF_NOT_OK_MR(CreateLazyTasksByRow(row_group_summary, operators));
      } else {
        RETURN_IF_NOT_OK_MR(CreateTasksByRow(row_group_summary, operators));
      }

      // need padded sample to the task
      if (num_padded_ > 0) {
        for (auto i = 0; i < num_padded_; ++i) {
          tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
        }
      }
    } else {
      // Slow mode keeps no materialized tasks; padding is recorded inside.
      RETURN_IF_NOT_OK_MR(CreateSlowTasksByRow());
    }
  } else {
    RETURN_IF_NOT_OK_MR(CreateTasksByCategory(operators[category_operator]));
  }

  MS_LOG(DEBUG) << "Succeed to create " << tasks_.Size() << " initial task to start with before sampling.";
  if (load_mode_ != LoadMode::kSlow) {
    tasks_.InitSampleIds();

    for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
      const auto &op = operators[operator_no];
      // The category operator (if any) was already applied in
      // CreateTasksByCategory — do not apply it twice.
      if (std::dynamic_pointer_cast<ShardCategory>(op)) {
        continue;
      }

      // Distributed samplers and shufflers need the per-shard layout.
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
        op->SetShardSampleCount(shard_sample_count_);
      }
      RETURN_IF_NOT_OK_MR((*op)(tasks_));
    }

    if (tasks_.permutation_.empty()) {
      tasks_.MakePerm();
    }
  } else {
    // Slow mode: category retrieval is unsupported; every other operator is
    // applied the same way as above.
    for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
      const auto &op = operators[operator_no];
      CHECK_FAIL_RETURN_UNEXPECTED_MR(
        !std::dynamic_pointer_cast<ShardCategory>(op),
        "[Internal ERROR] The retrieval function is not available when in slow loading mode.");
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
        op->SetShardSampleCount(shard_sample_count_);
      }
      RETURN_IF_NOT_OK_MR((*op)(tasks_));
    }
  }

  num_rows_ = tasks_.Size();
  MS_LOG(INFO) << "The total number of samples is " << num_rows_
               << ", the number of samples after sampling is: " << tasks_.SizeAfterSampling();

  return Status::OK();
}
1499 
// Read one sample, identified by task_id, into *task_content_ptr.
//
// Runs on a consumer thread; consumer_id selects the per-consumer file
// stream set so streams are never shared between threads.  Padded tasks
// short-circuit with an empty payload.  In lazy/slow mode the blob location
// and scalar fields are fetched from the index on demand; in fast mode they
// were stored in the task at creation time.
//
// @param task_id index into tasks_ (fast/lazy) or a global sample id (slow)
// @param consumer_id id of the calling consumer thread
// @param task_content_ptr output: (task type, [(blob bytes, scalar json)])
// @return Status::OK() on success
Status ShardReader::ConsumerOneTask(int64_t task_id, uint32_t consumer_id,
                                    std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(task_content_ptr);
  if (load_mode_ == LoadMode::kFast || load_mode_ == LoadMode::kLazy) {
    // All tasks are done
    CHECK_FAIL_RETURN_UNEXPECTED_MR(task_id < tasks_.Size(), "[Internal ERROR] 'task_id': " + std::to_string(task_id) +
                                                               " is out of bound: " + std::to_string(tasks_.Size()));
  } else {
    // Slow mode: the valid range is total samples (last prefix-sum entry)
    // plus the padded samples, since no task objects were materialized.
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      task_id < (num_padded_ + shard_sample_count_[shard_sample_count_.size() - 1]),
      "[Internal ERROR] 'task_id': " + std::to_string(task_id) +
        " is out of bound: " + std::to_string(num_padded_ + shard_sample_count_[shard_sample_count_.size() - 1]));
  }

  uint32_t shard_id = 0;
  uint32_t group_id = 0;
  uint32_t blob_start = 0;
  uint32_t blob_end = 0;
  json var_fields;
  // Pick up task from task list
  ShardTask task = tasks_.GetTaskByID(task_id);

  // check task type
  auto task_type = std::get<0>(task);
  if (task_type == TaskType::kPaddedTask) {
    // Padded samples carry no data; return an empty payload immediately.
    *task_content_ptr =
      std::make_shared<TASK_CONTENT>(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
    return Status::OK();
  }

  shard_id = std::get<0>(std::get<1>(task));  // shard id

  if (load_mode_ == LoadMode::kLazy || load_mode_ == LoadMode::kSlow) {
    // get scalar variable fields by sample id
    uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));

    // read the meta from index
    std::shared_ptr<ROW_GROUPS> row_group_ptr;
    RETURN_IF_NOT_OK_MR(
      ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, consumer_id, sample_id_in_shard, &row_group_ptr));
    auto &offsets = std::get<0>(*row_group_ptr);
    auto &local_columns = std::get<1>(*row_group_ptr);

    group_id = offsets[shard_id][0][1];       // group_id
    blob_start = offsets[shard_id][0][2];     // blob start
    blob_end = offsets[shard_id][0][3];       // blob end
    var_fields = local_columns[shard_id][0];  // scalar variable field
  } else {
    // Fast mode: everything needed was stored in the task at creation time.
    group_id = std::get<1>(std::get<1>(task));  // group id
    blob_start = std::get<2>(task)[0];          // blob start
    blob_end = std::get<2>(task)[1];            // blob end
    var_fields = std::get<3>(task);             // scalar variable field
  }

  // read the blob from data file
  std::shared_ptr<Page> page_ptr;
  RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
  MS_LOG(DEBUG) << "Success to get page by group id: " << group_id;

  // Pack image list
  std::vector<uint8_t> images(blob_end - blob_start);
  // Absolute file offset: header + page position + offset within the page.
  auto file_offset = header_size_ + page_size_ * (page_ptr->GetPageID()) + blob_start;

  auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    // Close the stream on failure so a broken handle is not reused.
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
  }
  auto &io_read =
    file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
  }

  // Deliver batch data to output map
  std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
  batch.emplace_back(std::move(images), std::move(var_fields));

  *task_content_ptr = std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::move(batch));
  return Status::OK();
}
1582 
// Consumer-thread body: repeatedly claims the next sample position, reads it
// via ConsumerOneTask, and publishes the batch into delivery_map_ for GetNext
// to pick up in order.  Exits when all samples have been claimed or when
// interrupt_ is raised (on error, it raises interrupt_ itself so the other
// threads and the iterator also stop).
void ShardReader::ConsumerByRow(int consumer_id) {
  // Set thread name
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(consumer_id)).c_str());
#endif

  // Loop forever
  for (;;) {
    int64_t sample_id_pos = 0;

    // Get next task ID
    // sample_id_position_ is shared by all consumers; the post-increment
    // atomically hands each thread a unique position.
    sample_id_pos = sample_id_position_++;

    auto task_content_ptr =
      std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
    int64_t task_id = 0;

    if (load_mode_ == LoadMode::kFast || load_mode_ == LoadMode::kLazy) {
      // All tasks are done
      if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
        return;
      }
      task_id = tasks_.sample_ids_[sample_id_pos];
    } else {
      // task_id is not correct when slow load mode
      if (sample_id_pos >= shard_sample_count_[shard_sample_count_.size() - 1]) {
        return;
      }
      task_id = sample_id_pos;
    }
    if (ConsumerOneTask(task_id, consumer_id, &task_content_ptr).IsError()) {
      MS_LOG(ERROR) << "[Internal ERROR] Error raised in ConsumerOneTask function.";
      interrupt_ = true;
      cv_iterator_.notify_one();
      return;
    }
    const auto &batch = (*task_content_ptr).second;
    // Hanging if maximum map size exceeded
    //   otherwise, set batch data in map
    {
      std::unique_lock<std::mutex> lck(mtx_delivery_);
      // Back-pressure: block until this position is within kNumBatchInMap of
      // the consumer-side deliver_id_, so the map stays bounded.
      cv_delivery_.wait(lck,
                        [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
      if (interrupt_) {
        return;
      }
      delivery_map_[sample_id_pos] = std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(batch);
    }
    // Wake the iterator waiting in GetNext for this position.
    cv_iterator_.notify_one();
  }
}
1634 
GetNext()1635 std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> ShardReader::GetNext() {
1636   if (interrupt_) {
1637     return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1638   }
1639 
1640   if (deliver_id_ >= static_cast<int>(tasks_.SizeAfterSampling())) {
1641     return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1642   }
1643 
1644   std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>> res;
1645   {
1646     std::unique_lock<std::mutex> lck(mtx_delivery_);
1647     cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); });
1648     if (interrupt_) {
1649       return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1650     }
1651     res = delivery_map_[deliver_id_];
1652     delivery_map_.erase(deliver_id_++);
1653   }
1654 
1655   cv_delivery_.notify_all();
1656 
1657   // extract every blob field from blob data
1658   std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> res_with_blobs;
1659   for (auto iter = res->begin(); iter != res->end(); iter++) {
1660     std::map<std::string, std::vector<uint8_t>> key_with_blob_fields;
1661     auto shard_column = GetShardColumn();
1662     auto schema = shard_header_->GetSchemas();  // current, we only support 1 schema yet
1663     auto blob_fields = schema[0]->GetBlobFields();
1664     for (auto blob_field : blob_fields) {
1665       const unsigned char *data = nullptr;
1666       std::unique_ptr<unsigned char[]> data_ptr;
1667       uint64_t n_bytes = 0;
1668       mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
1669       uint64_t column_data_type_size = 1;
1670       std::vector<int64_t> column_shape;
1671       if (shard_column->GetColumnValueByName(blob_field, std::get<0>(*iter), std::get<1>(*iter), &data, &data_ptr,
1672                                              &n_bytes, &column_data_type, &column_data_type_size,
1673                                              &column_shape) != Status::OK()) {
1674         MS_LOG(ERROR) << "[Internal ERROR] Failed to extract blob fields from blob data";
1675         return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1676       }
1677       key_with_blob_fields[blob_field] = std::vector<uint8_t>(data, data + n_bytes);
1678     }
1679 
1680     res_with_blobs.emplace_back(std::move(key_with_blob_fields), std::move(std::get<1>(*iter)));
1681   }
1682 
1683   return res_with_blobs;
1684 }
1685 
GetNextById(const int64_t & task_id,const int32_t & consumer_id,std::shared_ptr<TASK_CONTENT> * task_content_ptr)1686 Status ShardReader::GetNextById(const int64_t &task_id, const int32_t &consumer_id,
1687                                 std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
1688   if (interrupt_) {
1689     return Status::OK();
1690   }
1691   RETURN_IF_NOT_OK_MR(ConsumerOneTask(task_id, consumer_id, task_content_ptr));
1692   return Status::OK();
1693 }
1694 
UnCompressBlob(const std::vector<uint8_t> & raw_blob_data,std::shared_ptr<std::vector<std::vector<uint8_t>>> * blob_data_ptr)1695 Status ShardReader::UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
1696                                    std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr) {
1697   RETURN_UNEXPECTED_IF_NULL_MR(blob_data_ptr);
1698   auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
1699   auto blob_fields = GetBlobFields().second;
1700   for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
1701     if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) {
1702       continue;
1703     }
1704     const unsigned char *data = nullptr;
1705     std::unique_ptr<unsigned char[]> data_ptr;
1706     uint64_t n_bytes = 0;
1707     RETURN_IF_NOT_OK_MR(
1708       shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes));
1709     if (data == nullptr) {
1710       data = reinterpret_cast<const unsigned char *>(data_ptr.get());
1711     }
1712     std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
1713     (*blob_data_ptr)->push_back(column);
1714   }
1715   return Status::OK();
1716 }
1717 
GetTotalBlobSize(int64_t * total_blob_size)1718 Status ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
1719   *total_blob_size = total_blob_size_;
1720   return Status::OK();
1721 }
1722 
Reset()1723 void ShardReader::Reset() {
1724   {
1725     std::lock_guard<std::mutex> lck(mtx_delivery_);
1726     sample_id_position_ = 0;
1727     deliver_id_ = 0;
1728   }
1729   cv_delivery_.notify_all();
1730 }
1731 
ShuffleTask()1732 void ShardReader::ShuffleTask() {
1733   // exist shuffle and distributed sampler in ops, skip shuffle
1734   bool has_sharding = false;
1735   for (const auto &op : operators_) {
1736     if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1737       has_sharding = true;
1738     }
1739   }
1740   for (const auto &op : operators_) {
1741     if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
1742       auto s = (*op)(tasks_);
1743       if (s.IsError()) {
1744         MS_LOG(WARNING) << "Failed to redo randomSampler in new epoch.";
1745       }
1746     } else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1747       auto s = (*op)(tasks_);
1748       if (s.IsError()) {
1749         MS_LOG(WARNING) << "Failed to redo distributeSampler in new epoch.";
1750       }
1751     }
1752   }
1753   if (load_mode_ != kSlow) {
1754     if (tasks_.permutation_.empty()) {
1755       tasks_.MakePerm();
1756     }
1757   } else {
1758     tasks_.generator_ids_.ResetShardIndexAndID();
1759   }
1760 }
1761 
GetSampleIds()1762 const std::vector<int64_t> *ShardReader::GetSampleIds() {
1763   // return const reference to private sample id list.
1764   return &(this->tasks_.sample_ids_);
1765 }
1766 
GetLoadMode() const1767 LoadMode ShardReader::GetLoadMode() const { return load_mode_; }
1768 
GetNextSampleIds()1769 std::vector<int64_t> ShardReader::GetNextSampleIds() { return tasks_.GetNextSampleIds(); }
1770 }  // namespace mindrecord
1771 }  // namespace mindspore
1772