1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/mindrecord/include/shard_index_generator.h"
17
18 #include "utils/file_utils.h"
19 #include "utils/ms_utils.h"
20
21 using mindspore::LogStream;
22 using mindspore::ExceptionType::NoExceptionType;
23 using mindspore::MsLogLevel::DEBUG;
24 using mindspore::MsLogLevel::ERROR;
25 using mindspore::MsLogLevel::INFO;
26
27 namespace mindspore {
28 namespace mindrecord {
ShardIndexGenerator(const std::string & file_path,bool append)29 ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
30 : file_path_(file_path),
31 append_(append),
32 page_size_(0),
33 header_size_(0),
34 schema_count_(0),
35 task_(0),
36 write_success_(true) {}
37
Build()38 Status ShardIndexGenerator::Build() {
39 std::shared_ptr<json> header_ptr;
40 RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path_, &header_ptr));
41 auto ds = std::make_shared<std::vector<std::string>>();
42 RETURN_IF_NOT_OK(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds));
43 ShardHeader header = ShardHeader();
44 RETURN_IF_NOT_OK(header.BuildDataset(*ds));
45 shard_header_ = header;
46 MS_LOG(INFO) << "Initialize header from mindrecord file for index successfully.";
47 return Status::OK();
48 }
49
GetValueByField(const string & field,json input,std::shared_ptr<std::string> * value)50 Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value) {
51 RETURN_UNEXPECTED_IF_NULL(value);
52 CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty.");
53 CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty.");
54
55 // parameter input does not contain the field
56 CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(),
57 "The field " + field + " is not found in json " + input.dump());
58
59 // schema does not contain the field
60 auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
61 CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
62 "The field " + field + " is not found in schema " + schema.dump());
63
64 // field should be scalar type
65 CHECK_FAIL_RETURN_UNEXPECTED(
66 kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(),
67 "The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");
68
69 if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
70 auto schema_field_options = schema[field];
71 CHECK_FAIL_RETURN_UNEXPECTED(
72 schema_field_options.find("shape") == schema_field_options.end(),
73 "The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable.");
74 *value = std::make_shared<std::string>(input[field].dump());
75 } else {
76 // the field type is string in here
77 *value = std::make_shared<std::string>(input[field].get<std::string>());
78 }
79 return Status::OK();
80 }
81
TakeFieldType(const string & field_path,json schema)82 std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
83 std::vector<std::string> field_name = StringSplit(field_path, kPoint);
84 for (uint64_t i = 0; i < field_name.size(); ++i) {
85 try {
86 if (i != field_name.size() - 1) {
87 // Get type information from json schema
88 schema = schema.at(field_name[i]);
89 schema = schema.at("properties");
90 } else {
91 // standard root layer exist "properties" if type is "object"
92 if (schema.find("properties") != schema.end()) {
93 schema = schema.at("properties");
94 }
95 schema = schema.at(field_name[i]);
96 std::string field_type = schema.at("type").dump();
97 if (field_type.length() <= 2) {
98 return "";
99 } else {
100 return field_type.substr(1, field_type.length() - 2);
101 }
102 }
103 } catch (...) {
104 MS_LOG(WARNING) << "Exception occurred while get field type.";
105 return "";
106 }
107 }
108 return "";
109 }
110
ConvertJsonToSQL(const std::string & json)111 std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) {
112 if (kDbJsonMap.find(json) != kDbJsonMap.end()) {
113 return kDbJsonMap.at(json);
114 } else {
115 return "TEXT";
116 }
117 }
118
Callback(void * not_used,int argc,char ** argv,char ** az_col_name)119 int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) {
120 for (auto i = 0; i < argc; i++) {
121 if (argv[i] != nullptr) {
122 MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr");
123 }
124 }
125 MS_LOG(INFO) << "\n";
126 return 0;
127 }
128
ExecuteSQL(const std::string & sql,sqlite3 * db,const std::string & success_msg)129 Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
130 char *z_err_msg = nullptr;
131 int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg);
132 if (rc != SQLITE_OK) {
133 std::ostringstream oss;
134 oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg;
135 MS_LOG(DEBUG) << oss.str();
136 sqlite3_free(z_err_msg);
137 sqlite3_close(db);
138 RETURN_STATUS_UNEXPECTED(oss.str());
139 } else {
140 if (!success_msg.empty()) {
141 MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg;
142 }
143 sqlite3_free(z_err_msg);
144 return Status::OK();
145 }
146 }
147
GenerateFieldName(const std::pair<uint64_t,std::string> & field,std::shared_ptr<std::string> * fn_ptr)148 Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field,
149 std::shared_ptr<std::string> *fn_ptr) {
150 RETURN_UNEXPECTED_IF_NULL(fn_ptr);
151 // Replaces dots and dashes with underscores for SQL use
152 std::string field_name = field.second;
153 // white list to avoid sql injection
154 std::replace_if(
155 field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_');
156 auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) {
157 return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9');
158 });
159 CHECK_FAIL_RETURN_UNEXPECTED(
160 pos == field_name.end(),
161 "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name);
162 *fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first));
163 return Status::OK();
164 }
165
CheckDatabase(const std::string & shard_address,sqlite3 ** db)166 Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) {
167 std::optional<std::string> dir = "";
168 std::optional<std::string> local_file_name = "";
169 FileUtils::SplitDirAndFileName(shard_address, &dir, &local_file_name);
170 if (!dir.has_value()) {
171 dir = ".";
172 }
173
174 auto realpath = FileUtils::GetRealPath(dir.value().data());
175 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
176
177 std::optional<std::string> whole_path = "";
178 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
179
180 std::ifstream fin(whole_path.value());
181 if (!append_ && fin.good()) {
182 fin.close();
183 RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address);
184 }
185 fin.close();
186 if (sqlite3_open_v2(common::SafeCStr(whole_path.value()), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) {
187 RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" +
188 std::string(sqlite3_errmsg(*db)));
189 }
190 MS_LOG(DEBUG) << "Opened database successfully";
191 return Status::OK();
192 }
193
CreateShardNameTable(sqlite3 * db,const std::string & shard_name)194 Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
195 // create shard_name table
196 std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;";
197 RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "drop table successfully."));
198 sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);";
199 RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "create table successfully."));
200 sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);";
201 sqlite3_stmt *stmt = nullptr;
202 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
203 if (stmt != nullptr) {
204 (void)sqlite3_finalize(stmt);
205 }
206 sqlite3_close(db);
207 RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
208 }
209
210 int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME");
211 if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) {
212 (void)sqlite3_finalize(stmt);
213 sqlite3_close(db);
214 RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
215 ", field value: " + std::string(shard_name));
216 }
217
218 if (sqlite3_step(stmt) != SQLITE_DONE) {
219 (void)sqlite3_finalize(stmt);
220 RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
221 }
222 (void)sqlite3_finalize(stmt);
223 return Status::OK();
224 }
225
CreateDatabase(int shard_no,sqlite3 ** db)226 Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) {
227 std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
228 CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + shard_no);
229 std::shared_ptr<std::string> fn_ptr;
230 RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr));
231 shard_address += ".db";
232 RETURN_IF_NOT_OK(CheckDatabase(shard_address, db));
233 std::string sql = "DROP TABLE IF EXISTS INDEXES;";
234 RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "drop table successfully."));
235 sql =
236 "CREATE TABLE INDEXES("
237 " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL"
238 ", PAGE_OFFSET_RAW INT NOT NULL, PAGE_OFFSET_RAW_END INT NOT NULL"
239 ", ROW_GROUP_ID INT NOT NULL, PAGE_ID_BLOB INT NOT NULL"
240 ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL";
241
242 int field_no = 0;
243 std::shared_ptr<std::string> field_ptr;
244 for (const auto &field : fields_) {
245 uint64_t schema_id = field.first;
246 std::shared_ptr<Schema> schema_ptr;
247 RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(schema_id, &schema_ptr));
248 json json_schema = (schema_ptr->GetSchema())["schema"];
249 std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema));
250 RETURN_IF_NOT_OK(GenerateFieldName(field, &field_ptr));
251 sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type;
252 }
253 sql += ", PRIMARY KEY(ROW_ID";
254 for (uint64_t i = 0; i < fields_.size(); ++i) {
255 sql += ",INC_" + std::to_string(i);
256 }
257 sql += "));";
258 RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "create table successfully."));
259 RETURN_IF_NOT_OK(CreateShardNameTable(*db, *fn_ptr));
260 return Status::OK();
261 }
262
GetSchemaDetails(const std::vector<uint64_t> & schema_lens,std::fstream & in,std::shared_ptr<std::vector<json>> * detail_ptr)263 Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
264 std::shared_ptr<std::vector<json>> *detail_ptr) {
265 RETURN_UNEXPECTED_IF_NULL(detail_ptr);
266 if (schema_count_ <= kMaxSchemaCount) {
267 for (int sc = 0; sc < schema_count_; ++sc) {
268 std::vector<char> schema_detail(schema_lens[sc]);
269 auto &io_read = in.read(&schema_detail[0], schema_lens[sc]);
270 if (!io_read.good() || io_read.fail() || io_read.bad()) {
271 in.close();
272 RETURN_STATUS_UNEXPECTED("Failed to read file.");
273 }
274 auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()));
275 (*detail_ptr)->emplace_back(j);
276 }
277 }
278 return Status::OK();
279 }
280
GenerateRawSQL(const std::vector<std::pair<uint64_t,std::string>> & fields,std::shared_ptr<std::string> * sql_ptr)281 Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
282 std::shared_ptr<std::string> *sql_ptr) {
283 std::string sql =
284 "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END,"
285 "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END";
286
287 int field_no = 0;
288 for (const auto &field : fields) {
289 std::shared_ptr<std::string> fn_ptr;
290 RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
291 sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr;
292 }
293 sql +=
294 ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB,"
295 ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END";
296 field_no = 0;
297 for (const auto &field : fields) {
298 std::shared_ptr<std::string> fn_ptr;
299 RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
300 sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr;
301 }
302 sql += " )";
303
304 *sql_ptr = std::make_shared<std::string>(sql);
305 return Status::OK();
306 }
307
BindParameterExecuteSQL(sqlite3 * db,const std::string & sql,const ROW_DATA & data)308 Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) {
309 sqlite3_stmt *stmt = nullptr;
310 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
311 if (stmt != nullptr) {
312 (void)sqlite3_finalize(stmt);
313 }
314 sqlite3_close(db);
315 RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
316 }
317 for (auto &row : data) {
318 for (auto &field : row) {
319 const auto &place_holder = std::get<0>(field);
320 const auto &field_type = std::get<1>(field);
321 const auto &field_value = std::get<2>(field);
322
323 int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
324 if (field_type == "INTEGER") {
325 if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
326 (void)sqlite3_finalize(stmt);
327 sqlite3_close(db);
328 RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
329 ", field value: " + std::string(field_value));
330 }
331 } else if (field_type == "NUMERIC") {
332 if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
333 (void)sqlite3_finalize(stmt);
334 sqlite3_close(db);
335 RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
336 ", field value: " + std::string(field_value));
337 }
338 } else if (field_type == "NULL") {
339 if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
340 (void)sqlite3_finalize(stmt);
341
342 sqlite3_close(db);
343 RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
344 ", field value: NULL");
345 }
346 } else {
347 if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
348 (void)sqlite3_finalize(stmt);
349 sqlite3_close(db);
350 RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
351 ", field value: " + std::string(field_value));
352 }
353 }
354 }
355 if (sqlite3_step(stmt) != SQLITE_DONE) {
356 (void)sqlite3_finalize(stmt);
357 RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
358 }
359 (void)sqlite3_reset(stmt);
360 }
361 (void)sqlite3_finalize(stmt);
362 return Status::OK();
363 }
364
AddBlobPageInfo(std::vector<std::tuple<std::string,std::string,std::string>> & row_data,const std::shared_ptr<Page> cur_blob_page,uint64_t & cur_blob_page_offset,std::fstream & in)365 Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
366 const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
367 std::fstream &in) {
368 row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));
369
370 // blob data start
371 row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset));
372 auto &io_seekg_blob =
373 in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
374 if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
375 in.close();
376 RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
377 }
378 uint64_t image_size = 0;
379 auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len);
380 if (!io_read.good() || io_read.fail() || io_read.bad()) {
381 MS_LOG(ERROR) << "File read failed";
382 in.close();
383 RETURN_STATUS_UNEXPECTED("Failed to read file.");
384 }
385
386 cur_blob_page_offset += (kInt64Len + image_size);
387 row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset));
388
389 return Status::OK();
390 }
391
AddIndexFieldByRawData(const std::vector<json> & schema_detail,std::vector<std::tuple<std::string,std::string,std::string>> & row_data)392 Status ShardIndexGenerator::AddIndexFieldByRawData(
393 const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) {
394 auto index_fields_ptr = std::make_shared<INDEX_FIELDS>();
395 RETURN_IF_NOT_OK(GenerateIndexFields(schema_detail, &index_fields_ptr));
396 int index = 0;
397 for (const auto &field : *index_fields_ptr) {
398 // assume simple field: string , number etc.
399 row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
400 row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
401 }
402 return Status::OK();
403 }
404
GenerateRowData(int shard_no,const std::map<int,int> & blob_id_to_page_id,int raw_page_id,std::fstream & in,std::shared_ptr<ROW_DATA> * row_data_ptr)405 Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
406 std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) {
407 RETURN_UNEXPECTED_IF_NULL(row_data_ptr);
408 // current raw data page
409 std::shared_ptr<Page> page_ptr;
410 RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr));
411 // related blob page
412 vector<pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds();
413
414 // pair: row_group id, offset in raw data page
415 for (pair<int, int> blob_ids : row_group_list) {
416 // get blob data page according to row_group id
417 auto iter = blob_id_to_page_id.find(blob_ids.first);
418 CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id.");
419 std::shared_ptr<Page> blob_page_ptr;
420 RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr));
421 // offset in current raw data page
422 auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
423 uint64_t cur_blob_page_offset = 0;
424 for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) {
425 std::vector<std::tuple<std::string, std::string, std::string>> row_data;
426 row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
427 row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID()));
428 row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID()));
429
430 // raw data start
431 row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));
432
433 // calculate raw data end
434 auto &io_seekg =
435 in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
436 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
437 in.close();
438 RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
439 }
440 std::vector<uint64_t> schema_lens;
441 if (schema_count_ <= kMaxSchemaCount) {
442 for (int sc = 0; sc < schema_count_; sc++) {
443 uint64_t schema_size = 0;
444
445 auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len);
446 if (!io_read.good() || io_read.fail() || io_read.bad()) {
447 in.close();
448 RETURN_STATUS_UNEXPECTED("Failed to read file.");
449 }
450
451 cur_raw_page_offset += (kInt64Len + schema_size);
452 schema_lens.push_back(schema_size);
453 }
454 }
455 row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset));
456
457 // Getting schema for getting data for fields
458 auto detail_ptr = std::make_shared<std::vector<json>>();
459 RETURN_IF_NOT_OK(GetSchemaDetails(schema_lens, in, &detail_ptr));
460 // start blob page info
461 RETURN_IF_NOT_OK(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in));
462
463 // start index field
464 AddIndexFieldByRawData(*detail_ptr, row_data);
465 (*row_data_ptr)->push_back(std::move(row_data));
466 }
467 }
468 return Status::OK();
469 }
470
GenerateIndexFields(const std::vector<json> & schema_detail,std::shared_ptr<INDEX_FIELDS> * index_fields_ptr)471 Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail,
472 std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) {
473 RETURN_UNEXPECTED_IF_NULL(index_fields_ptr);
474 // index fields
475 std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
476 for (const auto &field : index_fields) {
477 CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range.");
478 std::shared_ptr<std::string> field_val_ptr;
479 RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr));
480 std::shared_ptr<Schema> schema_ptr;
481 RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(field.first, &schema_ptr));
482 std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"]));
483 std::shared_ptr<std::string> fn_ptr;
484 RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
485 (*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr);
486 }
487 return Status::OK();
488 }
489
ExecuteTransaction(const int & shard_no,sqlite3 * db,const std::vector<int> & raw_page_ids,const std::map<int,int> & blob_id_to_page_id)490 Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
491 const std::map<int, int> &blob_id_to_page_id) {
492 // Add index data to database
493 std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
494 CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty.");
495
496 auto realpath = FileUtils::GetRealPath(shard_address.data());
497 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
498 std::fstream in;
499 in.open(realpath.value(), std::ios::in | std::ios::binary);
500 if (!in.good()) {
501 in.close();
502 RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address);
503 }
504 (void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
505 for (int raw_page_id : raw_page_ids) {
506 std::shared_ptr<std::string> sql_ptr;
507 RELEASE_AND_RETURN_IF_NOT_OK(GenerateRawSQL(fields_, &sql_ptr), db, in);
508 auto row_data_ptr = std::make_shared<ROW_DATA>();
509 RELEASE_AND_RETURN_IF_NOT_OK(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db, in);
510 RELEASE_AND_RETURN_IF_NOT_OK(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in);
511 MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db.";
512 }
513 (void)sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
514 in.close();
515
516 // Close database
517 sqlite3_close(db);
518 db = nullptr;
519 return Status::OK();
520 }
521
WriteToDatabase()522 Status ShardIndexGenerator::WriteToDatabase() {
523 fields_ = shard_header_.GetFields();
524 page_size_ = shard_header_.GetPageSize();
525 header_size_ = shard_header_.GetHeaderSize();
526 schema_count_ = shard_header_.GetSchemaCount();
527 CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount,
528 "num shards: " + std::to_string(shard_header_.GetShardCount()) +
529 " exceeds max count:" + std::to_string(kMaxSchemaCount));
530
531 task_ = 0; // set two atomic vars to initial value
532 write_success_ = true;
533
534 // spawn half the physical threads or total number of shards whichever is smaller
535 const unsigned int num_workers =
536 std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
537
538 std::vector<std::thread> threads;
539 threads.reserve(num_workers);
540
541 for (size_t t = 0; t < threads.capacity(); t++) {
542 threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
543 }
544
545 for (size_t t = 0; t < threads.capacity(); t++) {
546 threads[t].join();
547 }
548 CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db.");
549 return Status::OK();
550 }
551
DatabaseWriter()552 void ShardIndexGenerator::DatabaseWriter() {
553 int shard_no = task_++;
554 while (shard_no < shard_header_.GetShardCount()) {
555 sqlite3 *db = nullptr;
556 if (CreateDatabase(shard_no, &db).IsError()) {
557 MS_LOG(ERROR) << "Failed to create Generate database.";
558 write_success_ = false;
559 return;
560 }
561 MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
562 // Pre-processing page information
563 auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
564
565 std::map<int, int> blob_id_to_page_id;
566 std::vector<int> raw_page_ids;
567 for (uint64_t i = 0; i < total_pages; ++i) {
568 std::shared_ptr<Page> page_ptr;
569 if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) {
570 MS_LOG(ERROR) << "Failed to get page.";
571 write_success_ = false;
572 return;
573 }
574 if (page_ptr->GetPageType() == "RAW_DATA") {
575 raw_page_ids.push_back(i);
576 } else if (page_ptr->GetPageType() == "BLOB_DATA") {
577 blob_id_to_page_id[page_ptr->GetPageTypeID()] = i;
578 }
579 }
580
581 if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) {
582 MS_LOG(ERROR) << "Failed to execute transaction.";
583 write_success_ = false;
584 return;
585 }
586 MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
587 shard_no = task_++;
588 }
589 }
Finalize(const std::vector<std::string> file_names)590 Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
591 CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty.");
592 ShardIndexGenerator sg{file_names[0]};
593 RETURN_IF_NOT_OK(sg.Build());
594 RETURN_IF_NOT_OK(sg.WriteToDatabase());
595 return Status::OK();
596 }
597 } // namespace mindrecord
598 } // namespace mindspore
599