1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_reader.h"
18
19 #include <algorithm>
20 #include <thread>
21
22 #include "utils/file_utils.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "utils/ms_utils.h"
25
26 using mindspore::LogStream;
27 using mindspore::ExceptionType::NoExceptionType;
28 using mindspore::MsLogLevel::DEBUG;
29 using mindspore::MsLogLevel::ERROR;
30 using mindspore::MsLogLevel::INFO;
31
32 namespace mindspore {
33 namespace mindrecord {
// Convert a string to an exact numeric type (int32_t / int64_t / float / double).
// A string that fails to parse yields a value-initialized (zero) result.
template <class Type>
Type StringToNum(const std::string &str) {
  Type value{};
  std::istringstream parser(str);
  parser >> value;
  return value;
}
42
// Default constructor: zero-initializes all bookkeeping counters. The reader
// becomes usable only after Init()/Open() have populated the header, the
// index databases and the file streams.
ShardReader::ShardReader()
    : header_size_(0),
      page_size_(0),
      shard_count_(0),
      n_consumer_(0),
      num_padded_(0),
      num_rows_(0),
      total_blob_size_(0),
      sample_id_position_(0),
      deliver_id_(0),
      lazy_load_(false),  // switched on in Init() when num_rows_ > LAZY_LOAD_THRESHOLD
      shard_sample_count_() {}
55
GetMeta(const std::string & file_path,std::shared_ptr<json> meta_data_ptr,std::shared_ptr<std::vector<std::string>> * addresses_ptr)56 Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
57 std::shared_ptr<std::vector<std::string>> *addresses_ptr) {
58 RETURN_UNEXPECTED_IF_NULL(addresses_ptr);
59 CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(file_path), "Invalid file, path: " + file_path);
60 std::shared_ptr<json> header_ptr;
61 RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path, &header_ptr));
62
63 *meta_data_ptr = {{"header_size", (*header_ptr)["header_size"]}, {"page_size", (*header_ptr)["page_size"]},
64 {"version", (*header_ptr)["version"]}, {"index_fields", (*header_ptr)["index_fields"]},
65 {"schema", (*header_ptr)["schema"]}, {"blob_fields", (*header_ptr)["blob_fields"]}};
66 *addresses_ptr = std::make_shared<std::vector<std::string>>((*header_ptr)["shard_addresses"]);
67 return Status::OK();
68 }
69
Init(const std::vector<std::string> & file_paths,bool load_dataset)70 Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
71 std::string file_path = file_paths[0];
72 auto first_meta_data_ptr = std::make_shared<json>();
73 std::shared_ptr<std::vector<std::string>> addresses_ptr;
74 RETURN_IF_NOT_OK(GetMeta(file_path, first_meta_data_ptr, &addresses_ptr));
75 if (file_paths.size() == 1 && load_dataset == true) {
76 auto ds = std::make_shared<std::vector<std::string>>();
77 RETURN_IF_NOT_OK(GetDatasetFiles(file_path, *addresses_ptr, &ds));
78 file_paths_ = *ds;
79 } else if (file_paths.size() >= 1 && load_dataset == false) {
80 file_paths_ = file_paths;
81 } else {
82 RETURN_STATUS_UNEXPECTED("Invalid data, number of MindRecord files [" + std::to_string(file_paths.size()) +
83 "] or 'load_dataset' [" + std::to_string(load_dataset) + "]is invalid.");
84 }
85 for (const auto &file : file_paths_) {
86 auto meta_data_ptr = std::make_shared<json>();
87 RETURN_IF_NOT_OK(GetMeta(file, meta_data_ptr, &addresses_ptr));
88 CHECK_FAIL_RETURN_UNEXPECTED(*meta_data_ptr == *first_meta_data_ptr,
89 "Invalid data, MindRecord files meta data is not consistent.");
90 sqlite3 *db = nullptr;
91 RETURN_IF_NOT_OK(VerifyDataset(&db, file));
92 database_paths_.push_back(db);
93 }
94 ShardHeader sh = ShardHeader();
95 RETURN_IF_NOT_OK(sh.BuildDataset(file_paths_, load_dataset));
96 shard_header_ = std::make_shared<ShardHeader>(sh);
97 header_size_ = shard_header_->GetHeaderSize();
98 page_size_ = shard_header_->GetPageSize();
99 // version < 3.0
100 if ((*first_meta_data_ptr)["version"] < kVersion) {
101 shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
102 } else {
103 shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
104 }
105 num_rows_ = 0;
106 auto row_group_summary = ReadRowGroupSummary();
107
108 // clear the shard_sample_count_, because it will be insert when Launch func
109 shard_sample_count_.clear();
110
111 for (const auto &rg : row_group_summary) {
112 num_rows_ += std::get<3>(rg);
113 }
114
115 if (num_rows_ > LAZY_LOAD_THRESHOLD) {
116 lazy_load_ = true;
117 MS_LOG(WARNING)
118 << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
119 << ", enable lazy load mode. If you want to speed up data loading, "
120 << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
121 << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
122 << "according to the current samples.";
123 }
124
125 auto disk_size = page_size_ * row_group_summary.size();
126 auto compression_size = shard_header_->GetCompressionSize();
127 total_blob_size_ = disk_size + compression_size;
128 MS_LOG(INFO) << "Blob data size on disk: " << disk_size << " , additional uncompression size: " << compression_size
129 << " , Total blob size: " << total_blob_size_;
130
131 MS_LOG(INFO) << "Succeed to get meta from mindrecord file & index file.";
132
133 return Status::OK();
134 }
135
VerifyDataset(sqlite3 ** db,const string & file)136 Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
137 // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
138 CHECK_FAIL_RETURN_UNEXPECTED(
139 sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK,
140 "Invalid database file, path: " + file + ".db, " + sqlite3_errmsg(*db));
141 MS_LOG(DEBUG) << "Succeed to Open database, path: " << file << ".db.";
142
143 string sql = "SELECT NAME from SHARD_NAME;";
144 std::vector<std::vector<std::string>> name;
145 char *errmsg = nullptr;
146 if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) {
147 std::ostringstream oss;
148 oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
149 sqlite3_free(errmsg);
150 sqlite3_close(*db);
151 RETURN_STATUS_UNEXPECTED(oss.str());
152 } else {
153 MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(name.size()) << " records from index.";
154 std::shared_ptr<std::string> fn_ptr;
155 RETURN_IF_NOT_OK(GetFileName(file, &fn_ptr));
156 if (name.empty() || name[0][0] != *fn_ptr) {
157 sqlite3_free(errmsg);
158 sqlite3_close(*db);
159 RETURN_STATUS_UNEXPECTED("Invalid database file, shard name [" + *fn_ptr + "] can not match [" + name[0][0] +
160 "].");
161 }
162 }
163 return Status::OK();
164 }
165
CheckColumnList(const std::vector<std::string> & selected_columns)166 Status ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
167 vector<int> inSchema(selected_columns.size(), 0);
168 for (auto &p : GetShardHeader()->GetSchemas()) {
169 auto schema = p->GetSchema()["schema"];
170 for (unsigned int i = 0; i < selected_columns.size(); ++i) {
171 if (schema.find(selected_columns[i]) != schema.end()) {
172 inSchema[i] = 1;
173 }
174 }
175 }
176 CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; }),
177 "Invalid data, column is not found in schema.");
178 return Status::OK();
179 }
180
Open()181 Status ShardReader::Open() {
182 file_streams_.clear();
183 for (const auto &file : file_paths_) {
184 std::optional<std::string> dir = "";
185 std::optional<std::string> local_file_name = "";
186 FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
187 if (!dir.has_value()) {
188 dir = ".";
189 }
190
191 auto realpath = FileUtils::GetRealPath(dir.value().data());
192 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
193
194 std::optional<std::string> whole_path = "";
195 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
196
197 std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
198 fs->open(whole_path.value(), std::ios::in | std::ios::binary);
199 if (!fs->good()) {
200 RETURN_STATUS_UNEXPECTED(
201 "Failed to open file: " + file +
202 ", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
203 }
204 MS_LOG(INFO) << "Succeed to open shard file.";
205 file_streams_.push_back(fs);
206 }
207 return Status::OK();
208 }
209
Open(int n_consumer)210 Status ShardReader::Open(int n_consumer) {
211 file_streams_random_ =
212 std::vector<std::vector<std::shared_ptr<std::fstream>>>(n_consumer, std::vector<std::shared_ptr<std::fstream>>());
213 for (const auto &file : file_paths_) {
214 for (int j = 0; j < n_consumer; ++j) {
215 std::optional<std::string> dir = "";
216 std::optional<std::string> local_file_name = "";
217 FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
218 if (!dir.has_value()) {
219 dir = ".";
220 }
221
222 auto realpath = FileUtils::GetRealPath(dir.value().data());
223 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
224
225 std::optional<std::string> whole_path = "";
226 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
227
228 std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
229 fs->open(whole_path.value(), std::ios::in | std::ios::binary);
230 if (!fs->good()) {
231 RETURN_STATUS_UNEXPECTED(
232 "Failed to open file: " + file +
233 ", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
234 }
235 file_streams_random_[j].push_back(fs);
236 }
237 MS_LOG(INFO) << "Succeed to open file, path: " << file;
238 }
239 return Status::OK();
240 }
241
FileStreamsOperator()242 void ShardReader::FileStreamsOperator() {
243 for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; --i) {
244 if (file_streams_[i] != nullptr) {
245 file_streams_[i]->close();
246 }
247 }
248 for (int i = static_cast<int>(file_streams_random_.size()) - 1; i >= 0; --i) {
249 for (int j = static_cast<int>(file_streams_random_[i].size()) - 1; j >= 0; --j) {
250 if (file_streams_random_[i][j] != nullptr) {
251 file_streams_random_[i][j]->close();
252 }
253 }
254 }
255 for (int i = static_cast<int>(database_paths_.size()) - 1; i >= 0; --i) {
256 if (database_paths_[i] != nullptr) {
257 auto ret = sqlite3_close(database_paths_[i]);
258 if (ret != SQLITE_OK) {
259 MS_LOG(ERROR) << "Failed to close database, error code: " << ret << ".";
260 }
261 database_paths_[i] = nullptr;
262 }
263 }
264 }
265
// Destructor: stop worker threads and release file/database handles via Close().
ShardReader::~ShardReader() { Close(); }
267
// Shut the reader down: signal the worker threads to stop, wake any that are
// blocked on the delivery condition variable, wait for them to exit, then
// release all file streams and database handles.
void ShardReader::Close() {
  {
    // Set the interrupt flag under the delivery lock so that threads sleeping
    // on cv_delivery_ cannot miss the wake-up.
    std::lock_guard<std::mutex> lck(mtx_delivery_);
    interrupt_ = true;  // interrupt reading and stop threads
  }
  cv_delivery_.notify_all();

  // Wait for all threads to finish
  for (auto &i_thread : thread_set_) {
    if (i_thread.joinable()) {
      i_thread.join();
    }
  }

  FileStreamsOperator();
}
284
// Header built by Init() from all shard files.
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }

