1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_reader.h"
18
19 #include <algorithm>
20 #include <thread>
21
22 #include "utils/file_utils.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "utils/ms_utils.h"
25
26 namespace mindspore {
27 namespace mindrecord {
// Convert a string to an exact numeric type (int32_t/int64_t/float/double).
// On parse failure the stream value-initializes the result (i.e. zero).
template <class Type>
Type StringToNum(const std::string &str) {
  Type result{};
  std::istringstream parser(str);
  parser >> result;
  return result;
}
36
// Default-construct a reader: all sizes/counters zeroed, fast load mode,
// no shard sample counts until Init()/Launch() populate them.
ShardReader::ShardReader()
    : header_size_(0),
      page_size_(0),
      shard_count_(0),
      n_consumer_(0),
      num_padded_(0),
      num_rows_(0),
      total_blob_size_(0),
      sample_id_position_(0),
      deliver_id_(0),
      load_mode_(LoadMode::kFast),
      shard_sample_count_() {}
49
GetMeta(const std::string & file_path,std::shared_ptr<json> meta_data_ptr,std::shared_ptr<std::vector<std::string>> * addresses_ptr)50 Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
51 std::shared_ptr<std::vector<std::string>> *addresses_ptr) {
52 RETURN_UNEXPECTED_IF_NULL_MR(addresses_ptr);
53 RETURN_IF_NOT_OK_MR(CheckFile(file_path));
54 std::shared_ptr<json> header_ptr;
55 RETURN_IF_NOT_OK_MR(ShardHeader::BuildSingleHeader(file_path, &header_ptr));
56
57 *meta_data_ptr = {{"header_size", (*header_ptr)["header_size"]}, {"page_size", (*header_ptr)["page_size"]},
58 {"version", (*header_ptr)["version"]}, {"index_fields", (*header_ptr)["index_fields"]},
59 {"schema", (*header_ptr)["schema"]}, {"blob_fields", (*header_ptr)["blob_fields"]}};
60 std::vector<std::string> addresses_vec = (*header_ptr)["shard_addresses"];
61 *addresses_ptr = std::make_shared<std::vector<std::string>>(addresses_vec);
62 return Status::OK();
63 }
64
Init(const std::vector<std::string> & file_paths,bool load_dataset)65 Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
66 std::string file_path = file_paths[0];
67 auto first_meta_data_ptr = std::make_shared<json>();
68 std::shared_ptr<std::vector<std::string>> addresses_ptr;
69 RETURN_IF_NOT_OK_MR(GetMeta(file_path, first_meta_data_ptr, &addresses_ptr));
70 if (file_paths.size() == 1 && load_dataset == true) {
71 auto ds = std::make_shared<std::vector<std::string>>();
72 RETURN_IF_NOT_OK_MR(GetDatasetFiles(file_path, *addresses_ptr, &ds));
73 file_paths_ = *ds; // load files according to shard_addresses
74 } else if (file_paths.size() >= 1 && load_dataset == false) {
75 file_paths_ = file_paths; // load files according to the input
76 } else {
77 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] The values of 'load_dataset' and 'file_paths' are not as expected.");
78 }
79 for (const auto &file : file_paths_) {
80 auto meta_data_ptr = std::make_shared<json>();
81 RETURN_IF_NOT_OK_MR(GetMeta(file, meta_data_ptr, &addresses_ptr));
82 CHECK_FAIL_RETURN_UNEXPECTED_MR(
83 *meta_data_ptr == *first_meta_data_ptr,
84 "Invalid file, the metadata of mindrecord file: " + file +
85 " is different from others, please make sure all the mindrecord files generated by the same script.");
86 sqlite3 *db = nullptr;
87 RETURN_IF_NOT_OK_MR(VerifyDataset(&db, file));
88 database_paths_.push_back(db);
89 }
90 ShardHeader sh = ShardHeader();
91 RETURN_IF_NOT_OK_MR(sh.BuildDataset(file_paths_, load_dataset));
92 shard_header_ = std::make_shared<ShardHeader>(sh);
93 header_size_ = shard_header_->GetHeaderSize();
94 page_size_ = shard_header_->GetPageSize();
95 // version < 3.0
96 if ((*first_meta_data_ptr)["version"] < kVersion) {
97 shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
98 } else {
99 shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
100 }
101 num_rows_ = 0;
102 auto row_group_summary = ReadRowGroupSummary();
103
104 // clear the shard_sample_count_, because it will be insert when Launch func
105 shard_sample_count_.clear();
106
107 constexpr int64_t get_index = 3;
108 for (const auto &rg : row_group_summary) {
109 num_rows_ += std::get<get_index>(rg);
110 }
111
112 if (num_rows_ > SLOW_LOAD_THRESHOLD) {
113 load_mode_ = LoadMode::kSlow;
114 tasks_.load_mode_ = LoadMode::kSlow;
115 MS_LOG(INFO) << "The number of samples is larger than " << SLOW_LOAD_THRESHOLD
116 << ", enable slow load mode. If you want to speed up data loading, "
117 << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
118 << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
119 << "according to the current samples.";
120 } else if (num_rows_ > LAZY_LOAD_THRESHOLD) {
121 load_mode_ = LoadMode::kLazy;
122 tasks_.load_mode_ = LoadMode::kLazy;
123 MS_LOG(INFO) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
124 << ", enable lazy load mode. If you want to speed up data loading, "
125 << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
126 << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
127 << "according to the current samples.";
128 } else {
129 load_mode_ = LoadMode::kFast;
130 tasks_.load_mode_ = LoadMode::kFast;
131 }
132
133 auto disk_size = page_size_ * row_group_summary.size();
134 auto compression_size = shard_header_->GetCompressionSize();
135 total_blob_size_ = disk_size + compression_size;
136 MS_LOG(INFO) << "The size of blob data on disk: " << disk_size
137 << " , additional uncompression size: " << compression_size
138 << " , total blob size: " << total_blob_size_;
139
140 MS_LOG(INFO) << "Succeed to get metadata from mindrecord files";
141
142 return Status::OK();
143 }
144
VerifyDataset(sqlite3 ** db,const string & file)145 Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
146 std::string path_utf8 = "";
147 #if defined(_WIN32) || defined(_WIN64)
148 path_utf8 = FileUtils::GB2312ToUTF_8((file + ".db").data());
149 #endif
150 if (path_utf8.empty()) {
151 path_utf8 = file + ".db";
152 }
153
154 // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
155 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
156 // use "unix-none" to avoid flock and achieve better performance on shared storage platform
157 CHECK_FAIL_RETURN_UNEXPECTED_MR(
158 sqlite3_open_v2(path_utf8.data(), db, SQLITE_OPEN_READONLY, "unix-none") == SQLITE_OK,
159 "Invalid file, failed to open mindrecord meta file. Please check whether the meta file: " + file +
160 ".db exists and do not rename the mindrecord file and meta file.");
161 #else
162 CHECK_FAIL_RETURN_UNEXPECTED_MR(
163 sqlite3_open_v2(path_utf8.data(), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK,
164 "Invalid file, failed to open mindrecord meta file. Please check whether the meta file: " + file +
165 ".db exists and do not rename the mindrecord file and meta file.");
166 #endif
167 MS_LOG(DEBUG) << "Succeed to open meta file, path: " << file << ".db.";
168
169 // starting a transaction during a read-only select operation can solve the problem of frequently
170 // accessing *-journal / *-wal files.
171 auto sql_code = sqlite3_exec(*db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
172 if (sql_code != SQLITE_OK) {
173 sqlite3_free(*db);
174 RETURN_STATUS_UNEXPECTED_MR("Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: " +
175 std::to_string(sql_code));
176 }
177
178 string sql = "SELECT NAME from SHARD_NAME;";
179 std::vector<std::vector<std::string>> name;
180 char *errmsg = nullptr;
181 if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) {
182 std::ostringstream oss;
183 oss << "Failed to execute the sql [ " << sql << " ] while verifying meta file, " << errmsg
184 << ".\nPlease check the meta file: " + file + ".db";
185 sqlite3_free(errmsg);
186 sqlite3_close(*db);
187 RETURN_STATUS_UNEXPECTED_MR(oss.str());
188 } else {
189 std::shared_ptr<std::string> fn_ptr;
190 RETURN_IF_NOT_OK_MR(GetFileName(file, &fn_ptr));
191 if (name.empty() || name[0][0] != *fn_ptr) {
192 sqlite3_free(errmsg);
193 sqlite3_close(*db);
194 RETURN_STATUS_UNEXPECTED_MR("Invalid file, mindrecord meta file: " + file + ".db and mindrecord file: " + file +
195 " can not match. Please do not rename the mindrecord file or meta file.");
196 }
197 }
198 return Status::OK();
199 }
200
CheckColumnList(const std::vector<std::string> & selected_columns)201 Status ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
202 auto schema_ptr = GetShardHeader()->GetSchemas()[0];
203 auto schema = schema_ptr->GetSchema()["schema"];
204 for (auto i = 0; i < selected_columns.size(); ++i) {
205 CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(selected_columns[i]) != schema.end(),
206 "Invalid data, column name: " + selected_columns[i] +
207 " can not found in schema. Please check the 'column_list'.");
208 }
209 return Status::OK();
210 }
211
Open(int n_consumer)212 Status ShardReader::Open(int n_consumer) {
213 file_streams_random_ =
214 std::vector<std::vector<std::shared_ptr<std::fstream>>>(n_consumer, std::vector<std::shared_ptr<std::fstream>>());
215 for (const auto &file : file_paths_) {
216 for (int j = 0; j < n_consumer; ++j) {
217 std::optional<std::string> dir = "";
218 std::optional<std::string> local_file_name = "";
219 FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
220 if (!dir.has_value()) {
221 dir = ".";
222 }
223
224 auto realpath = FileUtils::GetRealPath(dir.value().c_str());
225 CHECK_FAIL_RETURN_UNEXPECTED_MR(
226 realpath.has_value(),
227 "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);
228
229 std::optional<std::string> whole_path = "";
230 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
231
232 std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
233 fs->open(whole_path.value(), std::ios::in | std::ios::binary);
234 if (!fs->good()) {
235 fs->close();
236 RETURN_STATUS_UNEXPECTED_MR(
237 "Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and "
238 "open files limit(ulimit -a): " +
239 file);
240 }
241 file_streams_random_[j].push_back(fs);
242 }
243 MS_LOG(INFO) << "Succeed to open file, path: " << file;
244 }
245 return Status::OK();
246 }
247
// Grow the per-consumer random-access stream pool by n_new_consumers,
// opening one new fstream per (new consumer, shard) pair. Must be called
// after Open(); updates n_consumer_ on success.
// @param n_new_consumers number of consumers to add (> 0, bounded by the
//        thread limit reported by GetMaxThreadNum()).
Status ShardReader::ExtendRandomFileStreams(const int n_new_consumers) {
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_new_consumers > 0,
                                  "n_new_consumers must be a positive number. Got: " + std::to_string(n_new_consumers));
  CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_streams_random_.empty(),
                                  "ExtendRandomFileStreams() must not be called prior to calling Open()");
  // make sure we won't exceed the number of allowed threads.
  uint32_t thread_limit = GetMaxThreadNum();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ + n_new_consumers <= thread_limit,
                                  "Requested increase in number of consumers will cause it to be above the number of "
                                  "allowed threads. n_new_consumers: " +
                                    std::to_string(n_new_consumers) +
                                    ", new n_consumers: " + std::to_string(n_consumer_ + n_new_consumers));

  // Reserve the new per-consumer slots first; they are filled per file below.
  for (int i = 0; i < n_new_consumers; i++) {
    (void)file_streams_random_.emplace_back(std::vector<std::shared_ptr<std::fstream>>());
  }

  for (const auto &file : file_paths_) {
    // Path resolution is per file; hoisted out of the consumer loop.
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().data());
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      realpath.has_value(), "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    // Only the newly added consumer slots [n_consumer_, n_consumer_ + n_new) get streams.
    for (int j = n_consumer_; j < n_consumer_ + n_new_consumers; ++j) {
      std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED_MR(
          "Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and "
          "open files limit(ulimit -a): " +
          file);
      }
      file_streams_random_[j].push_back(fs);
    }
    MS_LOG(INFO) << "Succeed to open file, path: " << file;
  }
  n_consumer_ += n_new_consumers;
  MS_LOG(INFO) << "n_consumer_ is increased by " + std::to_string(n_new_consumers) + " to " +
                    std::to_string(n_consumer_);

  return Status::OK();
}
300
ShrinkRandomFileStreams(const int n_remove_consumers)301 Status ShardReader::ShrinkRandomFileStreams(const int n_remove_consumers) {
302 CHECK_FAIL_RETURN_UNEXPECTED_MR(
303 n_remove_consumers > 0, "n_remove_consumers must be a positive number. Got: " + std::to_string(n_remove_consumers));
304 CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_streams_random_.empty(),
305 "ShrinkRandomFileStreams() must not be called prior to calling Open()");
306 // make sure we won't go below the number of allowed threads.
307 CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ - n_remove_consumers >= kMinConsumerCount,
308 "Requested decrease in number of consumers will cause it to be below the number of "
309 "allowed threads. n_remove_consumers: " +
310 std::to_string(n_remove_consumers) +
311 ", new n_consumers: " + std::to_string(n_consumer_ - n_remove_consumers));
312
313 for (int i = n_consumer_ - 1; i >= n_consumer_ - n_remove_consumers; i--) {
314 for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
315 if (file_streams_random_[i][j] != nullptr) {
316 file_streams_random_[i][j]->close();
317 }
318 }
319 file_streams_random_.pop_back();
320 }
321 n_consumer_ -= n_remove_consumers;
322 MS_LOG(INFO) << "n_consumer_ is decreased by " + std::to_string(n_remove_consumers) + " to " +
323 std::to_string(n_consumer_);
324
325 return Status::OK();
326 }
327
FileStreamsOperator()328 void ShardReader::FileStreamsOperator() {
329 for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; --i) {
330 if (file_streams_[i] != nullptr) {
331 file_streams_[i]->close();
332 }
333 }
334 for (int i = static_cast<int>(file_streams_random_.size()) - 1; i >= 0; --i) {
335 for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
336 if (file_streams_random_[i][j] != nullptr) {
337 file_streams_random_[i][j]->close();
338 }
339 }
340 }
341 for (int i = static_cast<int>(database_paths_.size()) - 1; i >= 0; --i) {
342 if (database_paths_[i] != nullptr) {
343 auto sql_code = sqlite3_exec(database_paths_[i], "END TRANSACTION;", nullptr, nullptr, nullptr);
344 if (sql_code != SQLITE_OK) {
345 sqlite3_close(database_paths_[i]);
346 MS_LOG(ERROR) << "Execute SQL statement `END TRANSACTION;` failed, SQLite result code: "
347 << std::to_string(sql_code);
348 continue;
349 }
350 auto ret = sqlite3_close(database_paths_[i]);
351 if (ret != SQLITE_OK) {
352 MS_LOG(ERROR) << "[Internal ERROR] Failed to close meta file, " << ret << ".";
353 }
354 database_paths_[i] = nullptr;
355 }
356 }
357 }
358
// Destructor: stop worker threads and release all streams/database handles.
ShardReader::~ShardReader() { Close(); }
360
// Stop delivery threads and release every owned resource.
// interrupt_ is set under mtx_delivery_ so that threads blocked on
// cv_delivery_ reliably observe it after notify_all(); all threads are joined
// before the streams/handles they may use are closed.
void ShardReader::Close() {
  {
    std::lock_guard<std::mutex> lck(mtx_delivery_);
    interrupt_ = true;  // interrupt reading and stop threads
  }
  cv_delivery_.notify_all();

  // Wait for all threads to finish
  for (auto &i_thread : thread_set_) {
    if (i_thread.joinable()) {
      i_thread.join();
    }
  }

  FileStreamsOperator();
}
377
// Header parsed from the mindrecord files (schemas, pages, statistics).
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
379
// Column accessor built in Init() from the header (version-dependent layout).
std::shared_ptr<ShardColumn> ShardReader::GetShardColumn() const { return shard_column_; }
381
// Number of shard files, as recorded in the shard header.
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
383
// Total number of samples across all shards (computed in Init()).
int64_t ShardReader::GetNumRows() const { return num_rows_; }
385
// Number of samples remaining after the configured samplers are applied.
int64_t ShardReader::GetNumRowsAfterSampling() const { return tasks_.SizeAfterSampling(); }
387
// Walk every blob page of every shard and summarize the row groups.
// Returns tuples of (shard id, page-type id, start row id, number of rows);
// an empty vector on error. Side effect: appends the cumulative sample count
// of each shard to shard_sample_count_.
std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
  std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
  int shard_count = shard_header_->GetShardCount();
  if (shard_count <= 0) {
    return row_group_summary;
  }

  uint32_t total_count = 0;
  for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
    // return -1 when page's size equals to 0.
    auto last_page_id = shard_header_->GetLastPageId(shard_id);
    if (static_cast<int>(last_page_id) == -1) {
      // Empty mindrecord file which does not contain any samples
      MS_LOG(WARNING) << "The mindrecord file: " << file_paths_[shard_id]
                      << " does not contain any samples, pls remove it.";
      row_group_summary.emplace_back(shard_id, 0, 0, 0);
      shard_sample_count_.push_back(total_count);
      continue;
    }
    for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
      std::shared_ptr<Page> page_ptr;
      // NOTE(review): GetPage status is deliberately ignored here; presumably
      // page ids up to last_page_id always resolve — confirm.
      (void)shard_header_->GetPage(shard_id, page_id, &page_ptr);
      if (page_ptr->GetPageType() != kPageTypeBlob) {
        continue;
      }
      uint64_t start_row_id = page_ptr->GetStartRowID();
      // An inconsistent page (start beyond end) invalidates the whole summary.
      if (start_row_id > page_ptr->GetEndRowID()) {
        return std::vector<std::tuple<int, int, int, uint64_t>>();
      }
      uint64_t number_of_rows = page_ptr->GetEndRowID() - start_row_id;
      total_count += number_of_rows;
      row_group_summary.emplace_back(shard_id, page_ptr->GetPageTypeID(), start_row_id, number_of_rows);
    }
    shard_sample_count_.push_back(total_count);
  }

  return row_group_summary;
}
426
// Convert the rows returned by an index query for one shard into blob offsets
// and per-sample label json, appended to (*offset_ptr)[shard_id] and
// (*col_val_ptr)[shard_id] respectively.
// Row layout in `labels[i]`: [0]=row-group id, [1]=blob start, [2]=blob end;
// when not all requested columns are indexed (all_in_index_ == false), also
// [3]=raw page id, [4]=label start, [5]=label end — offsets into the raw page
// of stream `fs`, whose label payload is msgpack-encoded json.
// Any stoi/stoull or I/O failure closes `fs` and returns an internal error.
Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
                                       std::shared_ptr<std::fstream> fs,
                                       std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
                                       int shard_id, const std::vector<std::string> &columns,
                                       std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
  auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
  for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
    try {
      uint64_t group_id = std::stoull(labels[i][0]);
      // Stored offsets point at a length prefix; skip it (kInt64Len bytes).
      uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
      uint64_t offset_end = std::stoull(labels[i][2]);
      CHECK_FAIL_RETURN_UNEXPECTED_MR(offset_end >= offset_start,
                                      "The sample's end offset: " + std::to_string(offset_end) +
                                        " should >= start offset: " + std::to_string(offset_start) + ", check fail.");
      (*offset_ptr)[shard_id].emplace_back(
        std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
      if (!all_in_index_) {
        // Some requested column is not indexed: read the label bytes straight
        // from the raw page and decode the msgpack json.
        int raw_page_id = std::stoi(labels[i][3]);
        uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
        uint64_t label_end = std::stoull(labels[i][5]);
        CHECK_FAIL_RETURN_UNEXPECTED_MR(label_end >= label_start,
                                        "The sample's end offset: " + std::to_string(label_end) +
                                          " should >= start offset: " + std::to_string(label_start) + ", check fail.");
        auto len = label_end - label_start;
        auto label_raw = std::vector<uint8_t>(len);
        auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
        if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
        }
        auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
        if (!io_read.good() || io_read.fail() || io_read.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
        }
        json label_json = json::from_msgpack(label_raw);
        json tmp;
        // Project only the requested columns; an empty column list means all.
        if (!columns.empty()) {
          for (const auto &col : columns) {
            if (label_json.find(col) != label_json.end()) {
              tmp[col] = label_json[col];
            }
          }
        } else {
          tmp = label_json;
        }
        (*col_val_ptr)[shard_id].emplace_back(tmp);
      } else {
        // Fully indexed: reconstruct the json row from the query columns.
        json construct_json;
        RETURN_IF_NOT_OK_MR(ConvertJsonValue(labels[i], columns, schema, &construct_json));
        (*col_val_ptr)[shard_id].emplace_back(construct_json);
      }
    } catch (std::out_of_range &e) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
                                  std::string(e.what()));
    } catch (std::invalid_argument &e) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
                                  std::string(e.what()));
    } catch (...) {
      fs->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Unexpected exception raised in ConvertLabelToJson function.");
    }
  }

  return Status::OK();
}
495
ConvertJsonValue(const std::vector<std::string> & label,const std::vector<std::string> & columns,const json & schema,json * value)496 Status ShardReader::ConvertJsonValue(const std::vector<std::string> &label, const std::vector<std::string> &columns,
497 const json &schema, json *value) {
498 constexpr int64_t index = 3;
499 for (unsigned int j = 0; j < columns.size(); ++j) {
500 if (schema[columns[j]]["type"] == "int32") {
501 (*value)[columns[j]] = StringToNum<int32_t>(label[j + index]);
502 } else if (schema[columns[j]]["type"] == "int64") {
503 (*value)[columns[j]] = StringToNum<int64_t>(label[j + index]);
504 } else if (schema[columns[j]]["type"] == "float32") {
505 (*value)[columns[j]] = StringToNum<float>(label[j + index]);
506 } else if (schema[columns[j]]["type"] == "float64") {
507 (*value)[columns[j]] = StringToNum<double>(label[j + index]);
508 } else {
509 (*value)[columns[j]] = std::string(label[j + index]);
510 }
511 }
512 return Status::OK();
513 }
ReadAllRowsInShard(int shard_id,const int32_t & consumer_id,const std::string & sql,const std::vector<std::string> & columns,std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr)514 Status ShardReader::ReadAllRowsInShard(int shard_id, const int32_t &consumer_id, const std::string &sql,
515 const std::vector<std::string> &columns,
516 std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
517 std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
518 auto db = database_paths_[shard_id];
519 std::vector<std::vector<std::string>> labels;
520 char *errmsg = nullptr;
521 int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
522 if (rc != SQLITE_OK) {
523 std::ostringstream oss;
524 oss << "[Internal ERROR] Failed to execute the sql [ " << sql << " ] while reading meta file, " << errmsg;
525 sqlite3_free(errmsg);
526 sqlite3_close(db);
527 db = nullptr;
528 RETURN_STATUS_UNEXPECTED_MR(oss.str());
529 }
530 MS_LOG(DEBUG) << "Succeed to get " << labels.size() << " records from shard " << std::to_string(shard_id)
531 << " index.";
532
533 sqlite3_free(errmsg);
534 return ConvertLabelToJson(labels, file_streams_random_[consumer_id][shard_id], offset_ptr, shard_id, columns,
535 col_val_ptr);
536 }
537
// Collect the distinct values of `category_field` across all shards (used by
// PKSampler). Spawns one thread per shard running GetClassesInShard; results
// are merged into *category_ptr under shard_locker_ by the workers.
Status ShardReader::GetAllClasses(const std::string &category_field,
                                  std::shared_ptr<std::set<std::string>> category_ptr) {
  // Map field name -> schema id so the indexed column name can be derived.
  std::map<std::string, uint64_t> index_columns;
  for (auto &field : GetShardHeader()->GetFields()) {
    index_columns[field.second] = field.first;
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    index_columns.find(category_field) != index_columns.end(),
    "Invalid data, 'class_column': " + category_field +
      " can not found in fields of mindrecord files. Please check 'class_column' in PKSampler.");
  std::shared_ptr<std::string> fn_ptr;
  RETURN_IF_NOT_OK_MR(
    ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr));
  std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
  std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
  for (int x = 0; x < shard_count_; x++) {
    threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
  }

  // Join all workers before returning; errors are logged inside the workers.
  for (int x = 0; x < shard_count_; x++) {
    threads[x].join();
  }
  return Status::OK();
}
562
GetClassesInShard(sqlite3 * db,int shard_id,const std::string & sql,std::shared_ptr<std::set<std::string>> category_ptr)563 void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
564 std::shared_ptr<std::set<std::string>> category_ptr) {
565 if (db == nullptr) {
566 return;
567 }
568 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
569 pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(shard_id)).c_str());
570 #endif
571 std::vector<std::vector<std::string>> columns;
572 char *errmsg = nullptr;
573 int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg);
574 if (ret != SQLITE_OK) {
575 sqlite3_free(errmsg);
576 sqlite3_close(db);
577 db = nullptr;
578 MS_LOG(ERROR) << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql)
579 << " ] while reading meta file, " << errmsg;
580 return;
581 }
582 MS_LOG(INFO) << "Succeed to get " << columns.size() << " records from shard " << std::to_string(shard_id)
583 << " index.";
584 std::lock_guard<std::mutex> lck(shard_locker_);
585 for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
586 category_ptr->emplace(columns[i][0]);
587 }
588 sqlite3_free(errmsg);
589 }
590
ReadAllRowGroup(const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUPS> * row_group_ptr)591 Status ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns,
592 std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
593 RETURN_UNEXPECTED_IF_NULL_MR(row_group_ptr);
594 std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
595 auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
596 shard_count_, std::vector<std::vector<uint64_t>>{});
597 auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
598
599 if (all_in_index_) {
600 for (unsigned int i = 0; i < columns.size(); ++i) {
601 fields += ',';
602 std::shared_ptr<std::string> fn_ptr;
603 RETURN_IF_NOT_OK_MR(
604 ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
605 fields += *fn_ptr;
606 }
607 } else { // fetch raw data from Raw page while some field is not index.
608 fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
609 }
610
611 std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;";
612
613 std::vector<std::future<Status>> async_results;
614 auto status = Status::OK();
615 for (int x = 0; x < shard_count_; x++) {
616 async_results.push_back(std::async(std::launch::async, &ShardReader::ReadAllRowsInShard, this, x, 0, sql, columns,
617 offset_ptr, col_val_ptr));
618 }
619
620 for (auto i = 0; i < async_results.size(); i++) {
621 auto res = async_results[i].get();
622 if (res.IsError() && status.IsOk()) {
623 status = res;
624 }
625 }
626 if (status.IsError()) {
627 return status;
628 }
629 *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
630 return Status::OK();
631 }
632
// Read the index row of ONE sample (`sample_id`) from one shard and wrap it
// in a ROW_GROUPS structure, using the stream set of `consumer_id`.
// Mirrors ReadAllRowGroup but with a WHERE ROW_ID filter and no async fan-out.
Status ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
                                                     const int32_t &consumer_id, const uint32_t &sample_id,
                                                     std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(row_group_ptr);
  std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
  auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
    shard_count_, std::vector<std::vector<uint64_t>>{});
  auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
  if (all_in_index_) {
    // Every requested column is indexed: select it directly.
    for (unsigned int i = 0; i < columns.size(); ++i) {
      fields += ',';
      std::shared_ptr<std::string> fn_ptr;
      RETURN_IF_NOT_OK_MR(
        ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
      fields += *fn_ptr;
    }
  } else {  // fetch raw data from Raw page while some field is not index.
    fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
  }

  std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);

  RETURN_IF_NOT_OK_MR(ReadAllRowsInShard(shard_id, consumer_id, sql, columns, offset_ptr, col_val_ptr));
  *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
  return Status::OK();
}
659
ReadRowGroupBrief(int group_id,int shard_id,const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUP_BRIEF> * row_group_brief_ptr)660 Status ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
661 std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
662 RETURN_UNEXPECTED_IF_NULL_MR(row_group_brief_ptr);
663 std::shared_ptr<Page> page_ptr;
664 RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
665 std::string file_name = file_paths_[shard_id];
666 uint64_t page_length = page_ptr->GetPageSize();
667 uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
668 std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id);
669 auto labels_ptr = std::make_shared<std::vector<json>>();
670 RETURN_IF_NOT_OK_MR(GetLabels(page_ptr->GetPageID(), shard_id, columns, {"", ""}, &labels_ptr));
671 *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
672 std::move(*labels_ptr));
673 return Status::OK();
674 }
675
ReadRowGroupCriteria(int group_id,int shard_id,const std::pair<std::string,std::string> & criteria,const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUP_BRIEF> * row_group_brief_ptr)676 Status ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
677 const std::pair<std::string, std::string> &criteria,
678 const std::vector<std::string> &columns,
679 std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
680 RETURN_UNEXPECTED_IF_NULL_MR(row_group_brief_ptr);
681 std::shared_ptr<Page> page_ptr;
682 RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
683 vector<string> criteria_list{criteria.first};
684 RETURN_IF_NOT_OK_MR(CheckColumnList(criteria_list));
685 std::string file_name = file_paths_[shard_id];
686 uint64_t page_length = page_ptr->GetPageSize();
687 uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
688 std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id, criteria);
689 if (image_offset.empty()) {
690 *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>();
691 }
692 auto labels_ptr = std::make_shared<std::vector<json>>();
693 RETURN_IF_NOT_OK_MR(GetLabels(page_ptr->GetPageID(), shard_id, columns, criteria, &labels_ptr));
694 *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
695 std::move(*labels_ptr));
696 return Status::OK();
697 }
698
SelectCallback(void * p_data,int num_fields,char ** p_fields,char ** p_col_names)699 int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) {
700 auto *records = static_cast<std::vector<std::vector<std::string>> *>(p_data);
701 if (num_fields > 0 && num_fields <= kMaxFieldCount) {
702 for (int i = 0; i < num_fields; ++i) {
703 if (p_fields[i] == nullptr) {
704 p_fields[i] = const_cast<char *>("");
705 }
706 }
707 }
708 records->emplace_back(p_fields, p_fields + num_fields);
709 return 0;
710 }
711
GetImageOffset(int page_id,int shard_id,const std::pair<std::string,std::string> & criteria)712 std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int shard_id,
713 const std::pair<std::string, std::string> &criteria) {
714 auto db = database_paths_[shard_id];
715
716 std::string sql = "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";
717
718 // whether use index search
719 if (!criteria.first.empty()) {
720 sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
721 }
722 sql += ";";
723 std::vector<std::vector<std::string>> image_offsets;
724
725 sqlite3_stmt *stmt = nullptr;
726 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
727 MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to prepare statement [ " << sql << " ].";
728 }
729
730 // bind the PAGE_ID_BLOB
731 int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
732 if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
733 (void)sqlite3_finalize(stmt);
734 MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to bind parameter of sql, key index: " << std::to_string(index)
735 << ", value: " << std::to_string(page_id);
736 }
737
738 // bind the criteria
739 if (!criteria.first.empty()) {
740 index = sqlite3_bind_parameter_index(stmt, ":criteria");
741 if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria.second), -1, SQLITE_STATIC) != SQLITE_OK) {
742 (void)sqlite3_finalize(stmt);
743 MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to bind parameter of sql, key index: " << std::to_string(index)
744 << ", value: " + criteria.second;
745 }
746 }
747
748 int rc = sqlite3_step(stmt);
749 while (rc != SQLITE_DONE) {
750 vector<string> tmp;
751 int ncols = sqlite3_column_count(stmt);
752 for (int i = 0; i < ncols; i++) {
753 tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
754 }
755 image_offsets.push_back(tmp);
756 rc = sqlite3_step(stmt);
757 }
758
759 auto finalize = sqlite3_finalize(stmt);
760 if (finalize != SQLITE_OK) {
761 MS_LOG(EXCEPTION) << "[Internal ERROR] Failed to finalize sql stmt, error code: " << std::to_string(finalize);
762 }
763
764 MS_LOG(DEBUG) << "Succeed to get " << image_offsets.size() << " records from index.";
765
766 std::vector<std::vector<uint64_t>> res;
767 for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) {
768 res.emplace_back(std::vector<uint64_t>{0, 0});
769 }
770 for (int i = 0; i < static_cast<int>(image_offsets.size()); i++) {
771 const auto &image_offset = image_offsets[i];
772 res[i][0] = std::stoull(image_offset[0]) + kInt64Len;
773 res[i][1] = std::stoull(image_offset[1]);
774 if (res[i][1] < res[i][0]) {
775 MS_LOG(EXCEPTION) << "The sample's end offset: " << std::to_string(res[i][1])
776 << " should >= start offset: " << std::to_string(res[i][0]) << ", check fail.";
777 }
778 }
779 return res;
780 }
781
GetPagesByCategory(int shard_id,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<uint64_t>> * pages_ptr)782 Status ShardReader::GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
783 std::shared_ptr<std::vector<uint64_t>> *pages_ptr) {
784 RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
785 auto db = database_paths_[shard_id];
786
787 std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";
788
789 if (!criteria.first.empty()) {
790 sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
791 }
792 sql += ";";
793 std::vector<std::vector<std::string>> page_ids;
794
795 sqlite3_stmt *stmt = nullptr;
796 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
797 (void)sqlite3_finalize(stmt);
798 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
799 }
800
801 if (!criteria.first.empty()) {
802 int index = sqlite3_bind_parameter_index(stmt, ":criteria");
803 if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria.second), -1, SQLITE_STATIC) != SQLITE_OK) {
804 (void)sqlite3_finalize(stmt);
805 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
806 std::to_string(index) + ", value: " + criteria.second);
807 }
808 }
809
810 int rc = sqlite3_step(stmt);
811 while (rc != SQLITE_DONE) {
812 vector<string> tmp;
813 int ncols = sqlite3_column_count(stmt);
814 for (int i = 0; i < ncols; i++) {
815 tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
816 }
817 page_ids.push_back(tmp);
818 rc = sqlite3_step(stmt);
819 }
820
821 auto finalize = sqlite3_finalize(stmt);
822 if (finalize != SQLITE_OK) {
823 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
824 std::to_string(finalize));
825 }
826
827 MS_LOG(DEBUG) << "Succeed to get " << page_ids.size() << " pages from index.";
828 for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
829 (*pages_ptr)->emplace_back(std::stoull(page_ids[i][0]));
830 }
831 return Status::OK();
832 }
833
GetBlobFields()834 std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
835 std::vector<std::string> blob_fields;
836 for (auto &p : GetShardHeader()->GetSchemas()) {
837 // assume one schema
838 const auto &fields = p->GetBlobFields();
839 blob_fields.assign(fields.begin(), fields.end());
840 break;
841 }
842 return std::make_pair(kCV, blob_fields);
843 }
844
CheckIfColumnInIndex(const std::vector<std::string> & columns)845 void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) {
846 // assume different schemas do not contain same key.
847 if (columns.empty()) {
848 all_in_index_ = false;
849 return;
850 }
851 for (auto &field : GetShardHeader()->GetFields()) {
852 column_schema_id_[field.second] = field.first;
853 }
854 for (auto &col : columns) {
855 if (column_schema_id_.find(col) == column_schema_id_.end()) {
856 all_in_index_ = false;
857 return;
858 }
859 }
860 }
861
QueryWithPageIdBlobAndCriteria(sqlite3 * db,const string & sql,const int & page_id,const string & criteria,std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr)862 Status ShardReader::QueryWithPageIdBlobAndCriteria(sqlite3 *db, const string &sql, const int &page_id,
863 const string &criteria,
864 std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
865 sqlite3_stmt *stmt = nullptr;
866 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
867 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
868 }
869
870 // bind the PAGE_ID_BLOB
871 int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
872 if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
873 (void)sqlite3_finalize(stmt);
874 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
875 std::to_string(index) + ", value: " + std::to_string(page_id));
876 }
877
878 // bind the criteria
879 index = sqlite3_bind_parameter_index(stmt, ":criteria");
880 if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) {
881 (void)sqlite3_finalize(stmt);
882 RETURN_STATUS_UNEXPECTED_MR(
883 "[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) + ", value: " + criteria);
884 }
885 int rc = sqlite3_step(stmt);
886 while (rc != SQLITE_DONE) {
887 vector<string> tmp;
888 int ncols = sqlite3_column_count(stmt);
889 for (int i = 0; i < ncols; i++) {
890 tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
891 }
892 labels_ptr->push_back(tmp);
893 rc = sqlite3_step(stmt);
894 }
895
896 auto finalize = sqlite3_finalize(stmt);
897 if (finalize != SQLITE_OK) {
898 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
899 std::to_string(finalize));
900 }
901 return Status::OK();
902 }
903
GetLabelsFromBinaryFile(int shard_id,const std::vector<std::string> & columns,const std::vector<std::vector<std::string>> & label_offsets,std::shared_ptr<std::vector<json>> * labels_ptr)904 Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
905 const std::vector<std::vector<std::string>> &label_offsets,
906 std::shared_ptr<std::vector<json>> *labels_ptr) {
907 RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
908 std::shared_ptr<std::fstream> fs = file_streams_random_[0][shard_id];
909
910 // init the return
911 for (unsigned int i = 0; i < label_offsets.size(); ++i) {
912 (*labels_ptr)->emplace_back(json{});
913 }
914
915 for (unsigned int i = 0; i < label_offsets.size(); ++i) {
916 const auto &labelOffset = label_offsets[i];
917 if (labelOffset.size() < 3) {
918 fs->close();
919 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] 'labelOffset' size should be less than 3 but got: " +
920 std::to_string(labelOffset.size()) + ".");
921 }
922 uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len;
923 uint64_t label_end = std::stoull(labelOffset[2]);
924 CHECK_FAIL_RETURN_UNEXPECTED_MR(label_end >= label_start,
925 "The sample's end offset: " + std::to_string(label_end) +
926 " should >= start offset: " + std::to_string(label_start) + ", check fail.");
927 int raw_page_id = std::stoi(labelOffset[0]);
928 auto len = label_end - label_start;
929 auto label_raw = std::vector<uint8_t>(len);
930 auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
931 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
932 fs->close();
933 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file, path: " + file_paths_[shard_id]);
934 }
935
936 auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
937 if (!io_read.good() || io_read.fail() || io_read.bad()) {
938 fs->close();
939 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file, path: " + file_paths_[shard_id]);
940 }
941
942 json label_json = json::from_msgpack(label_raw);
943 json tmp = label_json;
944 for (auto &col : columns) {
945 if (label_json.find(col) != label_json.end()) {
946 tmp[col] = label_json[col];
947 }
948 }
949 (*(*labels_ptr))[i] = tmp;
950 }
951 return Status::OK();
952 }
953
GetLabelsFromPage(int page_id,int shard_id,const std::vector<std::string> & columns,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<json>> * labels_ptr)954 Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
955 const std::pair<std::string, std::string> &criteria,
956 std::shared_ptr<std::vector<json>> *labels_ptr) {
957 RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
958 // get page info from sqlite
959 auto db = database_paths_[shard_id];
960 std::string sql =
961 "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";
962
963 auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
964 if (!criteria.first.empty()) {
965 sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria;";
966 RETURN_IF_NOT_OK_MR(QueryWithPageIdBlobAndCriteria(db, sql, page_id, criteria.second, label_offset_ptr));
967 } else {
968 sql += ";";
969 sqlite3_stmt *stmt = nullptr;
970 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
971 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
972 }
973
974 // bind the PAGE_ID_BLOB
975 int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
976 if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
977 (void)sqlite3_finalize(stmt);
978 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
979 std::to_string(index) + ", value: " + std::to_string(page_id));
980 }
981
982 int rc = sqlite3_step(stmt);
983 while (rc != SQLITE_DONE) {
984 vector<string> tmp;
985 int ncols = sqlite3_column_count(stmt);
986 for (int i = 0; i < ncols; i++) {
987 tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
988 }
989 label_offset_ptr->push_back(tmp);
990 rc = sqlite3_step(stmt);
991 }
992
993 auto finalize = sqlite3_finalize(stmt);
994 if (finalize != SQLITE_OK) {
995 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
996 std::to_string(finalize));
997 }
998
999 MS_LOG(DEBUG) << "Succeed to get " << label_offset_ptr->size() << " records from index.";
1000 }
1001 // get labels from binary file
1002 return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr, labels_ptr);
1003 }
1004
GetLabels(int page_id,int shard_id,const std::vector<std::string> & columns,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<json>> * labels_ptr)1005 Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
1006 const std::pair<std::string, std::string> &criteria,
1007 std::shared_ptr<std::vector<json>> *labels_ptr) {
1008 RETURN_UNEXPECTED_IF_NULL_MR(labels_ptr);
1009 if (all_in_index_) {
1010 auto db = database_paths_[shard_id];
1011 std::string fields;
1012 for (unsigned int i = 0; i < columns.size(); ++i) {
1013 if (i > 0) {
1014 fields += ',';
1015 }
1016 uint64_t schema_id = column_schema_id_[columns[i]];
1017 fields += columns[i] + "_" + std::to_string(schema_id);
1018 }
1019 if (fields.empty()) {
1020 fields = "*";
1021 }
1022 auto labels = std::make_shared<std::vector<std::vector<std::string>>>();
1023 std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = :page_id_blob";
1024 if (!criteria.first.empty()) {
1025 sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria;";
1026 RETURN_IF_NOT_OK_MR(QueryWithPageIdBlobAndCriteria(db, sql, page_id, criteria.second, labels));
1027 } else {
1028 sql += ";";
1029 sqlite3_stmt *stmt = nullptr;
1030 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
1031 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
1032 }
1033
1034 // bind the PAGE_ID_BLOB
1035 int index = sqlite3_bind_parameter_index(stmt, ":page_id_blob");
1036 if (sqlite3_bind_int64(stmt, index, page_id) != SQLITE_OK) {
1037 (void)sqlite3_finalize(stmt);
1038 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
1039 std::to_string(index) + ", value: " + std::to_string(page_id));
1040 }
1041
1042 int rc = sqlite3_step(stmt);
1043 while (rc != SQLITE_DONE) {
1044 vector<string> tmp;
1045 int ncols = sqlite3_column_count(stmt);
1046 for (int i = 0; i < ncols; i++) {
1047 tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
1048 }
1049 labels->push_back(tmp);
1050 rc = sqlite3_step(stmt);
1051 }
1052
1053 auto finalize = sqlite3_finalize(stmt);
1054 if (finalize != SQLITE_OK) {
1055 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to finalize sql stmt, error code: " +
1056 std::to_string(finalize));
1057 }
1058
1059 MS_LOG(DEBUG) << "Succeed to get " << labels->size() << " records from index.";
1060 }
1061 for (unsigned int i = 0; i < labels->size(); ++i) {
1062 (*labels_ptr)->emplace_back(json{});
1063 }
1064 for (unsigned int i = 0; i < labels->size(); ++i) {
1065 json construct_json;
1066 for (unsigned int j = 0; j < columns.size(); ++j) {
1067 // construct json "f1": value
1068 auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
1069
1070 // convert the string to base type by schema
1071 if (schema[columns[j]]["type"] == "int32") {
1072 construct_json[columns[j]] = StringToNum<int32_t>((*labels)[i][j]);
1073 } else if (schema[columns[j]]["type"] == "int64") {
1074 construct_json[columns[j]] = StringToNum<int64_t>((*labels)[i][j]);
1075 } else if (schema[columns[j]]["type"] == "float32") {
1076 construct_json[columns[j]] = StringToNum<float>((*labels)[i][j]);
1077 } else if (schema[columns[j]]["type"] == "float64") {
1078 construct_json[columns[j]] = StringToNum<double>((*labels)[i][j]);
1079 } else {
1080 construct_json[columns[j]] = std::string((*labels)[i][j]);
1081 }
1082 }
1083 (*(*labels_ptr))[i] = construct_json;
1084 }
1085 return Status::OK();
1086 }
1087 return GetLabelsFromPage(page_id, shard_id, columns, criteria, labels_ptr);
1088 }
1089
// Strict-weak ordering for row-group tuples: primary key is element 1
// (shard id), with element 0 (group id) breaking ties.
bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int, int> b) {
  if (std::get<1>(a) != std::get<1>(b)) {
    return std::get<1>(a) < std::get<1>(b);
  }
  return std::get<0>(a) < std::get<0>(b);
}
1093
GetNumClasses(const std::string & category_field)1094 int64_t ShardReader::GetNumClasses(const std::string &category_field) {
1095 auto shard_count = file_paths_.size();
1096 auto index_fields = shard_header_->GetFields();
1097
1098 std::map<std::string, int64_t> map_schema_id_fields;
1099 for (auto &field : index_fields) {
1100 map_schema_id_fields[field.second] = field.first;
1101 }
1102
1103 if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
1104 MS_LOG(ERROR) << "[Internal ERROR] 'category_field' " << category_field
1105 << " can not found in index fields of mindrecord files.";
1106 return -1;
1107 }
1108 std::shared_ptr<std::string> fn_ptr;
1109 (void)ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field),
1110 &fn_ptr);
1111 std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
1112 std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
1113 auto category_ptr = std::make_shared<std::set<std::string>>();
1114 sqlite3 *db = nullptr;
1115 for (int x = 0; x < shard_count; x++) {
1116 std::string path_utf8 = "";
1117 #if defined(_WIN32) || defined(_WIN64)
1118 path_utf8 = FileUtils::GB2312ToUTF_8((file_paths_[x] + ".db").data());
1119 #endif
1120 if (path_utf8.empty()) {
1121 path_utf8 = file_paths_[x] + ".db";
1122 }
1123
1124 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
1125 // use "unix-none" to avoid flock and achieve better performance on shared storage platform
1126 int rc = sqlite3_open_v2(path_utf8.data(), &db, SQLITE_OPEN_READONLY, "unix-none");
1127 #else
1128 int rc = sqlite3_open_v2(path_utf8.data(), &db, SQLITE_OPEN_READONLY, nullptr);
1129 #endif
1130 if (SQLITE_OK != rc) {
1131 MS_LOG(ERROR) << "[Internal ERROR] Failed to open meta file: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
1132 return -1;
1133 }
1134
1135 // starting a transaction during a read-only select operation can solve the problem of frequently
1136 // accessing *-journal / *-wal files.
1137 auto sql_code = sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
1138 if (sql_code != SQLITE_OK) {
1139 sqlite3_free(db);
1140 MS_LOG(ERROR) << "Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: "
1141 << std::to_string(sql_code);
1142 return -1;
1143 }
1144 threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
1145 }
1146
1147 for (int x = 0; x < shard_count; x++) {
1148 threads[x].join();
1149 }
1150 sqlite3_close(db);
1151 return category_ptr->size();
1152 }
1153
// Compute the number of rows the reader will deliver after the whole operator
// chain (shuffle / category / sample ops) is applied, writing it to *count.
// The chain is walked from the innermost child to the outermost op (hence the
// stack); num_padded extra samples are accounted for exactly once, at the
// first (root-most) op that consumes them.
Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
                                   const std::shared_ptr<ShardOperator> &ops, int64_t *count,
                                   const int64_t num_padded) {
  // Init loads headers/indexes and fills num_rows_ with the raw sample count.
  RETURN_IF_NOT_OK_MR(Init(file_paths, load_dataset));
  int64_t num_samples = num_rows_;
  // `root` guards the one-time application of num_padded.
  bool root = true;
  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
  std::shared_ptr<ShardOperator> op(ops);
  // Push the chain so the deepest child is processed first.
  while (op != nullptr) {
    stack_ops.push(op);
    op = op->GetChildOp();
  }
  while (!stack_ops.empty()) {
    op = stack_ops.top();
    stack_ops.pop();
    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
      num_samples = op->GetNumSamples(num_samples, 0);
      // Padded samples are added once, at the first shuffle op encountered.
      if (num_padded > 0 && root == true) {
        num_samples += num_padded;
        root = false;
      }
    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
      std::string category_field = category_op->GetCategoryField();
      // GetNumClasses returns -1 on failure; GetNumSamples handles that below.
      auto num_classes = GetNumClasses(category_field);
      num_samples = category_op->GetNumSamples(num_samples, num_classes);
      if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
        // PK sampling may additionally cap the total with its own num_samples.
        auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
        if (tmp != 0 && num_samples != -1) {
          num_samples = std::min(num_samples, tmp);
        }

        CHECK_FAIL_RETURN_UNEXPECTED_MR(num_samples != -1,
                                        "Invalid data, 'num_samples': " + std::to_string(num_samples) +
                                          " is out of bound: " + std::to_string(std::numeric_limits<int64_t>::max()));
      }
    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
        // Distributed sampling consumes the padding itself (SetNumPaddedSamples),
        // and only at the root-most op.
        if (root == true) {
          sampler_op->SetNumPaddedSamples(num_padded);
          num_samples = op->GetNumSamples(num_samples, 0);
          CHECK_FAIL_RETURN_UNEXPECTED_MR(
            num_samples != -1,
            "Invalid data, the size of dataset and padded samples: " + std::to_string(num_padded) +
              " can not be divisible by the value of 'num_shards'.\n Please adjust the value of 'num_padded'.");
          root = false;
        }
      } else {
        num_samples = op->GetNumSamples(num_samples, 0);
        num_samples += num_padded;
      }
    } else {
      // Any other op leaves the count unchanged apart from padding.
      if (num_padded > 0) {
        num_samples += num_padded;
      }
    }
  }
  *count = num_samples;
  return Status::OK();
}
1215
Open(const std::vector<std::string> & file_paths,bool load_dataset,int n_consumer,const std::vector<std::string> & selected_columns,const std::vector<std::shared_ptr<ShardOperator>> & operators,int64_t num_padded,LoadMode load_mode)1216 Status ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
1217 const std::vector<std::string> &selected_columns,
1218 const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded,
1219 LoadMode load_mode) {
1220 load_mode_ = load_mode;
1221
1222 // Open file and set header by ShardReader
1223 RETURN_IF_NOT_OK_MR(Init(file_paths, load_dataset));
1224 auto thread_limit = GetMaxThreadNum();
1225 if (n_consumer > thread_limit) {
1226 n_consumer = thread_limit;
1227 }
1228 if (n_consumer < kMinConsumerCount) {
1229 n_consumer = kMinConsumerCount;
1230 }
1231
1232 selected_columns_ = selected_columns;
1233 RETURN_IF_NOT_OK_MR(CheckColumnList(selected_columns_));
1234
1235 // Initialize argument
1236 shard_count_ = static_cast<int>(file_paths_.size());
1237 n_consumer_ = n_consumer;
1238 num_padded_ = num_padded;
1239
1240 operators_ = operators;
1241 RETURN_IF_NOT_OK_MR(Open(n_consumer));
1242 return Status::OK();
1243 }
1244
Launch(bool is_sample_read)1245 Status ShardReader::Launch(bool is_sample_read) {
1246 // Get all row groups' info
1247 auto row_group_summary = ReadRowGroupSummary();
1248
1249 // Sort row group by (group_id, shard_id), prepare for parallel reading
1250 std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
1251 auto status = CreateTasks(row_group_summary, operators_);
1252 if (status.IsError()) {
1253 interrupt_ = true;
1254 return status;
1255 }
1256 if (is_sample_read) {
1257 return Status::OK();
1258 }
1259 // Start provider consumer threads
1260 thread_set_ = std::vector<std::thread>(n_consumer_);
1261 CHECK_FAIL_RETURN_UNEXPECTED_MR(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount,
1262 "Invalid data, 'num_parallel_workers' should be less than or equal to " +
1263 std::to_string(kMaxConsumerCount) + "but got: " + std::to_string(n_consumer_));
1264
1265 for (int x = 0; x < n_consumer_; ++x) {
1266 thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
1267 }
1268
1269 MS_LOG(INFO) << "Succeed to launch read thread.";
1270 return Status::OK();
1271 }
1272
// Build the task list for category-based sampling (e.g. PK sampling): for each
// category, scan the shards' index for matching pages and insert up to
// num_elements tasks per category, then combine all per-category lists.
Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
  CheckIfColumnInIndex(selected_columns_);
  auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
  auto categories = category_op->GetCategories();
  // Maximum number of samples taken per category.
  int64_t num_elements = category_op->GetNumElements();
  int64_t num_samples = 0;
  if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
    num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      num_samples >= 0,
      "Invalid data, 'num_samples' should be greater than or equal to 0, but got: " + std::to_string(num_samples));
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    num_elements > 0,
    "[Internal ERROR] 'num_elements' should be greater than 0, but got: " + std::to_string(num_elements));
  // When no explicit categories are given, discover up to num_categories
  // distinct values of the category field from the index.
  if (categories.empty() == true) {
    std::string category_field = category_op->GetCategoryField();
    int64_t num_categories = category_op->GetNumCategories();
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      num_categories > 0,
      "[Internal ERROR] 'num_categories' should be greater than 0, but got: " + std::to_string(num_categories));
    auto category_ptr = std::make_shared<std::set<std::string>>();
    RETURN_IF_NOT_OK_MR(GetAllClasses(category_field, category_ptr));
    int i = 0;
    for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
      categories.emplace_back(category_field, *it);
      i++;
    }
  }
  // Generate a vector of task lists. Each catogory has a list of tasks.
  std::vector<ShardTaskList> categoryTasks(categories.size());
  for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
    // Number of tasks collected so far for this category (capped at num_elements).
    int category_index = 0;
    for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
      auto pages_ptr = std::make_shared<std::vector<uint64_t>>();
      RETURN_IF_NOT_OK_MR(GetPagesByCategory(shard_id, categories[categoryNo], &pages_ptr));
      for (const auto &page_id : *pages_ptr) {
        if (category_index >= num_elements) {
          break;
        }
        std::shared_ptr<Page> page_ptr;
        RETURN_IF_NOT_OK_MR(shard_header_->GetPage(shard_id, page_id, &page_ptr));
        auto group_id = page_ptr->GetPageTypeID();
        std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
        // Read only the samples of this page that match the category value.
        RETURN_IF_NOT_OK_MR(
          ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_, &row_group_brief_ptr));
        // Tuple element 3 holds the per-sample blob offsets of the row group.
        auto offsets = std::get<3>(*row_group_brief_ptr);

        auto number_of_rows = offsets.size();
        for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
          if (category_index < num_elements) {
            // Element 4 holds the matching label json for each sample.
            categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
                                                 std::get<3>(*row_group_brief_ptr)[iStart],
                                                 std::get<4>(*row_group_brief_ptr)[iStart]);
            category_index++;
          }
        }
        MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks.";
      }
    }
  }
  // Merge the per-category lists (with/without replacement) into tasks_.
  tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);

  tasks_.InitSampleIds();
  // Let the category operator post-process / select from the combined tasks.
  RETURN_IF_NOT_OK_MR((*category_op)(tasks_));
  return Status::OK();
}
1340
// Fast-load mode: materialize one task per sample up front. All row-group
// offsets and (already loaded) label columns are read, then the tasks_ slots
// are filled in parallel, one thread per shard, each writing a disjoint range.
Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                     const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  std::shared_ptr<ROW_GROUPS> row_group_ptr;
  RETURN_IF_NOT_OK_MR(ReadAllRowGroup(selected_columns_, &row_group_ptr));
  // offsets[shard][row] = {shard_id, group_id, blob_start, blob_end};
  // local_columns[shard][row] = label json of the row.
  auto &offsets = std::get<0>(*row_group_ptr);
  auto &local_columns = std::get<1>(*row_group_ptr);
  int sample_count = 0;
  for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
    sample_count += offsets[shard_id].size();
  }
  // The index and the data files must agree on the number of samples.
  CHECK_FAIL_RETURN_UNEXPECTED_MR(sample_count == num_rows_, "Unequal number of index entries and data entries.");
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  // Global slot index where each shard's tasks begin; shards therefore write
  // non-overlapping ranges and need no synchronization.
  uint32_t current_offset = 0;
  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    // current_offset is captured by value: each thread keeps its own start.
    init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
      pthread_setname_np(pthread_self(), std::string("ParallelCreateTasks" + std::to_string(shard_id)).c_str());
