• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/mindrecord/include/shard_header.h"
18 
#include <algorithm>
#include <atomic>
#include <cmath>
#include <fstream>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <utility>
#include <vector>
24 
25 #include "utils/file_utils.h"
26 #include "utils/ms_utils.h"
27 #include "minddata/mindrecord/include/shard_error.h"
28 #include "minddata/mindrecord/include/shard_page.h"
29 
30 using mindspore::LogStream;
31 using mindspore::ExceptionType::NoExceptionType;
32 using mindspore::MsLogLevel::ERROR;
33 
34 namespace mindspore {
35 namespace mindrecord {
// Failure flag shared by the header-loading worker threads: GetHeadersOneTask() raises it when a header
// cannot be loaded, and BuildDataset() inspects and clears it once the workers have been joined.
std::atomic<bool> thread_status{false};
ShardHeader()37 ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) {
38   index_ = std::make_shared<Index>();
39 }
40 
InitializeHeader(const std::vector<json> & headers,bool load_dataset)41 Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
42   shard_count_ = headers.size();
43   int shard_index = 0;
44   bool first = true;
45   for (const auto &header : headers) {
46     if (first) {
47       first = false;
48       RETURN_IF_NOT_OK(ParseSchema(header["schema"]));
49       RETURN_IF_NOT_OK(ParseIndexFields(header["index_fields"]));
50       RETURN_IF_NOT_OK(ParseStatistics(header["statistics"]));
51       ParseShardAddress(header["shard_addresses"]);
52       header_size_ = header["header_size"].get<uint64_t>();
53       page_size_ = header["page_size"].get<uint64_t>();
54       compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
55     }
56     RETURN_IF_NOT_OK(ParsePage(header["page"], shard_index, load_dataset));
57     shard_index++;
58   }
59   return Status::OK();
60 }
61 
CheckFileStatus(const std::string & path)62 Status ShardHeader::CheckFileStatus(const std::string &path) {
63   auto realpath = FileUtils::GetRealPath(path.data());
64   CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + path);
65   std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
66   CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file, file path: " + path);
67   // fetch file size
68   auto &io_seekg = fin.seekg(0, std::ios::end);
69   if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
70     fin.close();
71     RETURN_STATUS_UNEXPECTED("Failed to seekg file, file path: " + path);
72   }
73 
74   size_t file_size = fin.tellg();
75   if (file_size < kMinFileSize) {
76     fin.close();
77     RETURN_STATUS_UNEXPECTED("Invalid file content, file " + path + " size is smaller than the lower limit.");
78   }
79   fin.close();
80   return Status::OK();
81 }
82 
ValidateHeader(const std::string & path,std::shared_ptr<json> * header_ptr)83 Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) {
84   RETURN_UNEXPECTED_IF_NULL(header_ptr);
85   RETURN_IF_NOT_OK(CheckFileStatus(path));
86   // read header size
87   json json_header;
88   std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary);
89   CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file, file path: " + path);
90 
91   uint64_t header_size = 0;
92   auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len);
93   if (!io_read.good() || io_read.fail() || io_read.bad()) {
94     fin.close();
95     RETURN_STATUS_UNEXPECTED("Failed to read file, file path: " + path);
96   }
97 
98   if (header_size > kMaxHeaderSize) {
99     fin.close();
100     RETURN_STATUS_UNEXPECTED("Invalid file content, incorrect file or file header is exceeds the upper limit.");
101   }
102 
103   // read header content
104   std::vector<uint8_t> header_content(header_size);
105   auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size);
106   if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) {
107     fin.close();
108     RETURN_STATUS_UNEXPECTED("Failed to read file, file path: " + path);
109   }
110 
111   fin.close();
112   std::string raw_header_content = std::string(header_content.begin(), header_content.end());
113   // parse json content
114   try {
115     json_header = json::parse(raw_header_content);
116   } catch (json::parse_error &e) {
117     RETURN_STATUS_UNEXPECTED("Json parse failed: " + std::string(e.what()));
118   }
119   *header_ptr = std::make_shared<json>(json_header);
120   return Status::OK();
121 }
122 
BuildSingleHeader(const std::string & file_path,std::shared_ptr<json> * header_ptr)123 Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) {
124   RETURN_UNEXPECTED_IF_NULL(header_ptr);
125   std::shared_ptr<json> raw_header;
126   RETURN_IF_NOT_OK(ValidateHeader(file_path, &raw_header));
127   uint64_t compression_size =
128     raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0;
129   json header = {{"shard_addresses", (*raw_header)["shard_addresses"]},
130                  {"header_size", (*raw_header)["header_size"]},
131                  {"page_size", (*raw_header)["page_size"]},
132                  {"compression_size", compression_size},
133                  {"index_fields", (*raw_header)["index_fields"]},
134                  {"blob_fields", (*raw_header)["schema"][0]["blob_fields"]},
135                  {"schema", (*raw_header)["schema"][0]["schema"]},
136                  {"version", (*raw_header)["version"]}};
137   *header_ptr = std::make_shared<json>(header);
138   return Status::OK();
139 }
140 
BuildDataset(const std::vector<std::string> & file_paths,bool load_dataset)141 Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
142   uint32_t thread_num = std::thread::hardware_concurrency();
143   if (thread_num == 0) {
144     thread_num = kThreadNumber;
145   }
146   uint32_t work_thread_num = 0;
147   uint32_t shard_count = file_paths.size();
148   int group_num = ceil(shard_count * 1.0 / thread_num);
149   std::vector<std::thread> thread_set(thread_num);
150   std::vector<json> headers(shard_count);
151   for (uint32_t x = 0; x < thread_num; ++x) {
152     int start_num = x * group_num;
153     int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
154     if (start_num >= end_num) {
155       continue;
156     }
157 
158     thread_set[x] =
159       std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
160     work_thread_num++;
161   }
162 
163   for (uint32_t x = 0; x < work_thread_num; ++x) {
164     thread_set[x].join();
165   }
166   if (thread_status) {
167     thread_status = false;
168     RETURN_STATUS_UNEXPECTED("Error occurred in GetHeadersOneTask thread.");
169   }
170   RETURN_IF_NOT_OK(InitializeHeader(headers, load_dataset));
171   return Status::OK();
172 }
173 
GetHeadersOneTask(int start,int end,std::vector<json> & headers,const vector<string> & realAddresses)174 void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers,
175                                     const vector<string> &realAddresses) {
176   if (thread_status || end > realAddresses.size()) {
177     return;
178   }
179   for (int x = start; x < end; ++x) {
180     std::shared_ptr<json> header;
181     auto status = ValidateHeader(realAddresses[x], &header);
182     if (status.IsError()) {
183       thread_status = true;
184       return;
185     }
186     (*header)["shard_addresses"] = realAddresses;
187     if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) ==
188         kSupportedVersion.end()) {
189       MS_LOG(ERROR) << "Invalid version, file version " << (*header)["version"].dump() << " can not match lib version "
190                     << kVersion << ".";
191       thread_status = true;
192       return;
193     }
194     headers[x] = *header;
195   }
196 }
197 
InitByFiles(const std::vector<std::string> & file_paths)198 Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
199   std::vector<std::string> file_names(file_paths.size());
200   std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string {
201     std::shared_ptr<std::string> fn;
202     return GetFileName(fp, &fn).IsOk() ? *fn : "";
203   });
204 
205   shard_addresses_ = std::move(file_names);
206   shard_count_ = file_paths.size();
207   CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount),
208                                "Invalid input, The number of MindRecord files " + std::to_string(shard_count_) +
209                                  "is not int range (0, " + std::to_string(kMaxShardCount) + "].");
210   pages_.resize(shard_count_);
211   return Status::OK();
212 }
213 
ParseIndexFields(const json & index_fields)214 Status ShardHeader::ParseIndexFields(const json &index_fields) {
215   std::vector<std::pair<uint64_t, std::string>> parsed_index_fields;
216   for (auto &index_field : index_fields) {
217     auto schema_id = index_field["schema_id"].get<uint64_t>();
218     std::string field_name = index_field["index_field"].get<std::string>();
219     std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name);
220     parsed_index_fields.push_back(parsed_index_field);
221   }
222   RETURN_IF_NOT_OK(AddIndexFields(parsed_index_fields));
223   return Status::OK();
224 }
225 
ParsePage(const json & pages,int shard_index,bool load_dataset)226 Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
227   // set shard_index when load_dataset is false
228   CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount, "Invalid input, The number of MindRecord files " +
229                                                                 std::to_string(shard_count_) + "is not int range (0, " +
230                                                                 std::to_string(kMaxFileCount) + "].");
231   if (pages_.empty()) {
232     pages_.resize(shard_count_);
233   }
234 
235   for (auto &page : pages) {
236     int page_id = page["page_id"];
237     int shard_id = page["shard_id"];
238     std::string page_type = page["page_type"];
239     int page_type_id = page["page_type_id"];
240     auto start_row_id = page["start_row_id"].get<uint64_t>();
241     auto end_row_id = page["end_row_id"].get<uint64_t>();
242 
243     std::vector<std::pair<int, uint64_t>> row_group_ids(page["row_group_ids"].size());
244     std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(),
245                    [](json rg) { return std::make_pair(rg["id"], rg["offset"].get<uint64_t>()); });
246 
247     auto page_size = page["page_size"].get<uint64_t>();
248 
249     std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
250                                                                end_row_id, row_group_ids, page_size);
251     if (load_dataset == true) {
252       pages_[shard_id].push_back(std::move(parsed_page));
253     } else {
254       pages_[shard_index].push_back(std::move(parsed_page));
255     }
256   }
257   return Status::OK();
258 }
259 
ParseStatistics(const json & statistics)260 Status ShardHeader::ParseStatistics(const json &statistics) {
261   for (auto &statistic : statistics) {
262     CHECK_FAIL_RETURN_UNEXPECTED(
263       statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(),
264       "Failed to deserialize statistics, statistic info: " + statistics.dump());
265     std::string statistic_description = statistic["desc"].get<std::string>();
266     json statistic_body = statistic["statistics"];
267     std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
268     RETURN_UNEXPECTED_IF_NULL(parsed_statistic);
269     AddStatistic(parsed_statistic);
270   }
271   return Status::OK();
272 }
273 
ParseSchema(const json & schemas)274 Status ShardHeader::ParseSchema(const json &schemas) {
275   for (auto &schema : schemas) {
276     // change how we get schemaBody once design is finalized
277     CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() &&
278                                    schema.find("schema") != schema.end(),
279                                  "Failed to deserialize schema, schema info: " + schema.dump());
280     std::string schema_description = schema["desc"].get<std::string>();
281     std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
282     json schema_body = schema["schema"];
283     std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body);
284     RETURN_UNEXPECTED_IF_NULL(parsed_schema);
285     AddSchema(parsed_schema);
286   }
287   return Status::OK();
288 }
289 
ParseShardAddress(const json & address)290 void ShardHeader::ParseShardAddress(const json &address) {
291   std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_));
292 }
293 
SerializeHeader()294 std::vector<std::string> ShardHeader::SerializeHeader() {
295   std::vector<std::string> header;
296   auto index = SerializeIndexFields();
297   auto stats = SerializeStatistics();
298   auto schema = SerializeSchema();
299   auto pages = SerializePage();
300   auto address = SerializeShardAddress();
301   if (shard_count_ > static_cast<int>(pages.size())) {
302     return std::vector<string>{};
303   }
304   if (shard_count_ <= kMaxShardCount) {
305     for (int shardId = 0; shardId < shard_count_; shardId++) {
306       string s;
307       s += "{\"header_size\":" + std::to_string(header_size_) + ",";
308       s += "\"index_fields\":" + index + ",";
309       s += "\"page\":" + pages[shardId] + ",";
310       s += "\"page_size\":" + std::to_string(page_size_) + ",";
311       s += "\"compression_size\":" + std::to_string(compression_size_) + ",";
312       s += "\"schema\":" + schema + ",";
313       s += "\"shard_addresses\":" + address + ",";
314       s += "\"shard_id\":" + std::to_string(shardId) + ",";
315       s += "\"statistics\":" + stats + ",";
316       s += "\"version\":\"" + std::string(kVersion) + "\"";
317       s += "}";
318       header.emplace_back(s);
319     }
320   }
321   return header;
322 }
323 
SerializeIndexFields()324 std::string ShardHeader::SerializeIndexFields() {
325   json j;
326   auto fields = index_->GetFields();
327   (void)std::transform(fields.begin(), fields.end(), std::back_inserter(j),
328                        [](const std::pair<uint64_t, std::string> &field) -> json {
329                          return {{"schema_id", field.first}, {"index_field", field.second}};
330                        });
331   return j.dump();
332 }
333 
SerializePage()334 std::vector<std::string> ShardHeader::SerializePage() {
335   std::vector<string> pages;
336   for (auto &shard_pages : pages_) {
337     json j;
338     (void)std::transform(shard_pages.begin(), shard_pages.end(), std::back_inserter(j),
339                          [](const std::shared_ptr<Page> &p) { return p->GetPage(); });
340     pages.emplace_back(j.dump());
341   }
342   return pages;
343 }
344 
SerializeStatistics()345 std::string ShardHeader::SerializeStatistics() {
346   json j;
347   (void)std::transform(statistics_.begin(), statistics_.end(), std::back_inserter(j),
348                        [](const std::shared_ptr<Statistics> &stats) { return stats->GetStatistics(); });
349   return j.dump();
350 }
351 
SerializeSchema()352 std::string ShardHeader::SerializeSchema() {
353   json j;
354   (void)std::transform(schema_.begin(), schema_.end(), std::back_inserter(j),
355                        [](const std::shared_ptr<Schema> &schema) { return schema->GetSchema(); });
356   return j.dump();
357 }
358 
SerializeShardAddress()359 std::string ShardHeader::SerializeShardAddress() {
360   json j;
361   std::shared_ptr<std::string> fn_ptr;
362   for (const auto &addr : shard_addresses_) {
363     (void)GetFileName(addr, &fn_ptr);
364     j.emplace_back(*fn_ptr);
365   }
366   return j.dump();
367 }
368 
GetPage(const int & shard_id,const int & page_id,std::shared_ptr<Page> * page_ptr)369 Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) {
370   RETURN_UNEXPECTED_IF_NULL(page_ptr);
371   if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
372     *page_ptr = pages_[shard_id][page_id];
373     return Status::OK();
374   }
375   page_ptr = nullptr;
376   RETURN_STATUS_UNEXPECTED("Failed to get Page, 'page_id': " + std::to_string(page_id));
377 }
378 
SetPage(const std::shared_ptr<Page> & new_page)379 Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
380   int shard_id = new_page->GetShardID();
381   int page_id = new_page->GetPageID();
382   if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
383     pages_[shard_id][page_id] = new_page;
384     return Status::OK();
385   }
386   RETURN_STATUS_UNEXPECTED("Failed to set Page, 'page_id': " + std::to_string(page_id));
387 }
388 
AddPage(const std::shared_ptr<Page> & new_page)389 Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
390   int shard_id = new_page->GetShardID();
391   int page_id = new_page->GetPageID();
392   if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
393     pages_[shard_id].push_back(new_page);
394     return Status::OK();
395   }
396   RETURN_STATUS_UNEXPECTED("Failed to add Page, 'page_id': " + std::to_string(page_id));
397 }
398 
GetLastPageId(const int & shard_id)399 int64_t ShardHeader::GetLastPageId(const int &shard_id) {
400   if (shard_id >= static_cast<int>(pages_.size())) {
401     return 0;
402   }
403   return pages_[shard_id].size() - 1;
404 }
405 
GetLastPageIdByType(const int & shard_id,const std::string & page_type)406 int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) {
407   if (shard_id >= static_cast<int>(pages_.size())) {
408     return 0;
409   }
410   int last_page_id = -1;
411   for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
412     if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
413       last_page_id = pages_[shard_id][i - 1]->GetPageID();
414       return last_page_id;
415     }
416   }
417   return last_page_id;
418 }
419 
GetPageByGroupId(const int & group_id,const int & shard_id,std::shared_ptr<Page> * page_ptr)420 Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) {
421   RETURN_UNEXPECTED_IF_NULL(page_ptr);
422   CHECK_FAIL_RETURN_UNEXPECTED(shard_id < static_cast<int>(pages_.size()), "Shard id is more than sum of shards.");
423   for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
424     auto page = pages_[shard_id][i - 1];
425     if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
426       *page_ptr = std::make_shared<Page>(*page);
427       return Status::OK();
428     }
429   }
430   page_ptr = nullptr;
431   RETURN_STATUS_UNEXPECTED("Failed to get Page, 'group_id': " + std::to_string(group_id));
432 }
433 
AddSchema(std::shared_ptr<Schema> schema)434 int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
435   if (schema == nullptr) {
436     MS_LOG(ERROR) << "The pointer of schema is null.";
437     return -1;
438   }
439 
440   if (!schema_.empty()) {
441     MS_LOG(ERROR) << "The schema can not be added twice.";
442     return -1;
443   }
444 
445   int64_t schema_id = schema->GetSchemaID();
446   if (schema_id == -1) {
447     schema_id = schema_.size();
448     schema->SetSchemaID(schema_id);
449   }
450   schema_.push_back(schema);
451   return schema_id;
452 }
453 
AddStatistic(std::shared_ptr<Statistics> statistic)454 void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
455   if (statistic) {
456     int64_t statistics_id = statistic->GetStatisticsID();
457     if (statistics_id == -1) {
458       statistics_id = statistics_.size();
459       statistic->SetStatisticsID(statistics_id);
460     }
461     statistics_.push_back(statistic);
462   }
463 }
464 
InitIndexPtr()465 std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
466   std::shared_ptr<Index> index = index_;
467   if (!index_) {
468     index = std::make_shared<Index>();
469     index_ = index;
470   }
471   return index;
472 }
473 
CheckIndexField(const std::string & field,const json & schema)474 Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
475   // check field name is or is not valid
476   CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
477                                "Invalid input, field [" + field + "] can not found in schema.");
478   CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes",
479                                "Invalid input, byte type field [" + field + "] can not set as an index field.");
480   CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(),
481                                "Invalid input, array type field [" + field + "] can not set as an index field.");
482   return Status::OK();
483 }
484 
AddIndexFields(const std::vector<std::string> & fields)485 Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
486   if (fields.empty()) {
487     return Status::OK();
488   }
489   CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Invalid data, schema is empty.");
490   // create index Object
491   std::shared_ptr<Index> index = InitIndexPtr();
492   for (const auto &schemaPtr : schema_) {
493     std::shared_ptr<Schema> schema_ptr;
494     RETURN_IF_NOT_OK(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr));
495     json schema = schema_ptr->GetSchema().at("schema");
496     // checkout and add fields for each schema
497     std::set<std::string> field_set;
498     for (const auto &item : index->GetFields()) {
499       field_set.insert(item.second);
500     }
501     for (const auto &field : fields) {
502       CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(),
503                                    "Invalid data, the same index field [" + field + "] can not added twice.");
504       // check field name is or is not valid
505       RETURN_IF_NOT_OK(CheckIndexField(field, schema));
506       field_set.insert(field);
507       // add field into index
508       index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
509     }
510   }
511   index_ = index;
512   return Status::OK();
513 }
514 
GetAllSchemaID(std::set<uint64_t> & bucket_count)515 Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
516   // get all schema id
517   for (const auto &schema : schema_) {
518     auto schema_id = schema->GetSchemaID();
519     CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) == bucket_count.end(),
520                                  "Invalid data, duplicate schema exist, schema id: " + std::to_string(schema_id));
521     bucket_count.insert(schema_id);
522   }
523   return Status::OK();
524 }
525 
AddIndexFields(std::vector<std::pair<uint64_t,std::string>> fields)526 Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
527   if (fields.empty()) {
528     return Status::OK();
529   }
530   // create index Object
531   std::shared_ptr<Index> index = InitIndexPtr();
532   // get all schema id
533   std::set<uint64_t> bucket_count;
534   RETURN_IF_NOT_OK(GetAllSchemaID(bucket_count));
535   // check and add fields for each schema
536   std::set<std::pair<uint64_t, std::string>> field_set;
537   for (const auto &item : index->GetFields()) {
538     field_set.insert(item);
539   }
540   for (const auto &field : fields) {
541     CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(),
542                                  "Invalid data, the same index field [" + field.second + "] can not added twice.");
543     uint64_t schema_id = field.first;
544     std::string field_name = field.second;
545 
546     // check schemaId is or is not valid
547     CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(),
548                                  "Invalid data, schema id [" + std::to_string(schema_id) + "] is invalid.");
549     // check field name is or is not valid
550     std::shared_ptr<Schema> schema_ptr;
551     RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr));
552     json schema = schema_ptr->GetSchema().at("schema");
553     CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(),
554                                  "Invalid data, field [" + field_name + "] is not found in schema.");
555     RETURN_IF_NOT_OK(CheckIndexField(field_name, schema));
556     field_set.insert(field);
557     // add field into index
558     index->AddIndexField(schema_id, field_name);
559   }
560   index_ = index;
561   return Status::OK();
562 }
563 
GetShardAddressByID(int64_t shard_id)564 std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
565   if (shard_id >= shard_addresses_.size()) {
566     return "";
567   }
568   return shard_addresses_.at(shard_id);
569 }
570 
GetSchemas()571 std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }
572 
GetStatistics()573 std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }
574 
GetFields()575 std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }
576 
GetIndex()577 std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }
578 
GetSchemaByID(int64_t schema_id,std::shared_ptr<Schema> * schema_ptr)579 Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) {
580   RETURN_UNEXPECTED_IF_NULL(schema_ptr);
581   int64_t schema_size = schema_.size();
582   CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schema_size,
583                                "Invalid data, schema id [" + std::to_string(schema_id) + "] is not in range [0, " +
584                                  std::to_string(schema_size) + ").");
585   *schema_ptr = schema_.at(schema_id);
586   return Status::OK();
587 }
588 
GetStatisticByID(int64_t statistic_id,std::shared_ptr<Statistics> * statistics_ptr)589 Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) {
590   RETURN_UNEXPECTED_IF_NULL(statistics_ptr);
591   int64_t statistics_size = statistics_.size();
592   CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size,
593                                "Invalid data, statistic id [" + std::to_string(statistic_id) +
594                                  "] is not in range [0, " + std::to_string(statistics_size) + ").");
595   *statistics_ptr = statistics_.at(statistic_id);
596   return Status::OK();
597 }
598 
PagesToFile(const std::string dump_file_name)599 Status ShardHeader::PagesToFile(const std::string dump_file_name) {
600   auto realpath = FileUtils::GetRealPath(dump_file_name.data());
601   CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + dump_file_name);
602   // write header content to file, dump whatever is in the file before
603   std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out);
604   CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file, path: " + dump_file_name);
605   auto pages = SerializePage();
606   for (const auto &shard_pages : pages) {
607     page_out_handle << shard_pages << "\n";
608   }
609   page_out_handle.close();
610   return Status::OK();
611 }
612 
FileToPages(const std::string dump_file_name)613 Status ShardHeader::FileToPages(const std::string dump_file_name) {
614   for (auto &v : pages_) {  // clean pages
615     v.clear();
616   }
617   auto realpath = FileUtils::GetRealPath(dump_file_name.data());
618   CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + dump_file_name);
619   // attempt to open the file contains the page in json
620   std::ifstream page_in_handle(realpath.value());
621   CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(),
622                                "Invalid file, page file does not exist, path: " + dump_file_name);
623   std::string line;
624   while (std::getline(page_in_handle, line)) {
625     RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true));
626   }
627   page_in_handle.close();
628   return Status::OK();
629 }
630 
Initialize(const std::shared_ptr<ShardHeader> * header_ptr,const json & schema,const std::vector<std::string> & index_fields,std::vector<std::string> & blob_fields,uint64_t & schema_id)631 Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
632                                const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
633                                uint64_t &schema_id) {
634   RETURN_UNEXPECTED_IF_NULL(header_ptr);
635   auto schema_ptr = Schema::Build("mindrecord", schema);
636   CHECK_FAIL_RETURN_UNEXPECTED(schema_ptr != nullptr, "Failed to build schema: " + schema.dump() + ".");
637   schema_id = (*header_ptr)->AddSchema(schema_ptr);
638   // create index
639   std::vector<std::pair<uint64_t, std::string>> id_index_fields;
640   if (!index_fields.empty()) {
641     (void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
642                     [schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
643     RETURN_IF_NOT_OK((*header_ptr)->AddIndexFields(id_index_fields));
644   }
645 
646   auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
647   blob_fields = build_schema_ptr->GetBlobFields();
648   return Status::OK();
649 }
650 }  // namespace mindrecord
651 }  // namespace mindspore
652