// Column helper created in Init() (selected by the file-format version).
std::shared_ptr<ShardColumn> ShardReader::GetShardColumn() const { return shard_column_; }

// Number of shard files making up the dataset.
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }

// Total number of samples, accumulated from the row-group summary in Init().
int ShardReader::GetNumRows() const { return num_rows_; }
292
// Walk every blob page of every shard and return one tuple
// (shard_id, page_type_id, start_row_id, number_of_rows) per row group.
// Side effect: appends one entry per shard to shard_sample_count_.
std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
  std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
  int shard_count = shard_header_->GetShardCount();
  if (shard_count <= 0) {
    return row_group_summary;
  }
  if (shard_count <= kMaxFileCount) {
    // NOTE(review): total_count accumulates ACROSS shards, so
    // shard_sample_count_ receives running prefix sums — presumably used to
    // map a global sample id to its shard; confirm against its consumers.
    uint32_t total_count = 0;
    for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
      // return -1 when page's size equals to 0.
      auto last_page_id = shard_header_->GetLastPageId(shard_id);
      if (static_cast<int>(last_page_id) == -1) {
        continue;
      }
      for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
        std::shared_ptr<Page> page_ptr;
        (void)shard_header_->GetPage(shard_id, page_id, &page_ptr);
        // Only blob pages contribute row groups.
        if (page_ptr->GetPageType() != kPageTypeBlob) {
          continue;
        }
        uint64_t start_row_id = page_ptr->GetStartRowID();
        if (start_row_id > page_ptr->GetEndRowID()) {
          // Corrupt page bounds: report an empty summary.
          return std::vector<std::tuple<int, int, int, uint64_t>>();
        }
        uint64_t number_of_rows = page_ptr->GetEndRowID() - start_row_id;
        total_count += number_of_rows;
        row_group_summary.emplace_back(shard_id, page_ptr->GetPageTypeID(), start_row_id, number_of_rows);
      }
      shard_sample_count_.push_back(total_count);
    }
  }

  return row_group_summary;
}
327
// Convert index-query rows (`labels`) into blob offsets and label json.
// Row layout (see the SELECT built in ReadAllRowGroup):
//   [0..2] = ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END;
//   [3..]  = indexed column values when all_in_index_, otherwise
//            PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END.
// Fills (*offset_ptr)[shard_id] with {shard_id, group_id, blob start, blob end}
// and (*col_val_ptr)[shard_id] with one json object per row. `fs` is closed on
// every exit path (NOTE(review): when all_in_index_ the caller never opened
// the stream, so the final close() only sets failbit — harmless here).
Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
                                       std::shared_ptr<std::fstream> fs,
                                       std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
                                       int shard_id, const std::vector<std::string> &columns,
                                       std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
  for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
    try {
      uint64_t group_id = std::stoull(labels[i][0]);
      // + kInt64Len: presumably skips a 64-bit length word preceding the
      // payload — confirm against the writer's page layout.
      uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
      uint64_t offset_end = std::stoull(labels[i][2]);
      (*offset_ptr)[shard_id].emplace_back(
        std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
      if (!all_in_index_) {
        // Some requested column is not indexed: read the msgpack-encoded label
        // back from the raw page and keep only the requested columns.
        int raw_page_id = std::stoi(labels[i][3]);
        uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
        uint64_t label_end = std::stoull(labels[i][5]);
        auto len = label_end - label_start;
        auto label_raw = std::vector<uint8_t>(len);
        auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
        if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
        }
        auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
        if (!io_read.good() || io_read.fail() || io_read.bad()) {
          fs->close();
          RETURN_STATUS_UNEXPECTED("Failed to read file.");
        }
        json label_json = json::from_msgpack(label_raw);
        json tmp;
        if (!columns.empty()) {
          for (auto &col : columns) {
            if (label_json.find(col) != label_json.end()) {
              tmp[col] = label_json[col];
            }
          }
        } else {
          tmp = label_json;
        }
        (*col_val_ptr)[shard_id].emplace_back(tmp);
      } else {
        // All columns are indexed: rebuild the json from the query result,
        // converting each string value to its schema-declared base type.
        json construct_json;
        for (unsigned int j = 0; j < columns.size(); ++j) {
          // construct json "f1": value
          auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];

          // convert the string to base type by schema
          if (schema[columns[j]]["type"] == "int32") {
            construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
          } else if (schema[columns[j]]["type"] == "int64") {
            construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
          } else if (schema[columns[j]]["type"] == "float32") {
            construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
          } else if (schema[columns[j]]["type"] == "float64") {
            construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
          } else {
            construct_json[columns[j]] = std::string(labels[i][j + 3]);
          }
        }
        (*col_val_ptr)[shard_id].emplace_back(construct_json);
      }
    } catch (std::out_of_range &e) {
      // std::sto* throws on values outside the target range.
      fs->close();
      RETURN_STATUS_UNEXPECTED("Out of range exception raised in ConvertLabelToJson function, " +
                               std::string(e.what()));
    } catch (std::invalid_argument &e) {
      // std::sto* throws on non-numeric input.
      fs->close();
      RETURN_STATUS_UNEXPECTED("Invalid argument exception raised in ConvertLabelToJson function, " +
                               std::string(e.what()));
    } catch (...) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Unknown exception raised in ConvertLabelToJson function");
    }
  }

  fs->close();
  return Status::OK();
}
406
ReadAllRowsInShard(int shard_id,const std::string & sql,const std::vector<std::string> & columns,std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr)407 Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
408 std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
409 std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
410 auto db = database_paths_[shard_id];
411 std::vector<std::vector<std::string>> labels;
412 char *errmsg = nullptr;
413 int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
414 if (rc != SQLITE_OK) {
415 std::ostringstream oss;
416 oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
417 sqlite3_free(errmsg);
418 sqlite3_close(db);
419 db = nullptr;
420 RETURN_STATUS_UNEXPECTED(oss.str());
421 }
422 MS_LOG(INFO) << "Succeed to get " << static_cast<int>(labels.size()) << " records from shard "
423 << std::to_string(shard_id) << " index.";
424
425 std::string file_name = file_paths_[shard_id];
426 auto realpath = FileUtils::GetRealPath(file_name.data());
427 if (!realpath.has_value()) {
428 sqlite3_free(errmsg);
429 sqlite3_close(db);
430 RETURN_STATUS_UNEXPECTED("Failed to get real path, path: " + file_name);
431 }
432
433 std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
434 if (!all_in_index_) {
435 fs->open(realpath.value(), std::ios::in | std::ios::binary);
436 if (!fs->good()) {
437 sqlite3_free(errmsg);
438 sqlite3_close(db);
439 RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file_name);
440 }
441 }
442 sqlite3_free(errmsg);
443 return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr);
444 }
445
GetAllClasses(const std::string & category_field,std::shared_ptr<std::set<std::string>> category_ptr)446 Status ShardReader::GetAllClasses(const std::string &category_field,
447 std::shared_ptr<std::set<std::string>> category_ptr) {
448 std::map<std::string, uint64_t> index_columns;
449 for (auto &field : GetShardHeader()->GetFields()) {
450 index_columns[field.second] = field.first;
451 }
452 CHECK_FAIL_RETURN_UNEXPECTED(index_columns.find(category_field) != index_columns.end(),
453 "Invalid data, index field " + category_field + " does not exist.");
454 std::shared_ptr<std::string> fn_ptr;
455 RETURN_IF_NOT_OK(
456 ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr));
457 std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
458 std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
459 for (int x = 0; x < shard_count_; x++) {
460 threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
461 }
462
463 for (int x = 0; x < shard_count_; x++) {
464 threads[x].join();
465 }
466 return Status::OK();
467 }
468
GetClassesInShard(sqlite3 * db,int shard_id,const std::string & sql,std::shared_ptr<std::set<std::string>> category_ptr)469 void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
470 std::shared_ptr<std::set<std::string>> category_ptr) {
471 if (db == nullptr) {
472 return;
473 }
474 std::vector<std::vector<std::string>> columns;
475 char *errmsg = nullptr;
476 int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg);
477 if (ret != SQLITE_OK) {
478 sqlite3_free(errmsg);
479 sqlite3_close(db);
480 db = nullptr;
481 MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
482 return;
483 }
484 MS_LOG(INFO) << "Succeed to get " << static_cast<int>(columns.size()) << " records from shard "
485 << std::to_string(shard_id) << " index.";
486 std::lock_guard<std::mutex> lck(shard_locker_);
487 for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
488 category_ptr->emplace(columns[i][0]);
489 }
490 sqlite3_free(errmsg);
491 }
492
ReadAllRowGroup(const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUPS> * row_group_ptr)493 Status ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns,
494 std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
495 RETURN_UNEXPECTED_IF_NULL(row_group_ptr);
496 std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
497 auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
498 shard_count_, std::vector<std::vector<uint64_t>>{});
499 auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
500
501 if (all_in_index_) {
502 for (unsigned int i = 0; i < columns.size(); ++i) {
503 fields += ',';
504 std::shared_ptr<std::string> fn_ptr;
505 RETURN_IF_NOT_OK(
506 ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
507 fields += *fn_ptr;
508 }
509 } else { // fetch raw data from Raw page while some field is not index.
510 fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
511 }
512
513 std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;";
514
515 std::vector<std::thread> thread_read_db = std::vector<std::thread>(shard_count_);
516 for (int x = 0; x < shard_count_; x++) {
517 thread_read_db[x] = std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, offset_ptr, col_val_ptr);
518 }
519
520 for (int x = 0; x < shard_count_; x++) {
521 thread_read_db[x].join();
522 }
523 *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
524 return Status::OK();
525 }
526
ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> & columns,const uint32_t & shard_id,const uint32_t & sample_id,std::shared_ptr<ROW_GROUPS> * row_group_ptr)527 Status ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
528 const uint32_t &sample_id,
529 std::shared_ptr<ROW_GROUPS> *row_group_ptr) {
530 RETURN_UNEXPECTED_IF_NULL(row_group_ptr);
531 std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
532 auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
533 shard_count_, std::vector<std::vector<uint64_t>>{});
534 auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
535 if (all_in_index_) {
536 for (unsigned int i = 0; i < columns.size(); ++i) {
537 fields += ',';
538 std::shared_ptr<std::string> fn_ptr;
539 RETURN_IF_NOT_OK(
540 ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr));
541 fields += *fn_ptr;
542 }
543 } else { // fetch raw data from Raw page while some field is not index.
544 fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
545 }
546
547 std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
548
549 RETURN_IF_NOT_OK(ReadAllRowsInShard(shard_id, sql, columns, offset_ptr, col_val_ptr));
550 *row_group_ptr = std::make_shared<ROW_GROUPS>(std::move(*offset_ptr), std::move(*col_val_ptr));
551 return Status::OK();
552 }
553
ReadRowGroupBrief(int group_id,int shard_id,const std::vector<std::string> & columns,std::shared_ptr<ROW_GROUP_BRIEF> * row_group_brief_ptr)554 Status ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
555 std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
556 RETURN_UNEXPECTED_IF_NULL(row_group_brief_ptr);
557 std::shared_ptr<Page> page_ptr;
558 RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
559 std::string file_name = file_paths_[shard_id];
560 uint64_t page_length = page_ptr->GetPageSize();
561 uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
562 std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id);
563 auto labels_ptr = std::make_shared<std::vector<json>>();
564 RETURN_IF_NOT_OK(GetLabels(page_ptr->GetPageID(), shard_id, columns, {"", ""}, &labels_ptr));
565 *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
566 std::move(*labels_ptr));
567 return Status::OK();
568 }
569
// Build a brief description of one row group, restricted to the samples that
// match `criteria` (field name, value). The criteria field must exist in some
// schema (CheckColumnList).
Status ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
                                         const std::pair<std::string, std::string> &criteria,
                                         const std::vector<std::string> &columns,
                                         std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr) {
  RETURN_UNEXPECTED_IF_NULL(row_group_brief_ptr);
  std::shared_ptr<Page> page_ptr;
  RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
  vector<string> criteria_list{criteria.first};
  RETURN_IF_NOT_OK(CheckColumnList(criteria_list));
  std::string file_name = file_paths_[shard_id];
  uint64_t page_length = page_ptr->GetPageSize();
  uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_;
  std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id, criteria);
  if (image_offset.empty()) {
    // NOTE(review): this assignment is overwritten unconditionally below —
    // it looks like a missing early `return Status::OK();`. Left as-is to
    // preserve behavior; confirm the intended semantics with the callers.
    *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>();
  }
  auto labels_ptr = std::make_shared<std::vector<json>>();
  RETURN_IF_NOT_OK(GetLabels(page_ptr->GetPageID(), shard_id, columns, criteria, &labels_ptr));
  *row_group_brief_ptr = std::make_shared<ROW_GROUP_BRIEF>(file_name, page_length, page_offset, std::move(image_offset),
                                                           std::move(*labels_ptr));
  return Status::OK();
}
592
SelectCallback(void * p_data,int num_fields,char ** p_fields,char ** p_col_names)593 int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) {
594 auto *records = static_cast<std::vector<std::vector<std::string>> *>(p_data);
595 if (num_fields > 0 && num_fields <= kMaxFieldCount) {
596 for (int i = 0; i < num_fields; ++i)
597 if (p_fields[i] == nullptr) p_fields[i] = const_cast<char *>("");
598 }
599 records->emplace_back(p_fields, p_fields + num_fields);
600 return 0;
601 }
602
GetImageOffset(int page_id,int shard_id,const std::pair<std::string,std::string> & criteria)603 std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int shard_id,
604 const std::pair<std::string, std::string> &criteria) {
605 auto db = database_paths_[shard_id];
606
607 std::string sql =
608 "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);
609
610 // whether use index search
611 if (!criteria.first.empty()) {
612 auto schema = shard_header_->GetSchemas()[0]->GetSchema();
613
614 // not number field should add '' in sql
615 if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
616 sql +=
617 " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
618 } else {
619 sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
620 criteria.second + "'";
621 }
622 }
623 sql += ";";
624 std::vector<std::vector<std::string>> image_offsets;
625 char *errmsg = nullptr;
626 int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg);
627 if (rc != SQLITE_OK) {
628 MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
629 sqlite3_free(errmsg);
630 sqlite3_close(db);
631 db = nullptr;
632 return std::vector<std::vector<uint64_t>>();
633 } else {
634 MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(image_offsets.size()) << " records from index.";
635 }
636 std::vector<std::vector<uint64_t>> res;
637 for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
638 for (int i = 0; i < static_cast<int>(image_offsets.size()); i++) {
639 const auto &image_offset = image_offsets[i];
640 res[i][0] = std::stoull(image_offset[0]) + kInt64Len;
641 res[i][1] = std::stoull(image_offset[1]);
642 }
643 sqlite3_free(errmsg);
644 return res;
645 }
646
GetPagesByCategory(int shard_id,const std::pair<std::string,std::string> & criteria,std::shared_ptr<std::vector<uint64_t>> * pages_ptr)647 Status ShardReader::GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
648 std::shared_ptr<std::vector<uint64_t>> *pages_ptr) {
649 RETURN_UNEXPECTED_IF_NULL(pages_ptr);
650 auto db = database_paths_[shard_id];
651
652 std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";
653
654 if (!criteria.first.empty()) {
655 auto schema = shard_header_->GetSchemas()[0]->GetSchema();
656 if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
657 sql +=
658 " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
659 } else {
660 sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
661 criteria.second + "'";
662 }
663 }
664 sql += ";";
665 std::vector<std::vector<std::string>> page_ids;
666 char *errmsg = nullptr;
667 int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg);
668 if (rc != SQLITE_OK) {
669 string ss(errmsg);
670 sqlite3_free(errmsg);
671 sqlite3_close(db);
672 db = nullptr;
673 RETURN_STATUS_UNEXPECTED(std::string("Failed to execute sql [") + common::SafeCStr(sql) + " ], " + ss);
674 } else {
675 MS_LOG(DEBUG) << "Succeed to get " << page_ids.size() << "pages from index.";
676 }
677 for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
678 (*pages_ptr)->emplace_back(std::stoull(page_ids[i][0]));
679 }
680 sqlite3_free(errmsg);
681 return Status::OK();
682 }
683
GetBlobFields()684 std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
685 std::vector<std::string> blob_fields;
686 for (auto &p : GetShardHeader()->GetSchemas()) {
687 // assume one schema
688 const auto &fields = p->GetBlobFields();
689 blob_fields.assign(fields.begin(), fields.end());
690 break;
691 }
692 return std::make_pair(kCV, blob_fields);
693 }
694
CheckIfColumnInIndex(const std::vector<std::string> & columns)695 void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) {
696 // assume different schemas do not contain same key.
697 if (columns.empty()) {
698 all_in_index_ = false;
699 return;
700 }
701 for (auto &field : GetShardHeader()->GetFields()) {
702 column_schema_id_[field.second] = field.first;
703 }
704 for (auto &col : columns) {
705 if (column_schema_id_.find(col) == column_schema_id_.end()) {
706 all_in_index_ = false;
707 return;
708 }
709 }
710 }
711
QueryWithCriteria(sqlite3 * db,const string & sql,const string & criteria,std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr)712 Status ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
713 std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
714 sqlite3_stmt *stmt = nullptr;
715 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
716 RETURN_STATUS_UNEXPECTED("Failed to prepare statement sql [ " + sql + " ].");
717 }
718 int index = sqlite3_bind_parameter_index(stmt, ":criteria");
719 if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) {
720 RETURN_STATUS_UNEXPECTED("Failed to bind parameter of sql, index: " + std::to_string(index) +
721 ", field value: " + criteria);
722 }
723 int rc = sqlite3_step(stmt);
724 while (rc != SQLITE_DONE) {
725 vector<string> tmp;
726 int ncols = sqlite3_column_count(stmt);
727 for (int i = 0; i < ncols; i++) {
728 tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
729 }
730 labels_ptr->push_back(tmp);
731 rc = sqlite3_step(stmt);
732 }
733 (void)sqlite3_finalize(stmt);
734 return Status::OK();
735 }
736
// Read raw label records (msgpack-encoded) from a shard's data file at the byte ranges
// listed in label_offsets, decode each, and write one json per row into *labels_ptr.
//
// @param shard_id       Index into file_paths_ selecting which shard file to read.
// @param columns        Column names to (re)assign from the decoded record; note that
//                       non-requested keys are NOT stripped, since tmp starts as a full copy.
// @param label_offsets  Rows of [raw_page_id, start, end] strings from the index database.
// @param labels_ptr     Output vector; entry i corresponds to label_offsets[i].
// @return Status error on null output, open/seek/read failure, or a malformed offset row.
Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
                                            const std::vector<std::vector<std::string>> &label_offsets,
                                            std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL(labels_ptr);
  std::string file_name = file_paths_[shard_id];
  auto realpath = FileUtils::GetRealPath(file_name.data());
  CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path=" + file_name);

  std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
  fs->open(realpath.value(), std::ios::in | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file, path: " + file_name);
  // init the return
  for (unsigned int i = 0; i < label_offsets.size(); ++i) {
    (*labels_ptr)->emplace_back(json{});
  }

  for (unsigned int i = 0; i < label_offsets.size(); ++i) {
    const auto &labelOffset = label_offsets[i];
    // Each row must carry at least [page_id, start, end].
    if (labelOffset.size() < 3) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Invalid data, labelOffset size: " + std::to_string(labelOffset.size()) +
                               " is invalid.");
    }
    // Skip the kInt64Len length prefix stored ahead of the msgpack payload.
    uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len;
    uint64_t label_end = std::stoull(labelOffset[2]);
    int raw_page_id = std::stoi(labelOffset[0]);
    // NOTE(review): assumes label_end >= label_start; a corrupt index row would wrap
    // len around to a huge value — confirm upstream validation.
    auto len = label_end - label_start;
    auto label_raw = std::vector<uint8_t>(len);
    // Absolute position = file header + raw page offset + in-page start.
    auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
    if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Failed to seekg file, path: " + file_name);
    }

    auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
    if (!io_read.good() || io_read.fail() || io_read.bad()) {
      fs->close();
      RETURN_STATUS_UNEXPECTED("Failed to read file, path: " + file_name);
    }

    json label_json = json::from_msgpack(label_raw);
    // tmp is a full copy of the record; the loop only refreshes the requested keys.
    json tmp = label_json;
    for (auto &col : columns) {
      if (label_json.find(col) != label_json.end()) {
        tmp[col] = label_json[col];
      }
    }
    (*(*labels_ptr))[i] = tmp;
  }
  return Status::OK();
}
// Fetch decoded labels for every sample whose blob lives on the given blob page,
// optionally filtered by an indexed (field, value) criteria pair, then delegate the
// actual record decoding to GetLabelsFromBinaryFile().
//
// @param page_id    Blob page id matched against PAGE_ID_BLOB in the INDEXES table.
// @param shard_id   Shard whose index database and data file are used.
// @param columns    Column names forwarded to GetLabelsFromBinaryFile().
// @param criteria   (field, value) filter; an empty field name disables filtering.
// @param labels_ptr Output vector of label json objects.
Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
                                      const std::pair<std::string, std::string> &criteria,
                                      std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL(labels_ptr);
  // get page info from sqlite
  auto db = database_paths_[shard_id];
  std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " +
                    std::to_string(page_id);
  auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
  if (!criteria.first.empty()) {
    // Indexed field columns are named "<field>_<schema_id>"; the value is bound as :criteria.
    sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
    RETURN_IF_NOT_OK(QueryWithCriteria(db, sql, criteria.second, label_offset_ptr));
  } else {
    sql += ";";
    char *errmsg = nullptr;
    int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg);
    if (rc != SQLITE_OK) {
      std::ostringstream oss;
      oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
      sqlite3_free(errmsg);
      // NOTE(review): db is a copy of database_paths_[shard_id]; closing it here leaves
      // the stored handle dangling — verify callers never reuse it after this error.
      sqlite3_close(db);
      db = nullptr;
      RETURN_STATUS_UNEXPECTED(oss.str());
    }
    MS_LOG(DEBUG) << "Succeed to get " << label_offset_ptr->size() << " records from index.";
    // errmsg is null on success; sqlite3_free(nullptr) is a harmless no-op.
    sqlite3_free(errmsg);
  }
  // get labels from binary file
  return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr, labels_ptr);
}
818
// Fetch labels for all samples on a blob page. Fast path: when every requested column is
// covered by the sqlite index (all_in_index_), values are read directly from the INDEXES
// table and converted back to typed json via the schema. Slow path: fall back to
// GetLabelsFromPage(), which decodes the raw msgpack records from the data file.
//
// @param page_id    Blob page id matched against PAGE_ID_BLOB.
// @param shard_id   Shard whose index database is queried.
// @param columns    Columns to materialize; empty means SELECT *.
// @param criteria   Optional (field, value) filter; empty field disables it.
// @param labels_ptr Output vector of label json objects, one per matching row.
Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
                              const std::pair<std::string, std::string> &criteria,
                              std::shared_ptr<std::vector<json>> *labels_ptr) {
  RETURN_UNEXPECTED_IF_NULL(labels_ptr);
  if (all_in_index_) {
    auto db = database_paths_[shard_id];
    std::string fields;
    // Build the projection list; indexed columns are stored as "<name>_<schema_id>".
    for (unsigned int i = 0; i < columns.size(); ++i) {
      if (i > 0) fields += ',';
      uint64_t schema_id = column_schema_id_[columns[i]];
      fields += columns[i] + "_" + std::to_string(schema_id);
    }
    if (fields.empty()) {
      fields = "*";
    }
    auto labels = std::make_shared<std::vector<std::vector<std::string>>>();
    std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);
    if (!criteria.first.empty()) {
      sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria";
      RETURN_IF_NOT_OK(QueryWithCriteria(db, sql, criteria.second, labels));
    } else {
      sql += ";";
      char *errmsg = nullptr;
      int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels.get(), &errmsg);
      if (rc != SQLITE_OK) {
        std::ostringstream oss;
        oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
        sqlite3_free(errmsg);
        // NOTE(review): closes the shared handle from database_paths_ — same caveat as
        // GetLabelsFromPage(); verify callers do not reuse it after an error.
        sqlite3_close(db);
        db = nullptr;
        RETURN_STATUS_UNEXPECTED(oss.str());
      } else {
        MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(labels->size()) << " records from index.";
      }
      // errmsg is null on success; freeing null is a no-op.
      sqlite3_free(errmsg);
    }
    // Pre-size the output so rows can be assigned by position below.
    for (unsigned int i = 0; i < labels->size(); ++i) {
      (*labels_ptr)->emplace_back(json{});
    }
    for (unsigned int i = 0; i < labels->size(); ++i) {
      json construct_json;
      for (unsigned int j = 0; j < columns.size(); ++j) {
        // construct json "f1": value
        auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];

        // convert the string to base type by schema
        if (schema[columns[j]]["type"] == "int32") {
          construct_json[columns[j]] = StringToNum<int32_t>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "int64") {
          construct_json[columns[j]] = StringToNum<int64_t>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "float32") {
          construct_json[columns[j]] = StringToNum<float>((*labels)[i][j]);
        } else if (schema[columns[j]]["type"] == "float64") {
          construct_json[columns[j]] = StringToNum<double>((*labels)[i][j]);
        } else {
          // Non-numeric types are kept as raw strings.
          construct_json[columns[j]] = std::string((*labels)[i][j]);
        }
      }
      (*(*labels_ptr))[i] = construct_json;
    }
    return Status::OK();
  }
  return GetLabelsFromPage(page_id, shard_id, columns, criteria, labels_ptr);
}
883
// Strict-weak ordering for row-group summary tuples: primary key is the second tuple
// field, with ties broken by the first field. Used to sort row groups for parallel reads.
bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int, int> b) {
  if (std::get<1>(a) != std::get<1>(b)) {
    return std::get<1>(a) < std::get<1>(b);
  }
  return std::get<0>(a) < std::get<0>(b);
}
887
// Count distinct values of category_field across all shards by running a
// SELECT DISTINCT against every shard's index database in parallel.
//
// @param category_field Indexed field whose distinct values are counted.
// @return Number of distinct values, or -1 when the field is not indexed or a
//         database fails to open.
int64_t ShardReader::GetNumClasses(const std::string &category_field) {
  auto shard_count = file_paths_.size();
  auto index_fields = shard_header_->GetFields();

  // Map field name -> schema id for the indexed fields.
  std::map<std::string, int64_t> map_schema_id_fields;
  for (auto &field : index_fields) {
    map_schema_id_fields[field.second] = field.first;
  }

  if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
    MS_LOG(ERROR) << "Invalid data, field " << category_field << " does not exist.";
    return -1;
  }
  std::shared_ptr<std::string> fn_ptr;
  (void)ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field),
                                               &fn_ptr);
  std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES";
  std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
  // category_ptr is shared by all worker threads to accumulate distinct values.
  auto category_ptr = std::make_shared<std::set<std::string>>();
  sqlite3 *db = nullptr;
  for (int x = 0; x < shard_count; x++) {
    int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
    if (SQLITE_OK != rc) {
      MS_LOG(ERROR) << "Failed to open database: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
      // NOTE(review): returning here leaves threads launched in earlier iterations
      // joinable; their destruction would call std::terminate — confirm whether a
      // mid-loop open failure is reachable in practice.
      return -1;
    }
    // NOTE(review): `db` is reused for each shard and only the last handle is closed
    // below — presumably GetClassesInShard closes the handle it receives; verify.
    threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
  }

  for (int x = 0; x < shard_count; x++) {
    threads[x].join();
  }
  sqlite3_close(db);
  return category_ptr->size();
}
923
// Compute the number of rows the dataset will yield after applying the operator chain
// (shuffle / category / sampling ops) and optional padded samples.
//
// @param file_paths   Shard files the reader is (re)initialized with.
// @param load_dataset Forwarded to Init().
// @param ops          Head of an operator chain, walked via GetChildOp().
// @param count        Output: resulting sample count.
// @param num_padded   Number of padded samples to account for (applied once, at root).
Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
                                   const std::shared_ptr<ShardOperator> &ops, int64_t *count, const int num_padded) {
  RETURN_IF_NOT_OK(Init(file_paths, load_dataset));
  int64_t num_samples = num_rows_;
  bool root = true;
  // Push the chain onto a stack so operators are evaluated child-first (deepest first).
  std::stack<std::shared_ptr<ShardOperator>> stack_ops;
  std::shared_ptr<ShardOperator> op(ops);
  while (op != nullptr) {
    stack_ops.push(op);
    op = op->GetChildOp();
  }
  while (!stack_ops.empty()) {
    op = stack_ops.top();
    stack_ops.pop();
    if (std::dynamic_pointer_cast<ShardShuffle>(op)) {
      num_samples = op->GetNumSamples(num_samples, 0);
      // Padding is added exactly once, by the first (root-level) op that consumes it.
      if (num_padded > 0 && root == true) {
        num_samples += num_padded;
        root = false;
      }
    } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
      std::string category_field = category_op->GetCategoryField();
      auto num_classes = GetNumClasses(category_field);
      num_samples = category_op->GetNumSamples(num_samples, num_classes);
      if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
        // PK sampling may further cap the count; 0 means "no explicit limit".
        auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
        if (tmp != 0 && num_samples != -1) {
          num_samples = std::min(num_samples, tmp);
        }
        CHECK_FAIL_RETURN_UNEXPECTED(
          num_samples != -1, "Invalid input, number of samples: " + std::to_string(num_samples) +
                               " exceeds the upper limit: " + std::to_string(std::numeric_limits<int64_t>::max()));
      }
    } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
      if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
        auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
        if (root == true) {
          // Padding is folded into the distributed sampler only at root level.
          sampler_op->SetNumPaddedSamples(num_padded);
          num_samples = op->GetNumSamples(num_samples, 0);
          CHECK_FAIL_RETURN_UNEXPECTED(num_samples != -1, "Invalid data, dataset size plus number of padded samples: " +
                                                            std::to_string(num_samples) +
                                                            " can not be divisible by number of shards.");
          root = false;
        }
      } else {
        num_samples = op->GetNumSamples(num_samples, 0);
      }
    } else {
      // Operators of other kinds leave the count unchanged apart from padding.
      if (num_padded > 0) {
        num_samples += num_padded;
      }
    }
  }
  *count = num_samples;
  return Status::OK();
}
981
Open(const std::vector<std::string> & file_paths,bool load_dataset,int n_consumer,const std::vector<std::string> & selected_columns,const std::vector<std::shared_ptr<ShardOperator>> & operators,int num_padded,bool lazy_load)982 Status ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
983 const std::vector<std::string> &selected_columns,
984 const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded, bool lazy_load) {
985 lazy_load_ = lazy_load;
986
987 // Open file and set header by ShardReader
988 RETURN_IF_NOT_OK(Init(file_paths, load_dataset));
989 auto thread_limit = GetMaxThreadNum();
990 if (n_consumer > thread_limit) {
991 n_consumer = thread_limit;
992 }
993 if (n_consumer < kMinConsumerCount) {
994 n_consumer = kMinConsumerCount;
995 }
996
997 selected_columns_ = selected_columns;
998 RETURN_IF_NOT_OK(CheckColumnList(selected_columns_));
999
1000 // Initialize argument
1001 shard_count_ = static_cast<int>(file_paths_.size());
1002 n_consumer_ = n_consumer;
1003 num_padded_ = num_padded;
1004
1005 operators_ = operators;
1006 RETURN_IF_NOT_OK(Open(n_consumer));
1007 return Status::OK();
1008 }
1009
Launch(bool is_sample_read)1010 Status ShardReader::Launch(bool is_sample_read) {
1011 // Get all row groups' info
1012 auto row_group_summary = ReadRowGroupSummary();
1013
1014 // Sort row group by (group_id, shard_id), prepare for parallel reading
1015 std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
1016 if (CreateTasks(row_group_summary, operators_).IsError()) {
1017 interrupt_ = true;
1018 RETURN_STATUS_UNEXPECTED("Failed to launch read threads.");
1019 }
1020 if (is_sample_read) {
1021 return Status::OK();
1022 }
1023 // Start provider consumer threads
1024 thread_set_ = std::vector<std::thread>(n_consumer_);
1025 CHECK_FAIL_RETURN_UNEXPECTED(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount,
1026 "Invalid data, number of consumer: " + std::to_string(n_consumer_) +
1027 " exceeds the upper limit: " + std::to_string(kMaxConsumerCount));
1028
1029 for (int x = 0; x < n_consumer_; ++x) {
1030 thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
1031 }
1032
1033 MS_LOG(INFO) << "Succeed to launch read thread.";
1034 return Status::OK();
1035 }
1036
// Build the task list for a category-based operator (ShardCategory / ShardPkSample):
// for each category, walk every shard's matching pages and insert up to num_elements
// tasks, then combine the per-category lists and apply the category operator.
//
// @param op Category operator; must be (or derive from) ShardCategory.
Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
  CheckIfColumnInIndex(selected_columns_);
  auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
  auto categories = category_op->GetCategories();
  int64_t num_elements = category_op->GetNumElements();
  int64_t num_samples = 0;
  if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
    num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
    CHECK_FAIL_RETURN_UNEXPECTED(
      num_samples >= 0,
      "Invalid input, num_samples must be greater than or equal to 0, but got " + std::to_string(num_samples));
  }
  CHECK_FAIL_RETURN_UNEXPECTED(
    num_elements > 0, "Invalid input, num_elements must be greater than 0, but got " + std::to_string(num_elements));
  // No explicit categories: discover up to num_categories distinct values of the field.
  if (categories.empty() == true) {
    std::string category_field = category_op->GetCategoryField();
    int64_t num_categories = category_op->GetNumCategories();
    CHECK_FAIL_RETURN_UNEXPECTED(num_categories > 0, "Invalid input, num_categories must be greater than 0, but got " +
                                                       std::to_string(num_elements));
    auto category_ptr = std::make_shared<std::set<std::string>>();
    RETURN_IF_NOT_OK(GetAllClasses(category_field, category_ptr));
    int i = 0;
    for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
      categories.emplace_back(category_field, *it);
      i++;
    }
  }
  // Generate a vector of task lists. Each category has a list of tasks.
  std::vector<ShardTaskList> categoryTasks(categories.size());
  for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
    // category_index counts tasks inserted for this category, capped at num_elements.
    int category_index = 0;
    for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
      auto pages_ptr = std::make_shared<std::vector<uint64_t>>();
      RETURN_IF_NOT_OK(GetPagesByCategory(shard_id, categories[categoryNo], &pages_ptr));
      for (const auto &page_id : *pages_ptr) {
        if (category_index >= num_elements) {
          break;
        }
        std::shared_ptr<Page> page_ptr;
        RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, page_id, &page_ptr));
        auto group_id = page_ptr->GetPageTypeID();
        std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
        RETURN_IF_NOT_OK(
          ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_, &row_group_brief_ptr));
        // Tuple field 3 holds the per-row blob offsets; field 4 the per-row labels.
        auto offsets = std::get<3>(*row_group_brief_ptr);

        auto number_of_rows = offsets.size();
        for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
          if (category_index < num_elements) {
            categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
                                                 std::get<3>(*row_group_brief_ptr)[iStart],
                                                 std::get<4>(*row_group_brief_ptr)[iStart]);
            category_index++;
          }
        }
        MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
      }
    }
  }
  // Merge per-category lists (with or without replacement) and run the operator.
  tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);

  tasks_.InitSampleIds();
  RETURN_IF_NOT_OK((*category_op)(tasks_));
  return Status::OK();
}
1102
// Eager task creation: read every row group's metadata up front and insert one task per
// sample, using one init thread per shard.
//
// @param row_group_summary Unused here; row metadata comes from ReadAllRowGroup().
// @param operators         Unused here; operators are applied later in CreateTasks().
Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                     const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  std::shared_ptr<ROW_GROUPS> row_group_ptr;
  RETURN_IF_NOT_OK(ReadAllRowGroup(selected_columns_, &row_group_ptr));
  auto &offsets = std::get<0>(*row_group_ptr);
  auto &local_columns = std::get<1>(*row_group_ptr);
  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
                               "Invalid data, number of shards: " + std::to_string(shard_count_) +
                                 " exceeds the upper limit: " + std::to_string(kMaxFileCount));
  int sample_count = 0;
  for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
    sample_count += offsets[shard_id].size();
  }
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  // Each shard writes into its own disjoint [current_offset, current_offset + n) slice
  // of tasks_, so the threads do not contend; current_offset is captured by value.
  uint32_t current_offset = 0;
  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() {
      auto offset = current_offset;
      for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
        // offsets row layout: [0]=shard id, [1]=group id, [2]=blob start, [3]=blob end.
        tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
                          std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]},
                          local_columns[shard_id][i]);
        offset++;
      }
    });
    current_offset += offsets[shard_id].size();
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1144
// Lazy task creation: insert one placeholder task per sample (empty blob range and
// empty json); the actual metadata is fetched later, per task, in ConsumerOneTask().
// Sample counts come from shard_sample_count_, which holds cumulative totals per shard.
//
// @param row_group_summary Unused in the lazy path.
// @param operators         Unused here; operators are applied later in CreateTasks().
Status ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                         const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  CheckIfColumnInIndex(selected_columns_);
  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
                               "Invalid data, number of shards: " + std::to_string(shard_count_) +
                                 " exceeds the upper limit: " + std::to_string(kMaxFileCount));
  // shard_sample_count_ is cumulative, so the last entry is the dataset total.
  uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";

  // Init the tasks_ size
  tasks_.ResizeTask(sample_count);

  // Init the task threads, maybe use ThreadPool is better
  std::vector<std::thread> init_tasks_thread(shard_count_);

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    // the offset indicate the shard start
    uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];

    // the count indicate the number of samples in the shard
    uint32_t shard_count =
      shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
    // Each thread fills its shard's disjoint slice of tasks_, so there is no contention.
    init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
      for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
        // here "i - current_offset" indicate the sample id in the shard
        tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
      }
    });
  }

  for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
    init_tasks_thread[shard_id].join();
  }
  return Status::OK();
}
1180
// Build the full task list: choose category-based, eager row-based, or lazy row-based
// creation, append padded tasks if configured, then apply every non-category operator
// (samplers/shuffles) to the task list.
//
// @param row_group_summary Sorted row-group summary, forwarded to the row-based builders.
// @param operators         Operator pipeline; the first ShardCategory (if any) drives
//                          category-based creation and is skipped in the apply loop.
Status ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                const std::vector<std::shared_ptr<ShardOperator>> &operators) {
  // Locate the first category operator, if any.
  int category_operator = -1;
  for (uint32_t i = 0; i < operators.size(); ++i) {
    const auto &op = operators[i];
    if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      category_operator = static_cast<int>(i);
      break;
    }
  }

  if (-1 == category_operator) {
    if (lazy_load_ == false) {
      RETURN_IF_NOT_OK(CreateTasksByRow(row_group_summary, operators));
    } else {
      RETURN_IF_NOT_OK(CreateLazyTasksByRow(row_group_summary, operators));
    }

    // need padded sample to the task
    if (num_padded_ > 0) {
      for (int i = 0; i < num_padded_; ++i) {
        tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
      }
    }
  } else {
    RETURN_IF_NOT_OK(CreateTasksByCategory(operators[category_operator]));
  }
  MS_LOG(DEBUG) << "Succeed to create " << tasks_.Size() << " initial task to start with before sampling.";
  tasks_.InitSampleIds();

  for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
    const auto &op = operators[operator_no];
    // Category ops were already applied inside CreateTasksByCategory().
    if (std::dynamic_pointer_cast<ShardCategory>(op)) {
      continue;
    }

    // Distributed samplers and shuffles need the per-shard cumulative counts.
    if (std::dynamic_pointer_cast<ShardDistributedSample>(op) || std::dynamic_pointer_cast<ShardShuffle>(op)) {
      op->SetShardSampleCount(shard_sample_count_);
    }
    RETURN_IF_NOT_OK((*op)(tasks_));
  }

  if (tasks_.permutation_.empty()) tasks_.MakePerm();
  num_rows_ = tasks_.Size();
  MS_LOG(INFO) << "The total number of samples is " << num_rows_
               << ", the number of samples after sampling is: " << tasks_.sample_ids_.size();

  return Status::OK();
}
1230
// Materialize a single task: resolve its blob location (directly from the task in eager
// mode, or via an index lookup in lazy mode), read the blob bytes from the shard file,
// and package them with the scalar fields into *task_content_ptr.
//
// @param task_id          Index into tasks_ for the task to materialize.
// @param consumer_id      Selects the per-consumer file stream set (thread safety).
// @param task_content_ptr Output: (task type, [(blob bytes, label json)]).
Status ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id,
                                    std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
  RETURN_UNEXPECTED_IF_NULL(task_content_ptr);
  // All tasks are done
  CHECK_FAIL_RETURN_UNEXPECTED(
    task_id < static_cast<int>(tasks_.Size()),
    "Invalid data, task id: " + std::to_string(task_id) + " exceeds the upper limit: " + std::to_string(tasks_.Size()));
  uint32_t shard_id = 0;
  uint32_t group_id = 0;
  uint32_t blob_start = 0;
  uint32_t blob_end = 0;
  json var_fields;
  // Pick up task from task list
  ShardTask task = tasks_.GetTaskByID(task_id);

  // check task type
  auto task_type = std::get<0>(task);
  // Padded tasks carry no data; return an empty payload immediately.
  if (task_type == TaskType::kPaddedTask) {
    *task_content_ptr =
      std::make_shared<TASK_CONTENT>(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
    return Status::OK();
  }

  shard_id = std::get<0>(std::get<1>(task));  // shard id

  if (lazy_load_ == false) {
    // Eager mode: the task already carries the blob location and scalar fields.
    group_id = std::get<1>(std::get<1>(task));  // group id
    blob_start = std::get<2>(task)[0];          // blob start
    blob_end = std::get<2>(task)[1];            // blob end
    var_fields = std::get<3>(task);             // scalar variable field
  } else {
    // get scalar variable fields by sample id
    uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));

    // read the meta from index
    std::shared_ptr<ROW_GROUPS> row_group_ptr;
    RETURN_IF_NOT_OK(ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, sample_id_in_shard, &row_group_ptr));
    auto &offsets = std::get<0>(*row_group_ptr);
    auto &local_columns = std::get<1>(*row_group_ptr);

    group_id = offsets[shard_id][0][1];       // group_id
    blob_start = offsets[shard_id][0][2];     // blob start
    blob_end = offsets[shard_id][0][3];       // blob end
    var_fields = local_columns[shard_id][0];  // scalar variable field
  }

  // read the blob from data file
  std::shared_ptr<Page> page_ptr;
  RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
  MS_LOG(DEBUG) << "Success to get page by group id: " << group_id;

  // Pack image list
  std::vector<uint8_t> images(blob_end - blob_start);
  // Absolute position = file header + page offset + in-page blob start.
  auto file_offset = header_size_ + page_size_ * (page_ptr->GetPageID()) + blob_start;

  // Per-consumer streams make concurrent reads of the same shard safe.
  auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
  }
  auto &io_read =
    file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    file_streams_random_[consumer_id][shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to read file.");
  }

  // Deliver batch data to output map
  std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
  batch.emplace_back(std::move(images), std::move(var_fields));

  *task_content_ptr = std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::move(batch));
  return Status::OK();
}
1305
ConsumerByRow(int consumer_id)1306 void ShardReader::ConsumerByRow(int consumer_id) {
1307 // Set thread name
1308 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
1309 auto thread_id = kThreadName + std::to_string(consumer_id);
1310 prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0);
1311 #endif
1312
1313 // Loop forever
1314 for (;;) {
1315 int sample_id_pos = 0;
1316
1317 // Get next task ID
1318 sample_id_pos = sample_id_position_++;
1319
1320 // All tasks are done
1321 if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
1322 return;
1323 }
1324 auto task_content_ptr =
1325 std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
1326 if (ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id, &task_content_ptr).IsError()) {
1327 MS_LOG(ERROR) << "Error raised in ConsumerOneTask function.";
1328 return;
1329 }
1330 const auto &batch = (*task_content_ptr).second;
1331 // Hanging if maximum map size exceeded
1332 // otherwise, set batch data in map
1333 {
1334 std::unique_lock<std::mutex> lck(mtx_delivery_);
1335 cv_delivery_.wait(lck,
1336 [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
1337 if (interrupt_) {
1338 return;
1339 }
1340 delivery_map_[sample_id_pos] =
1341 std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
1342 }
1343 cv_iterator_.notify_one();
1344 }
1345 }
1346
// Pop the next in-order batch from delivery_map_, blocking until the consumer threads
// have produced the batch for deliver_id_. Returns an empty vector when interrupted or
// when every sampled id has already been delivered.
std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
  if (interrupt_) {
    return std::vector<std::tuple<std::vector<uint8_t>, json>>();
  }
  // All sampled ids delivered: the epoch is exhausted.
  if (deliver_id_ >= static_cast<int>(tasks_.sample_ids_.size())) {
    return std::vector<std::tuple<std::vector<uint8_t>, json>>();
  }

  std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>> res;
  {
    std::unique_lock<std::mutex> lck(mtx_delivery_);
    // Wait until the batch for the current deliver_id_ has been published (or interrupt).
    cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); });
    if (interrupt_) {
      return std::vector<std::tuple<std::vector<uint8_t>, json>>();
    }
    res = delivery_map_[deliver_id_];
    delivery_map_.erase(deliver_id_++);
  }

  // Freeing a delivery slot may unblock consumers waiting on the window in ConsumerByRow.
  cv_delivery_.notify_all();

  return *res;
}
1370
GetNextById(const int64_t & task_id,const int32_t & consumer_id)1371 TASK_CONTENT ShardReader::GetNextById(const int64_t &task_id, const int32_t &consumer_id) {
1372 auto task_content_ptr =
1373 std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
1374 if (interrupt_) {
1375 return *task_content_ptr;
1376 }
1377 (void)ConsumerOneTask(task_id, consumer_id, &task_content_ptr);
1378 return std::move(*task_content_ptr);
1379 }
1380
UnCompressBlob(const std::vector<uint8_t> & raw_blob_data,std::shared_ptr<std::vector<std::vector<uint8_t>>> * blob_data_ptr)1381 Status ShardReader::UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
1382 std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr) {
1383 RETURN_UNEXPECTED_IF_NULL(blob_data_ptr);
1384 auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
1385 auto blob_fields = GetBlobFields().second;
1386 for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
1387 if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue;
1388 const unsigned char *data = nullptr;
1389 std::unique_ptr<unsigned char[]> data_ptr;
1390 uint64_t n_bytes = 0;
1391 RETURN_IF_NOT_OK(
1392 shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes));
1393 if (data == nullptr) {
1394 data = reinterpret_cast<const unsigned char *>(data_ptr.get());
1395 }
1396 std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
1397 (*blob_data_ptr)->push_back(column);
1398 }
1399 return Status::OK();
1400 }
1401
GetTotalBlobSize(int64_t * total_blob_size)1402 Status ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
1403 *total_blob_size = total_blob_size_;
1404 return Status::OK();
1405 }
1406
Reset()1407 void ShardReader::Reset() {
1408 {
1409 std::lock_guard<std::mutex> lck(mtx_delivery_);
1410 sample_id_position_ = 0;
1411 deliver_id_ = 0;
1412 }
1413 cv_delivery_.notify_all();
1414 }
1415
ShuffleTask()1416 void ShardReader::ShuffleTask() {
1417 // exist shuffle and distributed sampler in ops, skip shuffle
1418 bool has_sharding = false;
1419 for (const auto &op : operators_) {
1420 if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1421 has_sharding = true;
1422 }
1423 }
1424 for (const auto &op : operators_) {
1425 if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
1426 auto s = (*op)(tasks_);
1427 if (s.IsError()) {
1428 MS_LOG(WARNING) << "Failed to redo randomSampler in new epoch.";
1429 }
1430 } else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
1431 auto s = (*op)(tasks_);
1432 if (s.IsError()) {
1433 MS_LOG(WARNING) << "Failed to redo distributeSampler in new epoch.";
1434 }
1435 }
1436 }
1437 if (tasks_.permutation_.empty()) tasks_.MakePerm();
1438 }
1439
GetSampleIds()1440 const std::vector<int> *ShardReader::GetSampleIds() {
1441 // return const reference to private sample id list.
1442 return &(this->tasks_.sample_ids_);
1443 }
1444
1445 } // namespace mindrecord
1446 } // namespace mindspore
1447