1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_header.h"
18
#include <exception>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
24
25 #include "utils/file_utils.h"
26 #include "utils/ms_utils.h"
27 #include "minddata/mindrecord/include/shard_error.h"
28 #include "minddata/mindrecord/include/shard_page.h"
29
30 using mindspore::LogStream;
31 using mindspore::ExceptionType::NoExceptionType;
32 using mindspore::MsLogLevel::ERROR;
33
34 namespace mindspore {
35 namespace mindrecord {
// Error flag shared by the header-loading worker threads: GetHeadersOneTask raises it when a
// header cannot be loaded, and BuildDataset inspects and clears it after joining the workers.
std::atomic<bool> thread_status{false};
ShardHeader()37 ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) {
38 index_ = std::make_shared<Index>();
39 }
40
InitializeHeader(const std::vector<json> & headers,bool load_dataset)41 Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
42 shard_count_ = headers.size();
43 int shard_index = 0;
44 bool first = true;
45 for (const auto &header : headers) {
46 if (first) {
47 first = false;
48 RETURN_IF_NOT_OK(ParseSchema(header["schema"]));
49 RETURN_IF_NOT_OK(ParseIndexFields(header["index_fields"]));
50 RETURN_IF_NOT_OK(ParseStatistics(header["statistics"]));
51 ParseShardAddress(header["shard_addresses"]);
52 header_size_ = header["header_size"].get<uint64_t>();
53 page_size_ = header["page_size"].get<uint64_t>();
54 compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
55 }
56 RETURN_IF_NOT_OK(ParsePage(header["page"], shard_index, load_dataset));
57 shard_index++;
58 }
59 return Status::OK();
60 }
61
CheckFileStatus(const std::string & path)62 Status ShardHeader::CheckFileStatus(const std::string &path) {
63 auto realpath = FileUtils::GetRealPath(path.data());
64 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + path);
65 std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
66 CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file, file path: " + path);
67 // fetch file size
68 auto &io_seekg = fin.seekg(0, std::ios::end);
69 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
70 fin.close();
71 RETURN_STATUS_UNEXPECTED("Failed to seekg file, file path: " + path);
72 }
73
74 size_t file_size = fin.tellg();
75 if (file_size < kMinFileSize) {
76 fin.close();
77 RETURN_STATUS_UNEXPECTED("Invalid file content, file " + path + " size is smaller than the lower limit.");
78 }
79 fin.close();
80 return Status::OK();
81 }
82
ValidateHeader(const std::string & path,std::shared_ptr<json> * header_ptr)83 Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) {
84 RETURN_UNEXPECTED_IF_NULL(header_ptr);
85 RETURN_IF_NOT_OK(CheckFileStatus(path));
86 // read header size
87 json json_header;
88 std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary);
89 CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file, file path: " + path);
90
91 uint64_t header_size = 0;
92 auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len);
93 if (!io_read.good() || io_read.fail() || io_read.bad()) {
94 fin.close();
95 RETURN_STATUS_UNEXPECTED("Failed to read file, file path: " + path);
96 }
97
98 if (header_size > kMaxHeaderSize) {
99 fin.close();
100 RETURN_STATUS_UNEXPECTED("Invalid file content, incorrect file or file header is exceeds the upper limit.");
101 }
102
103 // read header content
104 std::vector<uint8_t> header_content(header_size);
105 auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size);
106 if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) {
107 fin.close();
108 RETURN_STATUS_UNEXPECTED("Failed to read file, file path: " + path);
109 }
110
111 fin.close();
112 std::string raw_header_content = std::string(header_content.begin(), header_content.end());
113 // parse json content
114 try {
115 json_header = json::parse(raw_header_content);
116 } catch (json::parse_error &e) {
117 RETURN_STATUS_UNEXPECTED("Json parse failed: " + std::string(e.what()));
118 }
119 *header_ptr = std::make_shared<json>(json_header);
120 return Status::OK();
121 }
122
BuildSingleHeader(const std::string & file_path,std::shared_ptr<json> * header_ptr)123 Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) {
124 RETURN_UNEXPECTED_IF_NULL(header_ptr);
125 std::shared_ptr<json> raw_header;
126 RETURN_IF_NOT_OK(ValidateHeader(file_path, &raw_header));
127 uint64_t compression_size =
128 raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0;
129 json header = {{"shard_addresses", (*raw_header)["shard_addresses"]},
130 {"header_size", (*raw_header)["header_size"]},
131 {"page_size", (*raw_header)["page_size"]},
132 {"compression_size", compression_size},
133 {"index_fields", (*raw_header)["index_fields"]},
134 {"blob_fields", (*raw_header)["schema"][0]["blob_fields"]},
135 {"schema", (*raw_header)["schema"][0]["schema"]},
136 {"version", (*raw_header)["version"]}};
137 *header_ptr = std::make_shared<json>(header);
138 return Status::OK();
139 }
140
BuildDataset(const std::vector<std::string> & file_paths,bool load_dataset)141 Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
142 uint32_t thread_num = std::thread::hardware_concurrency();
143 if (thread_num == 0) {
144 thread_num = kThreadNumber;
145 }
146 uint32_t work_thread_num = 0;
147 uint32_t shard_count = file_paths.size();
148 int group_num = ceil(shard_count * 1.0 / thread_num);
149 std::vector<std::thread> thread_set(thread_num);
150 std::vector<json> headers(shard_count);
151 for (uint32_t x = 0; x < thread_num; ++x) {
152 int start_num = x * group_num;
153 int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
154 if (start_num >= end_num) {
155 continue;
156 }
157
158 thread_set[x] =
159 std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
160 work_thread_num++;
161 }
162
163 for (uint32_t x = 0; x < work_thread_num; ++x) {
164 thread_set[x].join();
165 }
166 if (thread_status) {
167 thread_status = false;
168 RETURN_STATUS_UNEXPECTED("Error occurred in GetHeadersOneTask thread.");
169 }
170 RETURN_IF_NOT_OK(InitializeHeader(headers, load_dataset));
171 return Status::OK();
172 }
173
GetHeadersOneTask(int start,int end,std::vector<json> & headers,const vector<string> & realAddresses)174 void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers,
175 const vector<string> &realAddresses) {
176 if (thread_status || end > realAddresses.size()) {
177 return;
178 }
179 for (int x = start; x < end; ++x) {
180 std::shared_ptr<json> header;
181 auto status = ValidateHeader(realAddresses[x], &header);
182 if (status.IsError()) {
183 thread_status = true;
184 return;
185 }
186 (*header)["shard_addresses"] = realAddresses;
187 if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) ==
188 kSupportedVersion.end()) {
189 MS_LOG(ERROR) << "Invalid version, file version " << (*header)["version"].dump() << " can not match lib version "
190 << kVersion << ".";
191 thread_status = true;
192 return;
193 }
194 headers[x] = *header;
195 }
196 }
197
InitByFiles(const std::vector<std::string> & file_paths)198 Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
199 std::vector<std::string> file_names(file_paths.size());
200 std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string {
201 std::shared_ptr<std::string> fn;
202 return GetFileName(fp, &fn).IsOk() ? *fn : "";
203 });
204
205 shard_addresses_ = std::move(file_names);
206 shard_count_ = file_paths.size();
207 CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount),
208 "Invalid input, The number of MindRecord files " + std::to_string(shard_count_) +
209 "is not int range (0, " + std::to_string(kMaxShardCount) + "].");
210 pages_.resize(shard_count_);
211 return Status::OK();
212 }
213
ParseIndexFields(const json & index_fields)214 Status ShardHeader::ParseIndexFields(const json &index_fields) {
215 std::vector<std::pair<uint64_t, std::string>> parsed_index_fields;
216 for (auto &index_field : index_fields) {
217 auto schema_id = index_field["schema_id"].get<uint64_t>();
218 std::string field_name = index_field["index_field"].get<std::string>();
219 std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name);
220 parsed_index_fields.push_back(parsed_index_field);
221 }
222 RETURN_IF_NOT_OK(AddIndexFields(parsed_index_fields));
223 return Status::OK();
224 }
225
ParsePage(const json & pages,int shard_index,bool load_dataset)226 Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
227 // set shard_index when load_dataset is false
228 CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount, "Invalid input, The number of MindRecord files " +
229 std::to_string(shard_count_) + "is not int range (0, " +
230 std::to_string(kMaxFileCount) + "].");
231 if (pages_.empty()) {
232 pages_.resize(shard_count_);
233 }
234
235 for (auto &page : pages) {
236 int page_id = page["page_id"];
237 int shard_id = page["shard_id"];
238 std::string page_type = page["page_type"];
239 int page_type_id = page["page_type_id"];
240 auto start_row_id = page["start_row_id"].get<uint64_t>();
241 auto end_row_id = page["end_row_id"].get<uint64_t>();
242
243 std::vector<std::pair<int, uint64_t>> row_group_ids(page["row_group_ids"].size());
244 std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(),
245 [](json rg) { return std::make_pair(rg["id"], rg["offset"].get<uint64_t>()); });
246
247 auto page_size = page["page_size"].get<uint64_t>();
248
249 std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
250 end_row_id, row_group_ids, page_size);
251 if (load_dataset == true) {
252 pages_[shard_id].push_back(std::move(parsed_page));
253 } else {
254 pages_[shard_index].push_back(std::move(parsed_page));
255 }
256 }
257 return Status::OK();
258 }
259
ParseStatistics(const json & statistics)260 Status ShardHeader::ParseStatistics(const json &statistics) {
261 for (auto &statistic : statistics) {
262 CHECK_FAIL_RETURN_UNEXPECTED(
263 statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(),
264 "Failed to deserialize statistics, statistic info: " + statistics.dump());
265 std::string statistic_description = statistic["desc"].get<std::string>();
266 json statistic_body = statistic["statistics"];
267 std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
268 RETURN_UNEXPECTED_IF_NULL(parsed_statistic);
269 AddStatistic(parsed_statistic);
270 }
271 return Status::OK();
272 }
273
ParseSchema(const json & schemas)274 Status ShardHeader::ParseSchema(const json &schemas) {
275 for (auto &schema : schemas) {
276 // change how we get schemaBody once design is finalized
277 CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() &&
278 schema.find("schema") != schema.end(),
279 "Failed to deserialize schema, schema info: " + schema.dump());
280 std::string schema_description = schema["desc"].get<std::string>();
281 std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
282 json schema_body = schema["schema"];
283 std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body);
284 RETURN_UNEXPECTED_IF_NULL(parsed_schema);
285 AddSchema(parsed_schema);
286 }
287 return Status::OK();
288 }
289
ParseShardAddress(const json & address)290 void ShardHeader::ParseShardAddress(const json &address) {
291 std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_));
292 }
293
SerializeHeader()294 std::vector<std::string> ShardHeader::SerializeHeader() {
295 std::vector<std::string> header;
296 auto index = SerializeIndexFields();
297 auto stats = SerializeStatistics();
298 auto schema = SerializeSchema();
299 auto pages = SerializePage();
300 auto address = SerializeShardAddress();
301 if (shard_count_ > static_cast<int>(pages.size())) {
302 return std::vector<string>{};
303 }
304 if (shard_count_ <= kMaxShardCount) {
305 for (int shardId = 0; shardId < shard_count_; shardId++) {
306 string s;
307 s += "{\"header_size\":" + std::to_string(header_size_) + ",";
308 s += "\"index_fields\":" + index + ",";
309 s += "\"page\":" + pages[shardId] + ",";
310 s += "\"page_size\":" + std::to_string(page_size_) + ",";
311 s += "\"compression_size\":" + std::to_string(compression_size_) + ",";
312 s += "\"schema\":" + schema + ",";
313 s += "\"shard_addresses\":" + address + ",";
314 s += "\"shard_id\":" + std::to_string(shardId) + ",";
315 s += "\"statistics\":" + stats + ",";
316 s += "\"version\":\"" + std::string(kVersion) + "\"";
317 s += "}";
318 header.emplace_back(s);
319 }
320 }
321 return header;
322 }
323
SerializeIndexFields()324 std::string ShardHeader::SerializeIndexFields() {
325 json j;
326 auto fields = index_->GetFields();
327 (void)std::transform(fields.begin(), fields.end(), std::back_inserter(j),
328 [](const std::pair<uint64_t, std::string> &field) -> json {
329 return {{"schema_id", field.first}, {"index_field", field.second}};
330 });
331 return j.dump();
332 }
333
SerializePage()334 std::vector<std::string> ShardHeader::SerializePage() {
335 std::vector<string> pages;
336 for (auto &shard_pages : pages_) {
337 json j;
338 (void)std::transform(shard_pages.begin(), shard_pages.end(), std::back_inserter(j),
339 [](const std::shared_ptr<Page> &p) { return p->GetPage(); });
340 pages.emplace_back(j.dump());
341 }
342 return pages;
343 }
344
SerializeStatistics()345 std::string ShardHeader::SerializeStatistics() {
346 json j;
347 (void)std::transform(statistics_.begin(), statistics_.end(), std::back_inserter(j),
348 [](const std::shared_ptr<Statistics> &stats) { return stats->GetStatistics(); });
349 return j.dump();
350 }
351
SerializeSchema()352 std::string ShardHeader::SerializeSchema() {
353 json j;
354 (void)std::transform(schema_.begin(), schema_.end(), std::back_inserter(j),
355 [](const std::shared_ptr<Schema> &schema) { return schema->GetSchema(); });
356 return j.dump();
357 }
358
SerializeShardAddress()359 std::string ShardHeader::SerializeShardAddress() {
360 json j;
361 std::shared_ptr<std::string> fn_ptr;
362 for (const auto &addr : shard_addresses_) {
363 (void)GetFileName(addr, &fn_ptr);
364 j.emplace_back(*fn_ptr);
365 }
366 return j.dump();
367 }
368
GetPage(const int & shard_id,const int & page_id,std::shared_ptr<Page> * page_ptr)369 Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) {
370 RETURN_UNEXPECTED_IF_NULL(page_ptr);
371 if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
372 *page_ptr = pages_[shard_id][page_id];
373 return Status::OK();
374 }
375 page_ptr = nullptr;
376 RETURN_STATUS_UNEXPECTED("Failed to get Page, 'page_id': " + std::to_string(page_id));
377 }
378
SetPage(const std::shared_ptr<Page> & new_page)379 Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
380 int shard_id = new_page->GetShardID();
381 int page_id = new_page->GetPageID();
382 if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
383 pages_[shard_id][page_id] = new_page;
384 return Status::OK();
385 }
386 RETURN_STATUS_UNEXPECTED("Failed to set Page, 'page_id': " + std::to_string(page_id));
387 }
388
AddPage(const std::shared_ptr<Page> & new_page)389 Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
390 int shard_id = new_page->GetShardID();
391 int page_id = new_page->GetPageID();
392 if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
393 pages_[shard_id].push_back(new_page);
394 return Status::OK();
395 }
396 RETURN_STATUS_UNEXPECTED("Failed to add Page, 'page_id': " + std::to_string(page_id));
397 }
398
GetLastPageId(const int & shard_id)399 int64_t ShardHeader::GetLastPageId(const int &shard_id) {
400 if (shard_id >= static_cast<int>(pages_.size())) {
401 return 0;
402 }
403 return pages_[shard_id].size() - 1;
404 }
405
GetLastPageIdByType(const int & shard_id,const std::string & page_type)406 int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) {
407 if (shard_id >= static_cast<int>(pages_.size())) {
408 return 0;
409 }
410 int last_page_id = -1;
411 for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
412 if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
413 last_page_id = pages_[shard_id][i - 1]->GetPageID();
414 return last_page_id;
415 }
416 }
417 return last_page_id;
418 }
419
GetPageByGroupId(const int & group_id,const int & shard_id,std::shared_ptr<Page> * page_ptr)420 Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) {
421 RETURN_UNEXPECTED_IF_NULL(page_ptr);
422 CHECK_FAIL_RETURN_UNEXPECTED(shard_id < static_cast<int>(pages_.size()), "Shard id is more than sum of shards.");
423 for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
424 auto page = pages_[shard_id][i - 1];
425 if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
426 *page_ptr = std::make_shared<Page>(*page);
427 return Status::OK();
428 }
429 }
430 page_ptr = nullptr;
431 RETURN_STATUS_UNEXPECTED("Failed to get Page, 'group_id': " + std::to_string(group_id));
432 }
433
AddSchema(std::shared_ptr<Schema> schema)434 int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
435 if (schema == nullptr) {
436 MS_LOG(ERROR) << "The pointer of schema is null.";
437 return -1;
438 }
439
440 if (!schema_.empty()) {
441 MS_LOG(ERROR) << "The schema can not be added twice.";
442 return -1;
443 }
444
445 int64_t schema_id = schema->GetSchemaID();
446 if (schema_id == -1) {
447 schema_id = schema_.size();
448 schema->SetSchemaID(schema_id);
449 }
450 schema_.push_back(schema);
451 return schema_id;
452 }
453
AddStatistic(std::shared_ptr<Statistics> statistic)454 void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
455 if (statistic) {
456 int64_t statistics_id = statistic->GetStatisticsID();
457 if (statistics_id == -1) {
458 statistics_id = statistics_.size();
459 statistic->SetStatisticsID(statistics_id);
460 }
461 statistics_.push_back(statistic);
462 }
463 }
464
InitIndexPtr()465 std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
466 std::shared_ptr<Index> index = index_;
467 if (!index_) {
468 index = std::make_shared<Index>();
469 index_ = index;
470 }
471 return index;
472 }
473
CheckIndexField(const std::string & field,const json & schema)474 Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
475 // check field name is or is not valid
476 CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
477 "Invalid input, field [" + field + "] can not found in schema.");
478 CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes",
479 "Invalid input, byte type field [" + field + "] can not set as an index field.");
480 CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(),
481 "Invalid input, array type field [" + field + "] can not set as an index field.");
482 return Status::OK();
483 }
484
AddIndexFields(const std::vector<std::string> & fields)485 Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
486 if (fields.empty()) {
487 return Status::OK();
488 }
489 CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Invalid data, schema is empty.");
490 // create index Object
491 std::shared_ptr<Index> index = InitIndexPtr();
492 for (const auto &schemaPtr : schema_) {
493 std::shared_ptr<Schema> schema_ptr;
494 RETURN_IF_NOT_OK(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr));
495 json schema = schema_ptr->GetSchema().at("schema");
496 // checkout and add fields for each schema
497 std::set<std::string> field_set;
498 for (const auto &item : index->GetFields()) {
499 field_set.insert(item.second);
500 }
501 for (const auto &field : fields) {
502 CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(),
503 "Invalid data, the same index field [" + field + "] can not added twice.");
504 // check field name is or is not valid
505 RETURN_IF_NOT_OK(CheckIndexField(field, schema));
506 field_set.insert(field);
507 // add field into index
508 index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
509 }
510 }
511 index_ = index;
512 return Status::OK();
513 }
514
GetAllSchemaID(std::set<uint64_t> & bucket_count)515 Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
516 // get all schema id
517 for (const auto &schema : schema_) {
518 auto schema_id = schema->GetSchemaID();
519 CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) == bucket_count.end(),
520 "Invalid data, duplicate schema exist, schema id: " + std::to_string(schema_id));
521 bucket_count.insert(schema_id);
522 }
523 return Status::OK();
524 }
525
AddIndexFields(std::vector<std::pair<uint64_t,std::string>> fields)526 Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
527 if (fields.empty()) {
528 return Status::OK();
529 }
530 // create index Object
531 std::shared_ptr<Index> index = InitIndexPtr();
532 // get all schema id
533 std::set<uint64_t> bucket_count;
534 RETURN_IF_NOT_OK(GetAllSchemaID(bucket_count));
535 // check and add fields for each schema
536 std::set<std::pair<uint64_t, std::string>> field_set;
537 for (const auto &item : index->GetFields()) {
538 field_set.insert(item);
539 }
540 for (const auto &field : fields) {
541 CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(),
542 "Invalid data, the same index field [" + field.second + "] can not added twice.");
543 uint64_t schema_id = field.first;
544 std::string field_name = field.second;
545
546 // check schemaId is or is not valid
547 CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(),
548 "Invalid data, schema id [" + std::to_string(schema_id) + "] is invalid.");
549 // check field name is or is not valid
550 std::shared_ptr<Schema> schema_ptr;
551 RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr));
552 json schema = schema_ptr->GetSchema().at("schema");
553 CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(),
554 "Invalid data, field [" + field_name + "] is not found in schema.");
555 RETURN_IF_NOT_OK(CheckIndexField(field_name, schema));
556 field_set.insert(field);
557 // add field into index
558 index->AddIndexField(schema_id, field_name);
559 }
560 index_ = index;
561 return Status::OK();
562 }
563
GetShardAddressByID(int64_t shard_id)564 std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
565 if (shard_id >= shard_addresses_.size()) {
566 return "";
567 }
568 return shard_addresses_.at(shard_id);
569 }
570
GetSchemas()571 std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }
572
GetStatistics()573 std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }
574
GetFields()575 std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }
576
GetIndex()577 std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }
578
GetSchemaByID(int64_t schema_id,std::shared_ptr<Schema> * schema_ptr)579 Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) {
580 RETURN_UNEXPECTED_IF_NULL(schema_ptr);
581 int64_t schema_size = schema_.size();
582 CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schema_size,
583 "Invalid data, schema id [" + std::to_string(schema_id) + "] is not in range [0, " +
584 std::to_string(schema_size) + ").");
585 *schema_ptr = schema_.at(schema_id);
586 return Status::OK();
587 }
588
GetStatisticByID(int64_t statistic_id,std::shared_ptr<Statistics> * statistics_ptr)589 Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) {
590 RETURN_UNEXPECTED_IF_NULL(statistics_ptr);
591 int64_t statistics_size = statistics_.size();
592 CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size,
593 "Invalid data, statistic id [" + std::to_string(statistic_id) +
594 "] is not in range [0, " + std::to_string(statistics_size) + ").");
595 *statistics_ptr = statistics_.at(statistic_id);
596 return Status::OK();
597 }
598
PagesToFile(const std::string dump_file_name)599 Status ShardHeader::PagesToFile(const std::string dump_file_name) {
600 auto realpath = FileUtils::GetRealPath(dump_file_name.data());
601 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + dump_file_name);
602 // write header content to file, dump whatever is in the file before
603 std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out);
604 CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file, path: " + dump_file_name);
605 auto pages = SerializePage();
606 for (const auto &shard_pages : pages) {
607 page_out_handle << shard_pages << "\n";
608 }
609 page_out_handle.close();
610 return Status::OK();
611 }
612
FileToPages(const std::string dump_file_name)613 Status ShardHeader::FileToPages(const std::string dump_file_name) {
614 for (auto &v : pages_) { // clean pages
615 v.clear();
616 }
617 auto realpath = FileUtils::GetRealPath(dump_file_name.data());
618 CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + dump_file_name);
619 // attempt to open the file contains the page in json
620 std::ifstream page_in_handle(realpath.value());
621 CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(),
622 "Invalid file, page file does not exist, path: " + dump_file_name);
623 std::string line;
624 while (std::getline(page_in_handle, line)) {
625 RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true));
626 }
627 page_in_handle.close();
628 return Status::OK();
629 }
630
Initialize(const std::shared_ptr<ShardHeader> * header_ptr,const json & schema,const std::vector<std::string> & index_fields,std::vector<std::string> & blob_fields,uint64_t & schema_id)631 Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
632 const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
633 uint64_t &schema_id) {
634 RETURN_UNEXPECTED_IF_NULL(header_ptr);
635 auto schema_ptr = Schema::Build("mindrecord", schema);
636 CHECK_FAIL_RETURN_UNEXPECTED(schema_ptr != nullptr, "Failed to build schema: " + schema.dump() + ".");
637 schema_id = (*header_ptr)->AddSchema(schema_ptr);
638 // create index
639 std::vector<std::pair<uint64_t, std::string>> id_index_fields;
640 if (!index_fields.empty()) {
641 (void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
642 [schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
643 RETURN_IF_NOT_OK((*header_ptr)->AddIndexFields(id_index_fields));
644 }
645
646 auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
647 blob_fields = build_schema_ptr->GetBlobFields();
648 return Status::OK();
649 }
650 } // namespace mindrecord
651 } // namespace mindspore
652