• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_
19 
#include <atomic>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "minddata/mindrecord/include/shard_header.h"
#include "./sqlite3.h"
30 
31 namespace mindspore {
32 namespace mindrecord {
33 using INDEX_FIELDS = std::vector<std::tuple<std::string, std::string, std::string>>;
34 using ROW_DATA = std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>;
35 class __attribute__((visibility("default"))) ShardIndexGenerator {
36  public:
37   explicit ShardIndexGenerator(const std::string &file_path, bool append = false);
38 
39   Status Build();
40 
41   static Status GenerateFieldName(const std::pair<uint64_t, std::string> &field, std::shared_ptr<std::string> *fn_ptr);
42 
~ShardIndexGenerator()43   ~ShardIndexGenerator() {}
44 
45   /// \brief fetch value in json by field name
46   /// \param[in] field
47   /// \param[in] input
48   /// \param[in] value
49   /// \return Status
50   Status GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value);
51 
52   /// \brief fetch field type in schema n by field path
53   /// \param[in] field_path
54   /// \param[in] schema
55   /// \return the type of field
56   static std::string TakeFieldType(const std::string &field_path, json schema);
57 
58   /// \brief create databases for indexes
59   Status WriteToDatabase();
60 
61   static Status Finalize(const std::vector<std::string> file_names);
62 
63  private:
64   static int Callback(void *not_used, int argc, char **argv, char **az_col_name);
65 
66   static Status ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = "");
67 
68   static std::string ConvertJsonToSQL(const std::string &json);
69 
70   Status CreateDatabase(int shard_no, sqlite3 **db);
71 
72   Status GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
73                           std::shared_ptr<std::vector<json>> *detail_ptr);
74 
75   static Status GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
76                                std::shared_ptr<std::string> *sql_ptr);
77 
78   Status CheckDatabase(const std::string &shard_address, sqlite3 **db);
79 
80   ///
81   /// \param shard_no
82   /// \param blob_id_to_page_id
83   /// \param raw_page_id
84   /// \param in
85   /// \return Status
86   Status GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, std::fstream &in,
87                          std::shared_ptr<ROW_DATA> *row_data_ptr);
88   ///
89   /// \param db
90   /// \param sql
91   /// \param data
92   /// \return
93   Status BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data);
94 
95   Status GenerateIndexFields(const std::vector<json> &schema_detail, std::shared_ptr<INDEX_FIELDS> *index_fields_ptr);
96 
97   Status ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
98                             const std::map<int, int> &blob_id_to_page_id);
99 
100   Status CreateShardNameTable(sqlite3 *db, const std::string &shard_name);
101 
102   Status AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
103                          const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, std::fstream &in);
104 
105   Status AddIndexFieldByRawData(const std::vector<json> &schema_detail,
106                                 std::vector<std::tuple<std::string, std::string, std::string>> &row_data);
107 
108   void DatabaseWriter();  // worker thread
109 
110   std::string file_path_;
111   bool append_;
112   ShardHeader shard_header_;
113   uint64_t page_size_;
114   uint64_t header_size_;
115   int schema_count_;
116   std::atomic_int task_;
117   std::atomic_bool write_success_;
118   std::vector<std::pair<uint64_t, std::string>> fields_;
119 };
120 }  // namespace mindrecord
121 }  // namespace mindspore
122 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_
123