#endif
      auto offset = current_offset;
      for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
        tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
                          std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]},
                          local_columns[shard_id][i]);
        offset++;
      }
    });
    current_offset += offsets[shard_id].size();
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1383
// Lazy-load mode: create placeholder tasks (no blob offsets, empty label json)
// for every sample; the actual data is fetched later when a task is consumed.
// shard_sample_count_ holds the cumulative sample counts per shard, so its
// last element is the dataset total.
Status ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                         const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    // the offset indicate the shard start
    uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];

    // the count indicate the number of samples in the shard
    uint32_t shard_count =
      shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
    // Each thread fills a disjoint [current_offset, current_offset + shard_count)
    // range of tasks_, so no synchronization is required.
    init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
      pthread_setname_np(pthread_self(), std::string("ParallelCreateLazyTasks" + std::to_string(shard_id)).c_str());
#endif
      for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
        // here "i - current_offset" indicate the sample id in the shard
        tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
      }
    });
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1419
CreateSlowTasksByRow()1420 Status ShardReader::CreateSlowTasksByRow() {
1421 CheckIfColumnInIndex(selected_columns_);
1422 uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
1423 MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";
1424 tasks_.padded_sample_ = num_padded_;
1425 tasks_.SetShardSampleCount(shard_sample_count_);
1426 return Status::OK();
1427 }
1428
CreateTasks(const std::vector<std::tuple<int,int,int,uint64_t>> & row_group_summary,const std::vector<std::shared_ptr<ShardOperator>> & operators)1429 Status ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
1430 const std::vector<std::shared_ptr<ShardOperator>> &operators) {
1431 int category_operator = -1;
1432 for (uint32_t i = 0; i < operators.size(); ++i) {
1433 const auto &op = operators[i];
1434 if (std::dynamic_pointer_cast<ShardCategory>(op)) {
1435 category_operator = static_cast<int>(i);
1436 break;
1437 }
1438 }
1439
1440 if (-1 == category_operator) {
1441 if (load_mode_ != LoadMode::kSlow) {
1442 if (load_mode_ == LoadMode::kLazy) {
1443 RETURN_IF_NOT_OK_MR(CreateLazyTasksByRow(row_group_summary, operators));
1444 } else {
1445 RETURN_IF_NOT_OK_MR(CreateTasksByRow(row_group_summary, operators));
1446 }
1447
1448 // need padded sample to the task
1449 if (num_padded_ > 0) {
1450 for (auto i = 0; i < num_padded_; ++i) {
1451 tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
1452 }
1453 }
1454 } else {
1455 RETURN_IF_NOT_OK_MR(CreateSlowTasksByRow());
1456 }
1457 } else {
1458 RETURN_IF_NOT_OK_MR(CreateTasksByCategory(operators[category_operator]));
1459 }
1460
1461 MS_LOG(DEBUG) << "Succeed to create " << tasks_.Size() << " initial task to start with before sampling.";
1462 if (load_mode_ != LoadMode::kSlow) {
1463 tasks_.InitSampleIds();
1464
1465 for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
1466 const auto &op = operators[operator_no];
1467 if (std::dynamic_pointer_cast<ShardCategory>(op)) {
1468 continue;
1469 }
1470
1471 if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
1472 op->SetShardSampleCount(shard_sample_count_);
1473 }
1474 RETURN_IF_NOT_OK_MR((*op)(tasks_));
1475 }
1476
1477 if (tasks_.permutation_.empty()) {
1478 tasks_.MakePerm();
1479 }
1480 } else {
1481 for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
1482 const auto &op = operators[operator_no];
1483 CHECK_FAIL_RETURN_UNEXPECTED_MR(
1484 !std::dynamic_pointer_cast<ShardCategory>(op),
1485 "[Internal ERROR] The retrieval function is not available when in slow loading mode.");
1486 if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
1487 op->SetShardSampleCount(shard_sample_count_);
1488 }
1489 RETURN_IF_NOT_OK_MR((*op)(tasks_));
1490 }
1491 }
1492
1493 num_rows_ = tasks_.Size();
1494 MS_LOG(INFO) << "The total number of samples is " << num_rows_
1495 << ", the number of samples after sampling is: " << tasks_.SizeAfterSampling();
1496
1497 return Status::OK();
1498 }
1499
ConsumerOneTask(int64_t task_id,uint32_t consumer_id,std::shared_ptr<TASK_CONTENT> * task_content_ptr)1500 Status ShardReader::ConsumerOneTask(int64_t task_id, uint32_t consumer_id,
1501 std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
1502 RETURN_UNEXPECTED_IF_NULL_MR(task_content_ptr);
1503 if (load_mode_ == LoadMode::kFast || load_mode_ == LoadMode::kLazy) {
1504 // All tasks are done
1505 CHECK_FAIL_RETURN_UNEXPECTED_MR(task_id < tasks_.Size(), "[Internal ERROR] 'task_id': " + std::to_string(task_id) +
1506 " is out of bound: " + std::to_string(tasks_.Size()));
1507 } else {
1508 CHECK_FAIL_RETURN_UNEXPECTED_MR(
1509 task_id < (num_padded_ + shard_sample_count_[shard_sample_count_.size() - 1]),
1510 "[Internal ERROR] 'task_id': " + std::to_string(task_id) +
1511 " is out of bound: " + std::to_string(num_padded_ + shard_sample_count_[shard_sample_count_.size() - 1]));
1512 }
1513
1514 uint32_t shard_id = 0;
1515 uint32_t group_id = 0;
1516 uint32_t blob_start = 0;
1517 uint32_t blob_end = 0;
1518 json var_fields;
1519 // Pick up task from task list
1520 ShardTask task = tasks_.GetTaskByID(task_id);
1521
1522 // check task type
1523 auto task_type = std::get<0>(task);
1524 if (task_type == TaskType::kPaddedTask) {
1525 *task_content_ptr =
1526 std::make_shared<TASK_CONTENT>(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
1527 return Status::OK();
1528 }
1529
1530 shard_id = std::get<0>(std::get<1>(task)); // shard id
1531
1532 if (load_mode_ == LoadMode::kLazy || load_mode_ == LoadMode::kSlow) {
1533 // get scalar variable fields by sample id
1534 uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));
1535
1536 // read the meta from index
1537 std::shared_ptr<ROW_GROUPS> row_group_ptr;
1538 RETURN_IF_NOT_OK_MR(
1539 ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, consumer_id, sample_id_in_shard, &row_group_ptr));
1540 auto &offsets = std::get<0>(*row_group_ptr);
1541 auto &local_columns = std::get<1>(*row_group_ptr);
1542
1543 group_id = offsets[shard_id][0][1]; // group_id
1544 blob_start = offsets[shard_id][0][2]; // blob start
1545 blob_end = offsets[shard_id][0][3]; // blob end
1546 var_fields = local_columns[shard_id][0]; // scalar variable field
1547 } else {
1548 group_id = std::get<1>(std::get<1>(task)); // group id
1549 blob_start = std::get<2>(task)[0]; // blob start
1550 blob_end = std::get<2>(task)[1]; // blob end
1551 var_fields = std::get<3>(task); // scalar variable field
1552 }
1553
1554 // read the blob from data file
1555 std::shared_ptr<Page> page_ptr;
1556 RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
1557 MS_LOG(DEBUG) << "Success to get page by group id: " << group_id;
1558
1559 // Pack image list
1560 std::vector<uint8_t> images(blob_end - blob_start);
1561 auto file_offset = header_size_ + page_size_ * (page_ptr->GetPageID()) + blob_start;
1562
1563 auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
1564 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
1565 file_streams_random_[consumer_id][shard_id]->close();
1566 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
1567 }
1568 auto &io_read =
1569 file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
1570 if (!io_read.good() || io_read.fail() || io_read.bad()) {
1571 file_streams_random_[consumer_id][shard_id]->close();
1572 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
1573 }
1574
1575 // Deliver batch data to output map
1576 std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
1577 batch.emplace_back(std::move(images), std::move(var_fields));
1578
1579 *task_content_ptr = std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::move(batch));
1580 return Status::OK();
1581 }
1582
ConsumerByRow(int consumer_id)1583 void ShardReader::ConsumerByRow(int consumer_id) {
1584 // Set thread name
1585 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
1586 pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(consumer_id)).c_str());
1587 #endif
1588
1589 // Loop forever
1590 for (;;) {
1591 int64_t sample_id_pos = 0;
1592
1593 // Get next task ID
1594 sample_id_pos = sample_id_position_++;
1595
1596 auto task_content_ptr =
1597 std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
1598 int64_t task_id = 0;
1599
1600 if (load_mode_ == LoadMode::kFast || load_mode_ == LoadMode::kLazy) {
1601 // All tasks are done
1602 if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
1603 return;
1604 }
1605 task_id = tasks_.sample_ids_[sample_id_pos];
1606 } else {
1607 // task_id is not correct when slow load mode
1608 if (sample_id_pos >= shard_sample_count_[shard_sample_count_.size() - 1]) {
1609 return;
1610 }
1611 task_id = sample_id_pos;
1612 }
1613 if (ConsumerOneTask(task_id, consumer_id, &task_content_ptr).IsError()) {
1614 MS_LOG(ERROR) << "[Internal ERROR] Error raised in ConsumerOneTask function.";
1615 interrupt_ = true;
1616 cv_iterator_.notify_one();
1617 return;
1618 }
1619 const auto &batch = (*task_content_ptr).second;
1620 // Hanging if maximum map size exceeded
1621 // otherwise, set batch data in map
1622 {
1623 std::unique_lock<std::mutex> lck(mtx_delivery_);
1624 cv_delivery_.wait(lck,
1625 [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
1626 if (interrupt_) {
1627 return;
1628 }
1629 delivery_map_[sample_id_pos] = std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(batch);
1630 }
1631 cv_iterator_.notify_one();
1632 }
1633 }
1634
GetNext()1635 std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> ShardReader::GetNext() {
1636 if (interrupt_) {
1637 return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1638 }
1639
1640 if (deliver_id_ >= static_cast<int>(tasks_.SizeAfterSampling())) {
1641 return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1642 }
1643
1644 std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>> res;
1645 {
1646 std::unique_lock<std::mutex> lck(mtx_delivery_);
1647 cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); });
1648 if (interrupt_) {
1649 return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1650 }
1651 res = delivery_map_[deliver_id_];
1652 delivery_map_.erase(deliver_id_++);
1653 }
1654
1655 cv_delivery_.notify_all();
1656
1657 // extract every blob field from blob data
1658 std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> res_with_blobs;
1659 for (auto iter = res->begin(); iter != res->end(); iter++) {
1660 std::map<std::string, std::vector<uint8_t>> key_with_blob_fields;
1661 auto shard_column = GetShardColumn();
1662 auto schema = shard_header_->GetSchemas(); // current, we only support 1 schema yet
1663 auto blob_fields = schema[0]->GetBlobFields();
1664 for (auto blob_field : blob_fields) {
1665 const unsigned char *data = nullptr;
1666 std::unique_ptr<unsigned char[]> data_ptr;
1667 uint64_t n_bytes = 0;
1668 mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
1669 uint64_t column_data_type_size = 1;
1670 std::vector<int64_t> column_shape;
1671 if (shard_column->GetColumnValueByName(blob_field, std::get<0>(*iter), std::get<1>(*iter), &data, &data_ptr,
1672 &n_bytes, &column_data_type, &column_data_type_size,
1673 &column_shape) != Status::OK()) {
1674 MS_LOG(ERROR) << "[Internal ERROR] Failed to extract blob fields from blob data";
1675 return std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>>();
1676 }
1677 key_with_blob_fields[blob_field] = std::vector<uint8_t>(data, data + n_bytes);
1678 }
1679
1680 res_with_blobs.emplace_back(std::move(key_with_blob_fields), std::move(std::get<1>(*iter)));
1681 }
1682
1683 return res_with_blobs;
1684 }
1685
GetNextById(const int64_t & task_id,const int32_t & consumer_id,std::shared_ptr<TASK_CONTENT> * task_content_ptr)1686 Status ShardReader::GetNextById(const int64_t &task_id, const int32_t &consumer_id,
1687 std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
1688 if (interrupt_) {
1689 return Status::OK();
1690 }
1691 RETURN_IF_NOT_OK_MR(ConsumerOneTask(task_id, consumer_id, task_content_ptr));
1692 return Status::OK();
1693 }
1694
UnCompressBlob(const std::vector<uint8_t> & raw_blob_data,std::shared_ptr<std::vector<std::vector<uint8_t>>> * blob_data_ptr)1695 Status ShardReader::UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
1696 std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr) {
1697 RETURN_UNEXPECTED_IF_NULL_MR(blob_data_ptr);
1698 auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
1699 auto blob_fields = GetBlobFields().second;
1700 for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
1701 if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) {
1702 continue;
1703 }
1704 const unsigned char *data = nullptr;
1705 std::unique_ptr<unsigned char[]> data_ptr;
1706 uint64_t n_bytes = 0;
1707 RETURN_IF_NOT_OK_MR(
1708 shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes));
1709 if (data == nullptr) {
1710 data = reinterpret_cast<const unsigned char *>(data_ptr.get());
1711 }
1712 std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
1713 (*blob_data_ptr)->push_back(column);
1714 }
1715 return Status::OK();
1716 }
1717
GetTotalBlobSize(int64_t * total_blob_size)1718 Status ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
1719 *total_blob_size = total_blob_size_;
1720 return Status::OK();
1721 }
1722
Reset()1723 void ShardReader::Reset() {
1724 {
1725 std::lock_guard<std::mutex> lck(mtx_delivery_);
1726 sample_id_position_ = 0;
1727 deliver_id_ = 0;
1728 }
1729 cv_delivery_.notify_all();
1730 }
1731
ShuffleTask()1732 void ShardReader::ShuffleTask() {
1733 // exist shuffle and distributed sampler in ops, skip shuffle
1734 bool has_sharding = false;
1735 for (const auto &op : operators_) {
1736 if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1737 has_sharding = true;
1738 }
1739 }
1740 for (const auto &op : operators_) {
1741 if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
1742 auto s = (*op)(tasks_);
1743 if (s.IsError()) {
1744 MS_LOG(WARNING) << "Failed to redo randomSampler in new epoch.";
1745 }
1746 } else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1747 auto s = (*op)(tasks_);
1748 if (s.IsError()) {
1749 MS_LOG(WARNING) << "Failed to redo distributeSampler in new epoch.";
1750 }
1751 }
1752 }
1753 if (load_mode_ != kSlow) {
1754 if (tasks_.permutation_.empty()) {
1755 tasks_.MakePerm();
1756 }
1757 } else {
1758 tasks_.generator_ids_.ResetShardIndexAndID();
1759 }
1760 }
1761
GetSampleIds()1762 const std::vector<int64_t> *ShardReader::GetSampleIds() {
1763 // return const reference to private sample id list.
1764 return &(this->tasks_.sample_ids_);
1765 }
1766
GetLoadMode() const1767 LoadMode ShardReader::GetLoadMode() const { return load_mode_; }
1768
GetNextSampleIds()1769 std::vector<int64_t> ShardReader::GetNextSampleIds() { return tasks_.GetNextSampleIds(); }
1770 } // namespace mindrecord
1771 } // namespace mindspore
1772