• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_segment.h"
18 #include "utils/ms_utils.h"
19 
20 #include "securec.h"
21 #include "minddata/mindrecord/include/common/shard_utils.h"
22 #include "pybind11/pybind11.h"
23 
24 namespace mindspore {
25 namespace mindrecord {
ShardSegment()26 ShardSegment::ShardSegment() { SetAllInIndex(false); }
27 
GetCategoryFields(std::shared_ptr<vector<std::string>> * fields_ptr)28 Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) {
29   RETURN_UNEXPECTED_IF_NULL_MR(fields_ptr);
30   // Skip if already populated
31   if (!candidate_category_fields_.empty()) {
32     *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
33     return Status::OK();
34   }
35 
36   std::string sql = "PRAGMA table_info(INDEXES);";
37   std::vector<std::vector<std::string>> field_names;
38 
39   char *errmsg = nullptr;
40   int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg);
41   if (rc != SQLITE_OK) {
42     std::ostringstream oss;
43     oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
44     sqlite3_free(errmsg);
45     sqlite3_close(database_paths_[0]);
46     database_paths_[0] = nullptr;
47     RETURN_STATUS_UNEXPECTED_MR(oss.str());
48   } else {
49     MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_names.size()) << " records from index.";
50   }
51 
52   uint32_t idx = kStartFieldId;
53   while (idx < field_names.size()) {
54     if (field_names[idx].size() < 2) {
55       sqlite3_free(errmsg);
56       sqlite3_close(database_paths_[0]);
57       database_paths_[0] = nullptr;
58       RETURN_STATUS_UNEXPECTED_MR("Invalid data, field_names size must be greater than 1, but got: " +
59                                   std::to_string(field_names[idx].size()));
60     }
61     candidate_category_fields_.push_back(field_names[idx][1]);
62     idx += 2;
63   }
64   sqlite3_free(errmsg);
65   *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
66   return Status::OK();
67 }
68 
SetCategoryField(std::string category_field)69 Status ShardSegment::SetCategoryField(std::string category_field) {
70   std::shared_ptr<vector<std::string>> fields_ptr;
71   RETURN_IF_NOT_OK_MR(GetCategoryFields(&fields_ptr));
72   category_field = category_field + "_0";
73   if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_),
74                   [category_field](std::string x) { return x == category_field; })) {
75     current_category_field_ = category_field;
76     return Status::OK();
77   }
78   RETURN_STATUS_UNEXPECTED_MR("Invalid data, field '" + category_field + "' is not a candidate category field.");
79 }
80 
ReadCategoryInfo(std::shared_ptr<std::string> * category_ptr)81 Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) {
82   RETURN_UNEXPECTED_IF_NULL_MR(category_ptr);
83   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
84   RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
85   // Convert category info to json string
86   *category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr));
87 
88   return Status::OK();
89 }
90 
WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> * category_info_ptr)91 Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) {
92   RETURN_UNEXPECTED_IF_NULL_MR(category_info_ptr);
93   std::map<std::string, int> counter;
94   CHECK_FAIL_RETURN_UNEXPECTED_MR(ValidateFieldName(current_category_field_),
95                                   "Invalid data, field: " + current_category_field_ + "is invalid.");
96   std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ +
97                     ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";";
98 
99   for (auto &db : database_paths_) {
100     std::vector<std::vector<std::string>> field_count;
101 
102     char *errmsg = nullptr;
103     if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) {
104       std::ostringstream oss;
105       oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
106       sqlite3_free(errmsg);
107       sqlite3_close(db);
108       db = nullptr;
109       RETURN_STATUS_UNEXPECTED_MR(oss.str());
110     } else {
111       MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_count.size()) << " records from index.";
112     }
113 
114     for (const auto &field : field_count) {
115       counter[field[0]] += std::stoi(field[1]);
116     }
117     sqlite3_free(errmsg);
118   }
119 
120   int idx = 0;
121   (*category_info_ptr)->resize(counter.size());
122   (void)std::transform(
123     counter.begin(), counter.end(), (*category_info_ptr)->begin(),
124     [&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); });
125   return Status::OK();
126 }
127 
ToJsonForCategory(const CATEGORY_INFO & tri_vec)128 std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) {
129   std::vector<json> category_json_vec;
130   for (auto q : tri_vec) {
131     json j;
132     j["id"] = std::get<0>(q);
133     j["name"] = std::get<1>(q);
134     j["count"] = std::get<2>(q);
135 
136     category_json_vec.emplace_back(j);
137   }
138 
139   json j_vec(category_json_vec);
140   json category_info;
141   category_info["key"] = current_category_field_;
142   category_info["categories"] = j_vec;
143   return category_info.dump();
144 }
145 
ReadAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * page_ptr)146 Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
147                                     std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) {
148   RETURN_UNEXPECTED_IF_NULL_MR(page_ptr);
149   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
150   RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
151   CHECK_FAIL_RETURN_UNEXPECTED_MR(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0,
152                                   "Invalid data, category_id: " + std::to_string(category_id) +
153                                     " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
154   int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
155   // Quit if category not found or page number is out of range
156   CHECK_FAIL_RETURN_UNEXPECTED_MR(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
157                                     page_no * n_rows_of_page < total_rows_in_category,
158                                   "Invalid data, page no: " + std::to_string(page_no) +
159                                     "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
160 
161   auto row_group_summary = ReadRowGroupSummary();
162 
163   uint64_t i_start = page_no * n_rows_of_page;
164   uint64_t i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
165   uint64_t idx = 0;
166   for (const auto &rg : row_group_summary) {
167     if (idx >= i_end) {
168       break;
169     }
170 
171     auto shard_id = std::get<0>(rg);
172     auto group_id = std::get<1>(rg);
173     std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
174     RETURN_IF_NOT_OK_MR(ReadRowGroupCriteria(
175       group_id, shard_id,
176       std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
177       &row_group_brief_ptr));
178     auto offsets = std::get<3>(*row_group_brief_ptr);
179     uint64_t number_of_rows = offsets.size();
180     if (idx + number_of_rows < i_start) {
181       idx += number_of_rows;
182       continue;
183     }
184 
185     for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) {
186       if (idx >= i_start && idx < i_end) {
187         auto images_ptr = std::make_shared<std::vector<uint8_t>>();
188         RETURN_IF_NOT_OK_MR(PackImages(group_id, shard_id, offsets[i], &images_ptr));
189         (*page_ptr)->push_back(std::move(*images_ptr));
190       }
191     }
192   }
193 
194   return Status::OK();
195 }
196 
PackImages(int group_id,int shard_id,std::vector<uint64_t> offset,std::shared_ptr<std::vector<uint8_t>> * images_ptr)197 Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
198                                 std::shared_ptr<std::vector<uint8_t>> *images_ptr) {
199   RETURN_UNEXPECTED_IF_NULL_MR(images_ptr);
200   std::shared_ptr<Page> page_ptr;
201   RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
202   // Pack image list
203   (*images_ptr)->resize(offset[1] - offset[0]);
204 
205   auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0];
206   auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
207   if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
208     file_streams_random_[0][shard_id]->close();
209     RETURN_STATUS_UNEXPECTED_MR("Failed to seekg file.");
210   }
211 
212   auto &io_read =
213     file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&((*(*images_ptr))[0])), offset[1] - offset[0]);
214   if (!io_read.good() || io_read.fail() || io_read.bad()) {
215     file_streams_random_[0][shard_id]->close();
216     RETURN_STATUS_UNEXPECTED_MR("Failed to read file.");
217   }
218   return Status::OK();
219 }
220 
ReadAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * pages_ptr)221 Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
222                                       std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) {
223   RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
224   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
225   RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
226   for (const auto &categories : *category_info_ptr) {
227     if (std::get<1>(categories) == category_name) {
228       RETURN_IF_NOT_OK_MR(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr));
229       return Status::OK();
230     }
231   }
232 
233   RETURN_STATUS_UNEXPECTED_MR("category_name: " + category_name + " could not found.");
234 }
235 
ReadAllAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES_WITH_BLOBS> * pages_ptr)236 Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
237                                        std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr) {
238   RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
239   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
240   RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
241   CHECK_FAIL_RETURN_UNEXPECTED_MR(category_id < static_cast<int64_t>(category_info_ptr->size()),
242                                   "Invalid data, category_id: " + std::to_string(category_id) +
243                                     " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
244 
245   int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
246   // Quit if category not found or page number is out of range
247   CHECK_FAIL_RETURN_UNEXPECTED_MR(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
248                                     page_no * n_rows_of_page < total_rows_in_category,
249                                   "Invalid data, page no: " + std::to_string(page_no) +
250                                     "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
251   auto row_group_summary = ReadRowGroupSummary();
252 
253   int i_start = page_no * n_rows_of_page;
254   int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
255   int idx = 0;
256   for (const auto &rg : row_group_summary) {
257     if (idx >= i_end) {
258       break;
259     }
260 
261     auto shard_id = std::get<0>(rg);
262     auto group_id = std::get<1>(rg);
263     std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
264     RETURN_IF_NOT_OK_MR(ReadRowGroupCriteria(
265       group_id, shard_id,
266       std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
267       &row_group_brief_ptr));
268     auto offsets = std::get<3>(*row_group_brief_ptr);
269     auto labels = std::get<4>(*row_group_brief_ptr);
270 
271     int number_of_rows = offsets.size();
272     if (idx + number_of_rows < i_start) {
273       idx += number_of_rows;
274       continue;
275     }
276     CHECK_FAIL_RETURN_UNEXPECTED_MR(number_of_rows <= static_cast<int>(labels.size()),
277                                     "Invalid data, number_of_rows: " + std::to_string(number_of_rows) + " is invalid.");
278     std::map<std::string, std::vector<uint8_t>> key_with_blob_fields;
279     for (int i = 0; i < number_of_rows; ++i, ++idx) {
280       if (idx >= i_start && idx < i_end) {
281         auto images_ptr = std::make_shared<std::vector<uint8_t>>();
282         RETURN_IF_NOT_OK_MR(PackImages(group_id, shard_id, offsets[i], &images_ptr));
283 
284         // extract every blob field from blob data
285         auto shard_column = GetShardColumn();
286         auto schema = shard_header_->GetSchemas();  // current, we only support 1 schema yet
287         auto blob_fields = schema[0]->GetBlobFields();
288         for (auto blob_field : blob_fields) {
289           const unsigned char *data = nullptr;
290           std::unique_ptr<unsigned char[]> data_ptr;
291           uint64_t n_bytes = 0;
292           mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
293           uint64_t column_data_type_size = 1;
294           std::vector<int64_t> column_shape;
295           RETURN_IF_NOT_OK_MR(shard_column->GetColumnValueByName(blob_field, *images_ptr, labels[i], &data, &data_ptr,
296                                                                  &n_bytes, &column_data_type, &column_data_type_size,
297                                                                  &column_shape));
298           key_with_blob_fields[blob_field] = std::vector<uint8_t>(data, data + n_bytes);
299         }
300 
301         (*pages_ptr)->emplace_back(std::move(key_with_blob_fields), std::move(labels[i]));
302       }
303     }
304   }
305   return Status::OK();
306 }
307 
ReadAllAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES_WITH_BLOBS> * pages_ptr)308 Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
309                                          std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr) {
310   RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
311   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
312   RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
313   // category_name to category_id
314   int64_t category_id = -1;
315   for (const auto &categories : *category_info_ptr) {
316     std::string categories_name = std::get<1>(categories);
317 
318     if (categories_name == category_name) {
319       category_id = std::get<0>(categories);
320       break;
321     }
322   }
323   CHECK_FAIL_RETURN_UNEXPECTED_MR(category_id != -1, "category_name: " + category_name + " could not found.");
324   return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr);
325 }
326 
GetBlobFields()327 std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
328   std::vector<std::string> blob_fields;
329   auto schema_list = GetShardHeader()->GetSchemas();
330   if (!schema_list.empty()) {
331     const auto &fields = schema_list[0]->GetBlobFields();
332     blob_fields.assign(fields.begin(), fields.end());
333   }
334   return std::make_pair(kCV, blob_fields);
335 }
336 
CleanUp(std::string field_name)337 std::string ShardSegment::CleanUp(std::string field_name) {
338   while (field_name.back() >= '0' && field_name.back() <= '9') {
339     field_name.pop_back();
340   }
341   field_name.pop_back();
342   return field_name;
343 }
344 }  // namespace mindrecord
345 }  // namespace mindspore
346