1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_segment.h"
18 #include "utils/ms_utils.h"
19
20 #include "./securec.h"
21 #include "minddata/mindrecord/include/common/shard_utils.h"
22 #include "pybind11/pybind11.h"
23
24 using mindspore::LogStream;
25 using mindspore::ExceptionType::NoExceptionType;
26 using mindspore::MsLogLevel::ERROR;
27 using mindspore::MsLogLevel::INFO;
28
29 namespace mindspore {
30 namespace mindrecord {
ShardSegment()31 ShardSegment::ShardSegment() { SetAllInIndex(false); }
32
GetCategoryFields(std::shared_ptr<vector<std::string>> * fields_ptr)33 Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) {
34 RETURN_UNEXPECTED_IF_NULL(fields_ptr);
35 // Skip if already populated
36 if (!candidate_category_fields_.empty()) {
37 *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
38 return Status::OK();
39 }
40
41 std::string sql = "PRAGMA table_info(INDEXES);";
42 std::vector<std::vector<std::string>> field_names;
43
44 char *errmsg = nullptr;
45 int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg);
46 if (rc != SQLITE_OK) {
47 std::ostringstream oss;
48 oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
49 sqlite3_free(errmsg);
50 sqlite3_close(database_paths_[0]);
51 database_paths_[0] = nullptr;
52 RETURN_STATUS_UNEXPECTED(oss.str());
53 } else {
54 MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_names.size()) << " records from index.";
55 }
56
57 uint32_t idx = kStartFieldId;
58 while (idx < field_names.size()) {
59 if (field_names[idx].size() < 2) {
60 sqlite3_free(errmsg);
61 sqlite3_close(database_paths_[0]);
62 database_paths_[0] = nullptr;
63 RETURN_STATUS_UNEXPECTED("Invalid data, field_names size must be greater than 1, but got " +
64 std::to_string(field_names[idx].size()));
65 }
66 candidate_category_fields_.push_back(field_names[idx][1]);
67 idx += 2;
68 }
69 sqlite3_free(errmsg);
70 *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
71 return Status::OK();
72 }
73
SetCategoryField(std::string category_field)74 Status ShardSegment::SetCategoryField(std::string category_field) {
75 std::shared_ptr<vector<std::string>> fields_ptr;
76 RETURN_IF_NOT_OK(GetCategoryFields(&fields_ptr));
77 category_field = category_field + "_0";
78 if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_),
79 [category_field](std::string x) { return x == category_field; })) {
80 current_category_field_ = category_field;
81 return Status::OK();
82 }
83 RETURN_STATUS_UNEXPECTED("Invalid data, field '" + category_field + "' is not a candidate category field.");
84 }
85
ReadCategoryInfo(std::shared_ptr<std::string> * category_ptr)86 Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) {
87 RETURN_UNEXPECTED_IF_NULL(category_ptr);
88 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
89 RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
90 // Convert category info to json string
91 *category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr));
92
93 return Status::OK();
94 }
95
WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> * category_info_ptr)96 Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) {
97 RETURN_UNEXPECTED_IF_NULL(category_info_ptr);
98 std::map<std::string, int> counter;
99 CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_),
100 "Invalid data, field: " + current_category_field_ + "is invalid.");
101 std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ +
102 ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";";
103
104 for (auto &db : database_paths_) {
105 std::vector<std::vector<std::string>> field_count;
106
107 char *errmsg = nullptr;
108 if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) {
109 std::ostringstream oss;
110 oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
111 sqlite3_free(errmsg);
112 sqlite3_close(db);
113 db = nullptr;
114 RETURN_STATUS_UNEXPECTED(oss.str());
115 } else {
116 MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_count.size()) << " records from index.";
117 }
118
119 for (const auto &field : field_count) {
120 counter[field[0]] += std::stoi(field[1]);
121 }
122 sqlite3_free(errmsg);
123 }
124
125 int idx = 0;
126 (*category_info_ptr)->resize(counter.size());
127 (void)std::transform(
128 counter.begin(), counter.end(), (*category_info_ptr)->begin(),
129 [&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); });
130 return Status::OK();
131 }
132
ToJsonForCategory(const CATEGORY_INFO & tri_vec)133 std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) {
134 std::vector<json> category_json_vec;
135 for (auto q : tri_vec) {
136 json j;
137 j["id"] = std::get<0>(q);
138 j["name"] = std::get<1>(q);
139 j["count"] = std::get<2>(q);
140
141 category_json_vec.emplace_back(j);
142 }
143
144 json j_vec(category_json_vec);
145 json category_info;
146 category_info["key"] = current_category_field_;
147 category_info["categories"] = j_vec;
148 return category_info.dump();
149 }
150
ReadAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * page_ptr)151 Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
152 std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) {
153 RETURN_UNEXPECTED_IF_NULL(page_ptr);
154 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
155 RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
156 CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0,
157 "Invalid data, category_id: " + std::to_string(category_id) +
158 " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
159 int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
160 // Quit if category not found or page number is out of range
161 CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
162 page_no * n_rows_of_page < total_rows_in_category,
163 "Invalid data, page no: " + std::to_string(page_no) +
164 "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
165
166 auto row_group_summary = ReadRowGroupSummary();
167
168 uint64_t i_start = page_no * n_rows_of_page;
169 uint64_t i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
170 uint64_t idx = 0;
171 for (const auto &rg : row_group_summary) {
172 if (idx >= i_end) break;
173
174 auto shard_id = std::get<0>(rg);
175 auto group_id = std::get<1>(rg);
176 std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
177 RETURN_IF_NOT_OK(ReadRowGroupCriteria(
178 group_id, shard_id,
179 std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
180 &row_group_brief_ptr));
181 auto offsets = std::get<3>(*row_group_brief_ptr);
182 uint64_t number_of_rows = offsets.size();
183 if (idx + number_of_rows < i_start) {
184 idx += number_of_rows;
185 continue;
186 }
187
188 for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) {
189 if (idx >= i_start && idx < i_end) {
190 auto images_ptr = std::make_shared<std::vector<uint8_t>>();
191 RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
192 (*page_ptr)->push_back(std::move(*images_ptr));
193 }
194 }
195 }
196
197 return Status::OK();
198 }
199
PackImages(int group_id,int shard_id,std::vector<uint64_t> offset,std::shared_ptr<std::vector<uint8_t>> * images_ptr)200 Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
201 std::shared_ptr<std::vector<uint8_t>> *images_ptr) {
202 RETURN_UNEXPECTED_IF_NULL(images_ptr);
203 std::shared_ptr<Page> page_ptr;
204 RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
205 // Pack image list
206 (*images_ptr)->resize(offset[1] - offset[0]);
207
208 auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0];
209 auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
210 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
211 file_streams_random_[0][shard_id]->close();
212 RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
213 }
214
215 auto &io_read =
216 file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&((*(*images_ptr))[0])), offset[1] - offset[0]);
217 if (!io_read.good() || io_read.fail() || io_read.bad()) {
218 file_streams_random_[0][shard_id]->close();
219 RETURN_STATUS_UNEXPECTED("Failed to read file.");
220 }
221 return Status::OK();
222 }
223
ReadAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * pages_ptr)224 Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
225 std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) {
226 RETURN_UNEXPECTED_IF_NULL(pages_ptr);
227 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
228 RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
229 for (const auto &categories : *category_info_ptr) {
230 if (std::get<1>(categories) == category_name) {
231 RETURN_IF_NOT_OK(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr));
232 return Status::OK();
233 }
234 }
235
236 RETURN_STATUS_UNEXPECTED("category_name: " + category_name + " could not found.");
237 }
238
ReadAllAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES> * pages_ptr)239 Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
240 std::shared_ptr<PAGES> *pages_ptr) {
241 RETURN_UNEXPECTED_IF_NULL(pages_ptr);
242 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
243 RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
244 CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int64_t>(category_info_ptr->size()),
245 "Invalid data, category_id: " + std::to_string(category_id) +
246 " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
247
248 int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
249 // Quit if category not found or page number is out of range
250 CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
251 page_no * n_rows_of_page < total_rows_in_category,
252 "Invalid data, page no: " + std::to_string(page_no) +
253 "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
254 auto row_group_summary = ReadRowGroupSummary();
255
256 int i_start = page_no * n_rows_of_page;
257 int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
258 int idx = 0;
259 for (const auto &rg : row_group_summary) {
260 if (idx >= i_end) {
261 break;
262 }
263
264 auto shard_id = std::get<0>(rg);
265 auto group_id = std::get<1>(rg);
266 std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
267 RETURN_IF_NOT_OK(ReadRowGroupCriteria(
268 group_id, shard_id,
269 std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
270 &row_group_brief_ptr));
271 auto offsets = std::get<3>(*row_group_brief_ptr);
272 auto labels = std::get<4>(*row_group_brief_ptr);
273
274 int number_of_rows = offsets.size();
275 if (idx + number_of_rows < i_start) {
276 idx += number_of_rows;
277 continue;
278 }
279 CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast<int>(labels.size()),
280 "Invalid data, number_of_rows: " + std::to_string(number_of_rows) + " is invalid.");
281 for (int i = 0; i < number_of_rows; ++i, ++idx) {
282 if (idx >= i_start && idx < i_end) {
283 auto images_ptr = std::make_shared<std::vector<uint8_t>>();
284 RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
285 (*pages_ptr)->emplace_back(std::move(*images_ptr), std::move(labels[i]));
286 }
287 }
288 }
289 return Status::OK();
290 }
291
ReadAllAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES> * pages_ptr)292 Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
293 std::shared_ptr<PAGES> *pages_ptr) {
294 RETURN_UNEXPECTED_IF_NULL(pages_ptr);
295 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
296 RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
297 // category_name to category_id
298 int64_t category_id = -1;
299 for (const auto &categories : *category_info_ptr) {
300 std::string categories_name = std::get<1>(categories);
301
302 if (categories_name == category_name) {
303 category_id = std::get<0>(categories);
304 break;
305 }
306 }
307 CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "category_name: " + category_name + " could not found.");
308 return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr);
309 }
310
GetBlobFields()311 std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
312 std::vector<std::string> blob_fields;
313 auto schema_list = GetShardHeader()->GetSchemas();
314 if (!schema_list.empty()) {
315 const auto &fields = schema_list[0]->GetBlobFields();
316 blob_fields.assign(fields.begin(), fields.end());
317 }
318 return std::make_pair(kCV, blob_fields);
319 }
320
CleanUp(std::string field_name)321 std::string ShardSegment::CleanUp(std::string field_name) {
322 while (field_name.back() >= '0' && field_name.back() <= '9') {
323 field_name.pop_back();
324 }
325 field_name.pop_back();
326 return field_name;
327 }
328 } // namespace mindrecord
329 } // namespace mindspore
330