1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "minddata/mindrecord/include/shard_index_generator.h"

#include <exception>

#include "utils/file_utils.h"
#include "utils/ms_utils.h"
20
21 namespace mindspore {
22 namespace mindrecord {
ShardIndexGenerator(const std::string & file_path,bool append)23 ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
24 : file_path_(file_path),
25 append_(append),
26 page_size_(0),
27 header_size_(0),
28 schema_count_(0),
29 task_(0),
30 write_success_(true) {}
31
Build()32 Status ShardIndexGenerator::Build() {
33 std::shared_ptr<json> header_ptr;
34 RETURN_IF_NOT_OK_MR(ShardHeader::BuildSingleHeader(file_path_, &header_ptr));
35 auto ds = std::make_shared<std::vector<std::string>>();
36 RETURN_IF_NOT_OK_MR(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds));
37 ShardHeader header = ShardHeader();
38 RETURN_IF_NOT_OK_MR(header.BuildDataset(*ds));
39 shard_header_ = header;
40 MS_LOG(INFO) << "Initialize header from mindrecord file for index successfully.";
41 return Status::OK();
42 }
43
GetValueByField(const string & field,const json & input,std::shared_ptr<std::string> * value)44 Status ShardIndexGenerator::GetValueByField(const string &field, const json &input,
45 std::shared_ptr<std::string> *value) {
46 RETURN_UNEXPECTED_IF_NULL_MR(value);
47
48 // parameter input does not contain the field
49 CHECK_FAIL_RETURN_UNEXPECTED_MR(input.find(field) != input.end(),
50 "[Internal ERROR] 'field': " + field + " can not found in raw data: " + input.dump());
51
52 // schema does not contain the field
53 auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
54 CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(field) != schema.end(),
55 "[Internal ERROR] 'field': " + field + " can not found in schema: " + schema.dump());
56
57 // field should be scalar type
58 CHECK_FAIL_RETURN_UNEXPECTED_MR(
59 kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(),
60 "[Internal ERROR] 'field': " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");
61
62 if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
63 auto schema_field_options = schema[field];
64 CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_field_options.find("shape") == schema_field_options.end(),
65 "[Internal ERROR] 'field': " + field + " shape is " +
66 schema[field]["shape"].dump() + " which is not retrievable.");
67 *value = std::make_shared<std::string>(input[field].dump());
68 } else {
69 // the field type is string in here
70 *value = std::make_shared<std::string>(input[field].get<std::string>());
71 }
72 return Status::OK();
73 }
74
TakeFieldType(const string & field_path,json & schema)75 std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json &schema) {
76 std::vector<std::string> field_name = StringSplit(field_path, kPoint);
77 for (uint64_t i = 0; i < field_name.size(); ++i) {
78 try {
79 if (i != field_name.size() - 1) {
80 // Get type information from json schema
81 schema = schema.at(field_name[i]);
82 schema = schema.at("properties");
83 } else {
84 // standard root layer exist "properties" if type is "object"
85 if (schema.find("properties") != schema.end()) {
86 schema = schema.at("properties");
87 }
88 schema = schema.at(field_name[i]);
89 std::string field_type = schema.at("type").dump();
90 if (field_type.length() <= 2) {
91 return "";
92 } else {
93 return field_type.substr(1, field_type.length() - 2);
94 }
95 }
96 } catch (...) {
97 MS_LOG(WARNING) << "Exception occurred while get field type.";
98 return "";
99 }
100 }
101 return "";
102 }
103
ConvertJsonToSQL(const std::string & json)104 std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) {
105 if (kDbJsonMap.find(json) != kDbJsonMap.end()) {
106 return kDbJsonMap.at(json);
107 } else {
108 return "TEXT";
109 }
110 }
111
Callback(void * not_used,int argc,char ** argv,char ** az_col_name)112 int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) {
113 for (auto i = 0; i < argc; i++) {
114 if (argv[i] != nullptr) {
115 MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr");
116 }
117 }
118 MS_LOG(INFO) << "\n";
119 return 0;
120 }
121
ExecuteSQL(const std::string & sql,sqlite3 * db,const std::string & success_msg)122 Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
123 char *z_err_msg = nullptr;
124 int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg);
125 if (rc != SQLITE_OK) {
126 std::ostringstream oss;
127 oss << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql) << " ], " << z_err_msg;
128 MS_LOG(DEBUG) << oss.str();
129 sqlite3_free(z_err_msg);
130 sqlite3_close(db);
131 RETURN_STATUS_UNEXPECTED_MR(oss.str());
132 } else {
133 if (!success_msg.empty()) {
134 MS_LOG(DEBUG) << "Suceess to execute the sql [ " << common::SafeCStr(sql) << " ], " << success_msg;
135 }
136 sqlite3_free(z_err_msg);
137 return Status::OK();
138 }
139 }
140
GenerateFieldName(const std::pair<uint64_t,std::string> & field,std::shared_ptr<std::string> * fn_ptr)141 Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field,
142 std::shared_ptr<std::string> *fn_ptr) {
143 RETURN_UNEXPECTED_IF_NULL_MR(fn_ptr);
144 // Replaces dots and dashes with underscores for SQL use
145 std::string field_name = field.second;
146 // white list to avoid sql injection
147 std::replace_if(
148 field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_');
149 auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) {
150 return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9');
151 });
152 CHECK_FAIL_RETURN_UNEXPECTED_MR(pos == field_name.end(), "Invalid data, field name: " + field_name +
153 "is not composed of '0-9' or 'a-z' or 'A-Z' or '_'.");
154 *fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first));
155 return Status::OK();
156 }
157
CheckDatabase(const std::string & shard_address,sqlite3 ** db)158 Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) {
159 std::optional<std::string> dir = "";
160 std::optional<std::string> local_file_name = "";
161 FileUtils::SplitDirAndFileName(shard_address, &dir, &local_file_name);
162 if (!dir.has_value()) {
163 dir = ".";
164 }
165
166 auto realpath = FileUtils::GetRealPath(dir.value().c_str());
167 CHECK_FAIL_RETURN_UNEXPECTED_MR(
168 realpath.has_value(),
169 "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + shard_address);
170
171 std::optional<std::string> whole_path = "";
172 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
173
174 std::ifstream fin(whole_path.value(), std::ios::in);
175 if (!append_ && fin.good()) {
176 fin.close();
177 RETURN_STATUS_UNEXPECTED_MR(
178 "Invalid file, mindrecord meta files already exist. Please check file path: " + shard_address +
179 ".\nIf you do not want to keep the files, set the 'overwrite' parameter to True and try again.");
180 }
181 fin.close();
182
183 std::string whole_path_utf8 = "";
184 #if defined(_WIN32) || defined(_WIN64)
185 whole_path_utf8 = FileUtils::GB2312ToUTF_8(whole_path.value().data());
186 #endif
187 if (whole_path_utf8.empty()) {
188 whole_path_utf8 = whole_path.value();
189 }
190
191 if (sqlite3_open_v2(whole_path_utf8.data(), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) {
192 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to open mindrecord meta file: " + shard_address + ", " +
193 sqlite3_errmsg(*db));
194 }
195 MS_LOG(DEBUG) << "Open meta file successfully";
196 return Status::OK();
197 }
198
CreateShardNameTable(sqlite3 * db,const std::string & shard_name)199 Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
200 // create shard_name table
201 std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;";
202 RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, db, "drop table successfully."));
203 sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);";
204 RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, db, "create table successfully."));
205 sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);";
206 sqlite3_stmt *stmt = nullptr;
207 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
208 if (stmt != nullptr) {
209 (void)sqlite3_finalize(stmt);
210 }
211 sqlite3_close(db);
212 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
213 }
214
215 int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME");
216 if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) {
217 (void)sqlite3_finalize(stmt);
218 sqlite3_close(db);
219 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
220 std::to_string(index) + ", value: " + shard_name);
221 }
222
223 if (sqlite3_step(stmt) != SQLITE_DONE) {
224 (void)sqlite3_finalize(stmt);
225 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to step execute stmt.");
226 }
227 (void)sqlite3_finalize(stmt);
228 return Status::OK();
229 }
230
CreateDatabase(int shard_no,sqlite3 ** db)231 Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) {
232 std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
233 std::shared_ptr<std::string> fn_ptr;
234 RETURN_IF_NOT_OK_MR(GetFileName(shard_address, &fn_ptr));
235 shard_address += ".db";
236 RETURN_IF_NOT_OK_MR(CheckDatabase(shard_address, db));
237 std::string sql = "DROP TABLE IF EXISTS INDEXES;";
238 RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, *db, "drop table successfully."));
239 sql =
240 "CREATE TABLE INDEXES("
241 " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL"
242 ", PAGE_OFFSET_RAW INT NOT NULL, PAGE_OFFSET_RAW_END INT NOT NULL"
243 ", ROW_GROUP_ID INT NOT NULL, PAGE_ID_BLOB INT NOT NULL"
244 ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL";
245
246 int field_no = 0;
247 std::shared_ptr<std::string> field_ptr;
248 for (const auto &field : fields_) {
249 uint64_t schema_id = field.first;
250 std::shared_ptr<Schema> schema_ptr;
251 RETURN_IF_NOT_OK_MR(shard_header_.GetSchemaByID(schema_id, &schema_ptr));
252 json json_schema = (schema_ptr->GetSchema())["schema"];
253 std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema));
254 RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &field_ptr));
255 sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type;
256 }
257 sql += ", PRIMARY KEY(ROW_ID";
258 for (uint64_t i = 0; i < fields_.size(); ++i) {
259 sql += ",INC_" + std::to_string(i);
260 }
261 sql += "));";
262 RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, *db, "create table successfully."));
263 RETURN_IF_NOT_OK_MR(CreateShardNameTable(*db, *fn_ptr));
264 return Status::OK();
265 }
266
GetSchemaDetails(const std::vector<uint64_t> & schema_lens,std::fstream & in,std::shared_ptr<std::vector<json>> * detail_ptr)267 Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
268 std::shared_ptr<std::vector<json>> *detail_ptr) {
269 RETURN_UNEXPECTED_IF_NULL_MR(detail_ptr);
270 if (schema_count_ <= kMaxSchemaCount) {
271 for (int sc = 0; sc < schema_count_; ++sc) {
272 std::vector<char> schema_detail(schema_lens[sc]);
273 auto &io_read = in.read(&schema_detail[0], schema_lens[sc]);
274 if (!io_read.good() || io_read.fail() || io_read.bad()) {
275 in.close();
276 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
277 }
278 auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()));
279 (*detail_ptr)->emplace_back(j);
280 }
281 }
282 return Status::OK();
283 }
284
GenerateRawSQL(const std::vector<std::pair<uint64_t,std::string>> & fields,std::shared_ptr<std::string> * sql_ptr)285 Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
286 std::shared_ptr<std::string> *sql_ptr) {
287 std::string sql =
288 "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END,"
289 "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END";
290
291 int field_no = 0;
292 for (const auto &field : fields) {
293 std::shared_ptr<std::string> fn_ptr;
294 RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &fn_ptr));
295 sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr;
296 }
297 sql +=
298 ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB,"
299 ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END";
300 field_no = 0;
301 for (const auto &field : fields) {
302 std::shared_ptr<std::string> fn_ptr;
303 RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &fn_ptr));
304 sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr;
305 }
306 sql += " )";
307
308 *sql_ptr = std::make_shared<std::string>(sql);
309 return Status::OK();
310 }
311
BindParameterExecuteSQL(sqlite3 * db,const std::string & sql,const ROW_DATA & data)312 Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) {
313 sqlite3_stmt *stmt = nullptr;
314 if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
315 if (stmt != nullptr) {
316 (void)sqlite3_finalize(stmt);
317 }
318 sqlite3_close(db);
319 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
320 }
321 for (auto &row : data) {
322 for (auto &field : row) {
323 const auto &place_holder = std::get<0>(field);
324 const auto &field_type = std::get<1>(field);
325 const auto &field_value = std::get<2>(field);
326
327 int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
328 if (field_type == "INTEGER") {
329 if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
330 (void)sqlite3_finalize(stmt);
331 sqlite3_close(db);
332 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
333 std::to_string(index) + ", value: " + field_value);
334 }
335 } else if (field_type == "REAL") {
336 if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
337 (void)sqlite3_finalize(stmt);
338 sqlite3_close(db);
339 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
340 std::to_string(index) + ", value: " + field_value);
341 }
342 } else if (field_type == "NULL") {
343 if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
344 (void)sqlite3_finalize(stmt);
345
346 sqlite3_close(db);
347 RETURN_STATUS_UNEXPECTED_MR(
348 "[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) + ", value: NULL");
349 }
350 } else {
351 if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
352 (void)sqlite3_finalize(stmt);
353 sqlite3_close(db);
354 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
355 std::to_string(index) + ", value: " + field_value);
356 }
357 }
358 }
359 if (sqlite3_step(stmt) != SQLITE_DONE) {
360 (void)sqlite3_finalize(stmt);
361 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to step execute stmt.");
362 }
363 (void)sqlite3_reset(stmt);
364 }
365 (void)sqlite3_finalize(stmt);
366 return Status::OK();
367 }
368
AddBlobPageInfo(std::vector<std::tuple<std::string,std::string,std::string>> & row_data,const std::shared_ptr<Page> cur_blob_page,uint64_t & cur_blob_page_offset,std::fstream & in)369 Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
370 const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
371 std::fstream &in) {
372 row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));
373
374 // blob data start
375 row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset));
376 auto &io_seekg_blob =
377 in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
378 if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
379 in.close();
380 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
381 }
382 uint64_t image_size = 0;
383 auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len);
384 if (!io_read.good() || io_read.fail() || io_read.bad()) {
385 in.close();
386 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
387 }
388
389 cur_blob_page_offset += (kInt64Len + image_size);
390 row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset));
391
392 return Status::OK();
393 }
394
AddIndexFieldByRawData(const std::vector<json> & schema_detail,std::vector<std::tuple<std::string,std::string,std::string>> & row_data)395 Status ShardIndexGenerator::AddIndexFieldByRawData(
396 const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) {
397 auto index_fields_ptr = std::make_shared<INDEX_FIELDS>();
398 RETURN_IF_NOT_OK_MR(GenerateIndexFields(schema_detail, &index_fields_ptr));
399 int index = 0;
400 for (const auto &field : *index_fields_ptr) {
401 // assume simple field: string , number etc.
402 row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
403 row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
404 }
405 return Status::OK();
406 }
407
GenerateRowData(int shard_no,const std::map<int,int> & blob_id_to_page_id,int raw_page_id,std::fstream & in,std::shared_ptr<ROW_DATA> * row_data_ptr)408 Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
409 std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) {
410 RETURN_UNEXPECTED_IF_NULL_MR(row_data_ptr);
411 // current raw data page
412 std::shared_ptr<Page> page_ptr;
413 RETURN_IF_NOT_OK_MR(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr));
414 // related blob page
415 std::vector<std::pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds();
416
417 // pair: row_group id, offset in raw data page
418 for (pair<int, int> blob_ids : row_group_list) {
419 // get blob data page according to row_group id
420 auto iter = blob_id_to_page_id.find(blob_ids.first);
421 CHECK_FAIL_RETURN_UNEXPECTED_MR(iter != blob_id_to_page_id.end(),
422 "[Internal ERROR] Failed to get page id from blob id.");
423 std::shared_ptr<Page> blob_page_ptr;
424 RETURN_IF_NOT_OK_MR(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr));
425 // offset in current raw data page
426 auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
427 uint64_t cur_blob_page_offset = 0;
428 for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) {
429 std::vector<std::tuple<std::string, std::string, std::string>> row_data;
430 row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
431 row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID()));
432 row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID()));
433
434 // raw data start
435 row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));
436
437 // calculate raw data end
438 auto &io_seekg =
439 in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
440 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
441 in.close();
442 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
443 }
444 std::vector<uint64_t> schema_lens;
445 if (schema_count_ <= kMaxSchemaCount) {
446 for (int sc = 0; sc < schema_count_; sc++) {
447 uint64_t schema_size = 0;
448
449 auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len);
450 if (!io_read.good() || io_read.fail() || io_read.bad()) {
451 in.close();
452 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
453 }
454
455 cur_raw_page_offset += (kInt64Len + schema_size);
456 schema_lens.push_back(schema_size);
457 }
458 }
459 row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset));
460
461 // Getting schema for getting data for fields
462 auto detail_ptr = std::make_shared<std::vector<json>>();
463 RETURN_IF_NOT_OK_MR(GetSchemaDetails(schema_lens, in, &detail_ptr));
464 // start blob page info
465 RETURN_IF_NOT_OK_MR(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in));
466
467 // start index field
468 RETURN_IF_NOT_OK_MR(AddIndexFieldByRawData(*detail_ptr, row_data));
469 (*row_data_ptr)->push_back(std::move(row_data));
470 }
471 }
472 return Status::OK();
473 }
474
GenerateIndexFields(const std::vector<json> & schema_detail,std::shared_ptr<INDEX_FIELDS> * index_fields_ptr)475 Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail,
476 std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) {
477 RETURN_UNEXPECTED_IF_NULL_MR(index_fields_ptr);
478 // index fields
479 std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
480 for (const auto &field : index_fields) {
481 CHECK_FAIL_RETURN_UNEXPECTED_MR(
482 field.first < schema_detail.size(),
483 "[Internal ERROR] 'field': " + field.second + " is out of bound:" + std::to_string(schema_detail.size()));
484 std::shared_ptr<std::string> field_val_ptr;
485 RETURN_IF_NOT_OK_MR(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr));
486 std::shared_ptr<Schema> schema_ptr;
487 RETURN_IF_NOT_OK_MR(shard_header_.GetSchemaByID(field.first, &schema_ptr));
488 std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"]));
489 std::shared_ptr<std::string> fn_ptr;
490 RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &fn_ptr));
491 (*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr);
492 }
493 return Status::OK();
494 }
495
ExecuteTransaction(const int & shard_no,sqlite3 * db,const std::vector<int> & raw_page_ids,const std::map<int,int> & blob_id_to_page_id)496 Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
497 const std::map<int, int> &blob_id_to_page_id) {
498 // Add index data to database
499 std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
500
501 auto realpath = FileUtils::GetRealPath(shard_address.c_str());
502 CHECK_FAIL_RETURN_UNEXPECTED_MR(
503 realpath.has_value(),
504 "Invalid file, failed to get the realpath of mindrecord files. Please check file path: " + shard_address);
505 std::fstream in;
506 in.open(realpath.value(), std::ios::in | std::ios::binary);
507 if (!in.good()) {
508 in.close();
509 RETURN_STATUS_UNEXPECTED_MR(
510 "Invalid file, failed to open mindrecord files. Please check file path, permission and open files limit(ulimit "
511 "-a): " +
512 shard_address);
513 }
514 auto sql_code = sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
515 if (sql_code != SQLITE_OK) {
516 in.close();
517 RETURN_STATUS_UNEXPECTED_MR("Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: " +
518 std::to_string(sql_code));
519 }
520 for (int raw_page_id : raw_page_ids) {
521 std::shared_ptr<std::string> sql_ptr;
522 RELEASE_AND_RETURN_IF_NOT_OK_MR(GenerateRawSQL(fields_, &sql_ptr), db, in);
523 auto row_data_ptr = std::make_shared<ROW_DATA>();
524 RELEASE_AND_RETURN_IF_NOT_OK_MR(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db,
525 in);
526 RELEASE_AND_RETURN_IF_NOT_OK_MR(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in);
527 MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db.";
528 }
529 sql_code = sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
530 if (sql_code != SQLITE_OK) {
531 in.close();
532 sqlite3_close(db);
533 RETURN_STATUS_UNEXPECTED_MR("Execute SQL statement `END TRANSACTION;` failed, SQLite result code: " +
534 std::to_string(sql_code));
535 }
536 in.close();
537
538 // Close database
539 sqlite3_close(db);
540 db = nullptr;
541 return Status::OK();
542 }
543
WriteToDatabase()544 Status ShardIndexGenerator::WriteToDatabase() {
545 fields_ = shard_header_.GetFields();
546 for (auto &field : fields_) {
547 // column name can't start with a number
548 CHECK_FAIL_RETURN_UNEXPECTED_MR(field.second[0] < '0' || field.second[0] > '9',
549 "Index field: " + field.second + " can't start with a number.");
550 }
551
552 page_size_ = shard_header_.GetPageSize();
553 header_size_ = shard_header_.GetHeaderSize();
554 schema_count_ = shard_header_.GetSchemaCount();
555 CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_header_.GetShardCount() <= kMaxShardCount,
556 "[Internal ERROR] 'shard_count': " + std::to_string(shard_header_.GetShardCount()) +
557 "is not in range (0, " + std::to_string(kMaxShardCount) + "].");
558
559 task_ = 0; // set two atomic vars to initial value
560 write_success_ = true;
561
562 // spawn half the physical threads or total number of shards whichever is smaller
563 const unsigned int num_workers =
564 std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
565
566 std::vector<std::thread> threads;
567 threads.reserve(num_workers);
568
569 for (size_t t = 0; t < threads.capacity(); t++) {
570 threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
571 }
572
573 for (size_t t = 0; t < threads.capacity(); t++) {
574 threads[t].join();
575 }
576 CHECK_FAIL_RETURN_UNEXPECTED_MR(write_success_, "[Internal ERROR] Failed to write mindrecord meta files.");
577 return Status::OK();
578 }
579
DatabaseWriter()580 void ShardIndexGenerator::DatabaseWriter() {
581 int shard_no = task_++;
582 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
583 pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(shard_no)).c_str());
584 #endif
585 while (shard_no < shard_header_.GetShardCount()) {
586 sqlite3 *db = nullptr;
587 Status st = CreateDatabase(shard_no, &db);
588 if (st.IsError()) {
589 // in thread, so we print error here
590 MS_LOG(ERROR) << st.ToString();
591 write_success_ = false;
592 return;
593 }
594 MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
595 // Pre-processing page information
596 auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
597
598 std::map<int, int> blob_id_to_page_id;
599 std::vector<int> raw_page_ids;
600 for (uint64_t i = 0; i < total_pages; ++i) {
601 std::shared_ptr<Page> page_ptr;
602 Status get_page_status = shard_header_.GetPage(shard_no, i, &page_ptr);
603 if (get_page_status.IsError()) {
604 MS_LOG(ERROR) << get_page_status.ToString();
605 write_success_ = false;
606 return;
607 }
608 if (page_ptr->GetPageType() == "RAW_DATA") {
609 raw_page_ids.push_back(i);
610 } else if (page_ptr->GetPageType() == "BLOB_DATA") {
611 blob_id_to_page_id[page_ptr->GetPageTypeID()] = i;
612 }
613 }
614 Status transaction_status = ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id);
615 if (transaction_status.IsError()) {
616 MS_LOG(ERROR) << transaction_status.ToString();
617 write_success_ = false;
618 return;
619 }
620 MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
621 shard_no = task_++;
622 }
623 }
624
Finalize(const std::vector<std::string> file_names)625 Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
626 CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_names.empty(), "[Internal ERROR] the size of mindrecord files is 0.");
627 ShardIndexGenerator sg{file_names[0]};
628 RETURN_IF_NOT_OK_MR(sg.Build());
629 RETURN_IF_NOT_OK_MR(sg.WriteToDatabase());
630 return Status::OK();
631 }
632 } // namespace mindrecord
633 } // namespace mindspore
634