• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_reader.h"
18 
19 #include <algorithm>
20 #include <thread>
21 
22 #include "utils/file_utils.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "utils/ms_utils.h"
25 
26 using mindspore::LogStream;
27 using mindspore::ExceptionType::NoExceptionType;
28 using mindspore::MsLogLevel::DEBUG;
29 using mindspore::MsLogLevel::ERROR;
30 using mindspore::MsLogLevel::INFO;
31 
32 namespace mindspore {
33 namespace mindrecord {
// Convert a string to an exact number type (int32_t/int64_t/float/double).
// Returns 0 when the string does not parse as `Type`: `num` is value-initialized
// up front, so a failed extraction can never leave it indeterminate (the original
// relied on C++11 stream semantics writing 0 on failure to an uninitialized local).
template <class Type>
Type StringToNum(const std::string &str) {
  std::istringstream iss(str);
  Type num{};
  iss >> num;
  return num;
}
42 
// Construct an idle reader with all counters zeroed; the real state (header,
// file streams, index databases) is established later by Init()/Open().
ShardReader::ShardReader()
    : header_size_(0),
      page_size_(0),
      shard_count_(0),
      n_consumer_(0),
      num_padded_(0),
      num_rows_(0),
      total_blob_size_(0),
      sample_id_position_(0),
      deliver_id_(0),
      lazy_load_(false),
      shard_sample_count_() {}
55 
GetMeta(const std::string & file_path,std::shared_ptr<json> meta_data_ptr,std::shared_ptr<std::vector<std::string>> * addresses_ptr)56 Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
57                             std::shared_ptr<std::vector<std::string>> *addresses_ptr) {
58   RETURN_UNEXPECTED_IF_NULL(addresses_ptr);
59   CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(file_path), "Invalid file, path: " + file_path);
60   std::shared_ptr<json> header_ptr;
61   RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path, &header_ptr));
62 
63   *meta_data_ptr = {{"header_size", (*header_ptr)["header_size"]}, {"page_size", (*header_ptr)["page_size"]},
64                     {"version", (*header_ptr)["version"]},         {"index_fields", (*header_ptr)["index_fields"]},
65                     {"schema", (*header_ptr)["schema"]},           {"blob_fields", (*header_ptr)["blob_fields"]}};
66   *addresses_ptr = std::make_shared<std::vector<std::string>>((*header_ptr)["shard_addresses"]);
67   return Status::OK();
68 }
69 
// Validate every MindRecord file, check that all shard headers are mutually
// consistent, open one sqlite index handle per shard, and build the in-memory
// header. When `load_dataset` is true a single path is expanded to the whole
// dataset via the shard addresses recorded in its header.
Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
  std::string file_path = file_paths[0];
  auto first_meta_data_ptr = std::make_shared<json>();
  std::shared_ptr<std::vector<std::string>> addresses_ptr;
  RETURN_IF_NOT_OK(GetMeta(file_path, first_meta_data_ptr, &addresses_ptr));
  if (file_paths.size() == 1 && load_dataset == true) {
    // One path + load_dataset: discover sibling shard files from the header.
    auto ds = std::make_shared<std::vector<std::string>>();
    RETURN_IF_NOT_OK(GetDatasetFiles(file_path, *addresses_ptr, &ds));
    file_paths_ = *ds;
  } else if (file_paths.size() >= 1 && load_dataset == false) {
    file_paths_ = file_paths;
  } else {
    RETURN_STATUS_UNEXPECTED("Invalid data, number of MindRecord files [" + std::to_string(file_paths.size()) +
                             "] or 'load_dataset' [" + std::to_string(load_dataset) + "]is invalid.");
  }
  for (const auto &file : file_paths_) {
    auto meta_data_ptr = std::make_shared<json>();
    RETURN_IF_NOT_OK(GetMeta(file, meta_data_ptr, &addresses_ptr));
    // Every shard must carry the same meta data (schema, page size, ...) as the first.
    CHECK_FAIL_RETURN_UNEXPECTED(*meta_data_ptr == *first_meta_data_ptr,
                                 "Invalid data, MindRecord files meta data is not consistent.");
    sqlite3 *db = nullptr;
    RETURN_IF_NOT_OK(VerifyDataset(&db, file));
    database_paths_.push_back(db);
  }
  ShardHeader sh = ShardHeader();
  RETURN_IF_NOT_OK(sh.BuildDataset(file_paths_, load_dataset));
  shard_header_ = std::make_shared<ShardHeader>(sh);
  header_size_ = shard_header_->GetHeaderSize();
  page_size_ = shard_header_->GetPageSize();
  // version < 3.0
  if ((*first_meta_data_ptr)["version"] < kVersion) {
    shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
  } else {
    shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
  }
  num_rows_ = 0;
  auto row_group_summary = ReadRowGroupSummary();

  // clear the shard_sample_count_, because it will be insert when Launch func
  shard_sample_count_.clear();

  // Total row count is the sum of rows across all row groups (tuple slot 3).
  for (const auto &rg : row_group_summary) {
    num_rows_ += std::get<3>(rg);
  }

  if (num_rows_ > LAZY_LOAD_THRESHOLD) {
    lazy_load_ = true;
    MS_LOG(WARNING)
      << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
      << ", enable lazy load mode. If you want to speed up data loading, "
      << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
      << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
      << "according to the current samples.";
  }

  auto disk_size = page_size_ * row_group_summary.size();
  auto compression_size = shard_header_->GetCompressionSize();
  total_blob_size_ = disk_size + compression_size;
  MS_LOG(INFO) << "Blob data size on disk: " << disk_size << " , additional uncompression size: " << compression_size
               << " , Total blob size: " << total_blob_size_;

  MS_LOG(INFO) << "Succeed to get meta from mindrecord file & index file.";

  return Status::OK();
}
135 
VerifyDataset(sqlite3 ** db,const string & file)136 Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
137   // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
138   CHECK_FAIL_RETURN_UNEXPECTED(
139     sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK,
140     "Invalid database file, path: " + file + ".db, " + sqlite3_errmsg(*db));
141   MS_LOG(DEBUG) << "Succeed to Open database, path: " << file << ".db.";
142 
143   string sql = "SELECT NAME from SHARD_NAME;";
144   std::vector<std::vector<std::string>> name;
145   char *errmsg = nullptr;
146   if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) {
147     std::ostringstream oss;
148     oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
149     sqlite3_free(errmsg);
150     sqlite3_close(*db);
151     RETURN_STATUS_UNEXPECTED(oss.str());
152   } else {
153     MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(name.size()) << " records from index.";
154     std::shared_ptr<std::string> fn_ptr;
155     RETURN_IF_NOT_OK(GetFileName(file, &fn_ptr));
156     if (name.empty() || name[0][0] != *fn_ptr) {
157       sqlite3_free(errmsg);
158       sqlite3_close(*db);
159       RETURN_STATUS_UNEXPECTED("Invalid database file, shard name [" + *fn_ptr + "] can not match [" + name[0][0] +
160                                "].");
161     }
162   }
163   return Status::OK();
164 }
165 
CheckColumnList(const std::vector<std::string> & selected_columns)166 Status ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
167   vector<int> inSchema(selected_columns.size(), 0);
168   for (auto &p : GetShardHeader()->GetSchemas()) {
169     auto schema = p->GetSchema()["schema"];
170     for (unsigned int i = 0; i < selected_columns.size(); ++i) {
171       if (schema.find(selected_columns[i]) != schema.end()) {
172         inSchema[i] = 1;
173       }
174     }
175   }
176   CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; }),
177                                "Invalid data, column is not found in schema.");
178   return Status::OK();
179 }
180 
Open()181 Status ShardReader::Open() {
182   file_streams_.clear();
183   for (const auto &file : file_paths_) {
184     std::optional<std::string> dir = "";
185     std::optional<std::string> local_file_name = "";
186     FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
187     if (!dir.has_value()) {
188       dir = ".";
189     }
190 
191     auto realpath = FileUtils::GetRealPath(dir.value().data());
192     CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
193 
194     std::optional<std::string> whole_path = "";
195     FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
196 
197     std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
198     fs->open(whole_path.value(), std::ios::in | std::ios::binary);
199     if (!fs->good()) {
200       RETURN_STATUS_UNEXPECTED(
201         "Failed to open file: " + file +
202         ", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
203     }
204     MS_LOG(INFO) << "Succeed to open shard file.";
205     file_streams_.push_back(fs);
206   }
207   return Status::OK();
208 }
209 
// Open `n_consumer` independent binary streams per shard file so each consumer
// thread can read concurrently without sharing a stream position. Streams are
// stored as file_streams_random_[consumer][shard].
Status ShardReader::Open(int n_consumer) {
  file_streams_random_ =
    std::vector<std::vector<std::shared_ptr<std::fstream>>>(n_consumer, std::vector<std::shared_ptr<std::fstream>>());
  for (const auto &file : file_paths_) {
    for (int j = 0; j < n_consumer; ++j) {
      // Resolve the directory part to its real path, then re-attach the file name.
      std::optional<std::string> dir = "";
      std::optional<std::string> local_file_name = "";
      FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
      if (!dir.has_value()) {
        dir = ".";
      }

      auto realpath = FileUtils::GetRealPath(dir.value().data());
      CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);

      std::optional<std::string> whole_path = "";
      FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

      std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      if (!fs->good()) {
        RETURN_STATUS_UNEXPECTED(
          "Failed to open file: " + file +
          ", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
      }
      file_streams_random_[j].push_back(fs);
    }
    MS_LOG(INFO) << "Succeed to open file, path: " << file;
  }
  return Status::OK();
}
241 
FileStreamsOperator()242 void ShardReader::FileStreamsOperator() {
243   for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; --i) {
244     if (file_streams_[i] != nullptr) {
245       file_streams_[i]->close();
246     }
247   }
248   for (int i = static_cast<int>(file_streams_random_.size()) - 1; i >= 0; --i) {
249     for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
250       if (file_streams_random_[i][j] != nullptr) {
251         file_streams_random_[i][j]->close();
252       }
253     }
254   }
255   for (int i = static_cast<int>(database_paths_.size()) - 1; i >= 0; --i) {
256     if (database_paths_[i] != nullptr) {
257       auto ret = sqlite3_close(database_paths_[i]);
258       if (ret != SQLITE_OK) {
259         MS_LOG(ERROR) << "Failed to close database, error code: " << ret << ".";
260       }
261       database_paths_[i] = nullptr;
262     }
263   }
264 }
265 
// Stop worker threads and release all streams/handles via Close().
ShardReader::~ShardReader() { Close(); }
267 
// Interrupt delivery, join all worker threads, then close every stream and
// database handle. Also invoked by the destructor.
void ShardReader::Close() {
  {
    // Set the flag under the delivery lock so waiting workers observe it reliably.
    std::lock_guard<std::mutex> lck(mtx_delivery_);
    interrupt_ = true;  // interrupt reading and stop threads
  }
  // Wake any worker blocked on the delivery condition variable.
  cv_delivery_.notify_all();

  // Wait for all threads to finish
  for (auto &i_thread : thread_set_) {
    if (i_thread.joinable()) {
      i_thread.join();
    }
  }

  FileStreamsOperator();
}
284 
// Accessor: in-memory shard header built by Init().
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }

