• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_segment.h"
18 #include "utils/ms_utils.h"
19 
20 #include "./securec.h"
21 #include "minddata/mindrecord/include/common/shard_utils.h"
22 #include "pybind11/pybind11.h"
23 
24 using mindspore::LogStream;
25 using mindspore::ExceptionType::NoExceptionType;
26 using mindspore::MsLogLevel::ERROR;
27 using mindspore::MsLogLevel::INFO;
28 
29 namespace mindspore {
30 namespace mindrecord {
ShardSegment()31 ShardSegment::ShardSegment() { SetAllInIndex(false); }
32 
GetCategoryFields(std::shared_ptr<vector<std::string>> * fields_ptr)33 Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) {
34   RETURN_UNEXPECTED_IF_NULL(fields_ptr);
35   // Skip if already populated
36   if (!candidate_category_fields_.empty()) {
37     *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
38     return Status::OK();
39   }
40 
41   std::string sql = "PRAGMA table_info(INDEXES);";
42   std::vector<std::vector<std::string>> field_names;
43 
44   char *errmsg = nullptr;
45   int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg);
46   if (rc != SQLITE_OK) {
47     std::ostringstream oss;
48     oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
49     sqlite3_free(errmsg);
50     sqlite3_close(database_paths_[0]);
51     database_paths_[0] = nullptr;
52     RETURN_STATUS_UNEXPECTED(oss.str());
53   } else {
54     MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_names.size()) << " records from index.";
55   }
56 
57   uint32_t idx = kStartFieldId;
58   while (idx < field_names.size()) {
59     if (field_names[idx].size() < 2) {
60       sqlite3_free(errmsg);
61       sqlite3_close(database_paths_[0]);
62       database_paths_[0] = nullptr;
63       RETURN_STATUS_UNEXPECTED("Invalid data, field_names size must be greater than 1, but got " +
64                                std::to_string(field_names[idx].size()));
65     }
66     candidate_category_fields_.push_back(field_names[idx][1]);
67     idx += 2;
68   }
69   sqlite3_free(errmsg);
70   *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
71   return Status::OK();
72 }
73 
SetCategoryField(std::string category_field)74 Status ShardSegment::SetCategoryField(std::string category_field) {
75   std::shared_ptr<vector<std::string>> fields_ptr;
76   RETURN_IF_NOT_OK(GetCategoryFields(&fields_ptr));
77   category_field = category_field + "_0";
78   if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_),
79                   [category_field](std::string x) { return x == category_field; })) {
80     current_category_field_ = category_field;
81     return Status::OK();
82   }
83   RETURN_STATUS_UNEXPECTED("Invalid data, field '" + category_field + "' is not a candidate category field.");
84 }
85 
ReadCategoryInfo(std::shared_ptr<std::string> * category_ptr)86 Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) {
87   RETURN_UNEXPECTED_IF_NULL(category_ptr);
88   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
89   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
90   // Convert category info to json string
91   *category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr));
92 
93   return Status::OK();
94 }
95 
WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> * category_info_ptr)96 Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) {
97   RETURN_UNEXPECTED_IF_NULL(category_info_ptr);
98   std::map<std::string, int> counter;
99   CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_),
100                                "Invalid data, field: " + current_category_field_ + "is invalid.");
101   std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ +
102                     ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";";
103 
104   for (auto &db : database_paths_) {
105     std::vector<std::vector<std::string>> field_count;
106 
107     char *errmsg = nullptr;
108     if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) {
109       std::ostringstream oss;
110       oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
111       sqlite3_free(errmsg);
112       sqlite3_close(db);
113       db = nullptr;
114       RETURN_STATUS_UNEXPECTED(oss.str());
115     } else {
116       MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_count.size()) << " records from index.";
117     }
118 
119     for (const auto &field : field_count) {
120       counter[field[0]] += std::stoi(field[1]);
121     }
122     sqlite3_free(errmsg);
123   }
124 
125   int idx = 0;
126   (*category_info_ptr)->resize(counter.size());
127   (void)std::transform(
128     counter.begin(), counter.end(), (*category_info_ptr)->begin(),
129     [&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); });
130   return Status::OK();
131 }
132 
ToJsonForCategory(const CATEGORY_INFO & tri_vec)133 std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) {
134   std::vector<json> category_json_vec;
135   for (auto q : tri_vec) {
136     json j;
137     j["id"] = std::get<0>(q);
138     j["name"] = std::get<1>(q);
139     j["count"] = std::get<2>(q);
140 
141     category_json_vec.emplace_back(j);
142   }
143 
144   json j_vec(category_json_vec);
145   json category_info;
146   category_info["key"] = current_category_field_;
147   category_info["categories"] = j_vec;
148   return category_info.dump();
149 }
150 
ReadAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * page_ptr)151 Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
152                                     std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) {
153   RETURN_UNEXPECTED_IF_NULL(page_ptr);
154   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
155   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
156   CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0,
157                                "Invalid data, category_id: " + std::to_string(category_id) +
158                                  " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
159   int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
160   // Quit if category not found or page number is out of range
161   CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
162                                  page_no * n_rows_of_page < total_rows_in_category,
163                                "Invalid data, page no: " + std::to_string(page_no) +
164                                  "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
165 
166   auto row_group_summary = ReadRowGroupSummary();
167 
168   uint64_t i_start = page_no * n_rows_of_page;
169   uint64_t i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
170   uint64_t idx = 0;
171   for (const auto &rg : row_group_summary) {
172     if (idx >= i_end) break;
173 
174     auto shard_id = std::get<0>(rg);
175     auto group_id = std::get<1>(rg);
176     std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
177     RETURN_IF_NOT_OK(ReadRowGroupCriteria(
178       group_id, shard_id,
179       std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
180       &row_group_brief_ptr));
181     auto offsets = std::get<3>(*row_group_brief_ptr);
182     uint64_t number_of_rows = offsets.size();
183     if (idx + number_of_rows < i_start) {
184       idx += number_of_rows;
185       continue;
186     }
187 
188     for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) {
189       if (idx >= i_start && idx < i_end) {
190         auto images_ptr = std::make_shared<std::vector<uint8_t>>();
191         RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
192         (*page_ptr)->push_back(std::move(*images_ptr));
193       }
194     }
195   }
196 
197   return Status::OK();
198 }
199 
PackImages(int group_id,int shard_id,std::vector<uint64_t> offset,std::shared_ptr<std::vector<uint8_t>> * images_ptr)200 Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
201                                 std::shared_ptr<std::vector<uint8_t>> *images_ptr) {
202   RETURN_UNEXPECTED_IF_NULL(images_ptr);
203   std::shared_ptr<Page> page_ptr;
204   RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
205   // Pack image list
206   (*images_ptr)->resize(offset[1] - offset[0]);
207 
208   auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0];
209   auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
210   if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
211     file_streams_random_[0][shard_id]->close();
212     RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
213   }
214 
215   auto &io_read =
216     file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&((*(*images_ptr))[0])), offset[1] - offset[0]);
217   if (!io_read.good() || io_read.fail() || io_read.bad()) {
218     file_streams_random_[0][shard_id]->close();
219     RETURN_STATUS_UNEXPECTED("Failed to read file.");
220   }
221   return Status::OK();
222 }
223 
ReadAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * pages_ptr)224 Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
225                                       std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) {
226   RETURN_UNEXPECTED_IF_NULL(pages_ptr);
227   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
228   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
229   for (const auto &categories : *category_info_ptr) {
230     if (std::get<1>(categories) == category_name) {
231       RETURN_IF_NOT_OK(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr));
232       return Status::OK();
233     }
234   }
235 
236   RETURN_STATUS_UNEXPECTED("category_name: " + category_name + " could not found.");
237 }
238 
ReadAllAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES> * pages_ptr)239 Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
240                                        std::shared_ptr<PAGES> *pages_ptr) {
241   RETURN_UNEXPECTED_IF_NULL(pages_ptr);
242   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
243   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
244   CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int64_t>(category_info_ptr->size()),
245                                "Invalid data, category_id: " + std::to_string(category_id) +
246                                  " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
247 
248   int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
249   // Quit if category not found or page number is out of range
250   CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
251                                  page_no * n_rows_of_page < total_rows_in_category,
252                                "Invalid data, page no: " + std::to_string(page_no) +
253                                  "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
254   auto row_group_summary = ReadRowGroupSummary();
255 
256   int i_start = page_no * n_rows_of_page;
257   int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
258   int idx = 0;
259   for (const auto &rg : row_group_summary) {
260     if (idx >= i_end) {
261       break;
262     }
263 
264     auto shard_id = std::get<0>(rg);
265     auto group_id = std::get<1>(rg);
266     std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
267     RETURN_IF_NOT_OK(ReadRowGroupCriteria(
268       group_id, shard_id,
269       std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
270       &row_group_brief_ptr));
271     auto offsets = std::get<3>(*row_group_brief_ptr);
272     auto labels = std::get<4>(*row_group_brief_ptr);
273 
274     int number_of_rows = offsets.size();
275     if (idx + number_of_rows < i_start) {
276       idx += number_of_rows;
277       continue;
278     }
279     CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast<int>(labels.size()),
280                                  "Invalid data, number_of_rows: " + std::to_string(number_of_rows) + " is invalid.");
281     for (int i = 0; i < number_of_rows; ++i, ++idx) {
282       if (idx >= i_start && idx < i_end) {
283         auto images_ptr = std::make_shared<std::vector<uint8_t>>();
284         RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
285         (*pages_ptr)->emplace_back(std::move(*images_ptr), std::move(labels[i]));
286       }
287     }
288   }
289   return Status::OK();
290 }
291 
ReadAllAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES> * pages_ptr)292 Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
293                                          std::shared_ptr<PAGES> *pages_ptr) {
294   RETURN_UNEXPECTED_IF_NULL(pages_ptr);
295   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
296   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
297   // category_name to category_id
298   int64_t category_id = -1;
299   for (const auto &categories : *category_info_ptr) {
300     std::string categories_name = std::get<1>(categories);
301 
302     if (categories_name == category_name) {
303       category_id = std::get<0>(categories);
304       break;
305     }
306   }
307   CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "category_name: " + category_name + " could not found.");
308   return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr);
309 }
310 
GetBlobFields()311 std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
312   std::vector<std::string> blob_fields;
313   auto schema_list = GetShardHeader()->GetSchemas();
314   if (!schema_list.empty()) {
315     const auto &fields = schema_list[0]->GetBlobFields();
316     blob_fields.assign(fields.begin(), fields.end());
317   }
318   return std::make_pair(kCV, blob_fields);
319 }
320 
CleanUp(std::string field_name)321 std::string ShardSegment::CleanUp(std::string field_name) {
322   while (field_name.back() >= '0' && field_name.back() <= '9') {
323     field_name.pop_back();
324   }
325   field_name.pop_back();
326   return field_name;
327 }
328 }  // namespace mindrecord
329 }  // namespace mindspore
330