/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/mindrecord/include/shard_header.h"

#include <algorithm>
#include <atomic>
#include <cmath>
#include <fstream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "utils/file_utils.h"
#include "utils/ms_utils.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_page.h"

namespace mindspore {
namespace mindrecord {
std::atomic<bool> thread_status(false);
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) {
  index_ = std::make_shared<Index>();
}

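// Builds the in-memory header from the per-shard header JSON objects. The
// dataset-wide fields (schema, index fields, statistics, shard addresses and
// header/page/compression sizes) are taken from the first header only, while
// the page table is parsed for every shard.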
Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
  shard_count_ = headers.size();
  int shard_index = 0;
  bool first = true;
  for (const auto &header : headers) {
    if (first) {
      first = false;
      RETURN_IF_NOT_OK_MR(ParseSchema(header["schema"]));
      RETURN_IF_NOT_OK_MR(ParseIndexFields(header["index_fields"]));
      RETURN_IF_NOT_OK_MR(ParseStatistics(header["statistics"]));
      ParseShardAddress(header["shard_addresses"]);
      header_size_ = header["header_size"].get<uint64_t>();
      page_size_ = header["page_size"].get<uint64_t>();
      compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
    }
    RETURN_IF_NOT_OK_MR(ParsePage(header["page"], shard_index, load_dataset));
    shard_index++;
  }
  return Status::OK();
}

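// Sanity-checks a single mindrecord file on disk: the path must resolve, the
// file must be readable, and its size must be at least kMinFileSize.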
Status ShardHeader::CheckFileStatus(const std::string &path) {
  auto realpath = FileUtils::GetRealPath(path.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    realpath.has_value(),
    "Invalid file, failed to get the realpath of mindrecord files. Please check file path: " + path);
  std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(fin.is_open(),
                                  "Invalid file, failed to open files for loading mindrecord files. Please check file "
                                  "path, permission and open file limit: " +
                                    path);
  // fetch file size
  auto &io_seekg = fin.seekg(0, std::ios::end);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] failed to seekg file, file path: " + path);
  }

  size_t file_size = fin.tellg();
  if (file_size < kMinFileSize) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("Invalid file, the size of mindrecord file: " + std::to_string(file_size) +
                                " is smaller than the lower limit: " + std::to_string(kMinFileSize) +
                                ".\n Please check file path: " + path +
                                " and use 'FileWriter' to generate valid mindrecord files.");
  }
  fin.close();
  return Status::OK();
}

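// Reads the header of one mindrecord file: the first kInt64Len bytes hold the
// header size, followed by the JSON header itself, which is parsed and
// returned through header_ptr.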
Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(header_ptr);
  RETURN_IF_NOT_OK_MR(CheckFileStatus(path));

  auto realpath = FileUtils::GetRealPath(path.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    realpath.has_value(),
    "Invalid file, failed to get the realpath of mindrecord files. Please check file path: " + path);

  // read header size
  json json_header;
  std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(fin.is_open(),
                                  "Invalid file, failed to open files for loading mindrecord files. Please check file "
                                  "path, permission and open file limit: " +
                                    path);

  uint64_t header_size = 0;
  auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len);
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] failed to read file, file path: " + path);
  }

  if (header_size > kMaxHeaderSize) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR(
      "Invalid file, the size of mindrecord file header is larger than the upper limit. \n"
      "The invalid mindrecord file is [" +
      path + "]. \nPlease use 'FileWriter' to generate valid mindrecord files.");
  }

  // read header content
  std::vector<uint8_t> header_content(header_size);
  auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size);
  if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) {
    fin.close();
    RETURN_STATUS_UNEXPECTED_MR("Invalid file, failed to read the header content of file: " + path +
                                ", please check the correctness of the MindRecord file.");
  }

  fin.close();
  std::string raw_header_content = std::string(header_content.begin(), header_content.end());
  // parse json content
  try {
    json_header = json::parse(raw_header_content);
  } catch (json::parse_error &e) {
    RETURN_STATUS_UNEXPECTED_MR("Invalid file, failed to parse the header content of file: " + path +
                                ", please check the correctness of the MindRecord file.");
  }
  *header_ptr = std::make_shared<json>(json_header);
  return Status::OK();
}

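// Condenses the raw header of a single file into the summary JSON exposed to
// callers (addresses, sizes, index fields, blob fields, schema and version).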
Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(header_ptr);
  std::shared_ptr<json> raw_header;
  RETURN_IF_NOT_OK_MR(ValidateHeader(file_path, &raw_header));
  uint64_t compression_size =
    raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0;
  json header = {{"shard_addresses", (*raw_header)["shard_addresses"]},
                 {"header_size", (*raw_header)["header_size"]},
                 {"page_size", (*raw_header)["page_size"]},
                 {"compression_size", compression_size},
                 {"index_fields", (*raw_header)["index_fields"]},
                 {"blob_fields", (*raw_header)["schema"][0]["blob_fields"]},
                 {"schema", (*raw_header)["schema"][0]["schema"]},
                 {"version", (*raw_header)["version"]}};
  *header_ptr = std::make_shared<json>(header);
  return Status::OK();
}

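// Loads and validates the headers of all shard files in parallel, using up to
// std::thread::hardware_concurrency() worker threads (kThreadNumber if that is
// unknown), then initializes this ShardHeader from the collected results.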
Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
  uint32_t thread_num = std::thread::hardware_concurrency();
  if (thread_num == 0) {
    thread_num = kThreadNumber;
  }
  uint32_t work_thread_num = 0;
  uint32_t shard_count = file_paths.size();
  int group_num = ceil(shard_count * 1.0 / thread_num);
  std::vector<std::thread> thread_set(thread_num);
  std::vector<json> headers(shard_count);
  for (uint32_t x = 0; x < thread_num; ++x) {
    int start_num = x * group_num;
    int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
    if (start_num >= end_num) {
      continue;
    }

    thread_set[x] =
      std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
    work_thread_num++;
  }

  for (uint32_t x = 0; x < work_thread_num; ++x) {
    thread_set[x].join();
  }
  if (thread_status) {
    thread_status = false;
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Error raised in GetHeadersOneTask function.");
  }
  RETURN_IF_NOT_OK_MR(InitializeHeader(headers, load_dataset));
  return Status::OK();
}

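// Worker task for BuildDataset: validates the headers of realAddresses[start, end)
// and stores them in headers; on any failure it sets the global thread_status
// flag instead of returning a Status.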
void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers,
                                    const std::vector<std::string> &realAddresses) {
  if (thread_status || end > realAddresses.size()) {
    return;
  }
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(),
                     std::string(__func__ + std::to_string(start) + ":" + std::to_string(end)).c_str());
#endif
  for (int x = start; x < end; ++x) {
    std::shared_ptr<json> header;
    auto status = ValidateHeader(realAddresses[x], &header);
    if (status.IsError()) {
      thread_status = true;
      return;
    }
    (*header)["shard_addresses"] = realAddresses;
    if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) ==
        kSupportedVersion.end()) {
      MS_LOG(ERROR) << "Invalid file, the version of mindrecord files " << (*header)["version"].dump()
                    << " is not supported.\nPlease use 'FileWriter' to generate valid mindrecord files.";
      thread_status = true;
      return;
    }
    headers[x] = *header;
  }
}

Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
  std::vector<std::string> file_names(file_paths.size());
  std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string {
    std::shared_ptr<std::string> fn;
    return GetFileName(fp, &fn).IsOk() ? *fn : "";
  });

  shard_addresses_ = std::move(file_names);
  shard_count_ = file_paths.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_count_ != 0 && (shard_count_ <= kMaxShardCount),
                                  "[Internal ERROR] 'shard_count_': " + std::to_string(shard_count_) +
                                    " is not in range (0, " + std::to_string(kMaxShardCount) + "].");
  pages_.resize(shard_count_);
  return Status::OK();
}

Status ShardHeader::ParseIndexFields(const json &index_fields) {
  std::vector<std::pair<uint64_t, std::string>> parsed_index_fields;
  for (auto &index_field : index_fields) {
    auto schema_id = index_field["schema_id"].get<uint64_t>();
    std::string field_name = index_field["index_field"].get<std::string>();
    std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name);
    parsed_index_fields.push_back(parsed_index_field);
  }
  RETURN_IF_NOT_OK_MR(AddIndexFields(parsed_index_fields));
  return Status::OK();
}

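// Deserializes the page table of one shard. When load_dataset is true the
// pages are attached to the shard_id recorded in the page itself; otherwise
// they are attached to the shard_index passed by the caller.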
Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
  // set shard_index when load_dataset is false
  if (pages_.empty()) {
    pages_.resize(shard_count_);
  }

  for (auto &page : pages) {
    int page_id = page["page_id"];
    int shard_id = page["shard_id"];
    std::string page_type = page["page_type"];
    int page_type_id = page["page_type_id"];
    auto start_row_id = page["start_row_id"].get<uint64_t>();
    auto end_row_id = page["end_row_id"].get<uint64_t>();

    std::vector<std::pair<int, uint64_t>> row_group_ids(page["row_group_ids"].size());
    std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(),
                   [](json rg) { return std::make_pair(rg["id"], rg["offset"].get<uint64_t>()); });

    auto page_size = page["page_size"].get<uint64_t>();

    std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
                                                               end_row_id, row_group_ids, page_size);
    if (load_dataset == true) {
      pages_[shard_id].push_back(std::move(parsed_page));
    } else {
      pages_[shard_index].push_back(std::move(parsed_page));
    }
  }
  return Status::OK();
}

Status ShardHeader::ParseStatistics(const json &statistics) {
  for (auto &statistic : statistics) {
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(),
      "[Internal ERROR] Failed to deserialize statistics: " + statistics.dump());
    std::string statistic_description = statistic["desc"].get<std::string>();
    json statistic_body = statistic["statistics"];
    std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
    RETURN_UNEXPECTED_IF_NULL_MR(parsed_statistic);
    AddStatistic(parsed_statistic);
  }
  return Status::OK();
}

Status ShardHeader::ParseSchema(const json &schemas) {
  for (auto &schema : schemas) {
    // change how we get schemaBody once design is finalized
    CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() &&
                                      schema.find("schema") != schema.end(),
                                    "[Internal ERROR] Failed to deserialize schema: " + schema.dump());
    std::string schema_description = schema["desc"].get<std::string>();
    std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
    json schema_body = schema["schema"];
    std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body);
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      parsed_schema != nullptr,
      "[Internal ERROR] Failed to build schema. Please check the preceding [ERROR] logs for more details.");
    AddSchema(parsed_schema);
  }
  return Status::OK();
}

void ShardHeader::ParseShardAddress(const json &address) {
  (void)std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_));
}

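// Serializes the header into one JSON string per shard; every shard shares the
// same schema, statistics, index fields and addresses but carries its own page
// table and shard_id.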
std::vector<std::string> ShardHeader::SerializeHeader() {
  std::vector<std::string> header;
  auto index = SerializeIndexFields();
  auto stats = SerializeStatistics();
  auto schema = SerializeSchema();
  auto pages = SerializePage();
  auto address = SerializeShardAddress();
  if (shard_count_ > static_cast<int>(pages.size())) {
    return std::vector<std::string>{};
  }
  if (shard_count_ <= kMaxShardCount) {
    for (int shardId = 0; shardId < shard_count_; shardId++) {
      std::string s;
      s += "{\"header_size\":" + std::to_string(header_size_) + ",";
      s += "\"index_fields\":" + index + ",";
      s += "\"page\":" + pages[shardId] + ",";
      s += "\"page_size\":" + std::to_string(page_size_) + ",";
      s += "\"compression_size\":" + std::to_string(compression_size_) + ",";
      s += "\"schema\":" + schema + ",";
      s += "\"shard_addresses\":" + address + ",";
      s += "\"shard_id\":" + std::to_string(shardId) + ",";
      s += "\"statistics\":" + stats + ",";
      s += "\"version\":\"" + std::string(kVersion) + "\"";
      s += "}";
      header.emplace_back(s);
    }
  }
  return header;
}

std::string ShardHeader::SerializeIndexFields() {
  json j;
  auto fields = index_->GetFields();
  (void)std::transform(fields.begin(), fields.end(), std::back_inserter(j),
                       [](const std::pair<uint64_t, std::string> &field) -> json {
                         return {{"schema_id", field.first}, {"index_field", field.second}};
                       });
  return j.dump();
}

std::vector<std::string> ShardHeader::SerializePage() {
  std::vector<std::string> pages;
  for (auto &shard_pages : pages_) {
    json j;
    (void)std::transform(shard_pages.begin(), shard_pages.end(), std::back_inserter(j),
                         [](const std::shared_ptr<Page> &p) { return p->GetPage(); });
    pages.emplace_back(j.dump());
  }
  return pages;
}

std::string ShardHeader::SerializeStatistics() {
  json j;
  (void)std::transform(statistics_.begin(), statistics_.end(), std::back_inserter(j),
                       [](const std::shared_ptr<Statistics> &stats) { return stats->GetStatistics(); });
  return j.dump();
}

std::string ShardHeader::SerializeSchema() {
  json j;
  (void)std::transform(schema_.begin(), schema_.end(), std::back_inserter(j),
                       [](const std::shared_ptr<Schema> &schema) { return schema->GetSchema(); });
  return j.dump();
}

std::string ShardHeader::SerializeShardAddress() {
  json j;
  std::shared_ptr<std::string> fn_ptr;
  for (const auto &addr : shard_addresses_) {
    (void)GetFileName(addr, &fn_ptr);
    (void)j.emplace_back(*fn_ptr);
  }
  return j.dump();
}

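// Page accessors: GetPage and SetPage require an existing (shard_id, page_id)
// slot, while AddPage only accepts a page whose id is exactly the next free
// slot of its shard.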
Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(page_ptr);
  if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
    *page_ptr = pages_[shard_id][page_id];
    return Status::OK();
  }
  *page_ptr = nullptr;
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to get Page, 'page_id': " + std::to_string(page_id));
}

Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
  int shard_id = new_page->GetShardID();
  int page_id = new_page->GetPageID();
  if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
    pages_[shard_id][page_id] = new_page;
    return Status::OK();
  }
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to set Page, 'page_id': " + std::to_string(page_id));
}

Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
  int shard_id = new_page->GetShardID();
  int page_id = new_page->GetPageID();
  if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
    pages_[shard_id].push_back(new_page);
    return Status::OK();
  }
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to add Page, 'page_id': " + std::to_string(page_id));
}

int64_t ShardHeader::GetLastPageId(const int &shard_id) {
  if (shard_id >= static_cast<int>(pages_.size())) {
    return 0;
  }
  return pages_[shard_id].size() - 1;
}

int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) {
  if (shard_id >= static_cast<int>(pages_.size())) {
    return 0;
  }
  int last_page_id = -1;
  for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
    if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
      last_page_id = pages_[shard_id][i - 1]->GetPageID();
      return last_page_id;
    }
  }
  return last_page_id;
}

Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(page_ptr);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_id < static_cast<int>(pages_.size()),
                                  "[Internal ERROR] 'shard_id': " + std::to_string(shard_id) +
                                    " should be smaller than the size of 'pages_': " + std::to_string(pages_.size()) +
                                    ".");
  for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
    auto page = pages_[shard_id][i - 1];
    if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
      *page_ptr = std::make_shared<Page>(*page);
      return Status::OK();
    }
  }
  *page_ptr = nullptr;
  RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to get Page, 'group_id': " + std::to_string(group_id));
}

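// Registers the dataset schema and assigns it an id. Only a single schema is
// supported: a null schema or a second call is rejected and -1 is returned.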
int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
  if (schema == nullptr) {
    MS_LOG(ERROR) << "[Internal ERROR] The pointer of schema is NULL.";
    return -1;
  }

  if (!schema_.empty()) {
    MS_LOG(ERROR) << "The schema is added repeatedly. Please remove the redundant 'add_schema' function.";
    return -1;
  }

  int64_t schema_id = schema->GetSchemaID();
  if (schema_id == -1) {
    schema_id = schema_.size();
    schema->SetSchemaID(schema_id);
  }
  schema_.push_back(schema);
  return schema_id;
}

void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
  if (statistic) {
    int64_t statistics_id = statistic->GetStatisticsID();
    if (statistics_id == -1) {
      statistics_id = statistics_.size();
      statistic->SetStatisticsID(statistics_id);
    }
    statistics_.push_back(statistic);
  }
}

std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
  std::shared_ptr<Index> index = index_;
  if (!index_) {
    index = std::make_shared<Index>();
    index_ = index;
  }
  return index;
}

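// An index field must exist in the schema, must not be of type "Bytes", and
// must not declare a "shape" (i.e. array fields cannot be indexed).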
Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
  // check whether the field name is valid
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(field) != schema.end(),
                                  "Invalid input, 'index_fields': " + field +
                                    " can not be found in schema: " + schema.dump() +
                                    ".\n Please use 'add_index' function to add proper 'index_fields'.");
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema[field]["type"] != "Bytes",
                                  "Invalid input, type of 'index_fields': " + field +
                                    " is bytes and can not be set as an 'index_fields'.\n Please use 'add_index' "
                                    "function to add the other 'index_fields'.");
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(),
    "Invalid input, type of 'index_fields': " + field +
      " is array and can not be set as an 'index_fields'.\n Please use 'add_index' function "
      "to add the other 'index_fields'.");
  return Status::OK();
}

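// Adds index fields by name: every requested field is validated against each
// registered schema and rejected if it was already added.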
Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
  if (fields.empty()) {
    return Status::OK();
  }
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    !GetSchemas().empty(), "Invalid data, schema is empty. Please use 'add_schema' function to add schema first.");
  // create index object
  std::shared_ptr<Index> index = InitIndexPtr();
  for (const auto &schemaPtr : schema_) {
    std::shared_ptr<Schema> schema_ptr;
    RETURN_IF_NOT_OK_MR(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr));
    json schema = schema_ptr->GetSchema().at("schema");
    // check and add fields for each schema
    std::set<std::string> field_set;
    for (const auto &item : index->GetFields()) {
      field_set.insert(item.second);
    }
    for (const auto &field : fields) {
      CHECK_FAIL_RETURN_UNEXPECTED_MR(
        field_set.find(field) == field_set.end(),
        "The 'index_fields': " + field + " is added repeatedly. Please remove the redundant 'add_index' function.");
      // check whether the field name is valid
      RETURN_IF_NOT_OK_MR(CheckIndexField(field, schema));
      field_set.insert(field);
      // add field into index
      index->AddIndexField(schemaPtr->GetSchemaID(), field);
    }
  }
  index_ = index;
  return Status::OK();
}

Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
  // get all schema ids
  for (const auto &schema : schema_) {
    auto schema_id = schema->GetSchemaID();
    CHECK_FAIL_RETURN_UNEXPECTED_MR(bucket_count.find(schema_id) == bucket_count.end(),
                                    "[Internal ERROR] duplicate schema exists, schema id: " + std::to_string(schema_id));
    bucket_count.insert(schema_id);
  }
  return Status::OK();
}

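// Adds index fields given explicit (schema_id, field_name) pairs, validating
// that the schema id exists and that the field is present in that schema.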
Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
  if (fields.empty()) {
    return Status::OK();
  }
  // create index object
  std::shared_ptr<Index> index = InitIndexPtr();
  // get all schema ids
  std::set<uint64_t> bucket_count;
  RETURN_IF_NOT_OK_MR(GetAllSchemaID(bucket_count));
  // check and add fields for each schema
  std::set<std::pair<uint64_t, std::string>> field_set;
  for (const auto &item : index->GetFields()) {
    field_set.insert(item);
  }
  for (const auto &field : fields) {
    CHECK_FAIL_RETURN_UNEXPECTED_MR(field_set.find(field) == field_set.end(),
                                    "The 'index_fields': " + field.second +
                                      " is added repeatedly. Please remove the redundant 'add_index' function.");
    uint64_t schema_id = field.first;
    std::string field_name = field.second;

    // check whether the schema id is valid
    CHECK_FAIL_RETURN_UNEXPECTED_MR(bucket_count.find(schema_id) != bucket_count.end(),
                                    "[Internal ERROR] 'schema_id': " + std::to_string(schema_id) + " can not be found.");
    // check whether the field name is valid
    std::shared_ptr<Schema> schema_ptr;
    RETURN_IF_NOT_OK_MR(GetSchemaByID(schema_id, &schema_ptr));
    json schema = schema_ptr->GetSchema().at("schema");
    CHECK_FAIL_RETURN_UNEXPECTED_MR(schema.find(field_name) != schema.end(),
                                    "Invalid input, 'index_fields': " + field_name +
                                      " can not be found in schema: " + schema.dump() +
                                      ".\n Please use 'add_index' function to add proper 'index_fields'.");
    RETURN_IF_NOT_OK_MR(CheckIndexField(field_name, schema));
    field_set.insert(field);
    // add field into index
    index->AddIndexField(schema_id, field_name);
  }
  index_ = index;
  return Status::OK();
}

std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
  if (shard_id >= static_cast<int64_t>(shard_addresses_.size())) {
    return "";
  }
  return shard_addresses_.at(shard_id);
}

std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }

std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }

std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }

std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }

Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(schema_ptr);
  int64_t schema_size = schema_.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_id >= 0 && schema_id < schema_size,
                                  "[Internal ERROR] 'schema_id': " + std::to_string(schema_id) +
                                    " is not in range [0, " + std::to_string(schema_size) + ").");
  *schema_ptr = schema_.at(schema_id);
  return Status::OK();
}

Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(statistics_ptr);
  int64_t statistics_size = statistics_.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(statistic_id >= 0 && statistic_id < statistics_size,
                                  "[Internal ERROR] 'statistic_id': " + std::to_string(statistic_id) +
                                    " is not in range [0, " + std::to_string(statistics_size) + ").");
  *statistics_ptr = statistics_.at(statistic_id);
  return Status::OK();
}

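// Dumps the serialized page tables (one JSON document per line, one line per
// shard) to dump_file_name, truncating any previous content.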
Status ShardHeader::PagesToFile(const std::string dump_file_name) {
  auto realpath = FileUtils::GetRealPath(dump_file_name.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(realpath.has_value(),
                                  "[Internal ERROR] Failed to get the realpath of Pages file, path: " + dump_file_name);
  // write header content to file, dump whatever is in the file before
  std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(page_out_handle.good(),
                                  "[Internal ERROR] Failed to open Pages file, path: " + dump_file_name);
  auto pages = SerializePage();
  for (const auto &shard_pages : pages) {
    page_out_handle << shard_pages << "\n";
  }
  page_out_handle.close();

  ChangeFileMode(realpath.value(), S_IRUSR | S_IWUSR);
  return Status::OK();
}

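// Inverse of PagesToFile: clears the in-memory page tables and re-parses them
// from the dumped file, one JSON line per shard.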
Status ShardHeader::FileToPages(const std::string dump_file_name) {
  for (auto &v : pages_) {  // clean pages
    v.clear();
  }
  auto realpath = FileUtils::GetRealPath(dump_file_name.c_str());
  CHECK_FAIL_RETURN_UNEXPECTED_MR(realpath.has_value(),
                                  "[Internal ERROR] Failed to get the realpath of Pages file, path: " + dump_file_name);
  // attempt to open the file containing the pages in json
  std::ifstream page_in_handle(realpath.value(), std::ios::in);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(page_in_handle.good(),
                                  "[Internal ERROR] Pages file does not exist, path: " + dump_file_name);
  std::string line;
  while (std::getline(page_in_handle, line)) {
    auto s = ParsePage(json::parse(line), -1, true);
    if (s != Status::OK()) {
      page_in_handle.close();
      return s;
    }
  }
  page_in_handle.close();
  return Status::OK();
}

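// Convenience helper: builds a Schema from raw JSON, registers it on
// *header_ptr, adds the given index fields under the new schema id, and
// returns the schema's blob fields through blob_fields.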
Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
                               const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
                               uint64_t &schema_id) {
  RETURN_UNEXPECTED_IF_NULL_MR(header_ptr);
  auto schema_ptr = Schema::Build("mindrecord", schema);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_ptr != nullptr, "[Internal ERROR] Failed to build schema: " + schema.dump() +
                                                           ". Please check the preceding [ERROR] logs for more details.");
  schema_id = (*header_ptr)->AddSchema(schema_ptr);
  // create index
  std::vector<std::pair<uint64_t, std::string>> id_index_fields;
  if (!index_fields.empty()) {
    (void)std::transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
                         [schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
    RETURN_IF_NOT_OK_MR((*header_ptr)->AddIndexFields(id_index_fields));
  }

  auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
  blob_fields = build_schema_ptr->GetBlobFields();
  return Status::OK();
}
}  // namespace mindrecord
}  // namespace mindspore