• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "minddata/mindrecord/include/shard_index_generator.h"

#include <algorithm>
#include <exception>
#include <stdexcept>
#include <string>
#include <utility>

#include "utils/file_utils.h"
#include "utils/ms_utils.h"
20 
21 namespace mindspore {
22 namespace mindrecord {
ShardIndexGenerator(const std::string & file_path,bool append)23 ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
24     : file_path_(file_path),
25       append_(append),
26       page_size_(0),
27       header_size_(0),
28       schema_count_(0),
29       task_(0),
30       write_success_(true) {}
31 
Build()32 Status ShardIndexGenerator::Build() {
33   std::shared_ptr<json> header_ptr;
34   RETURN_IF_NOT_OK_MR(ShardHeader::BuildSingleHeader(file_path_, &header_ptr));
35   auto ds = std::make_shared<std::vector<std::string>>();
36   RETURN_IF_NOT_OK_MR(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds));
37   ShardHeader header = ShardHeader();
38   RETURN_IF_NOT_OK_MR(header.BuildDataset(*ds));
39   shard_header_ = header;
40   MS_LOG(INFO) << "Initialize header from mindrecord file for index successfully.";
41   return Status::OK();
42 }
43 
GetValueByField(const string & field,const json & input,std::shared_ptr<std::string> * value)44 Status ShardIndexGenerator::GetValueByField(const string &field, const json &input,
45                                             std::shared_ptr<std::string> *value) {
46   RETURN_UNEXPECTED_IF_NULL_MR(value);
47 
48   // parameter input does not contain the field
49   CHECK_FAIL_RETURN_UNEXPECTED_MR(input.find(field) != input.end(),
50                                   "[Internal ERROR] 'field': " + field + " can not found in raw data: " + input.dump());
51 
52   // schema does not contain the field
53   auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
54   CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(field) != schema.end(),
55                                   "[Internal ERROR] 'field': " + field + " can not found in schema: " + schema.dump());
56 
57   // field should be scalar type
58   CHECK_FAIL_RETURN_UNEXPECTED_MR(
59     kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(),
60     "[Internal ERROR] 'field': " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");
61 
62   if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
63     auto schema_field_options = schema[field];
64     CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_field_options.find("shape") == schema_field_options.end(),
65                                     "[Internal ERROR] 'field': " + field + " shape is " +
66                                       schema[field]["shape"].dump() + " which is not retrievable.");
67     *value = std::make_shared<std::string>(input[field].dump());
68   } else {
69     // the field type is string in here
70     *value = std::make_shared<std::string>(input[field].get<std::string>());
71   }
72   return Status::OK();
73 }
74 
TakeFieldType(const string & field_path,json & schema)75 std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json &schema) {
76   std::vector<std::string> field_name = StringSplit(field_path, kPoint);
77   for (uint64_t i = 0; i < field_name.size(); ++i) {
78     try {
79       if (i != field_name.size() - 1) {
80         // Get type information from json schema
81         schema = schema.at(field_name[i]);
82         schema = schema.at("properties");
83       } else {
84         // standard root layer exist "properties" if type is "object"
85         if (schema.find("properties") != schema.end()) {
86           schema = schema.at("properties");
87         }
88         schema = schema.at(field_name[i]);
89         std::string field_type = schema.at("type").dump();
90         if (field_type.length() <= 2) {
91           return "";
92         } else {
93           return field_type.substr(1, field_type.length() - 2);
94         }
95       }
96     } catch (...) {
97       MS_LOG(WARNING) << "Exception occurred while get field type.";
98       return "";
99     }
100   }
101   return "";
102 }
103 
ConvertJsonToSQL(const std::string & json)104 std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) {
105   if (kDbJsonMap.find(json) != kDbJsonMap.end()) {
106     return kDbJsonMap.at(json);
107   } else {
108     return "TEXT";
109   }
110 }
111 
Callback(void * not_used,int argc,char ** argv,char ** az_col_name)112 int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) {
113   for (auto i = 0; i < argc; i++) {
114     if (argv[i] != nullptr) {
115       MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr");
116     }
117   }
118   MS_LOG(INFO) << "\n";
119   return 0;
120 }
121 
ExecuteSQL(const std::string & sql,sqlite3 * db,const std::string & success_msg)122 Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
123   char *z_err_msg = nullptr;
124   int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg);
125   if (rc != SQLITE_OK) {
126     std::ostringstream oss;
127     oss << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql) << " ], " << z_err_msg;
128     MS_LOG(DEBUG) << oss.str();
129     sqlite3_free(z_err_msg);
130     sqlite3_close(db);
131     RETURN_STATUS_UNEXPECTED_MR(oss.str());
132   } else {
133     if (!success_msg.empty()) {
134       MS_LOG(DEBUG) << "Suceess to execute the sql [ " << common::SafeCStr(sql) << " ], " << success_msg;
135     }
136     sqlite3_free(z_err_msg);
137     return Status::OK();
138   }
139 }
140 
GenerateFieldName(const std::pair<uint64_t,std::string> & field,std::shared_ptr<std::string> * fn_ptr)141 Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field,
142                                               std::shared_ptr<std::string> *fn_ptr) {
143   RETURN_UNEXPECTED_IF_NULL_MR(fn_ptr);
144   // Replaces dots and dashes with underscores for SQL use
145   std::string field_name = field.second;
146   // white list to avoid sql injection
147   std::replace_if(
148     field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_');
149   auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) {
150     return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9');
151   });
152   CHECK_FAIL_RETURN_UNEXPECTED_MR(pos == field_name.end(), "Invalid data, field name: " + field_name +
153                                                              "is not composed of '0-9' or 'a-z' or 'A-Z' or '_'.");
154   *fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first));
155   return Status::OK();
156 }
157 
CheckDatabase(const std::string & shard_address,sqlite3 ** db)158 Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) {
159   std::optional<std::string> dir = "";
160   std::optional<std::string> local_file_name = "";
161   FileUtils::SplitDirAndFileName(shard_address, &dir, &local_file_name);
162   if (!dir.has_value()) {
163     dir = ".";
164   }
165 
166   auto realpath = FileUtils::GetRealPath(dir.value().c_str());
167   CHECK_FAIL_RETURN_UNEXPECTED_MR(
168     realpath.has_value(),
169     "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + shard_address);
170 
171   std::optional<std::string> whole_path = "";
172   FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
173 
174   std::ifstream fin(whole_path.value(), std::ios::in);
175   if (!append_ && fin.good()) {
176     fin.close();
177     RETURN_STATUS_UNEXPECTED_MR(
178       "Invalid file, mindrecord meta files already exist. Please check file path: " + shard_address +
179       ".\nIf you do not want to keep the files, set the 'overwrite' parameter to True and try again.");
180   }
181   fin.close();
182 
183   std::string whole_path_utf8 = "";
184 #if defined(_WIN32) || defined(_WIN64)
185   whole_path_utf8 = FileUtils::GB2312ToUTF_8(whole_path.value().data());
186 #endif
187   if (whole_path_utf8.empty()) {
188     whole_path_utf8 = whole_path.value();
189   }
190 
191   if (sqlite3_open_v2(whole_path_utf8.data(), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) {
192     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to open mindrecord meta file: " + shard_address + ", " +
193                                 sqlite3_errmsg(*db));
194   }
195   MS_LOG(DEBUG) << "Open meta file successfully";
196   return Status::OK();
197 }
198 
CreateShardNameTable(sqlite3 * db,const std::string & shard_name)199 Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
200   // create shard_name table
201   std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;";
202   RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, db, "drop table successfully."));
203   sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);";
204   RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, db, "create table successfully."));
205   sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);";
206   sqlite3_stmt *stmt = nullptr;
207   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
208     if (stmt != nullptr) {
209       (void)sqlite3_finalize(stmt);
210     }
211     sqlite3_close(db);
212     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
213   }
214 
215   int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME");
216   if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) {
217     (void)sqlite3_finalize(stmt);
218     sqlite3_close(db);
219     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
220                                 std::to_string(index) + ", value: " + shard_name);
221   }
222 
223   if (sqlite3_step(stmt) != SQLITE_DONE) {
224     (void)sqlite3_finalize(stmt);
225     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to step execute stmt.");
226   }
227   (void)sqlite3_finalize(stmt);
228   return Status::OK();
229 }
230 
CreateDatabase(int shard_no,sqlite3 ** db)231 Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) {
232   std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
233   std::shared_ptr<std::string> fn_ptr;
234   RETURN_IF_NOT_OK_MR(GetFileName(shard_address, &fn_ptr));
235   shard_address += ".db";
236   RETURN_IF_NOT_OK_MR(CheckDatabase(shard_address, db));
237   std::string sql = "DROP TABLE IF EXISTS INDEXES;";
238   RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, *db, "drop table successfully."));
239   sql =
240     "CREATE TABLE INDEXES("
241     "  ROW_ID               INT  NOT NULL, PAGE_ID_RAW          INT  NOT NULL"
242     ", PAGE_OFFSET_RAW      INT  NOT NULL, PAGE_OFFSET_RAW_END  INT  NOT NULL"
243     ", ROW_GROUP_ID         INT  NOT NULL, PAGE_ID_BLOB         INT  NOT NULL"
244     ", PAGE_OFFSET_BLOB     INT  NOT NULL, PAGE_OFFSET_BLOB_END INT  NOT NULL";
245 
246   int field_no = 0;
247   std::shared_ptr<std::string> field_ptr;
248   for (const auto &field : fields_) {
249     uint64_t schema_id = field.first;
250     std::shared_ptr<Schema> schema_ptr;
251     RETURN_IF_NOT_OK_MR(shard_header_.GetSchemaByID(schema_id, &schema_ptr));
252     json json_schema = (schema_ptr->GetSchema())["schema"];
253     std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema));
254     RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &field_ptr));
255     sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type;
256   }
257   sql += ", PRIMARY KEY(ROW_ID";
258   for (uint64_t i = 0; i < fields_.size(); ++i) {
259     sql += ",INC_" + std::to_string(i);
260   }
261   sql += "));";
262   RETURN_IF_NOT_OK_MR(ExecuteSQL(sql, *db, "create table successfully."));
263   RETURN_IF_NOT_OK_MR(CreateShardNameTable(*db, *fn_ptr));
264   return Status::OK();
265 }
266 
GetSchemaDetails(const std::vector<uint64_t> & schema_lens,std::fstream & in,std::shared_ptr<std::vector<json>> * detail_ptr)267 Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
268                                              std::shared_ptr<std::vector<json>> *detail_ptr) {
269   RETURN_UNEXPECTED_IF_NULL_MR(detail_ptr);
270   if (schema_count_ <= kMaxSchemaCount) {
271     for (int sc = 0; sc < schema_count_; ++sc) {
272       std::vector<char> schema_detail(schema_lens[sc]);
273       auto &io_read = in.read(&schema_detail[0], schema_lens[sc]);
274       if (!io_read.good() || io_read.fail() || io_read.bad()) {
275         in.close();
276         RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
277       }
278       auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()));
279       (*detail_ptr)->emplace_back(j);
280     }
281   }
282   return Status::OK();
283 }
284 
GenerateRawSQL(const std::vector<std::pair<uint64_t,std::string>> & fields,std::shared_ptr<std::string> * sql_ptr)285 Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
286                                            std::shared_ptr<std::string> *sql_ptr) {
287   std::string sql =
288     "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END,"
289     "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END";
290 
291   int field_no = 0;
292   for (const auto &field : fields) {
293     std::shared_ptr<std::string> fn_ptr;
294     RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &fn_ptr));
295     sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr;
296   }
297   sql +=
298     ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB,"
299     ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END";
300   field_no = 0;
301   for (const auto &field : fields) {
302     std::shared_ptr<std::string> fn_ptr;
303     RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &fn_ptr));
304     sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr;
305   }
306   sql += " )";
307 
308   *sql_ptr = std::make_shared<std::string>(sql);
309   return Status::OK();
310 }
311 
BindParameterExecuteSQL(sqlite3 * db,const std::string & sql,const ROW_DATA & data)312 Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) {
313   sqlite3_stmt *stmt = nullptr;
314   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
315     if (stmt != nullptr) {
316       (void)sqlite3_finalize(stmt);
317     }
318     sqlite3_close(db);
319     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
320   }
321   for (auto &row : data) {
322     for (auto &field : row) {
323       const auto &place_holder = std::get<0>(field);
324       const auto &field_type = std::get<1>(field);
325       const auto &field_value = std::get<2>(field);
326 
327       int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
328       if (field_type == "INTEGER") {
329         if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
330           (void)sqlite3_finalize(stmt);
331           sqlite3_close(db);
332           RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
333                                       std::to_string(index) + ", value: " + field_value);
334         }
335       } else if (field_type == "REAL") {
336         if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
337           (void)sqlite3_finalize(stmt);
338           sqlite3_close(db);
339           RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
340                                       std::to_string(index) + ", value: " + field_value);
341         }
342       } else if (field_type == "NULL") {
343         if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
344           (void)sqlite3_finalize(stmt);
345 
346           sqlite3_close(db);
347           RETURN_STATUS_UNEXPECTED_MR(
348             "[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) + ", value: NULL");
349         }
350       } else {
351         if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
352           (void)sqlite3_finalize(stmt);
353           sqlite3_close(db);
354           RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to bind parameter of sql, key index: " +
355                                       std::to_string(index) + ", value: " + field_value);
356         }
357       }
358     }
359     if (sqlite3_step(stmt) != SQLITE_DONE) {
360       (void)sqlite3_finalize(stmt);
361       RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to step execute stmt.");
362     }
363     (void)sqlite3_reset(stmt);
364   }
365   (void)sqlite3_finalize(stmt);
366   return Status::OK();
367 }
368 
AddBlobPageInfo(std::vector<std::tuple<std::string,std::string,std::string>> & row_data,const std::shared_ptr<Page> cur_blob_page,uint64_t & cur_blob_page_offset,std::fstream & in)369 Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
370                                             const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
371                                             std::fstream &in) {
372   row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));
373 
374   // blob data start
375   row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset));
376   auto &io_seekg_blob =
377     in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
378   if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
379     in.close();
380     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
381   }
382   uint64_t image_size = 0;
383   auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len);
384   if (!io_read.good() || io_read.fail() || io_read.bad()) {
385     in.close();
386     RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
387   }
388 
389   cur_blob_page_offset += (kInt64Len + image_size);
390   row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset));
391 
392   return Status::OK();
393 }
394 
AddIndexFieldByRawData(const std::vector<json> & schema_detail,std::vector<std::tuple<std::string,std::string,std::string>> & row_data)395 Status ShardIndexGenerator::AddIndexFieldByRawData(
396   const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) {
397   auto index_fields_ptr = std::make_shared<INDEX_FIELDS>();
398   RETURN_IF_NOT_OK_MR(GenerateIndexFields(schema_detail, &index_fields_ptr));
399   int index = 0;
400   for (const auto &field : *index_fields_ptr) {
401     // assume simple field: string , number etc.
402     row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
403     row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
404   }
405   return Status::OK();
406 }
407 
GenerateRowData(int shard_no,const std::map<int,int> & blob_id_to_page_id,int raw_page_id,std::fstream & in,std::shared_ptr<ROW_DATA> * row_data_ptr)408 Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
409                                             std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) {
410   RETURN_UNEXPECTED_IF_NULL_MR(row_data_ptr);
411   // current raw data page
412   std::shared_ptr<Page> page_ptr;
413   RETURN_IF_NOT_OK_MR(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr));
414   // related blob page
415   std::vector<std::pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds();
416 
417   // pair: row_group id, offset in raw data page
418   for (pair<int, int> blob_ids : row_group_list) {
419     // get blob data page according to row_group id
420     auto iter = blob_id_to_page_id.find(blob_ids.first);
421     CHECK_FAIL_RETURN_UNEXPECTED_MR(iter != blob_id_to_page_id.end(),
422                                     "[Internal ERROR] Failed to get page id from blob id.");
423     std::shared_ptr<Page> blob_page_ptr;
424     RETURN_IF_NOT_OK_MR(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr));
425     // offset in current raw data page
426     auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
427     uint64_t cur_blob_page_offset = 0;
428     for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) {
429       std::vector<std::tuple<std::string, std::string, std::string>> row_data;
430       row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
431       row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID()));
432       row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID()));
433 
434       // raw data start
435       row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));
436 
437       // calculate raw data end
438       auto &io_seekg =
439         in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
440       if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
441         in.close();
442         RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
443       }
444       std::vector<uint64_t> schema_lens;
445       if (schema_count_ <= kMaxSchemaCount) {
446         for (int sc = 0; sc < schema_count_; sc++) {
447           uint64_t schema_size = 0;
448 
449           auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len);
450           if (!io_read.good() || io_read.fail() || io_read.bad()) {
451             in.close();
452             RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
453           }
454 
455           cur_raw_page_offset += (kInt64Len + schema_size);
456           schema_lens.push_back(schema_size);
457         }
458       }
459       row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset));
460 
461       // Getting schema for getting data for fields
462       auto detail_ptr = std::make_shared<std::vector<json>>();
463       RETURN_IF_NOT_OK_MR(GetSchemaDetails(schema_lens, in, &detail_ptr));
464       // start blob page info
465       RETURN_IF_NOT_OK_MR(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in));
466 
467       // start index field
468       RETURN_IF_NOT_OK_MR(AddIndexFieldByRawData(*detail_ptr, row_data));
469       (*row_data_ptr)->push_back(std::move(row_data));
470     }
471   }
472   return Status::OK();
473 }
474 
GenerateIndexFields(const std::vector<json> & schema_detail,std::shared_ptr<INDEX_FIELDS> * index_fields_ptr)475 Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail,
476                                                 std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) {
477   RETURN_UNEXPECTED_IF_NULL_MR(index_fields_ptr);
478   // index fields
479   std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
480   for (const auto &field : index_fields) {
481     CHECK_FAIL_RETURN_UNEXPECTED_MR(
482       field.first < schema_detail.size(),
483       "[Internal ERROR] 'field': " + field.second + " is out of bound:" + std::to_string(schema_detail.size()));
484     std::shared_ptr<std::string> field_val_ptr;
485     RETURN_IF_NOT_OK_MR(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr));
486     std::shared_ptr<Schema> schema_ptr;
487     RETURN_IF_NOT_OK_MR(shard_header_.GetSchemaByID(field.first, &schema_ptr));
488     std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"]));
489     std::shared_ptr<std::string> fn_ptr;
490     RETURN_IF_NOT_OK_MR(GenerateFieldName(field, &fn_ptr));
491     (*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr);
492   }
493   return Status::OK();
494 }
495 
ExecuteTransaction(const int & shard_no,sqlite3 * db,const std::vector<int> & raw_page_ids,const std::map<int,int> & blob_id_to_page_id)496 Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
497                                                const std::map<int, int> &blob_id_to_page_id) {
498   // Add index data to database
499   std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
500 
501   auto realpath = FileUtils::GetRealPath(shard_address.c_str());
502   CHECK_FAIL_RETURN_UNEXPECTED_MR(
503     realpath.has_value(),
504     "Invalid file, failed to get the realpath of mindrecord files. Please check file path: " + shard_address);
505   std::fstream in;
506   in.open(realpath.value(), std::ios::in | std::ios::binary);
507   if (!in.good()) {
508     in.close();
509     RETURN_STATUS_UNEXPECTED_MR(
510       "Invalid file, failed to open mindrecord files. Please check file path, permission and open files limit(ulimit "
511       "-a): " +
512       shard_address);
513   }
514   auto sql_code = sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
515   if (sql_code != SQLITE_OK) {
516     in.close();
517     RETURN_STATUS_UNEXPECTED_MR("Execute SQL statement `BEGIN TRANSACTION;` failed, SQLite result code: " +
518                                 std::to_string(sql_code));
519   }
520   for (int raw_page_id : raw_page_ids) {
521     std::shared_ptr<std::string> sql_ptr;
522     RELEASE_AND_RETURN_IF_NOT_OK_MR(GenerateRawSQL(fields_, &sql_ptr), db, in);
523     auto row_data_ptr = std::make_shared<ROW_DATA>();
524     RELEASE_AND_RETURN_IF_NOT_OK_MR(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db,
525                                     in);
526     RELEASE_AND_RETURN_IF_NOT_OK_MR(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in);
527     MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db.";
528   }
529   sql_code = sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
530   if (sql_code != SQLITE_OK) {
531     in.close();
532     sqlite3_close(db);
533     RETURN_STATUS_UNEXPECTED_MR("Execute SQL statement `END TRANSACTION;` failed, SQLite result code: " +
534                                 std::to_string(sql_code));
535   }
536   in.close();
537 
538   // Close database
539   sqlite3_close(db);
540   db = nullptr;
541   return Status::OK();
542 }
543 
WriteToDatabase()544 Status ShardIndexGenerator::WriteToDatabase() {
545   fields_ = shard_header_.GetFields();
546   for (auto &field : fields_) {
547     // column name can't start with a number
548     CHECK_FAIL_RETURN_UNEXPECTED_MR(field.second[0] < '0' || field.second[0] > '9',
549                                     "Index field: " + field.second + " can't start with a number.");
550   }
551 
552   page_size_ = shard_header_.GetPageSize();
553   header_size_ = shard_header_.GetHeaderSize();
554   schema_count_ = shard_header_.GetSchemaCount();
555   CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_header_.GetShardCount() <= kMaxShardCount,
556                                   "[Internal ERROR] 'shard_count': " + std::to_string(shard_header_.GetShardCount()) +
557                                     "is not in range (0, " + std::to_string(kMaxShardCount) + "].");
558 
559   task_ = 0;  // set two atomic vars to initial value
560   write_success_ = true;
561 
562   // spawn half the physical threads or total number of shards whichever is smaller
563   const unsigned int num_workers =
564     std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
565 
566   std::vector<std::thread> threads;
567   threads.reserve(num_workers);
568 
569   for (size_t t = 0; t < threads.capacity(); t++) {
570     threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
571   }
572 
573   for (size_t t = 0; t < threads.capacity(); t++) {
574     threads[t].join();
575   }
576   CHECK_FAIL_RETURN_UNEXPECTED_MR(write_success_, "[Internal ERROR] Failed to write mindrecord meta files.");
577   return Status::OK();
578 }
579 
DatabaseWriter()580 void ShardIndexGenerator::DatabaseWriter() {
581   int shard_no = task_++;
582 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
583   pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(shard_no)).c_str());
584 #endif
585   while (shard_no < shard_header_.GetShardCount()) {
586     sqlite3 *db = nullptr;
587     Status st = CreateDatabase(shard_no, &db);
588     if (st.IsError()) {
589       // in thread, so we print error here
590       MS_LOG(ERROR) << st.ToString();
591       write_success_ = false;
592       return;
593     }
594     MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
595     // Pre-processing page information
596     auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
597 
598     std::map<int, int> blob_id_to_page_id;
599     std::vector<int> raw_page_ids;
600     for (uint64_t i = 0; i < total_pages; ++i) {
601       std::shared_ptr<Page> page_ptr;
602       Status get_page_status = shard_header_.GetPage(shard_no, i, &page_ptr);
603       if (get_page_status.IsError()) {
604         MS_LOG(ERROR) << get_page_status.ToString();
605         write_success_ = false;
606         return;
607       }
608       if (page_ptr->GetPageType() == "RAW_DATA") {
609         raw_page_ids.push_back(i);
610       } else if (page_ptr->GetPageType() == "BLOB_DATA") {
611         blob_id_to_page_id[page_ptr->GetPageTypeID()] = i;
612       }
613     }
614     Status transaction_status = ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id);
615     if (transaction_status.IsError()) {
616       MS_LOG(ERROR) << transaction_status.ToString();
617       write_success_ = false;
618       return;
619     }
620     MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
621     shard_no = task_++;
622   }
623 }
624 
Finalize(const std::vector<std::string> file_names)625 Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
626   CHECK_FAIL_RETURN_UNEXPECTED_MR(!file_names.empty(), "[Internal ERROR] the size of mindrecord files is 0.");
627   ShardIndexGenerator sg{file_names[0]};
628   RETURN_IF_NOT_OK_MR(sg.Build());
629   RETURN_IF_NOT_OK_MR(sg.WriteToDatabase());
630   return Status::OK();
631 }
632 }  // namespace mindrecord
633 }  // namespace mindspore
634