1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_segment.h"
18 #include "utils/ms_utils.h"
19
20 #include "securec.h"
21 #include "minddata/mindrecord/include/common/shard_utils.h"
22 #include "pybind11/pybind11.h"
23
24 namespace mindspore {
25 namespace mindrecord {
ShardSegment()26 ShardSegment::ShardSegment() { SetAllInIndex(false); }
27
GetCategoryFields(std::shared_ptr<vector<std::string>> * fields_ptr)28 Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) {
29 RETURN_UNEXPECTED_IF_NULL_MR(fields_ptr);
30 // Skip if already populated
31 if (!candidate_category_fields_.empty()) {
32 *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
33 return Status::OK();
34 }
35
36 std::string sql = "PRAGMA table_info(INDEXES);";
37 std::vector<std::vector<std::string>> field_names;
38
39 char *errmsg = nullptr;
40 int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg);
41 if (rc != SQLITE_OK) {
42 std::ostringstream oss;
43 oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
44 sqlite3_free(errmsg);
45 sqlite3_close(database_paths_[0]);
46 database_paths_[0] = nullptr;
47 RETURN_STATUS_UNEXPECTED_MR(oss.str());
48 } else {
49 MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_names.size()) << " records from index.";
50 }
51
52 uint32_t idx = kStartFieldId;
53 while (idx < field_names.size()) {
54 if (field_names[idx].size() < 2) {
55 sqlite3_free(errmsg);
56 sqlite3_close(database_paths_[0]);
57 database_paths_[0] = nullptr;
58 RETURN_STATUS_UNEXPECTED_MR("Invalid data, field_names size must be greater than 1, but got: " +
59 std::to_string(field_names[idx].size()));
60 }
61 candidate_category_fields_.push_back(field_names[idx][1]);
62 idx += 2;
63 }
64 sqlite3_free(errmsg);
65 *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
66 return Status::OK();
67 }
68
SetCategoryField(std::string category_field)69 Status ShardSegment::SetCategoryField(std::string category_field) {
70 std::shared_ptr<vector<std::string>> fields_ptr;
71 RETURN_IF_NOT_OK_MR(GetCategoryFields(&fields_ptr));
72 category_field = category_field + "_0";
73 if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_),
74 [category_field](std::string x) { return x == category_field; })) {
75 current_category_field_ = category_field;
76 return Status::OK();
77 }
78 RETURN_STATUS_UNEXPECTED_MR("Invalid data, field '" + category_field + "' is not a candidate category field.");
79 }
80
ReadCategoryInfo(std::shared_ptr<std::string> * category_ptr)81 Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) {
82 RETURN_UNEXPECTED_IF_NULL_MR(category_ptr);
83 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
84 RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
85 // Convert category info to json string
86 *category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr));
87
88 return Status::OK();
89 }
90
WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> * category_info_ptr)91 Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) {
92 RETURN_UNEXPECTED_IF_NULL_MR(category_info_ptr);
93 std::map<std::string, int> counter;
94 CHECK_FAIL_RETURN_UNEXPECTED_MR(ValidateFieldName(current_category_field_),
95 "Invalid data, field: " + current_category_field_ + "is invalid.");
96 std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ +
97 ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";";
98
99 for (auto &db : database_paths_) {
100 std::vector<std::vector<std::string>> field_count;
101
102 char *errmsg = nullptr;
103 if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) {
104 std::ostringstream oss;
105 oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
106 sqlite3_free(errmsg);
107 sqlite3_close(db);
108 db = nullptr;
109 RETURN_STATUS_UNEXPECTED_MR(oss.str());
110 } else {
111 MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_count.size()) << " records from index.";
112 }
113
114 for (const auto &field : field_count) {
115 counter[field[0]] += std::stoi(field[1]);
116 }
117 sqlite3_free(errmsg);
118 }
119
120 int idx = 0;
121 (*category_info_ptr)->resize(counter.size());
122 (void)std::transform(
123 counter.begin(), counter.end(), (*category_info_ptr)->begin(),
124 [&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); });
125 return Status::OK();
126 }
127
ToJsonForCategory(const CATEGORY_INFO & tri_vec)128 std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) {
129 std::vector<json> category_json_vec;
130 for (auto q : tri_vec) {
131 json j;
132 j["id"] = std::get<0>(q);
133 j["name"] = std::get<1>(q);
134 j["count"] = std::get<2>(q);
135
136 category_json_vec.emplace_back(j);
137 }
138
139 json j_vec(category_json_vec);
140 json category_info;
141 category_info["key"] = current_category_field_;
142 category_info["categories"] = j_vec;
143 return category_info.dump();
144 }
145
ReadAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * page_ptr)146 Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
147 std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) {
148 RETURN_UNEXPECTED_IF_NULL_MR(page_ptr);
149 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
150 RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
151 CHECK_FAIL_RETURN_UNEXPECTED_MR(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0,
152 "Invalid data, category_id: " + std::to_string(category_id) +
153 " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
154 int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
155 // Quit if category not found or page number is out of range
156 CHECK_FAIL_RETURN_UNEXPECTED_MR(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
157 page_no * n_rows_of_page < total_rows_in_category,
158 "Invalid data, page no: " + std::to_string(page_no) +
159 "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
160
161 auto row_group_summary = ReadRowGroupSummary();
162
163 uint64_t i_start = page_no * n_rows_of_page;
164 uint64_t i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
165 uint64_t idx = 0;
166 for (const auto &rg : row_group_summary) {
167 if (idx >= i_end) {
168 break;
169 }
170
171 auto shard_id = std::get<0>(rg);
172 auto group_id = std::get<1>(rg);
173 std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
174 RETURN_IF_NOT_OK_MR(ReadRowGroupCriteria(
175 group_id, shard_id,
176 std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
177 &row_group_brief_ptr));
178 auto offsets = std::get<3>(*row_group_brief_ptr);
179 uint64_t number_of_rows = offsets.size();
180 if (idx + number_of_rows < i_start) {
181 idx += number_of_rows;
182 continue;
183 }
184
185 for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) {
186 if (idx >= i_start && idx < i_end) {
187 auto images_ptr = std::make_shared<std::vector<uint8_t>>();
188 RETURN_IF_NOT_OK_MR(PackImages(group_id, shard_id, offsets[i], &images_ptr));
189 (*page_ptr)->push_back(std::move(*images_ptr));
190 }
191 }
192 }
193
194 return Status::OK();
195 }
196
PackImages(int group_id,int shard_id,std::vector<uint64_t> offset,std::shared_ptr<std::vector<uint8_t>> * images_ptr)197 Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
198 std::shared_ptr<std::vector<uint8_t>> *images_ptr) {
199 RETURN_UNEXPECTED_IF_NULL_MR(images_ptr);
200 std::shared_ptr<Page> page_ptr;
201 RETURN_IF_NOT_OK_MR(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
202 // Pack image list
203 (*images_ptr)->resize(offset[1] - offset[0]);
204
205 auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0];
206 auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
207 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
208 file_streams_random_[0][shard_id]->close();
209 RETURN_STATUS_UNEXPECTED_MR("Failed to seekg file.");
210 }
211
212 auto &io_read =
213 file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&((*(*images_ptr))[0])), offset[1] - offset[0]);
214 if (!io_read.good() || io_read.fail() || io_read.bad()) {
215 file_streams_random_[0][shard_id]->close();
216 RETURN_STATUS_UNEXPECTED_MR("Failed to read file.");
217 }
218 return Status::OK();
219 }
220
ReadAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<std::vector<std::vector<uint8_t>>> * pages_ptr)221 Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
222 std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) {
223 RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
224 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
225 RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
226 for (const auto &categories : *category_info_ptr) {
227 if (std::get<1>(categories) == category_name) {
228 RETURN_IF_NOT_OK_MR(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr));
229 return Status::OK();
230 }
231 }
232
233 RETURN_STATUS_UNEXPECTED_MR("category_name: " + category_name + " could not found.");
234 }
235
ReadAllAtPageById(int64_t category_id,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES_WITH_BLOBS> * pages_ptr)236 Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
237 std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr) {
238 RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
239 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
240 RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
241 CHECK_FAIL_RETURN_UNEXPECTED_MR(category_id < static_cast<int64_t>(category_info_ptr->size()),
242 "Invalid data, category_id: " + std::to_string(category_id) +
243 " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
244
245 int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
246 // Quit if category not found or page number is out of range
247 CHECK_FAIL_RETURN_UNEXPECTED_MR(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
248 page_no * n_rows_of_page < total_rows_in_category,
249 "Invalid data, page no: " + std::to_string(page_no) +
250 "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
251 auto row_group_summary = ReadRowGroupSummary();
252
253 int i_start = page_no * n_rows_of_page;
254 int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
255 int idx = 0;
256 for (const auto &rg : row_group_summary) {
257 if (idx >= i_end) {
258 break;
259 }
260
261 auto shard_id = std::get<0>(rg);
262 auto group_id = std::get<1>(rg);
263 std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
264 RETURN_IF_NOT_OK_MR(ReadRowGroupCriteria(
265 group_id, shard_id,
266 std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
267 &row_group_brief_ptr));
268 auto offsets = std::get<3>(*row_group_brief_ptr);
269 auto labels = std::get<4>(*row_group_brief_ptr);
270
271 int number_of_rows = offsets.size();
272 if (idx + number_of_rows < i_start) {
273 idx += number_of_rows;
274 continue;
275 }
276 CHECK_FAIL_RETURN_UNEXPECTED_MR(number_of_rows <= static_cast<int>(labels.size()),
277 "Invalid data, number_of_rows: " + std::to_string(number_of_rows) + " is invalid.");
278 std::map<std::string, std::vector<uint8_t>> key_with_blob_fields;
279 for (int i = 0; i < number_of_rows; ++i, ++idx) {
280 if (idx >= i_start && idx < i_end) {
281 auto images_ptr = std::make_shared<std::vector<uint8_t>>();
282 RETURN_IF_NOT_OK_MR(PackImages(group_id, shard_id, offsets[i], &images_ptr));
283
284 // extract every blob field from blob data
285 auto shard_column = GetShardColumn();
286 auto schema = shard_header_->GetSchemas(); // current, we only support 1 schema yet
287 auto blob_fields = schema[0]->GetBlobFields();
288 for (auto blob_field : blob_fields) {
289 const unsigned char *data = nullptr;
290 std::unique_ptr<unsigned char[]> data_ptr;
291 uint64_t n_bytes = 0;
292 mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType;
293 uint64_t column_data_type_size = 1;
294 std::vector<int64_t> column_shape;
295 RETURN_IF_NOT_OK_MR(shard_column->GetColumnValueByName(blob_field, *images_ptr, labels[i], &data, &data_ptr,
296 &n_bytes, &column_data_type, &column_data_type_size,
297 &column_shape));
298 key_with_blob_fields[blob_field] = std::vector<uint8_t>(data, data + n_bytes);
299 }
300
301 (*pages_ptr)->emplace_back(std::move(key_with_blob_fields), std::move(labels[i]));
302 }
303 }
304 }
305 return Status::OK();
306 }
307
ReadAllAtPageByName(std::string category_name,int64_t page_no,int64_t n_rows_of_page,std::shared_ptr<PAGES_WITH_BLOBS> * pages_ptr)308 Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
309 std::shared_ptr<PAGES_WITH_BLOBS> *pages_ptr) {
310 RETURN_UNEXPECTED_IF_NULL_MR(pages_ptr);
311 auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
312 RETURN_IF_NOT_OK_MR(WrapCategoryInfo(&category_info_ptr));
313 // category_name to category_id
314 int64_t category_id = -1;
315 for (const auto &categories : *category_info_ptr) {
316 std::string categories_name = std::get<1>(categories);
317
318 if (categories_name == category_name) {
319 category_id = std::get<0>(categories);
320 break;
321 }
322 }
323 CHECK_FAIL_RETURN_UNEXPECTED_MR(category_id != -1, "category_name: " + category_name + " could not found.");
324 return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr);
325 }
326
GetBlobFields()327 std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
328 std::vector<std::string> blob_fields;
329 auto schema_list = GetShardHeader()->GetSchemas();
330 if (!schema_list.empty()) {
331 const auto &fields = schema_list[0]->GetBlobFields();
332 blob_fields.assign(fields.begin(), fields.end());
333 }
334 return std::make_pair(kCV, blob_fields);
335 }
336
CleanUp(std::string field_name)337 std::string ShardSegment::CleanUp(std::string field_name) {
338 while (field_name.back() >= '0' && field_name.back() <= '9') {
339 field_name.pop_back();
340 }
341 field_name.pop_back();
342 return field_name;
343 }
344 } // namespace mindrecord
345 } // namespace mindspore
346