/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/mindrecord/include/shard_header.h"

#include <algorithm>
#include <atomic>
#include <cmath>
#include <fstream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "utils/file_utils.h"
#include "utils/ms_utils.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_page.h"

namespace mindspore {
namespace mindrecord {
std::atomic<bool> thread_status(false);
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) {
  index_ = std::make_shared<Index>();
}

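// Populate the header from the parsed JSON headers of all shard files. The schema, index
// fields, statistics, shard addresses and sizes are taken from the first header only (they
// are expected to be identical across shards); the page list is parsed from every header.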
Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
  shard_count_ = headers.size();
  int shard_index = 0;
  bool first = true;
  for (const auto &header : headers) {
    if (first) {
      first = false;
      RETURN_IF_NOT_OK_MR(ParseSchema(header["schema"]));
      RETURN_IF_NOT_OK_MR(ParseIndexFields(header["index_fields"]));
      RETURN_IF_NOT_OK_MR(ParseStatistics(header["statistics"]));
      ParseShardAddress(header["shard_addresses"]);
      header_size_ = header["header_size"].get<uint64_t>();
      page_size_ = header["page_size"].get<uint64_t>();
      compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
    }
    RETURN_IF_NOT_OK_MR(ParsePage(header["page"], shard_index, load_dataset));
    shard_index++;
  }
  return Status::OK();
}

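// Verify that a mindrecord file exists, can be opened, and is no smaller than kMinFileSize.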
Status ShardHeader::CheckFileStatus(const std::string &path) {
  auto realpath = FileUtils::GetRealPath(path.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    realpath.has_value(),
    "Invalid file, failed to get the realpath of mindrecord files. Please check file path: " + path);
  std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(fin.is_open(),
                                  "Invalid file, failed to open files for loading mindrecord files. Please check file "
                                  "path, permission and open file limit: " +
                                    path);
  // fetch file size
  auto &io_seekg = fin.seekg(0, std::ios::end);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] failed to seekg file, file path: " + path);
  }

  size_t file_size = fin.tellg();
  if (file_size < kMinFileSize) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("Invalid file, the size of mindrecord file: " + std::to_string(file_size) +
                                " is smaller than the lower limit: " + std::to_string(kMinFileSize) +
                                ".\n Please check file path: " + path +
                                " and use 'FileWriter' to generate valid mindrecord files.");
  }
  fin.close();
  return Status::OK();
}

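// Read the raw header of a mindrecord file: the first kInt64Len bytes hold the header size,
// followed by the JSON-encoded header content, which is parsed and returned via header_ptr.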
Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(header_ptr);
  RETURN_IF_NOT_OK_MR(CheckFileStatus(path));

  auto realpath = FileUtils::GetRealPath(path.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    realpath.has_value(),
    "Invalid file, failed to get the realpath of mindrecord files. Please check file path: " + path);

  // read header size
  json json_header;
  std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(fin.is_open(),
                                  "Invalid file, failed to open files for loading mindrecord files. Please check file "
                                  "path, permission and open file limit: " +
                                    path);

  uint64_t header_size = 0;
  auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len);
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] failed to read file, file path: " + path);
  }

  if (header_size > kMaxHeaderSize) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR(
      "Invalid file, the size of mindrecord file header is larger than the upper limit. \n"
      "The invalid mindrecord file is [" +
      path + "]. \nPlease use 'FileWriter' to generate valid mindrecord files.");
  }

  // read header content
  std::vector<uint8_t> header_content(header_size);
  auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size);
  if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("Invalid file, failed to read the header content of file: " + path +
                                ", please check the correctness of the MindRecord file.");
  }

  fin.close();
  std::string raw_header_content = std::string(header_content.begin(), header_content.end());
  // parse json content
  try {
    json_header = json::parse(raw_header_content);
  } catch (json::parse_error &e) {
    RETURN_STATUS_UNEXPECTED_MR("Invalid file, failed to parse the header content of file: " + path +
                                ", please check the correctness of the MindRecord file.");
  }
  *header_ptr = std::make_shared<json>(json_header);
  return Status::OK();
}

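// Build a reduced header (shard addresses, sizes, index/blob fields, schema and version)
// from a single mindrecord file, without the per-shard page information.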
Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(header_ptr);
  std::shared_ptr<json> raw_header;
  RETURN_IF_NOT_OK_MR(ValidateHeader(file_path, &raw_header));
  uint64_t compression_size =
    raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0;
  json header = {{"shard_addresses", (*raw_header)["shard_addresses"]},
                 {"header_size", (*raw_header)["header_size"]},
                 {"page_size", (*raw_header)["page_size"]},
                 {"compression_size", compression_size},
                 {"index_fields", (*raw_header)["index_fields"]},
                 {"blob_fields", (*raw_header)["schema"][0]["blob_fields"]},
                 {"schema", (*raw_header)["schema"][0]["schema"]},
                 {"version", (*raw_header)["version"]}};
  *header_ptr = std::make_shared<json>(header);
  return Status::OK();
}

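// Read and validate the headers of all shard files in parallel, one worker thread per group
// of files, then initialize this object from the collected headers.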
Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
  uint32_t thread_num = std::thread::hardware_concurrency();
  if (thread_num == 0) {
    thread_num = kThreadNumber;
  }
  uint32_t work_thread_num = 0;
  uint32_t shard_count = file_paths.size();
  int group_num = static_cast<int>(ceil(shard_count * 1.0 / thread_num));
  std::vector<std::thread> thread_set(thread_num);
  std::vector<json> headers(shard_count);
  for (uint32_t x = 0; x < thread_num; ++x) {
    int start_num = x * group_num;
    int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
    if (start_num >= end_num) {
      continue;
    }

    thread_set[x] =
      std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
    work_thread_num++;
  }

  for (uint32_t x = 0; x < work_thread_num; ++x) {
    thread_set[x].join();
  }
  if (thread_status) {
    thread_status = false;
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Error raised in GetHeadersOneTask function.");
  }
  RETURN_IF_NOT_OK_MR(InitializeHeader(headers, load_dataset));
  return Status::OK();
}

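// Worker routine for BuildDataset: validate the header of each file in [start, end) and
// store it into headers. On failure the global thread_status flag is raised so that
// BuildDataset can report the error after joining the workers.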
void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers,
                                    const vector<string> &realAddresses) {
  if (thread_status || static_cast<size_t>(end) > realAddresses.size()) {
    return;
  }
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(start) + ":" + std::to_string(end)).c_str());
#endif
  for (int x = start; x < end; ++x) {
    std::shared_ptr<json> header;
    auto status = ValidateHeader(realAddresses[x], &header);
    if (status.IsError()) {
      thread_status = true;
      return;
    }
    (*header)["shard_addresses"] = realAddresses;
    if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) ==
        kSupportedVersion.end()) {
      MS_LOG(ERROR) << "Invalid file, the version of mindrecord files " << (*header)["version"].dump()
                    << " is not supported.\nPlease use 'FileWriter' to generate valid mindrecord files.";
      thread_status = true;
      return;
    }
    headers[x] = *header;
  }
}

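// Initialize the shard addresses and page containers from a list of file paths, without
// reading the files themselves.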
Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
  std::vector<std::string> file_names(file_paths.size());
  std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string {
    std::shared_ptr<std::string> fn;
    return GetFileName(fp, &fn).IsOk() ? *fn : "";
  });

  shard_addresses_ = std::move(file_names);
  shard_count_ = file_paths.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_count_ != 0 && (shard_count_ <= kMaxShardCount),
                                  "[Internal ERROR] 'shard_count_': " + std::to_string(shard_count_) +
                                    " is not in range (0, " + std::to_string(kMaxShardCount) + "].");
  pages_.resize(shard_count_);
  return Status::OK();
}

Status ShardHeader::ParseIndexFields(const json &index_fields) {
  std::vector<std::pair<uint64_t, std::string>> parsed_index_fields;
  for (auto &index_field : index_fields) {
    auto schema_id = index_field["schema_id"].get<uint64_t>();
    std::string field_name = index_field["index_field"].get<std::string>();
    std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name);
    parsed_index_fields.push_back(parsed_index_field);
  }
  RETURN_IF_NOT_OK_MR(AddIndexFields(parsed_index_fields));
  return Status::OK();
}

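// Deserialize the page list of one shard. When load_dataset is true, each page is grouped by
// the shard_id recorded in the page itself; otherwise it is stored under shard_index.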
Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
  // shard_index is only used when load_dataset is false
  if (pages_.empty()) {
    pages_.resize(shard_count_);
  }

  for (auto &page : pages) {
    int page_id = page["page_id"];
    int shard_id = page["shard_id"];
    std::string page_type = page["page_type"];
    int page_type_id = page["page_type_id"];
    auto start_row_id = page["start_row_id"].get<uint64_t>();
    auto end_row_id = page["end_row_id"].get<uint64_t>();

    std::vector<std::pair<int, uint64_t>> row_group_ids(page["row_group_ids"].size());
    std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(),
                   [](json rg) { return std::make_pair(rg["id"], rg["offset"].get<uint64_t>()); });

    auto page_size = page["page_size"].get<uint64_t>();

    std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
                                                               end_row_id, row_group_ids, page_size);
    if (load_dataset) {
      pages_[shard_id].push_back(std::move(parsed_page));
    } else {
      pages_[shard_index].push_back(std::move(parsed_page));
    }
  }
  return Status::OK();
}

Status ShardHeader::ParseStatistics(const json &statistics) {
  for (auto &statistic : statistics) {
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(),
      "[Internal ERROR] Failed to deserialize statistics: " + statistics.dump());
    std::string statistic_description = statistic["desc"].get<std::string>();
    json statistic_body = statistic["statistics"];
    std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
    RETURN_UNEXPECTED_IF_NULL_MR(parsed_statistic);
    AddStatistic(parsed_statistic);
  }
  return Status::OK();
}

Status ShardHeader::ParseSchema(const json &schemas) {
  for (auto &schema : schemas) {
    // change how we get schemaBody once design is finalized
    CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() &&
                                      schema.find("schema") != schema.end(),
                                    "[Internal ERROR] Failed to deserialize schema: " + schema.dump());
    std::string schema_description = schema["desc"].get<std::string>();
    std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
    json schema_body = schema["schema"];
    std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body);
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      parsed_schema != nullptr,
      "[Internal ERROR] Failed to build schema. Please check the [ERROR] logs above for more details.");
    AddSchema(parsed_schema);
  }
  return Status::OK();
}

void ShardHeader::ParseShardAddress(const json &address) {
  (void)std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_));
}

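// Serialize the complete header into one JSON string per shard; only the "page" and
// "shard_id" fields differ between the shards.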
std::vector<std::string> ShardHeader::SerializeHeader() {
  std::vector<std::string> header;
  auto index = SerializeIndexFields();
  auto stats = SerializeStatistics();
  auto schema = SerializeSchema();
  auto pages = SerializePage();
  auto address = SerializeShardAddress();
  if (shard_count_ > static_cast<int>(pages.size())) {
    return std::vector<string>{};
  }
  if (shard_count_ <= kMaxShardCount) {
    for (int shardId = 0; shardId < shard_count_; shardId++) {
      string s;
      s += "{\"header_size\":" + std::to_string(header_size_) + ",";
      s += "\"index_fields\":" + index + ",";
      s += "\"page\":" + pages[shardId] + ",";
      s += "\"page_size\":" + std::to_string(page_size_) + ",";
      s += "\"compression_size\":" + std::to_string(compression_size_) + ",";
      s += "\"schema\":" + schema + ",";
      s += "\"shard_addresses\":" + address + ",";
      s += "\"shard_id\":" + std::to_string(shardId) + ",";
      s += "\"statistics\":" + stats + ",";
      s += "\"version\":\"" + std::string(kVersion) + "\"";
      s += "}";
      header.emplace_back(s);
    }
  }
  return header;
}

std::string ShardHeader::SerializeIndexFields() {
  json j;
  auto fields = index_->GetFields();
  (void)std::transform(fields.begin(), fields.end(), std::back_inserter(j),
                       [](const std::pair<uint64_t, std::string> &field) -> json {
                         return {{"schema_id", field.first}, {"index_field", field.second}};
                       });
  return j.dump();
}

std::vector<std::string> ShardHeader::SerializePage() {
  std::vector<string> pages;
  for (auto &shard_pages : pages_) {
    json j;
    (void)std::transform(shard_pages.begin(), shard_pages.end(), std::back_inserter(j),
                         [](const std::shared_ptr<Page> &p) { return p->GetPage(); });
    pages.emplace_back(j.dump());
  }
  return pages;
}

std::string ShardHeader::SerializeStatistics() {
  json j;
  (void)std::transform(statistics_.begin(), statistics_.end(), std::back_inserter(j),
                       [](const std::shared_ptr<Statistics> &stats) { return stats->GetStatistics(); });
  return j.dump();
}

std::string ShardHeader::SerializeSchema() {
  json j;
  (void)std::transform(schema_.begin(), schema_.end(), std::back_inserter(j),
                       [](const std::shared_ptr<Schema> &schema) { return schema->GetSchema(); });
  return j.dump();
}

std::string ShardHeader::SerializeShardAddress() {
  json j;
  std::shared_ptr<std::string> fn_ptr;
  for (const auto &addr : shard_addresses_) {
    (void)GetFileName(addr, &fn_ptr);
    (void)j.emplace_back(*fn_ptr);
  }
  return j.dump();
}

Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(page_ptr);
  if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
    *page_ptr = pages_[shard_id][page_id];
    return Status::OK();
  }
  *page_ptr = nullptr;
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to get Page, 'page_id': " + std::to_string(page_id));
}

Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
  int shard_id = new_page->GetShardID();
  int page_id = new_page->GetPageID();
  if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
    pages_[shard_id][page_id] = new_page;
    return Status::OK();
  }
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to set Page, 'page_id': " + std::to_string(page_id));
}

Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
  int shard_id = new_page->GetShardID();
  int page_id = new_page->GetPageID();
  if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
    pages_[shard_id].push_back(new_page);
    return Status::OK();
  }
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to add Page, 'page_id': " + std::to_string(page_id));
}

int64_t ShardHeader::GetLastPageId(const int &shard_id) {
  if (shard_id >= static_cast<int>(pages_.size())) {
    return 0;
  }
  return pages_[shard_id].size() - 1;
}

int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) {
  if (shard_id >= static_cast<int>(pages_.size())) {
    return 0;
  }
  int last_page_id = -1;
  for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
    if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
      last_page_id = pages_[shard_id][i - 1]->GetPageID();
      return last_page_id;
    }
  }
  return last_page_id;
}

Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(page_ptr);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_id < static_cast<int>(pages_.size()),
                                  "[Internal ERROR] 'shard_id': " + std::to_string(shard_id) +
                                    " should be smaller than the size of 'pages_': " + std::to_string(pages_.size()) +
                                    ".");
  for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
    auto page = pages_[shard_id][i - 1];
    if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
      *page_ptr = std::make_shared<Page>(*page);
      return Status::OK();
    }
  }
  *page_ptr = nullptr;
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to get Page, 'group_id': " + std::to_string(group_id));
}

int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
  if (schema == nullptr) {
    MS_LOG(ERROR) << "[Internal ERROR] The pointer of schema is NULL.";
    return -1;
  }

  if (!schema_.empty()) {
    MS_LOG(ERROR) << "The schema is added repeatedly. Please remove the redundant 'add_schema' function.";
    return -1;
  }

  int64_t schema_id = schema->GetSchemaID();
  if (schema_id == -1) {
    schema_id = schema_.size();
    schema->SetSchemaID(schema_id);
  }
  schema_.push_back(schema);
  return schema_id;
}

void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
  if (statistic) {
    int64_t statistics_id = statistic->GetStatisticsID();
    if (statistics_id == -1) {
      statistics_id = statistics_.size();
      statistic->SetStatisticsID(statistics_id);
    }
    statistics_.push_back(statistic);
  }
}

std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
  std::shared_ptr<Index> index = index_;
  if (!index_) {
    index = std::make_shared<Index>();
    index_ = index;
  }
  return index;
}

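// Check that a candidate index field exists in the schema and is neither a bytes field nor
// an array field, since those types cannot be indexed.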
Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
  // check whether the field name is valid
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(field) != schema.end(),
                                  "Invalid input, 'index_fields': " + field +
                                    " cannot be found in schema: " + schema.dump() +
                                    ".\n Please use 'add_index' function to add proper 'index_fields'.");
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema[field]["type"] != "Bytes",
                                  "Invalid input, the type of 'index_fields': " + field +
                                    " is bytes and cannot be set as an 'index_fields'.\n Please use 'add_index' "
                                    "function to add the other 'index_fields'.");
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(),
    "Invalid input, the type of 'index_fields': " + field +
      " is array and cannot be set as an 'index_fields'.\n Please use 'add_index' function "
      "to add the other 'index_fields'.");
  return Status::OK();
}

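// Add index fields by name, applying each field to every schema; duplicated fields are
// rejected.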
Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
  if (fields.empty()) {
    return Status::OK();
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    !GetSchemas().empty(), "Invalid data, schema is empty. Please use 'add_schema' function to add schema first.");
  // create the index object
  std::shared_ptr<Index> index = InitIndexPtr();
  for (const auto &schemaPtr : schema_) {
    std::shared_ptr<Schema> schema_ptr;
    RETURN_IF_NOT_OK_MR(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr));
    json schema = schema_ptr->GetSchema().at("schema");
    // check and add the fields for each schema
    std::set<std::string> field_set;
    for (const auto &item : index->GetFields()) {
      field_set.insert(item.second);
    }
    for (const auto &field : fields) {
      CHECK_FAIL_RETURN_UNEXPECTED_MR(
        field_set.find(field) == field_set.end(),
        "The 'index_fields': " + field + " is added repeatedly. Please remove the redundant 'add_index' function.");
      // check whether the field name is valid
      RETURN_IF_NOT_OK_MR(CheckIndexField(field, schema));
      field_set.insert(field);
      // add the field into the index
      index->AddIndexField(schemaPtr->GetSchemaID(), field);
    }
  }
  index_ = index;
  return Status::OK();
}

Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
  // get all schema ids
  for (const auto &schema : schema_) {
    auto schema_id = schema->GetSchemaID();
    CHECK_FAIL_RETURN_UNEXPECTED_MR(bucket_count.find(schema_id) == bucket_count.end(),
                                    "[Internal ERROR] duplicate schema exists, schema id: " + std::to_string(schema_id));
    bucket_count.insert(schema_id);
  }
  return Status::OK();
}

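// Add index fields given as (schema id, field name) pairs, validating that each schema id
// exists and that each field is indexable and not yet registered.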
Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
  if (fields.empty()) {
    return Status::OK();
  }
  // create the index object
  std::shared_ptr<Index> index = InitIndexPtr();
  // get all schema ids
  std::set<uint64_t> bucket_count;
  RETURN_IF_NOT_OK_MR(GetAllSchemaID(bucket_count));
  // check and add the fields for each schema
  std::set<std::pair<uint64_t, std::string>> field_set;
  for (const auto &item : index->GetFields()) {
    field_set.insert(item);
  }
  for (const auto &field : fields) {
    CHECK_FAIL_RETURN_UNEXPECTED_MR(field_set.find(field) == field_set.end(),
                                    "The 'index_fields': " + field.second +
                                      " is added repeatedly. Please remove the redundant 'add_index' function.");
    uint64_t schema_id = field.first;
    std::string field_name = field.second;

    // check whether the schema id is valid
    CHECK_FAIL_RETURN_UNEXPECTED_MR(bucket_count.find(schema_id) != bucket_count.end(),
                                    "[Internal ERROR] 'schema_id': " + std::to_string(schema_id) + " cannot be found.");
    // check whether the field name is valid
    std::shared_ptr<Schema> schema_ptr;
    RETURN_IF_NOT_OK_MR(GetSchemaByID(schema_id, &schema_ptr));
    json schema = schema_ptr->GetSchema().at("schema");
    CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(field_name) != schema.end(),
                                    "Invalid input, 'index_fields': " + field_name +
                                      " cannot be found in schema: " + schema.dump() +
                                      ".\n Please use 'add_index' function to add proper 'index_fields'.");
    RETURN_IF_NOT_OK_MR(CheckIndexField(field_name, schema));
    field_set.insert(field);
    // add the field into the index
    index->AddIndexField(schema_id, field_name);
  }
  index_ = index;
  return Status::OK();
}

std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
  if (shard_id >= static_cast<int64_t>(shard_addresses_.size())) {
    return "";
  }
  return shard_addresses_.at(shard_id);
}

std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }

std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }

std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }

std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }

Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(schema_ptr);
  int64_t schema_size = schema_.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_id >= 0 && schema_id < schema_size,
                                  "[Internal ERROR] 'schema_id': " + std::to_string(schema_id) +
                                    " is not in range [0, " + std::to_string(schema_size) + ").");
  *schema_ptr = schema_.at(schema_id);
  return Status::OK();
}

Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(statistics_ptr);
  int64_t statistics_size = statistics_.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(statistic_id >= 0 && statistic_id < statistics_size,
                                  "[Internal ERROR] 'statistic_id': " + std::to_string(statistic_id) +
                                    " is not in range [0, " + std::to_string(statistics_size) + ").");
  *statistics_ptr = statistics_.at(statistic_id);
  return Status::OK();
}

Status ShardHeader::PagesToFile(const std::string dump_file_name) {
  auto realpath = FileUtils::GetRealPath(dump_file_name.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(realpath.has_value(),
                                  "[Internal ERROR] Failed to get the realpath of Pages file, path: " + dump_file_name);
  // write the page content to the file, truncating whatever was in the file before
  std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(page_out_handle.good(),
                                  "[Internal ERROR] Failed to open Pages file, path: " + dump_file_name);
  auto pages = SerializePage();
  for (const auto &shard_pages : pages) {
    page_out_handle << shard_pages << "\n";
  }
  page_out_handle.close();

  ChangeFileMode(realpath.value(), S_IRUSR | S_IWUSR);
  return Status::OK();
}

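// Rebuild the in-memory page lists from a dump file previously written by PagesToFile,
// which stores one JSON page list per line.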
Status ShardHeader::FileToPages(const std::string dump_file_name) {
  for (auto &v : pages_) {  // clean pages
    v.clear();
  }
  auto realpath = FileUtils::GetRealPath(dump_file_name.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(realpath.has_value(),
                                  "[Internal ERROR] Failed to get the realpath of Pages file, path: " + dump_file_name);
  // attempt to open the file that contains the pages in JSON
  std::ifstream page_in_handle(realpath.value(), std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(page_in_handle.good(),
                                  "[Internal ERROR] Pages file does not exist, path: " + dump_file_name);
  std::string line;
  while (std::getline(page_in_handle, line)) {
    auto s = ParsePage(json::parse(line), -1, true);
    if (s.IsError()) {
      page_in_handle.close();
      return s;
    }
  }
  page_in_handle.close();
  return Status::OK();
}

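// Helper for creating a new dataset header: build a schema described as "mindrecord" from
// the given JSON, register it together with its index fields, and return the resulting
// schema id and blob fields through the output parameters.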
Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
                               const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
                               uint64_t &schema_id) {
  RETURN_UNEXPECTED_IF_NULL_MR(header_ptr);
  auto schema_ptr = Schema::Build("mindrecord", schema);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_ptr != nullptr, "[Internal ERROR] Failed to build schema: " + schema.dump() +
                                                           ". Check the [ERROR] logs above for more details.");
  schema_id = (*header_ptr)->AddSchema(schema_ptr);
  // create the index
  std::vector<std::pair<uint64_t, std::string>> id_index_fields;
  if (!index_fields.empty()) {
    (void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
                    [schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
    RETURN_IF_NOT_OK_MR((*header_ptr)->AddIndexFields(id_index_fields));
  }

  auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
  blob_fields = build_schema_ptr->GetBlobFields();
  return Status::OK();
}
}  // namespace mindrecord
}  // namespace mindspore