// Accessor: column helper selected in Init() according to the file version.
std::shared_ptr<ShardColumn> ShardReader::GetShardColumn() const { return shard_column_; }

// Accessor: number of shard files, as recorded in the header.
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }

// Accessor: total row count summed over all row groups during Init().
int ShardReader::GetNumRows() const { return num_rows_; }
292 
// Summarize every blob page of every shard as a tuple of
// (shard_id, page_type_id, start_row_id, number_of_rows). As a side effect,
// appends the running (cumulative across shards) sample total to
// shard_sample_count_. Returns an empty vector when a page reports an
// inconsistent row range.
std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
  std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
  int shard_count = shard_header_->GetShardCount();
  if (shard_count <= 0) {
    return row_group_summary;
  }
  // NOTE(review): shard counts above kMaxFileCount silently produce an empty
  // summary here — confirm shard_count is validated against kMaxFileCount earlier.
  if (shard_count <= kMaxFileCount) {
    uint32_t total_count = 0;
    for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
      // return -1 when page's size equals to 0.
      auto last_page_id = shard_header_->GetLastPageId(shard_id);
      if (static_cast<int>(last_page_id) == -1) {
        continue;
      }
      for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
        std::shared_ptr<Page> page_ptr;
        (void)shard_header_->GetPage(shard_id, page_id, &page_ptr);
        // Only blob pages contribute rows to the summary.
        if (page_ptr->GetPageType() != kPageTypeBlob) {
          continue;
        }
        uint64_t start_row_id = page_ptr->GetStartRowID();
        if (start_row_id > page_ptr->GetEndRowID()) {
          return std::vector<std::tuple<int, int, int, uint64_t>>();
        }
        uint64_t number_of_rows = page_ptr->GetEndRowID() - start_row_id;
        total_count += number_of_rows;
        row_group_summary.emplace_back(shard_id, page_ptr->GetPageTypeID(), start_row_id, number_of_rows);
      }
      // Cumulative sample count up to and including this shard.
      shard_sample_count_.push_back(total_count);
    }
  }

  return row_group_summary;
}
327 
// Convert the raw index rows returned by sqlite (`labels`) into json objects for
// shard `shard_id`, and record each sample's blob offset range in
// (*offset_ptr)[shard_id]. Each row of `labels` starts with
// [group_id, blob_start, blob_end]; the remaining columns are either
// [raw_page_id, raw_start, raw_end] (when not every requested field is indexed,
// in which case the msgpack-encoded label is read from `fs`) or one indexed
// value per requested column, converted to its schema type.
// `fs` is always closed before returning, on success and on every error path.
Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
                                       std::shared_ptr<std::fstream> fs,
                                       std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
                                       int shard_id, const std::vector<std::string> &columns,
                                       std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
  for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
    try {
      uint64_t group_id = std::stoull(labels[i][0]);
      // kInt64Len skips the length prefix stored in front of each blob record.
      uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
      uint64_t offset_end = std::stoull(labels[i][2]);
      (*offset_ptr)[shard_id].emplace_back(
        std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
      if (!all_in_index_) {
        // Not all fields are indexed: read the msgpack label bytes from the raw page.
        int raw_page_id = std::stoi(labels[i][3]);
        uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
        uint64_t label_end = std::stoull(labels[i][5]);
        auto len = label_end - label_start;
        auto label_raw = std::vector<uint8_t>(len);
        auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
        if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
        }
        auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
        if (!io_read.good() || io_read.fail() || io_read.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED("Failed to read file.");
        }
        json label_json = json::from_msgpack(label_raw);
        json tmp;
        if (!columns.empty()) {
          // Keep only the requested columns.
          for (auto &col : columns) {
            if (label_json.find(col) != label_json.end()) {
              tmp[col] = label_json[col];
            }
          }
        } else {
          tmp = label_json;
        }
        (*col_val_ptr)[shard_id].emplace_back(tmp);
      } else {
        // All requested fields are indexed: values arrive as strings from column 3 on.
        json construct_json;
        for (unsigned int j = 0; j < columns.size(); ++j) {
          // construct json "f1": value
          auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];

          // convert the string to base type by schema
          if (schema[columns[j]]["type"] == "int32") {
            construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
          } else if (schema[columns[j]]["type"] == "int64") {
            construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
          } else if (schema[columns[j]]["type"] == "float32") {
            construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
          } else if (schema[columns[j]]["type"] == "float64") {
            construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
          } else {
            construct_json[columns[j]] = std::string(labels[i][j + 3]);
          }
        }
        (*col_val_ptr)[shard_id].emplace_back(construct_json);
      }
    } catch (std::out_of_range &e) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Out of range exception raised in ConvertLabelToJson function, " +
                               std::string(e.what()));
    } catch (std::invalid_argument &e) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Invalid argument exception raised in ConvertLabelToJson function, " +
                               std::string(e.what()));
    } catch (...) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Unknown exception raised in ConvertLabelToJson function");
    }
  }

  fs->close();
  return Status::OK();
}
406 
ReadAllRowsInShard(int shard_id,const std::string & sql,const std::vector<std::string> & columns,std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr)407 Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
408                                        std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
409                                        std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
410   auto db = database_paths_[shard_id];
411   std::vector<std::vector<std::string>> labels;
412   char *errmsg = nullptr;
413   int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
414   if (rc != SQLITE_OK) {
415     std::ostringstream oss;
416     oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
417     sqlite3_free(errmsg);
418     sqlite3_close(db);
419     db = nullptr;
420     RETURN_STATUS_UNEXPECTED(oss.str());
421   }
422   MS_LOG(INFO) << "Succeed to get " << static_cast<int>(labels.size()) << " records from shard "
423                << std::to_string(shard_id) << " index.";
424 
425   std::string file_name = file_paths_[shard_id];
426   auto realpath = FileUtils::GetRealPath(file_name.data());
427   if (!realpath.has_value()) {
428     sqlite3_free(errmsg);
429     sqlite3_close(db);
430     RETURN_STATUS_UNEXPECTED("Failed to get real path, path: " + file_name);
431   }
432 
433   std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
434   if (!all_in_index_) {
435     fs->open(realpath.value(), std::ios::in | std::ios::binary);
436     if (!fs->good()) {
437       sqlite3_free(errmsg);
438       sqlite3_close(db);
439       RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file_name);
440     }
441   }
442   sqlite3_free(errmsg);
443   return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr);
444 }
445 
// Collect the distinct values of indexed field `category_field` across all
// shards into `category_ptr`, querying each shard's index database on its own
// thread (workers synchronize on shard_locker_ when merging).
Status ShardReader::GetAllClasses(const std::string &category_field,
                                  std::shared_ptr<std::set<std::string>> category_ptr) {
  std::map<std::string, uint64_t> index_columns;
  for (auto &field : GetShardHeader()->GetFields()) {
    // field is (schema_id, field_name); invert it to look up the id by name.
    index_columns[field.second] = field.first;
  }
  CHECK_FAIL_RETURN_UNEXPECTED(index_columns.find(category_field) != index_columns.end(),
                               "Invalid data, index field " + category_field + " does not exist.");
  std::shared_ptr<std::string> fn_ptr;
  RETURN_IF_NOT_OK(
    ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr));
  std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
  // One worker per shard; each gets its own database handle.
  std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
  for (int x = 0; x < shard_count_; x++) {
    threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
  }

  for (int x = 0; x < shard_count_; x++) {
    threads[x].join();
  }
  return Status::OK();
}
468 
GetClassesInShard(sqlite3 * db,int shard_id,const std::string & sql,std::shared_ptr<std::set<std::string>> category_ptr)469 void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
470                                     std::shared_ptr<std::set<std::string>> category_ptr) {
471   if (db == nullptr) {
472     return;
473   }
474   std::vector<std::vector<std::string>> columns;
475   char *errmsg = nullptr;
476   int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg);
477   if (ret != SQLITE_OK) {
478     sqlite3_free(errmsg);
479     sqlite3_close(db);
480     db = nullptr;
481     MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
482     return;
483   }
484   MS_LOG(INFO) << "Succeed to get " << static_cast<int>(columns.size()) << " records from shard "
485                << std::to_string(shard_id) << " index.";
486   std::lock_guard<std::mutex> lck(shard_locker_);
487   for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
488     category_ptr->emplace(columns[i][0]);
489   }
490   sqlite3_free(errmsg);
491 }
492 
// Read index meta for all rows of all shards in parallel (one thread per shard).
// `columns` selects which indexed fields to fetch; when not every field is
// indexed, the raw page location columns are fetched instead so workers can read
// labels straight from the shard files.
Status ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns,
                                    std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
  RETURN_UNEXPECTED_IF_NULL(row_group_ptr);
  std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
  // Each worker writes only its own shard's slot, so no locking is needed here.
  auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
    shard_count_, std::vector<std::vector<uint64_t>>{});
  auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});

  if (all_in_index_) {
    // Append one generated column name per requested field.
    for (unsigned int i = 0; i < columns.size(); ++i) {
      fields += ',';
      std::shared_ptr<std::string> fn_ptr;
      RETURN_IF_NOT_OK(
        ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
      fields += *fn_ptr;
    }
  } else {  // fetch raw data from Raw page while some field is not index.
    fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
  }

  std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;";

  std::vector<std::thread> thread_read_db = std::vector<std::thread>(shard_count_);
  for (int x = 0; x < shard_count_; x++) {
    thread_read_db[x] = std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, offset_ptr, col_val_ptr);
  }

  for (int x = 0; x < shard_count_; x++) {
    thread_read_db[x].join();
  }
  *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
  return Status::OK();
}
526 
// Read index meta for the single sample `sample_id` of shard `shard_id`.
// Mirrors ReadAllRowGroup but restricts the query to one ROW_ID and runs on the
// calling thread; the result vectors still carry one slot per shard.
Status ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
                                                     const uint32_t &sample_id,
                                                     std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
  RETURN_UNEXPECTED_IF_NULL(row_group_ptr);
  std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
  auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
    shard_count_, std::vector<std::vector<uint64_t>>{});
  auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
  if (all_in_index_) {
    // Append one generated column name per requested field.
    for (unsigned int i = 0; i < columns.size(); ++i) {
      fields += ',';
      std::shared_ptr<std::string> fn_ptr;
      RETURN_IF_NOT_OK(
        ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
      fields += *fn_ptr;
    }
  } else {  // fetch raw data from Raw page while some field is not index.
    fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
  }

  std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);

  RETURN_IF_NOT_OK(ReadAllRowsInShard(shard_id, sql, columns, offset_ptr, col_val_ptr));
  *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
  return Status::OK();
}
553 
ReadRowGroupBrief(int group_id,int shard_id,const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUP_BRIEF> * row_group_brief_ptr)554 Status ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
555                                       std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
556   RETURN_UNEXPECTED_IF_NULL(row_group_brief_ptr);
557   std::shared_ptr<Page> page_ptr;
558   RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
559   std::string file_name = file_paths_[shard_id];
560   uint64_t page_length = page_ptr->GetPageSize();
561   uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
562   std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id);
563   auto labels_ptr = std::make_shared<std::vector<json>>();
564   RETURN_IF_NOT_OK(GetLabels(page_ptr->GetPageID(), shard_id, columns, {"", ""}, &labels_ptr));
565   *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
566                                                            std::move(*labels_ptr));
567   return Status::OK();
568 }
569 
ReadRowGroupCriteria(int group_id,int shard_id,const std::pair<std::string,std::string> & criteria,const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUP_BRIEF> * row_group_brief_ptr)570 Status ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
571                                          const std::pair<std::string, std::string> &criteria,
572                                          const std::vector<std::string> &columns,
573                                          std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
574   RETURN_UNEXPECTED_IF_NULL(row_group_brief_ptr);
575   std::shared_ptr<Page> page_ptr;
576   RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
577   vector<string> criteria_list{criteria.first};
578   RETURN_IF_NOT_OK(CheckColumnList(criteria_list));
579   std::string file_name = file_paths_[shard_id];
580   uint64_t page_length = page_ptr->GetPageSize();
581   uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
582   std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id, criteria);
583   if (image_offset.empty()) {
584     *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>();
585   }
586   auto labels_ptr = std::make_shared<std::vector<json>>();
587   RETURN_IF_NOT_OK(GetLabels(page_ptr->GetPageID(), shard_id, columns, criteria, &labels_ptr));
588   *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
589                                                            std::move(*labels_ptr));
590   return Status::OK();
591 }
592 
SelectCallback(void * p_data,int num_fields,char ** p_fields,char ** p_col_names)593 int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) {
594   auto *records = static_cast<std::vector<std::vector<std::string>> *>(p_data);
595   if (num_fields > 0 && num_fields <= kMaxFieldCount) {
596     for (int i = 0; i < num_fields; ++i)
597       if (p_fields[i] == nullptr) p_fields[i] = const_cast<char *>("");
598   }
599   records->emplace_back(p_fields, p_fields + num_fields);
600   return 0;
601 }
602 
// Query the blob byte ranges of every row on blob page `page_id` of shard
// `shard_id`. When `criteria` names an indexed field, only rows whose field
// equals criteria.second are returned. Returns an empty vector on query failure.
std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int shard_id,
                                                               const std::pair<std::string, std::string> &criteria) {
  auto db = database_paths_[shard_id];

  std::string sql =
    "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);

  // whether use index search
  if (!criteria.first.empty()) {
    auto schema = shard_header_->GetSchemas()[0]->GetSchema();

    // not number field should add '' in sql
    // NOTE(review): criteria.second is concatenated into the SQL text unescaped;
    // callers appear to pass internal field values — confirm this is never fed
    // untrusted input before widening its use.
    if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
      sql +=
        " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
    } else {
      sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
             criteria.second + "'";
    }
  }
  sql += ";";
  std::vector<std::vector<std::string>> image_offsets;
  char *errmsg = nullptr;
  int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg);
  if (rc != SQLITE_OK) {
    MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
    sqlite3_free(errmsg);
    // NOTE(review): this closes the handle while database_paths_[shard_id] still
    // stores it, so the teardown in FileStreamsOperator() would close a stale
    // pointer afterwards — verify and clear the stored handle if so.
    sqlite3_close(db);
    db = nullptr;
    return std::vector<std::vector<uint64_t>>();
  } else {
    MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(image_offsets.size()) << " records from index.";
  }
  // Pre-size the result with {0, 0} pairs, then fill in the parsed offsets.
  std::vector<std::vector<uint64_t>> res;
  for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
  for (int i = 0; i < static_cast<int>(image_offsets.size()); i++) {
    const auto &image_offset = image_offsets[i];
    // kInt64Len skips the length prefix stored in front of each blob record.
    res[i][0] = std::stoull(image_offset[0]) + kInt64Len;
    res[i][1] = std::stoull(image_offset[1]);
  }
  sqlite3_free(errmsg);
  return res;
}
646 
GetPagesByCategory(int shard_id,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<uint64_t>> * pages_ptr)647 Status ShardReader::GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
648                                        std::shared_ptr<std::vector<uint64_t>> *pages_ptr) {
649   RETURN_UNEXPECTED_IF_NULL(pages_ptr);
650   auto db = database_paths_[shard_id];
651 
652   std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";
653 
654   if (!criteria.first.empty()) {
655     auto schema = shard_header_->GetSchemas()[0]->GetSchema();
656     if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
657       sql +=
658         " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
659     } else {
660       sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
661              criteria.second + "'";
662     }
663   }
664   sql += ";";
665   std::vector<std::vector<std::string>> page_ids;
666   char *errmsg = nullptr;
667   int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg);
668   if (rc != SQLITE_OK) {
669     string ss(errmsg);
670     sqlite3_free(errmsg);
671     sqlite3_close(db);
672     db = nullptr;
673     RETURN_STATUS_UNEXPECTED(std::string("Failed to execute sql [") + common::SafeCStr(sql) + " ], " + ss);
674   } else {
675     MS_LOG(DEBUG) << "Succeed to get " << page_ids.size() << "pages from index.";
676   }
677   for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
678     (*pages_ptr)->emplace_back(std::stoull(page_ids[i][0]));
679   }
680   sqlite3_free(errmsg);
681   return Status::OK();
682 }
683 
GetBlobFields()684 std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
685   std::vector<std::string> blob_fields;
686   for (auto &p : GetShardHeader()->GetSchemas()) {
687     // assume one schema
688     const auto &fields = p->GetBlobFields();
689     blob_fields.assign(fields.begin(), fields.end());
690     break;
691   }
692   return std::make_pair(kCV, blob_fields);
693 }
694 
CheckIfColumnInIndex(const std::vector<std::string> & columns)695 void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) {
696   // assume different schemas do not contain same key.
697   if (columns.empty()) {
698     all_in_index_ = false;
699     return;
700   }
701   for (auto &field : GetShardHeader()->GetFields()) {
702     column_schema_id_[field.second] = field.first;
703   }
704   for (auto &col : columns) {
705     if (column_schema_id_.find(col) == column_schema_id_.end()) {
706       all_in_index_ = false;
707       return;
708     }
709   }
710 }
711 
QueryWithCriteria(sqlite3 * db,const string & sql,const string & criteria,std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr)712 Status ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
713                                       std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
714   sqlite3_stmt *stmt = nullptr;
715   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
716     RETURN_STATUS_UNEXPECTED("Failed to prepare statement sql [ " + sql + " ].");
717   }
718   int index = sqlite3_bind_parameter_index(stmt, ":criteria");
719   if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) {
720     RETURN_STATUS_UNEXPECTED("Failed to bind parameter of sql, index: " + std::to_string(index) +
721                              ", field value: " + criteria);
722   }
723   int rc = sqlite3_step(stmt);
724   while (rc != SQLITE_DONE) {
725     vector<string> tmp;
726     int ncols = sqlite3_column_count(stmt);
727     for (int i = 0; i < ncols; i++) {
728       tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
729     }
730     labels_ptr->push_back(tmp);
731     rc = sqlite3_step(stmt);
732   }
733   (void)sqlite3_finalize(stmt);
734   return Status::OK();
735 }
736 
// Read the label json of each sample directly from the shard's binary file.
// @param shard_id       index into file_paths_
// @param columns        requested column names (see NOTE below)
// @param label_offsets  one row per sample: [raw_page_id, start, end] as
//                       decimal strings (offsets within the raw page)
// @param labels_ptr     output: one json object per sample, same order
// Each label record is msgpack-encoded json located at
// page_size_ * raw_page_id + header_size_ + start (+ kInt64Len skipped).
Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
                                            const std::vector<std::vector<std::string>> &label_offsets,
                                            std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL(labels_ptr);
  std::string file_name = file_paths_[shard_id];
  auto realpath = FileUtils::GetRealPath(file_name.data());
  CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path=" + file_name);

  std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
  fs->open(realpath.value(), std::ios::in | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file, path: " + file_name);
  // init the return: pre-size the output so entries can be assigned by index
  for (unsigned int i = 0; i < label_offsets.size(); ++i) {
    (*labels_ptr)->emplace_back(json{});
  }

  for (unsigned int i = 0; i < label_offsets.size(); ++i) {
    const auto &labelOffset = label_offsets[i];
    // every offset row must provide page id, start and end
    if (labelOffset.size() < 3) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Invalid data, labelOffset size: " + std::to_string(labelOffset.size()) +
                               " is invalid.");
    }
    // skip the kInt64Len-byte prefix of the stored record before the payload
    uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len;
    uint64_t label_end = std::stoull(labelOffset[2]);
    int raw_page_id = std::stoi(labelOffset[0]);
    auto len = label_end - label_start;
    auto label_raw = std::vector<uint8_t>(len);
    // absolute position = file header + raw page origin + in-page offset
    auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
    if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Failed to seekg file, path: " + file_name);
    }

    auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
    if (!io_read.good() || io_read.fail() || io_read.bad()) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Failed to read file, path: " + file_name);
    }

    // the payload is msgpack-encoded json
    json label_json = json::from_msgpack(label_raw);
    json tmp = label_json;
    // NOTE(review): tmp starts as a full copy of label_json, so this loop only
    // re-assigns values that are already present — all fields end up in the
    // result regardless of `columns`. Confirm whether filtering was intended.
    for (auto &col : columns) {
      if (label_json.find(col) != label_json.end()) {
        tmp[col] = label_json[col];
      }
    }
    (*(*labels_ptr))[i] = tmp;
  }
  return Status::OK();
}
// Look up, via the shard's sqlite index, the raw-page location of every sample
// on blob page `page_id`, then materialize their label json from the binary
// shard file through GetLabelsFromBinaryFile.
// `criteria` optionally restricts rows to an indexed field value; that value
// is bound as a sqlite parameter (":criteria"), so no quoting is needed.
Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
                                      const std::pair<std::string, std::string> &criteria,
                                      std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL(labels_ptr);
  // get page info from sqlite
  auto db = database_paths_[shard_id];
  std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " +
                    std::to_string(page_id);
  auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
  if (!criteria.first.empty()) {
    // filtered query: the value is supplied through the bound placeholder
    sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
    RETURN_IF_NOT_OK(QueryWithCriteria(db, sql, criteria.second, label_offset_ptr));
  } else {
    sql += ";";
    char *errmsg = nullptr;
    int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg);
    if (rc != SQLITE_OK) {
      std::ostringstream oss;
      oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
      sqlite3_free(errmsg);
      sqlite3_close(db);
      db = nullptr;
      RETURN_STATUS_UNEXPECTED(oss.str());
    }
    MS_LOG(DEBUG) << "Succeed to get " << label_offset_ptr->size() << " records from index.";
    // errmsg is null on success; freeing null is a no-op
    sqlite3_free(errmsg);
  }
  // get labels from binary file
  return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr, labels_ptr);
}
818 
// Fetch the label json of every sample on a blob page.
// Fast path (all_in_index_): all requested columns live in the sqlite index,
// so values are read as strings straight from INDEXES and converted to their
// schema-declared types. Slow path: delegate to GetLabelsFromPage, which
// decodes the msgpack labels from the shard file.
Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
                              const std::pair<std::string, std::string> &criteria,
                              std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL(labels_ptr);
  if (all_in_index_) {
    auto db = database_paths_[shard_id];
    // build the comma-separated projection "<col>_<schema_id>, ..."
    std::string fields;
    for (unsigned int i = 0; i < columns.size(); ++i) {
      if (i > 0) fields += ',';
      uint64_t schema_id = column_schema_id_[columns[i]];
      fields += columns[i] + "_" + std::to_string(schema_id);
    }
    if (fields.empty()) {
      fields = "*";  // no explicit columns: select everything
    }
    auto labels = std::make_shared<std::vector<std::vector<std::string>>>();
    std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);
    if (!criteria.first.empty()) {
      // criteria value is bound as a sqlite parameter, not inlined
      sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria";
      RETURN_IF_NOT_OK(QueryWithCriteria(db, sql, criteria.second, labels));
    } else {
      sql += ";";
      char *errmsg = nullptr;
      int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels.get(), &errmsg);
      if (rc != SQLITE_OK) {
        std::ostringstream oss;
        oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
        sqlite3_free(errmsg);
        sqlite3_close(db);
        db = nullptr;
        RETURN_STATUS_UNEXPECTED(oss.str());
      } else {
        MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(labels->size()) << " records from index.";
      }
      // errmsg is null on success; freeing null is a no-op
      sqlite3_free(errmsg);
    }
    // pre-size the output so rows can be assigned by index below
    for (unsigned int i = 0; i < labels->size(); ++i) {
      (*labels_ptr)->emplace_back(json{});
    }
    for (unsigned int i = 0; i < labels->size(); ++i) {
      json construct_json;
      for (unsigned int j = 0; j < columns.size(); ++j) {
        // construct json "f1": value
        auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];

        // convert the string to base type by schema
        if (schema[columns[j]]["type"] == "int32") {
          construct_json[columns[j]] = StringToNum<int32_t>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "int64") {
          construct_json[columns[j]] = StringToNum<int64_t>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "float32") {
          construct_json[columns[j]] = StringToNum<float>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "float64") {
          construct_json[columns[j]] = StringToNum<double>((*labels)[i][j]);
        } else {
          // non-numeric types are kept as the raw string
          construct_json[columns[j]] = std::string((*labels)[i][j]);
        }
      }
      (*(*labels_ptr))[i] = construct_json;
    }
    return Status::OK();
  }
  return GetLabelsFromPage(page_id, shard_id, columns, criteria, labels_ptr);
}
883 
// Comparator for sorting row-group summaries: ascending by the second tuple
// element (group id), ties broken by the first (shard id). Used with std::sort
// before parallel reading. Tuples are taken by const reference to avoid
// copying four ints on every comparison.
bool ResortRowGroups(const std::tuple<int, int, int, int> &a, const std::tuple<int, int, int, int> &b) {
  return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b));
}
887 
GetNumClasses(const std::string & category_field)888 int64_t ShardReader::GetNumClasses(const std::string &category_field) {
889   auto shard_count = file_paths_.size();
890   auto index_fields = shard_header_->GetFields();
891 
892   std::map<std::string, int64_t> map_schema_id_fields;
893   for (auto &field : index_fields) {
894     map_schema_id_fields[field.second] = field.first;
895   }
896 
897   if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
898     MS_LOG(ERROR) << "Invalid data, field " << category_field << " does not exist.";
899     return -1;
900   }
901   std::shared_ptr<std::string> fn_ptr;
902   (void)ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field),
903                                                &fn_ptr);
904   std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
905   std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
906   auto category_ptr = std::make_shared<std::set<std::string>>();
907   sqlite3 *db = nullptr;
908   for (int x = 0; x < shard_count; x++) {
909     int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
910     if (SQLITE_OK != rc) {
911       MS_LOG(ERROR) << "Failed to open database: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
912       return -1;
913     }
914     threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
915   }
916 
917   for (int x = 0; x < shard_count; x++) {
918     threads[x].join();
919   }
920   sqlite3_close(db);
921   return category_ptr->size();
922 }
923 
// Compute how many rows the dataset will yield after applying the operator
// chain `ops` on top of the raw row count, accounting for padded samples.
// @param file_paths   shard files; forwarded to Init together with load_dataset
// @param ops          tail of an operator chain; children reached via GetChildOp
// @param count        output: resulting sample count
// @param num_padded   number of padded (fake) samples to add once, at root level
Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
                                   const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) {
  RETURN_IF_NOT_OK(Init(file_paths, load_dataset));
  int64_t num_samples = num_rows_;
  bool root = true;  // ensures padded samples are only accounted for once
  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
  std::shared_ptr<ShardOperator> op(ops);
  // push the chain onto a stack so the deepest child operator is applied first
  while (op != nullptr) {
    stack_ops.push(op);
    op = op->GetChildOp();
  }
  while (!stack_ops.empty()) {
    op = stack_ops.top();
    stack_ops.pop();
    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
      num_samples = op->GetNumSamples(num_samples, 0);
      if (num_padded > 0 && root == true) {
        num_samples += num_padded;
        root = false;
      }
    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      // category sampling depends on the number of distinct classes
      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
      std::string category_field = category_op->GetCategoryField();
      auto num_classes = GetNumClasses(category_field);
      num_samples = category_op->GetNumSamples(num_samples, num_classes);
      if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
        // PK sampling may additionally cap the count at its own num_samples
        auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
        if (tmp != 0 && num_samples != -1) {
          num_samples = std::min(num_samples, tmp);
        }
        CHECK_FAIL_RETURN_UNEXPECTED(
          num_samples != -1, "Invalid input, number of samples: " + std::to_string(num_samples) +
                               " exceeds the upper limit: " + std::to_string(std::numeric_limits<int64_t>::max()));
      }
    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
        if (root == true) {
          // distributed sampling sees the padded total; -1 from GetNumSamples
          // signals that the padded total is not divisible across shards
          sampler_op->SetNumPaddedSamples(num_padded);
          num_samples = op->GetNumSamples(num_samples, 0);
          CHECK_FAIL_RETURN_UNEXPECTED(num_samples != -1, "Invalid data, dataset size plus number of padded samples: " +
                                                            std::to_string(num_samples) +
                                                            " can not be divisible by number of shards.");
          root = false;
        }
      } else {
        num_samples = op->GetNumSamples(num_samples, 0);
      }
    } else {
      // any other operator type: padded samples are simply appended
      if (num_padded > 0) {
        num_samples += num_padded;
      }
    }
  }
  *count = num_samples;
  return Status::OK();
}
981 
// Open the reader for multi-consumer access: load shard files and header,
// clamp the consumer-thread count into [kMinConsumerCount, GetMaxThreadNum()],
// validate the selected columns, record the operator list, then open the
// per-consumer file streams via the Open(n_consumer) overload.
// @param num_padded  count of padded samples appended later during CreateTasks
// @param lazy_load   when true, task metadata is fetched on demand per sample
Status ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
                         const std::vector<std::string> &selected_columns,
                         const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded, bool lazy_load) {
  lazy_load_ = lazy_load;

  // Open file and set header by ShardReader
  RETURN_IF_NOT_OK(Init(file_paths, load_dataset));
  // clamp the requested consumer count: apply the upper bound, then the lower
  auto thread_limit = GetMaxThreadNum();
  if (n_consumer > thread_limit) {
    n_consumer = thread_limit;
  }
  if (n_consumer < kMinConsumerCount) {
    n_consumer = kMinConsumerCount;
  }

  selected_columns_ = selected_columns;
  RETURN_IF_NOT_OK(CheckColumnList(selected_columns_));

  // Initialize argument
  shard_count_ = static_cast<int>(file_paths_.size());
  n_consumer_ = n_consumer;
  num_padded_ = num_padded;

  operators_ = operators;
  RETURN_IF_NOT_OK(Open(n_consumer));
  return Status::OK();
}
1009 
// Build the task list from the row-group summary and, unless this is a
// sample-only read, spawn the consumer threads that deliver rows in parallel.
// @param is_sample_read  when true, stop after task creation (no threads)
Status ShardReader::Launch(bool is_sample_read) {
  // Get all row groups' info
  auto row_group_summary = ReadRowGroupSummary();

  // Sort row group by (group_id, shard_id), prepare for parallel reading
  std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
  if (CreateTasks(row_group_summary, operators_).IsError()) {
    interrupt_ = true;  // signal any consumers to stop before reporting failure
    RETURN_STATUS_UNEXPECTED("Failed to launch read threads.");
  }
  if (is_sample_read) {
    // caller only needs the task list; no consumer threads required
    return Status::OK();
  }
  // Start provider consumer threads
  thread_set_ = std::vector<std::thread>(n_consumer_);
  CHECK_FAIL_RETURN_UNEXPECTED(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount,
                               "Invalid data, number of consumer: " + std::to_string(n_consumer_) +
                                 " exceeds the upper limit: " + std::to_string(kMaxConsumerCount));

  for (int x = 0; x < n_consumer_; ++x) {
    thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
  }

  MS_LOG(INFO) << "Succeed to launch read thread.";
  return Status::OK();
}
1036 
// Build the task list grouped by category (for ShardCategory / ShardPkSample).
// For each requested category, walk every shard's matching pages and insert up
// to num_elements read tasks; the per-category lists are then combined
// (optionally with replacement) and handed to the category operator.
Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
  CheckIfColumnInIndex(selected_columns_);
  auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
  auto categories = category_op->GetCategories();
  int64_t num_elements = category_op->GetNumElements();
  int64_t num_samples = 0;
  if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
    num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
    CHECK_FAIL_RETURN_UNEXPECTED(
      num_samples >= 0,
      "Invalid input, num_samples must be greater than or equal to 0, but got " + std::to_string(num_samples));
  }
  CHECK_FAIL_RETURN_UNEXPECTED(
    num_elements > 0, "Invalid input, num_elements must be greater than 0, but got " + std::to_string(num_elements));
  if (categories.empty() == true) {
    // no explicit categories: discover up to num_categories distinct values of
    // the category field across all shards
    std::string category_field = category_op->GetCategoryField();
    int64_t num_categories = category_op->GetNumCategories();
    CHECK_FAIL_RETURN_UNEXPECTED(num_categories > 0, "Invalid input, num_categories must be greater than 0, but got " +
                                                       std::to_string(num_elements));
    auto category_ptr = std::make_shared<std::set<std::string>>();
    RETURN_IF_NOT_OK(GetAllClasses(category_field, category_ptr));
    int i = 0;
    for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
      categories.emplace_back(category_field, *it);
      i++;
    }
  }
  // Generate a vector of task lists.  Each catogory has a list of tasks.
  std::vector<ShardTaskList> categoryTasks(categories.size());
  for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
    int category_index = 0;  // tasks inserted so far for this category, capped at num_elements
    for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
      auto pages_ptr = std::make_shared<std::vector<uint64_t>>();
      RETURN_IF_NOT_OK(GetPagesByCategory(shard_id, categories[categoryNo], &pages_ptr));
      for (const auto &page_id : *pages_ptr) {
        if (category_index >= num_elements) {
          break;
        }
        std::shared_ptr<Page> page_ptr;
        RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, page_id, &page_ptr));
        auto group_id = page_ptr->GetPageTypeID();
        std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
        RETURN_IF_NOT_OK(
          ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_, &row_group_brief_ptr));
        // tuple element 3 holds the per-row offsets; element 4 the row labels
        auto offsets = std::get<3>(*row_group_brief_ptr);

        auto number_of_rows = offsets.size();
        for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
          if (category_index < num_elements) {
            categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
                                                 std::get<3>(*row_group_brief_ptr)[iStart],
                                                 std::get<4>(*row_group_brief_ptr)[iStart]);
            category_index++;
          }
        }
        MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
      }
    }
  }
  tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);

  tasks_.InitSampleIds();
  // let the category operator finalize/sample the combined task list
  RETURN_IF_NOT_OK((*category_op)(tasks_));
  return Status::OK();
}
1102 
// Eager task creation: read all row-group metadata up front and insert one
// task per sample. One helper thread per shard fills its contiguous slice of
// tasks_ (disjoint offset ranges, so no locking is needed).
Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                     const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  std::shared_ptr<ROW_GROUPS> row_group_ptr;
  RETURN_IF_NOT_OK(ReadAllRowGroup(selected_columns_, &row_group_ptr));
  auto &offsets = std::get<0>(*row_group_ptr);
  auto &local_columns = std::get<1>(*row_group_ptr);
  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
                               "Invalid data, number of shards: " + std::to_string(shard_count_) +
                                 " exceeds the upper limit: " + std::to_string(kMaxFileCount));
  int sample_count = 0;
  for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
    sample_count += offsets[shard_id].size();
  }
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  uint32_t current_offset = 0;
  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    // offsets/local_columns are captured by reference: they outlive the joins
    // below, and each thread writes a disjoint [current_offset, +size) range.
    init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() {
      auto offset = current_offset;
      for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
        // offsets row layout (see ConsumerOneTask): [0]=shard id, [1]=group id,
        // [2]=blob start, [3]=blob end
        tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
                          std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]},
                          local_columns[shard_id][i]);
        offset++;
      }
    });
    current_offset += offsets[shard_id].size();
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1144 
// Lazy task creation: insert placeholder tasks (empty offsets, empty json)
// keyed only by (shard id, sample id in shard); the real metadata is fetched
// later in ConsumerOneTask. shard_sample_count_ is used as a cumulative
// (prefix-sum) sample count per shard: the last entry is the dataset total.
Status ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                         const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
                               "Invalid data, number of shards: " + std::to_string(shard_count_) +
                                 " exceeds the upper limit: " + std::to_string(kMaxFileCount));
  uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    // the offset indicate the shard start
    uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];

    // the count indicate the number of samples in the shard
    uint32_t shard_count =
      shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
    // each thread fills its own disjoint [current_offset, +shard_count) slice
    init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
      for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
        // here "i - current_offset" indicate the sample id in the shard
        tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
      }
    });
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1180 
// Build tasks_ from the row-group summary. If a ShardCategory operator is
// present, tasks are organized per category; otherwise per row (eager or lazy
// depending on lazy_load_), with padded tasks appended. All non-category
// operators are then applied in order to sample/permute the task list.
Status ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  // locate the first category operator, if any
  int category_operator = -1;
  for (uint32_t i = 0; i < operators.size(); ++i) {
    const auto &op = operators[i];
    if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      category_operator = static_cast<int>(i);
      break;
    }
  }

  if (-1 == category_operator) {
    if (lazy_load_ == false) {
      RETURN_IF_NOT_OK(CreateTasksByRow(row_group_summary, operators));
    } else {
      RETURN_IF_NOT_OK(CreateLazyTasksByRow(row_group_summary, operators));
    }

    // need padded sample to the task
    if (num_padded_ > 0) {
      for (int i = 0; i < num_padded_; ++i) {
        tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
      }
    }
  } else {
    RETURN_IF_NOT_OK(CreateTasksByCategory(operators[category_operator]));
  }
  MS_LOG(DEBUG) << "Succeed to create " << tasks_.Size() << " initial task to start with before sampling.";
  tasks_.InitSampleIds();

  for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
    const auto &op = operators[operator_no];
    // the category operator was already applied inside CreateTasksByCategory
    if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      continue;
    }

    if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
      // these operators need the per-shard cumulative counts to partition/shuffle
      op->SetShardSampleCount(shard_sample_count_);
    }
    RETURN_IF_NOT_OK((*op)(tasks_));
  }

  if (tasks_.permutation_.empty()) tasks_.MakePerm();
  num_rows_ = tasks_.Size();
  MS_LOG(INFO) << "The total number of samples is " << num_rows_
               << ", the number of samples after sampling is: " << tasks_.sample_ids_.size();

  return Status::OK();
}
1230 
// Execute one read task: resolve the task's shard/group/blob coordinates
// (stored in the task in eager mode, looked up from the index in lazy mode),
// read the blob bytes from the shard file using this consumer's private file
// stream, and package (blob, label json) into *task_content_ptr.
// Padded tasks short-circuit with an empty batch.
Status ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id,
                                    std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
  RETURN_UNEXPECTED_IF_NULL(task_content_ptr);
  // All tasks are done
  CHECK_FAIL_RETURN_UNEXPECTED(
    task_id < static_cast<int>(tasks_.Size()),
    "Invalid data, task id: " + std::to_string(task_id) + " exceeds the upper limit: " + std::to_string(tasks_.Size()));
  uint32_t shard_id = 0;
  uint32_t group_id = 0;
  uint32_t blob_start = 0;
  uint32_t blob_end = 0;
  json var_fields;
  // Pick up task from task list
  ShardTask task = tasks_.GetTaskByID(task_id);

  // check task type
  auto task_type = std::get<0>(task);
  if (task_type == TaskType::kPaddedTask) {
    // padded samples carry no data; return an empty batch tagged as padded
    *task_content_ptr =
      std::make_shared<TASK_CONTENT>(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
    return Status::OK();
  }

  shard_id = std::get<0>(std::get<1>(task));  // shard id

  if (lazy_load_ == false) {
    // eager mode: all metadata was stored in the task at creation time
    group_id = std::get<1>(std::get<1>(task));  // group id
    blob_start = std::get<2>(task)[0];          // blob start
    blob_end = std::get<2>(task)[1];            // blob end
    var_fields = std::get<3>(task);             // scalar variable field
  } else {
    // lazy mode: the task only stores the sample id; fetch metadata now
    // get scalar variable fields by sample id
    uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));

    // read the meta from index
    std::shared_ptr<ROW_GROUPS> row_group_ptr;
    RETURN_IF_NOT_OK(ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, sample_id_in_shard, &row_group_ptr));
    auto &offsets = std::get<0>(*row_group_ptr);
    auto &local_columns = std::get<1>(*row_group_ptr);

    group_id = offsets[shard_id][0][1];       // group_id
    blob_start = offsets[shard_id][0][2];     // blob start
    blob_end = offsets[shard_id][0][3];       // blob end
    var_fields = local_columns[shard_id][0];  // scalar variable field
  }

  // read the blob from data file
  std::shared_ptr<Page> page_ptr;
  RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
  MS_LOG(DEBUG) << "Success to get page by group id: " << group_id;

  // Pack image list
  std::vector<uint8_t> images(blob_end - blob_start);
  // absolute position = file header + page origin + in-page blob offset
  auto file_offset = header_size_ + page_size_ * (page_ptr->GetPageID()) + blob_start;

  // each consumer has its own stream per shard, so no locking is required here
  auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
  }
  auto &io_read =
    file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to read file.");
  }

  // Deliver batch data to output map
  std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
  batch.emplace_back(std::move(images), std::move(var_fields));

  *task_content_ptr = std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::move(batch));
  return Status::OK();
}
1305 
ConsumerByRow(int consumer_id)1306 void ShardReader::ConsumerByRow(int consumer_id) {
1307   // Set thread name
1308 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
1309   auto thread_id = kThreadName + std::to_string(consumer_id);
1310   prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0);
1311 #endif
1312 
1313   // Loop forever
1314   for (;;) {
1315     int sample_id_pos = 0;
1316 
1317     // Get next task ID
1318     sample_id_pos = sample_id_position_++;
1319 
1320     // All tasks are done
1321     if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
1322       return;
1323     }
1324     auto task_content_ptr =
1325       std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
1326     if (ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id, &task_content_ptr).IsError()) {
1327       MS_LOG(ERROR) << "Error raised in ConsumerOneTask function.";
1328       return;
1329     }
1330     const auto &batch = (*task_content_ptr).second;
1331     // Hanging if maximum map size exceeded
1332     //   otherwise, set batch data in map
1333     {
1334       std::unique_lock<std::mutex> lck(mtx_delivery_);
1335       cv_delivery_.wait(lck,
1336                         [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
1337       if (interrupt_) {
1338         return;
1339       }
1340       delivery_map_[sample_id_pos] =
1341         std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
1342     }
1343     cv_iterator_.notify_one();
1344   }
1345 }
1346 
GetNext()1347 std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
1348   if (interrupt_) {
1349     return std::vector<std::tuple<std::vector<uint8_t>, json>>();
1350   }
1351   if (deliver_id_ >= static_cast<int>(tasks_.sample_ids_.size())) {
1352     return std::vector<std::tuple<std::vector<uint8_t>, json>>();
1353   }
1354 
1355   std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>> res;
1356   {
1357     std::unique_lock<std::mutex> lck(mtx_delivery_);
1358     cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); });
1359     if (interrupt_) {
1360       return std::vector<std::tuple<std::vector<uint8_t>, json>>();
1361     }
1362     res = delivery_map_[deliver_id_];
1363     delivery_map_.erase(deliver_id_++);
1364   }
1365 
1366   cv_delivery_.notify_all();
1367 
1368   return *res;
1369 }
1370 
GetNextById(const int64_t & task_id,const int32_t & consumer_id)1371 TASK_CONTENT ShardReader::GetNextById(const int64_t &task_id, const int32_t &consumer_id) {
1372   auto task_content_ptr =
1373     std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
1374   if (interrupt_) {
1375     return *task_content_ptr;
1376   }
1377   (void)ConsumerOneTask(task_id, consumer_id, &task_content_ptr);
1378   return std::move(*task_content_ptr);
1379 }
1380 
UnCompressBlob(const std::vector<uint8_t> & raw_blob_data,std::shared_ptr<std::vector<std::vector<uint8_t>>> * blob_data_ptr)1381 Status ShardReader::UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
1382                                    std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr) {
1383   RETURN_UNEXPECTED_IF_NULL(blob_data_ptr);
1384   auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
1385   auto blob_fields = GetBlobFields().second;
1386   for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
1387     if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue;
1388     const unsigned char *data = nullptr;
1389     std::unique_ptr<unsigned char[]> data_ptr;
1390     uint64_t n_bytes = 0;
1391     RETURN_IF_NOT_OK(
1392       shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes));
1393     if (data == nullptr) {
1394       data = reinterpret_cast<const unsigned char *>(data_ptr.get());
1395     }
1396     std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
1397     (*blob_data_ptr)->push_back(column);
1398   }
1399   return Status::OK();
1400 }
1401 
GetTotalBlobSize(int64_t * total_blob_size)1402 Status ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
1403   *total_blob_size = total_blob_size_;
1404   return Status::OK();
1405 }
1406 
Reset()1407 void ShardReader::Reset() {
1408   {
1409     std::lock_guard<std::mutex> lck(mtx_delivery_);
1410     sample_id_position_ = 0;
1411     deliver_id_ = 0;
1412   }
1413   cv_delivery_.notify_all();
1414 }
1415 
ShuffleTask()1416 void ShardReader::ShuffleTask() {
1417   // exist shuffle and distributed sampler in ops, skip shuffle
1418   bool has_sharding = false;
1419   for (const auto &op : operators_) {
1420     if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1421       has_sharding = true;
1422     }
1423   }
1424   for (const auto &op : operators_) {
1425     if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
1426       auto s = (*op)(tasks_);
1427       if (s.IsError()) {
1428         MS_LOG(WARNING) << "Failed to redo randomSampler in new epoch.";
1429       }
1430     } else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1431       auto s = (*op)(tasks_);
1432       if (s.IsError()) {
1433         MS_LOG(WARNING) << "Failed to redo distributeSampler in new epoch.";
1434       }
1435     }
1436   }
1437   if (tasks_.permutation_.empty()) tasks_.MakePerm();
1438 }
1439 
GetSampleIds()1440 const std::vector<int> *ShardReader::GetSampleIds() {
1441   // return const reference to private sample id list.
1442   return &(this->tasks_.sample_ids_);
1443 }
1444 
1445 }  // namespace mindrecord
1446 }  // namespace mindspore
1447