• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/mindrecord/include/shard_index_generator.h"
17 
18 #include "utils/file_utils.h"
19 #include "utils/ms_utils.h"
20 
21 using mindspore::LogStream;
22 using mindspore::ExceptionType::NoExceptionType;
23 using mindspore::MsLogLevel::DEBUG;
24 using mindspore::MsLogLevel::ERROR;
25 using mindspore::MsLogLevel::INFO;
26 
27 namespace mindspore {
28 namespace mindrecord {
ShardIndexGenerator(const std::string & file_path,bool append)29 ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
30     : file_path_(file_path),
31       append_(append),
32       page_size_(0),
33       header_size_(0),
34       schema_count_(0),
35       task_(0),
36       write_success_(true) {}
37 
Build()38 Status ShardIndexGenerator::Build() {
39   std::shared_ptr<json> header_ptr;
40   RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path_, &header_ptr));
41   auto ds = std::make_shared<std::vector<std::string>>();
42   RETURN_IF_NOT_OK(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds));
43   ShardHeader header = ShardHeader();
44   RETURN_IF_NOT_OK(header.BuildDataset(*ds));
45   shard_header_ = header;
46   MS_LOG(INFO) << "Initialize header from mindrecord file for index successfully.";
47   return Status::OK();
48 }
49 
GetValueByField(const string & field,json input,std::shared_ptr<std::string> * value)50 Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value) {
51   RETURN_UNEXPECTED_IF_NULL(value);
52   CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty.");
53   CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty.");
54 
55   // parameter input does not contain the field
56   CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(),
57                                "The field " + field + " is not found in json " + input.dump());
58 
59   // schema does not contain the field
60   auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
61   CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
62                                "The field " + field + " is not found in schema " + schema.dump());
63 
64   // field should be scalar type
65   CHECK_FAIL_RETURN_UNEXPECTED(
66     kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(),
67     "The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");
68 
69   if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
70     auto schema_field_options = schema[field];
71     CHECK_FAIL_RETURN_UNEXPECTED(
72       schema_field_options.find("shape") == schema_field_options.end(),
73       "The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable.");
74     *value = std::make_shared<std::string>(input[field].dump());
75   } else {
76     // the field type is string in here
77     *value = std::make_shared<std::string>(input[field].get<std::string>());
78   }
79   return Status::OK();
80 }
81 
TakeFieldType(const string & field_path,json schema)82 std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
83   std::vector<std::string> field_name = StringSplit(field_path, kPoint);
84   for (uint64_t i = 0; i < field_name.size(); ++i) {
85     try {
86       if (i != field_name.size() - 1) {
87         // Get type information from json schema
88         schema = schema.at(field_name[i]);
89         schema = schema.at("properties");
90       } else {
91         // standard root layer exist "properties" if type is "object"
92         if (schema.find("properties") != schema.end()) {
93           schema = schema.at("properties");
94         }
95         schema = schema.at(field_name[i]);
96         std::string field_type = schema.at("type").dump();
97         if (field_type.length() <= 2) {
98           return "";
99         } else {
100           return field_type.substr(1, field_type.length() - 2);
101         }
102       }
103     } catch (...) {
104       MS_LOG(WARNING) << "Exception occurred while get field type.";
105       return "";
106     }
107   }
108   return "";
109 }
110 
ConvertJsonToSQL(const std::string & json)111 std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) {
112   if (kDbJsonMap.find(json) != kDbJsonMap.end()) {
113     return kDbJsonMap.at(json);
114   } else {
115     return "TEXT";
116   }
117 }
118 
Callback(void * not_used,int argc,char ** argv,char ** az_col_name)119 int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) {
120   for (auto i = 0; i < argc; i++) {
121     if (argv[i] != nullptr) {
122       MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr");
123     }
124   }
125   MS_LOG(INFO) << "\n";
126   return 0;
127 }
128 
ExecuteSQL(const std::string & sql,sqlite3 * db,const std::string & success_msg)129 Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
130   char *z_err_msg = nullptr;
131   int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg);
132   if (rc != SQLITE_OK) {
133     std::ostringstream oss;
134     oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg;
135     MS_LOG(DEBUG) << oss.str();
136     sqlite3_free(z_err_msg);
137     sqlite3_close(db);
138     RETURN_STATUS_UNEXPECTED(oss.str());
139   } else {
140     if (!success_msg.empty()) {
141       MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg;
142     }
143     sqlite3_free(z_err_msg);
144     return Status::OK();
145   }
146 }
147 
GenerateFieldName(const std::pair<uint64_t,std::string> & field,std::shared_ptr<std::string> * fn_ptr)148 Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field,
149                                               std::shared_ptr<std::string> *fn_ptr) {
150   RETURN_UNEXPECTED_IF_NULL(fn_ptr);
151   // Replaces dots and dashes with underscores for SQL use
152   std::string field_name = field.second;
153   // white list to avoid sql injection
154   std::replace_if(
155     field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_');
156   auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) {
157     return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9');
158   });
159   CHECK_FAIL_RETURN_UNEXPECTED(
160     pos == field_name.end(),
161     "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name);
162   *fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first));
163   return Status::OK();
164 }
165 
CheckDatabase(const std::string & shard_address,sqlite3 ** db)166 Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) {
167   std::optional<std::string> dir = "";
168   std::optional<std::string> local_file_name = "";
169   FileUtils::SplitDirAndFileName(shard_address, &dir, &local_file_name);
170   if (!dir.has_value()) {
171     dir = ".";
172   }
173 
174   auto realpath = FileUtils::GetRealPath(dir.value().data());
175   CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
176 
177   std::optional<std::string> whole_path = "";
178   FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
179 
180   std::ifstream fin(whole_path.value());
181   if (!append_ && fin.good()) {
182     fin.close();
183     RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address);
184   }
185   fin.close();
186   if (sqlite3_open_v2(common::SafeCStr(whole_path.value()), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) {
187     RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" +
188                              std::string(sqlite3_errmsg(*db)));
189   }
190   MS_LOG(DEBUG) << "Opened database successfully";
191   return Status::OK();
192 }
193 
CreateShardNameTable(sqlite3 * db,const std::string & shard_name)194 Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
195   // create shard_name table
196   std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;";
197   RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "drop table successfully."));
198   sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);";
199   RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "create table successfully."));
200   sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);";
201   sqlite3_stmt *stmt = nullptr;
202   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
203     if (stmt != nullptr) {
204       (void)sqlite3_finalize(stmt);
205     }
206     sqlite3_close(db);
207     RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
208   }
209 
210   int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME");
211   if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) {
212     (void)sqlite3_finalize(stmt);
213     sqlite3_close(db);
214     RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
215                              ", field value: " + std::string(shard_name));
216   }
217 
218   if (sqlite3_step(stmt) != SQLITE_DONE) {
219     (void)sqlite3_finalize(stmt);
220     RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
221   }
222   (void)sqlite3_finalize(stmt);
223   return Status::OK();
224 }
225 
CreateDatabase(int shard_no,sqlite3 ** db)226 Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) {
227   std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
228   CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + shard_no);
229   std::shared_ptr<std::string> fn_ptr;
230   RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr));
231   shard_address += ".db";
232   RETURN_IF_NOT_OK(CheckDatabase(shard_address, db));
233   std::string sql = "DROP TABLE IF EXISTS INDEXES;";
234   RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "drop table successfully."));
235   sql =
236     "CREATE TABLE INDEXES("
237     "  ROW_ID               INT  NOT NULL, PAGE_ID_RAW          INT  NOT NULL"
238     ", PAGE_OFFSET_RAW      INT  NOT NULL, PAGE_OFFSET_RAW_END  INT  NOT NULL"
239     ", ROW_GROUP_ID         INT  NOT NULL, PAGE_ID_BLOB         INT  NOT NULL"
240     ", PAGE_OFFSET_BLOB     INT  NOT NULL, PAGE_OFFSET_BLOB_END INT  NOT NULL";
241 
242   int field_no = 0;
243   std::shared_ptr<std::string> field_ptr;
244   for (const auto &field : fields_) {
245     uint64_t schema_id = field.first;
246     std::shared_ptr<Schema> schema_ptr;
247     RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(schema_id, &schema_ptr));
248     json json_schema = (schema_ptr->GetSchema())["schema"];
249     std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema));
250     RETURN_IF_NOT_OK(GenerateFieldName(field, &field_ptr));
251     sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type;
252   }
253   sql += ", PRIMARY KEY(ROW_ID";
254   for (uint64_t i = 0; i < fields_.size(); ++i) {
255     sql += ",INC_" + std::to_string(i);
256   }
257   sql += "));";
258   RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "create table successfully."));
259   RETURN_IF_NOT_OK(CreateShardNameTable(*db, *fn_ptr));
260   return Status::OK();
261 }
262 
GetSchemaDetails(const std::vector<uint64_t> & schema_lens,std::fstream & in,std::shared_ptr<std::vector<json>> * detail_ptr)263 Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
264                                              std::shared_ptr<std::vector<json>> *detail_ptr) {
265   RETURN_UNEXPECTED_IF_NULL(detail_ptr);
266   if (schema_count_ <= kMaxSchemaCount) {
267     for (int sc = 0; sc < schema_count_; ++sc) {
268       std::vector<char> schema_detail(schema_lens[sc]);
269       auto &io_read = in.read(&schema_detail[0], schema_lens[sc]);
270       if (!io_read.good() || io_read.fail() || io_read.bad()) {
271         in.close();
272         RETURN_STATUS_UNEXPECTED("Failed to read file.");
273       }
274       auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()));
275       (*detail_ptr)->emplace_back(j);
276     }
277   }
278   return Status::OK();
279 }
280 
GenerateRawSQL(const std::vector<std::pair<uint64_t,std::string>> & fields,std::shared_ptr<std::string> * sql_ptr)281 Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
282                                            std::shared_ptr<std::string> *sql_ptr) {
283   std::string sql =
284     "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END,"
285     "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END";
286 
287   int field_no = 0;
288   for (const auto &field : fields) {
289     std::shared_ptr<std::string> fn_ptr;
290     RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
291     sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr;
292   }
293   sql +=
294     ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB,"
295     ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END";
296   field_no = 0;
297   for (const auto &field : fields) {
298     std::shared_ptr<std::string> fn_ptr;
299     RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
300     sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr;
301   }
302   sql += " )";
303 
304   *sql_ptr = std::make_shared<std::string>(sql);
305   return Status::OK();
306 }
307 
BindParameterExecuteSQL(sqlite3 * db,const std::string & sql,const ROW_DATA & data)308 Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) {
309   sqlite3_stmt *stmt = nullptr;
310   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
311     if (stmt != nullptr) {
312       (void)sqlite3_finalize(stmt);
313     }
314     sqlite3_close(db);
315     RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
316   }
317   for (auto &row : data) {
318     for (auto &field : row) {
319       const auto &place_holder = std::get<0>(field);
320       const auto &field_type = std::get<1>(field);
321       const auto &field_value = std::get<2>(field);
322 
323       int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
324       if (field_type == "INTEGER") {
325         if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
326           (void)sqlite3_finalize(stmt);
327           sqlite3_close(db);
328           RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
329                                    ", field value: " + std::string(field_value));
330         }
331       } else if (field_type == "NUMERIC") {
332         if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
333           (void)sqlite3_finalize(stmt);
334           sqlite3_close(db);
335           RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
336                                    ", field value: " + std::string(field_value));
337         }
338       } else if (field_type == "NULL") {
339         if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
340           (void)sqlite3_finalize(stmt);
341 
342           sqlite3_close(db);
343           RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
344                                    ", field value: NULL");
345         }
346       } else {
347         if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
348           (void)sqlite3_finalize(stmt);
349           sqlite3_close(db);
350           RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
351                                    ", field value: " + std::string(field_value));
352         }
353       }
354     }
355     if (sqlite3_step(stmt) != SQLITE_DONE) {
356       (void)sqlite3_finalize(stmt);
357       RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
358     }
359     (void)sqlite3_reset(stmt);
360   }
361   (void)sqlite3_finalize(stmt);
362   return Status::OK();
363 }
364 
AddBlobPageInfo(std::vector<std::tuple<std::string,std::string,std::string>> & row_data,const std::shared_ptr<Page> cur_blob_page,uint64_t & cur_blob_page_offset,std::fstream & in)365 Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
366                                             const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
367                                             std::fstream &in) {
368   row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));
369 
370   // blob data start
371   row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset));
372   auto &io_seekg_blob =
373     in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
374   if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
375     in.close();
376     RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
377   }
378   uint64_t image_size = 0;
379   auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len);
380   if (!io_read.good() || io_read.fail() || io_read.bad()) {
381     MS_LOG(ERROR) << "File read failed";
382     in.close();
383     RETURN_STATUS_UNEXPECTED("Failed to read file.");
384   }
385 
386   cur_blob_page_offset += (kInt64Len + image_size);
387   row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset));
388 
389   return Status::OK();
390 }
391 
AddIndexFieldByRawData(const std::vector<json> & schema_detail,std::vector<std::tuple<std::string,std::string,std::string>> & row_data)392 Status ShardIndexGenerator::AddIndexFieldByRawData(
393   const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) {
394   auto index_fields_ptr = std::make_shared<INDEX_FIELDS>();
395   RETURN_IF_NOT_OK(GenerateIndexFields(schema_detail, &index_fields_ptr));
396   int index = 0;
397   for (const auto &field : *index_fields_ptr) {
398     // assume simple field: string , number etc.
399     row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
400     row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
401   }
402   return Status::OK();
403 }
404 
GenerateRowData(int shard_no,const std::map<int,int> & blob_id_to_page_id,int raw_page_id,std::fstream & in,std::shared_ptr<ROW_DATA> * row_data_ptr)405 Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
406                                             std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) {
407   RETURN_UNEXPECTED_IF_NULL(row_data_ptr);
408   // current raw data page
409   std::shared_ptr<Page> page_ptr;
410   RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr));
411   // related blob page
412   vector<pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds();
413 
414   // pair: row_group id, offset in raw data page
415   for (pair<int, int> blob_ids : row_group_list) {
416     // get blob data page according to row_group id
417     auto iter = blob_id_to_page_id.find(blob_ids.first);
418     CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id.");
419     std::shared_ptr<Page> blob_page_ptr;
420     RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr));
421     // offset in current raw data page
422     auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
423     uint64_t cur_blob_page_offset = 0;
424     for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) {
425       std::vector<std::tuple<std::string, std::string, std::string>> row_data;
426       row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
427       row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID()));
428       row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID()));
429 
430       // raw data start
431       row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));
432 
433       // calculate raw data end
434       auto &io_seekg =
435         in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
436       if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
437         in.close();
438         RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
439       }
440       std::vector<uint64_t> schema_lens;
441       if (schema_count_ <= kMaxSchemaCount) {
442         for (int sc = 0; sc < schema_count_; sc++) {
443           uint64_t schema_size = 0;
444 
445           auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len);
446           if (!io_read.good() || io_read.fail() || io_read.bad()) {
447             in.close();
448             RETURN_STATUS_UNEXPECTED("Failed to read file.");
449           }
450 
451           cur_raw_page_offset += (kInt64Len + schema_size);
452           schema_lens.push_back(schema_size);
453         }
454       }
455       row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset));
456 
457       // Getting schema for getting data for fields
458       auto detail_ptr = std::make_shared<std::vector<json>>();
459       RETURN_IF_NOT_OK(GetSchemaDetails(schema_lens, in, &detail_ptr));
460       // start blob page info
461       RETURN_IF_NOT_OK(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in));
462 
463       // start index field
464       AddIndexFieldByRawData(*detail_ptr, row_data);
465       (*row_data_ptr)->push_back(std::move(row_data));
466     }
467   }
468   return Status::OK();
469 }
470 
GenerateIndexFields(const std::vector<json> & schema_detail,std::shared_ptr<INDEX_FIELDS> * index_fields_ptr)471 Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail,
472                                                 std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) {
473   RETURN_UNEXPECTED_IF_NULL(index_fields_ptr);
474   // index fields
475   std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
476   for (const auto &field : index_fields) {
477     CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range.");
478     std::shared_ptr<std::string> field_val_ptr;
479     RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr));
480     std::shared_ptr<Schema> schema_ptr;
481     RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(field.first, &schema_ptr));
482     std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"]));
483     std::shared_ptr<std::string> fn_ptr;
484     RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
485     (*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr);
486   }
487   return Status::OK();
488 }
489 
ExecuteTransaction(const int & shard_no,sqlite3 * db,const std::vector<int> & raw_page_ids,const std::map<int,int> & blob_id_to_page_id)490 Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
491                                                const std::map<int, int> &blob_id_to_page_id) {
492   // Add index data to database
493   std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
494   CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty.");
495 
496   auto realpath = FileUtils::GetRealPath(shard_address.data());
497   CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
498   std::fstream in;
499   in.open(realpath.value(), std::ios::in | std::ios::binary);
500   if (!in.good()) {
501     in.close();
502     RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address);
503   }
504   (void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
505   for (int raw_page_id : raw_page_ids) {
506     std::shared_ptr<std::string> sql_ptr;
507     RELEASE_AND_RETURN_IF_NOT_OK(GenerateRawSQL(fields_, &sql_ptr), db, in);
508     auto row_data_ptr = std::make_shared<ROW_DATA>();
509     RELEASE_AND_RETURN_IF_NOT_OK(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db, in);
510     RELEASE_AND_RETURN_IF_NOT_OK(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in);
511     MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db.";
512   }
513   (void)sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
514   in.close();
515 
516   // Close database
517   sqlite3_close(db);
518   db = nullptr;
519   return Status::OK();
520 }
521 
WriteToDatabase()522 Status ShardIndexGenerator::WriteToDatabase() {
523   fields_ = shard_header_.GetFields();
524   page_size_ = shard_header_.GetPageSize();
525   header_size_ = shard_header_.GetHeaderSize();
526   schema_count_ = shard_header_.GetSchemaCount();
527   CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount,
528                                "num shards: " + std::to_string(shard_header_.GetShardCount()) +
529                                  " exceeds max count:" + std::to_string(kMaxSchemaCount));
530 
531   task_ = 0;  // set two atomic vars to initial value
532   write_success_ = true;
533 
534   // spawn half the physical threads or total number of shards whichever is smaller
535   const unsigned int num_workers =
536     std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
537 
538   std::vector<std::thread> threads;
539   threads.reserve(num_workers);
540 
541   for (size_t t = 0; t < threads.capacity(); t++) {
542     threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
543   }
544 
545   for (size_t t = 0; t < threads.capacity(); t++) {
546     threads[t].join();
547   }
548   CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db.");
549   return Status::OK();
550 }
551 
DatabaseWriter()552 void ShardIndexGenerator::DatabaseWriter() {
553   int shard_no = task_++;
554   while (shard_no < shard_header_.GetShardCount()) {
555     sqlite3 *db = nullptr;
556     if (CreateDatabase(shard_no, &db).IsError()) {
557       MS_LOG(ERROR) << "Failed to create Generate database.";
558       write_success_ = false;
559       return;
560     }
561     MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
562     // Pre-processing page information
563     auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
564 
565     std::map<int, int> blob_id_to_page_id;
566     std::vector<int> raw_page_ids;
567     for (uint64_t i = 0; i < total_pages; ++i) {
568       std::shared_ptr<Page> page_ptr;
569       if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) {
570         MS_LOG(ERROR) << "Failed to get page.";
571         write_success_ = false;
572         return;
573       }
574       if (page_ptr->GetPageType() == "RAW_DATA") {
575         raw_page_ids.push_back(i);
576       } else if (page_ptr->GetPageType() == "BLOB_DATA") {
577         blob_id_to_page_id[page_ptr->GetPageTypeID()] = i;
578       }
579     }
580 
581     if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) {
582       MS_LOG(ERROR) << "Failed to execute transaction.";
583       write_success_ = false;
584       return;
585     }
586     MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
587     shard_no = task_++;
588   }
589 }
Finalize(const std::vector<std::string> file_names)590 Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
591   CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty.");
592   ShardIndexGenerator sg{file_names[0]};
593   RETURN_IF_NOT_OK(sg.Build());
594   RETURN_IF_NOT_OK(sg.WriteToDatabase());
595   return Status::OK();
596 }
597 }  // namespace mindrecord
598 }  // namespace mindspore
599