/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_writer.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "./securec.h"

using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::DEBUG;
using mindspore::MsLogLevel::ERROR;
using mindspore::MsLogLevel::INFO;

namespace mindspore {
namespace mindrecord {
ShardWriter::ShardWriter()
    : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) {
  compression_size_ = 0;
}

ShardWriter::~ShardWriter() {
  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
    file_streams_[i]->close();
  }
}

Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
  // Get full path from file name
  for (const auto &path : paths) {
    CHECK_FAIL_RETURN_UNEXPECTED(CheckIsValidUtf8(path),
                                 "Invalid data, file name: " + path + " contains invalid utf-8 character.");
    char resolved_path[PATH_MAX] = {0};
    char buf[PATH_MAX] = {0};
    CHECK_FAIL_RETURN_UNEXPECTED(strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) == EOK,
                                 "Failed to call securec func [strncpy_s], path: " + path);
#if defined(_WIN32) || defined(_WIN64)
    RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX));
    RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX));
#else
    // The parent directory must already exist; the file itself may not, since it may be created later.
    CHECK_FAIL_RETURN_UNEXPECTED(realpath(dirname(&(buf[0])), resolved_path) != nullptr,
                                 "Invalid file, path: " + std::string(resolved_path));
    if (realpath(common::SafeCStr(path), resolved_path) == nullptr) {
      MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << " does not exist yet; it will be created.";
    }
#endif
    file_paths_.emplace_back(string(resolved_path));
  }
  return Status::OK();
}

Status ShardWriter::OpenDataFiles(bool append) {
  // Open files
  for (const auto &file : file_paths_) {
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().data());
    CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
    if (!append) {
      // if not appending and the mindrecord file already exists, return FAILED
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      if (fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED("Invalid file, MindRecord file already exists in path: " + file);
      }
      fs->close();
      // open the mindrecord file for writing
      fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
      if (!fs->good()) {
        RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file);
      }
    } else {
      // open the mindrecord file for appending
      fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED("Failed to open file for appending data, path: " + file);
      }
    }
    MS_LOG(INFO) << "Successfully opened shard file, path: " << file;
    file_streams_.push_back(fs);
  }
  return Status::OK();
}

Status ShardWriter::RemoveLockFile() {
  // Remove temporary files
  int ret = std::remove(pages_file_.c_str());
  if (ret == 0) {
    MS_LOG(DEBUG) << "Successfully removed page file, path: " << pages_file_;
  }

  ret = std::remove(lock_file_.c_str());
  if (ret == 0) {
    MS_LOG(DEBUG) << "Successfully removed lock file, path: " << lock_file_;
  }
  return Status::OK();
}

Status ShardWriter::InitLockFile() {
  CHECK_FAIL_RETURN_UNEXPECTED(file_paths_.size() != 0, "Invalid data, file_paths_ is not initialized.");

  lock_file_ = file_paths_[0] + kLockFileSuffix;
  pages_file_ = file_paths_[0] + kPageFileSuffix;
  RETURN_IF_NOT_OK(RemoveLockFile());
  return Status::OK();
}

Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
  shard_count_ = paths.size();
  CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ <= kMaxSchemaCount,
                               "Invalid data, schema_count_ must be less than or equal to " +
                                 std::to_string(kMaxSchemaCount) + ", but got " + std::to_string(schema_count_));

  // Get full path from file name
  RETURN_IF_NOT_OK(GetFullPathFromFileName(paths));
  // Open files
  RETURN_IF_NOT_OK(OpenDataFiles(append));
  // Init lock file
  RETURN_IF_NOT_OK(InitLockFile());
  return Status::OK();
}
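
// Illustrative write flow (a minimal sketch, not part of this file; it assumes
// the declarations in shard_writer.h and a `header` built via ShardHeader):
//
//   auto writer = std::make_unique<ShardWriter>();
//   RETURN_IF_NOT_OK(writer->Open({"data.mindrecord0", "data.mindrecord1"}, false));
//   RETURN_IF_NOT_OK(writer->SetShardHeader(header));
//   RETURN_IF_NOT_OK(writer->WriteRawData(raw_data, blob_data, true, false));
//   RETURN_IF_NOT_OK(writer->Commit());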

Status ShardWriter::OpenForAppend(const std::string &path) {
  CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(path), "Invalid file, path: " + path);
  std::shared_ptr<json> header_ptr;
  RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(path, &header_ptr));
  auto ds = std::make_shared<std::vector<std::string>>();
  RETURN_IF_NOT_OK(GetDatasetFiles(path, (*header_ptr)["shard_addresses"], &ds));
  ShardHeader header = ShardHeader();
  RETURN_IF_NOT_OK(header.BuildDataset(*ds));
  shard_header_ = std::make_shared<ShardHeader>(header);
  RETURN_IF_NOT_OK(SetHeaderSize(shard_header_->GetHeaderSize()));
  RETURN_IF_NOT_OK(SetPageSize(shard_header_->GetPageSize()));
  compression_size_ = shard_header_->GetCompressionSize();
  RETURN_IF_NOT_OK(Open(*ds, true));
  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
  return Status::OK();
}
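
// Append-mode sketch (illustrative only): OpenForAppend() restores the header,
// header size and page size from the existing file, so only new rows are needed:
//
//   auto writer = std::make_unique<ShardWriter>();
//   RETURN_IF_NOT_OK(writer->OpenForAppend("data.mindrecord0"));
//   RETURN_IF_NOT_OK(writer->WriteRawData(new_raw_data, new_blob_data, true, false));
//   RETURN_IF_NOT_OK(writer->Commit());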

Status ShardWriter::Commit() {
  // Read pages file
  std::ifstream page_file(pages_file_.c_str());
  if (page_file.good()) {
    page_file.close();
    RETURN_IF_NOT_OK(shard_header_->FileToPages(pages_file_));
  }
  RETURN_IF_NOT_OK(WriteShardHeader());
  MS_LOG(INFO) << "Successfully wrote meta data.";
  // Remove lock file
  RETURN_IF_NOT_OK(RemoveLockFile());

  return Status::OK();
}

Status ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) {
  RETURN_IF_NOT_OK(header_data->InitByFiles(file_paths_));
  // generate the index fields automatically when none are specified
  std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
  if (fields.empty()) {
    MS_LOG(DEBUG) << "Index field is not set, it will be generated automatically.";
    std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
    for (const auto &schema : schemas) {
      json jsonSchema = schema->GetSchema()["schema"];
      for (const auto &el : jsonSchema.items()) {
        if (el.value()["type"] == "string" ||
            (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
          fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
        }
      }
    }
    // if the schemas contain only blob data, there is nothing to index
    if (!fields.empty()) {
      RETURN_IF_NOT_OK(header_data->AddIndexFields(fields));
    }
  }

  shard_header_ = header_data;
  shard_header_->SetHeaderSize(header_size_);
  shard_header_->SetPageSize(page_size_);
  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
  return Status::OK();
}

Status ShardWriter::SetHeaderSize(const uint64_t &header_size) {
  // header_size [16KB, 128MB]
  CHECK_FAIL_RETURN_UNEXPECTED(header_size >= kMinHeaderSize && header_size <= kMaxHeaderSize,
                               "Invalid data, header size: " + std::to_string(header_size) + " should be in range [" +
                                 std::to_string(kMinHeaderSize) + " bytes, " + std::to_string(kMaxHeaderSize) +
                                 " bytes].");
  CHECK_FAIL_RETURN_UNEXPECTED(
    header_size % 4 == 0,
    "Invalid data, header size " + std::to_string(header_size) + " should be divisible by four.");
  header_size_ = header_size;
  return Status::OK();
}

Status ShardWriter::SetPageSize(const uint64_t &page_size) {
  // page_size [32KB, 256MB]
  CHECK_FAIL_RETURN_UNEXPECTED(page_size >= kMinPageSize && page_size <= kMaxPageSize,
                               "Invalid data, page size: " + std::to_string(page_size) + " should be in range [" +
                                 std::to_string(kMinPageSize) + " bytes, " + std::to_string(kMaxPageSize) +
                                 " bytes].");
  CHECK_FAIL_RETURN_UNEXPECTED(
    page_size % 4 == 0, "Invalid data, page size " + std::to_string(page_size) + " should be divisible by four.");
  page_size_ = page_size;
  return Status::OK();
}
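
// Both sizes are byte counts and must be multiples of four. For example, a
// caller could request 32MB pages like this (the value is only an example
// within the documented range, not a recommendation):
//
//   RETURN_IF_NOT_OK(writer->SetPageSize(1 << 25));  // 32MB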

void ShardWriter::DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data,
                                  std::vector<std::vector<uint8_t>> &blob_data) {
  // collect the locations of the bad rows
  std::set<int, std::greater<int>> delete_set;
  for (auto &err_mg : err_mg_) {
    uint64_t id = err_mg.first;
    auto sub_err_mg = err_mg.second;
    for (auto &subMg : sub_err_mg) {
      int loc = subMg.first;
      std::string message = subMg.second;
      MS_LOG(ERROR) << "Invalid input, the " << loc + 1 << "-th data is invalid, " << message;
      (void)delete_set.insert(loc);
    }
  }

  auto it = raw_data.begin();
  if (delete_set.size() == it->second.size()) {
    raw_data.clear();
    blob_data.clear();
    return;
  }

  // delete the wrong raw data, iterating locations in descending order so that
  // earlier indices stay valid
  for (auto &loc : delete_set) {
    // delete row data
    for (auto &raw : raw_data) {
      (void)raw.second.erase(raw.second.begin() + loc);
    }

    // delete blob data
    (void)blob_data.erase(blob_data.begin() + loc);
  }
}

void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message,
                                         std::map<int, std::string> &err_raw_data) {
  std::lock_guard<std::mutex> lock(check_mutex_);
  (void)err_raw_data.insert(std::make_pair(row, message));
}

Status ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
                                          std::map<int, std::string> &err_raw_data) {
  auto data_type = std::string(value["type"].get<std::string>());
  if ((data_type == "int32" && !data[key].is_number_integer()) ||
      (data_type == "int64" && !data[key].is_number_integer()) ||
      (data_type == "float32" && !data[key].is_number_float()) ||
      (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) {
    std::string message =
      "field: " + key + ", type: " + data_type + ", value: " + data[key].dump() + " does not match the schema.";
    PopulateMutexErrorData(i, message, err_raw_data);
    RETURN_STATUS_UNEXPECTED(message);
  }

  if (data_type == "int32" && data[key].is_number_integer()) {
    int64_t temp_value = data[key];
    if (temp_value < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) ||
        temp_value > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
      std::string message =
        "field: " + key + ", type: " + data_type + ", value: " + data[key].dump() + " is out of range.";
      PopulateMutexErrorData(i, message, err_raw_data);
      RETURN_STATUS_UNEXPECTED(message);
    }
  }
  return Status::OK();
}
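
// Example of a row this check rejects (illustrative): with a schema entry
// {"label": {"type": "int32"}}, the row {"label": 3000000000} parses as an
// integer but exceeds INT32_MAX, so it is recorded in err_raw_data and later
// removed by DeleteErrorData().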

void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
                                 std::map<int, std::string> &err_raw_data) {
  if (start_row < 0 || start_row > end_row || end_row > static_cast<int>(sub_raw_data.size())) {
    return;
  }
  for (int i = start_row; i < end_row; i++) {
    json data = sub_raw_data[i];

    for (auto iter = schema.begin(); iter != schema.end(); iter++) {
      std::string key = iter.key();
      json value = iter.value();
      if (data.find(key) == data.end()) {
        std::string message = "'" + key + "' object cannot be found in data: " + value.dump();
        PopulateMutexErrorData(i, message, err_raw_data);
        break;
      }

      if (value.size() == kInt2) {
        // Skip the check since all shaped data is stored as blob
        continue;
      }

      if (CheckDataTypeAndValue(key, value, data, i, err_raw_data).IsError()) {
        break;
      }
    }
  }
}

Status ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &raw_data) {
  auto rawdata_iter = raw_data.begin();

  // make sure the raw data matches the schema
  for (; rawdata_iter != raw_data.end(); ++rawdata_iter) {
    // used for storing errors
    std::map<int, std::string> sub_err_mg;
    int schema_id = rawdata_iter->first;
    std::shared_ptr<Schema> schema_ptr;
    RETURN_IF_NOT_OK(shard_header_->GetSchemaByID(schema_id, &schema_ptr));
    json schema = schema_ptr->GetSchema()["schema"];
    for (const auto &field : schema_ptr->GetBlobFields()) {
      (void)schema.erase(field);
    }
    std::vector<json> sub_raw_data = rawdata_iter->second;

    // calculate the start position and end position for each thread
    int batch_size = rawdata_iter->second.size() / shard_count_;
    int thread_num = shard_count_;
    CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid data, thread_num should be positive.");
    if (thread_num > kMaxThreadCount) {
      thread_num = kMaxThreadCount;
    }
    CHECK_FAIL_RETURN_UNEXPECTED(
      thread_num <= kMaxThreadCount,
      "Invalid data, thread_num should be less than or equal to " + std::to_string(kMaxThreadCount));
    std::vector<std::thread> thread_set(thread_num);

    // start multiple threads
    int start_row = 0, end_row = 0;
    for (int x = 0; x < thread_num; ++x) {
      if (x != thread_num - 1) {
        start_row = batch_size * x;
        end_row = batch_size * (x + 1);
      } else {
        start_row = batch_size * x;
        end_row = rawdata_iter->second.size();
      }
      thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema,
                                  std::ref(sub_raw_data), std::ref(sub_err_mg));
    }
    // Wait for the threads to finish
    for (int x = 0; x < thread_num; ++x) {
      thread_set[x].join();
    }

    (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg));
  }
  return Status::OK();
}

Status ShardWriter::ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                    std::vector<std::vector<uint8_t>> &blob_data, bool sign,
                                    std::shared_ptr<std::pair<int, int>> *count_ptr) {
  RETURN_UNEXPECTED_IF_NULL(count_ptr);
  auto rawdata_iter = raw_data.begin();
  schema_count_ = raw_data.size();
  CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ > 0, "Invalid data, schema count should be positive.");

  // keep the schema ids
  std::set<int64_t> schema_ids;
  row_count_ = (rawdata_iter->second).size();

  // Determine whether the number of schemas matches the header
  CHECK_FAIL_RETURN_UNEXPECTED(shard_header_->GetSchemas().size() == schema_count_,
                               "Invalid data, schema count: " + std::to_string(schema_count_) + " is not matched.");
  // Determine raw_data size == blob_data size
  CHECK_FAIL_RETURN_UNEXPECTED(raw_data[0].size() == blob_data.size(),
                               "Invalid data, raw data size: " + std::to_string(raw_data[0].size()) +
                                 " is not equal to blob data size: " + std::to_string(blob_data.size()) + ".");

  // Determine whether the number of samples is the same for each schema
  for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
    CHECK_FAIL_RETURN_UNEXPECTED(
      row_count_ == rawdata_iter->second.size(),
      "Invalid data, number of samples: " + std::to_string(rawdata_iter->second.size()) +
        " for schema is not matched.");
    (void)schema_ids.insert(rawdata_iter->first);
  }
  const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
  // Every schema defined in the header must have matching raw data
  CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(schemas.begin(), schemas.end(),
                                            [schema_ids](const std::shared_ptr<Schema> &schema) {
                                              return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
                                            }),
                               "Invalid data, schema id of data is not matched.");
  if (!sign) {
    *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
    return Status::OK();
  }

  // check the data against the schema
  RETURN_IF_NOT_OK(CheckData(raw_data));

  // delete the wrong data from the raw data
  DeleteErrorData(raw_data, blob_data);

  // update the row count
  row_count_ = row_count_ - err_mg_.begin()->second.size();
  *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
  return Status::OK();
}

void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,
                            std::vector<std::vector<uint8_t>> &bin_data) {
  // Guard against out-of-range access when more threads are launched than rows remain
  if (start >= end) {
    flag_ = true;
    return;
  }
  int schema_count = static_cast<int>(raw_data.size());
  std::map<uint64_t, vector<json>>::const_iterator rawdata_iter;
  for (int x = start; x < end; ++x) {
    int cnt = 0;
    for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
      const json &line = raw_data.at(rawdata_iter->first)[x];
      std::vector<std::uint8_t> bline = json::to_msgpack(line);

      // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2]
      bin_data[x * schema_count + cnt] = bline;
      cnt++;
    }
  }
}
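
// Worked example of the interleaved layout above: with 2 schemas and 2 samples,
// bin_data holds [S1-Schema1, S1-Schema2, S2-Schema1, S2-Schema2], i.e. the
// msgpack bytes of sample x under schema slot cnt land at index
// x * schema_count + cnt.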

Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr) {
  if (!parallel_writer) {
    *fd_ptr = std::make_unique<int>(0);
    return Status::OK();
  }

#if defined(_WIN32) || defined(_WIN64)
  const int fd = 0;
  MS_LOG(DEBUG) << "Lock file done by Python.";

#else
  const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
  if (fd >= 0) {
    flock(fd, LOCK_EX);
  } else {
    RETURN_STATUS_UNEXPECTED("Failed to lock file, path: " + lock_file_);
  }
#endif

  // Open files
  file_streams_.clear();
  for (const auto &file : file_paths_) {
    auto realpath = FileUtils::GetRealPath(file.data());
    if (!realpath.has_value()) {
      close(fd);
      RETURN_STATUS_UNEXPECTED("Failed to get real path, path: " + file);
    }
    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
    fs->open(realpath.value(), std::ios::in | std::ios::out | std::ios::binary);
    if (fs->fail()) {
      close(fd);
      RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file);
    }
    file_streams_.push_back(fs);
  }
  auto status = shard_header_->FileToPages(pages_file_);
  if (status.IsError()) {
    close(fd);
    RETURN_STATUS_UNEXPECTED("Error raised in FileToPages function.");
  }
  *fd_ptr = std::make_unique<int>(fd);
  return Status::OK();
}

Status ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
  if (!parallel_writer) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK(shard_header_->PagesToFile(pages_file_));
  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
    file_streams_[i]->close();
  }
#if defined(_WIN32) || defined(_WIN64)
  MS_LOG(DEBUG) << "Unlock file done by Python.";

#else
  flock(fd, LOCK_UN);
  close(fd);
#endif
  return Status::OK();
}

Status ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
                                         std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count,
                                         int *row_count) {
  // check the free disk size
  std::shared_ptr<uint64_t> size_ptr;
  RETURN_IF_NOT_OK(GetDiskSize(file_paths_[0], kFreeSize, &size_ptr));
  CHECK_FAIL_RETURN_UNEXPECTED(*size_ptr >= kMinFreeDiskSize,
                               "Not enough free disk space, free disk size: " + std::to_string(*size_ptr));
  // compress blobs
  if (shard_column_->CheckCompressBlob()) {
    for (auto &blob : blob_data) {
      int64_t compression_bytes = 0;
      blob = shard_column_->CompressBlob(blob, &compression_bytes);
      compression_size_ += compression_bytes;
    }
  }

  // Add a 4-byte dummy blob per row if there are no blob fields
  if (blob_data.size() == 0 && raw_data.size() > 0) {
    blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
  }

  // Add a dummy id per row if all fields are blob fields
  if (blob_data.size() > 0 && raw_data.size() == 0) {
    raw_data.insert(std::pair<uint64_t, std::vector<json>>(0, std::vector<json>(blob_data.size(), kDummyId)));
  }
  std::shared_ptr<std::pair<int, int>> count_ptr;
  RETURN_IF_NOT_OK(ValidateRawData(raw_data, blob_data, sign, &count_ptr));
  *schema_count = (*count_ptr).first;
  *row_count = (*count_ptr).second;
  return Status::OK();
}

Status ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
                                  const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
                                  std::shared_ptr<std::vector<uint8_t>> *output) {
  if (blob_fields.empty()) {
    return Status::OK();
  }
  if (blob_fields.size() == 1) {
    auto &blob = row_bin_data.at(blob_fields[0]);
    auto blob_size = blob->size();
    *output = std::make_shared<std::vector<uint8_t>>(blob_size);
    std::copy(blob->begin(), blob->end(), (*output)->begin());
  } else {
    size_t output_size = 0;
    for (auto &field : blob_fields) {
      output_size += row_bin_data.at(field)->size();
    }
    output_size += blob_fields.size() * sizeof(uint64_t);
    *output = std::make_shared<std::vector<uint8_t>>(output_size);
    std::vector<uint8_t> buf(sizeof(uint64_t), 0);
    size_t idx = 0;
    for (auto &field : blob_fields) {
      auto &b = row_bin_data.at(field);
      uint64_t blob_size = b->size();
      // encode the blob size as big-endian
      for (size_t i = 0; i < buf.size(); ++i) {
        buf[buf.size() - 1 - i] = (std::numeric_limits<uint8_t>::max()) & blob_size;
        blob_size >>= 8u;
      }
      std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
      idx += buf.size();
      std::copy(b->begin(), b->end(), (*output)->begin() + idx);
      idx += b->size();
    }
  }
  return Status::OK();
}
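
// Resulting multi-field layout (sketch): each blob is preceded by its size as
// an 8-byte big-endian integer. For two blobs of 3 and 2 bytes:
//
//   [00 00 00 00 00 00 00 03][b0 b1 b2][00 00 00 00 00 00 00 02][c0 c1]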

Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                 std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
  // Lock the writer if loading data in parallel
  std::unique_ptr<int> fd_ptr;
  RETURN_IF_NOT_OK(LockWriter(parallel_writer, &fd_ptr));

  // Get the count of schemas and rows
  int schema_count = 0;
  int row_count = 0;

  // Pre-check and validate the raw data
  RETURN_IF_NOT_OK(WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count));
  CHECK_FAIL_RETURN_UNEXPECTED(row_count >= kInt0, "Invalid data, raw data size should be non-negative.");
  if (row_count == kInt0) {
    return Status::OK();
  }
  std::vector<std::vector<uint8_t>> bin_raw_data(row_count * schema_count);
  // Serialize raw data
  RETURN_IF_NOT_OK(SerializeRawData(raw_data, bin_raw_data, row_count));
  // Set row size of raw data
  RETURN_IF_NOT_OK(SetRawDataSize(bin_raw_data));
  // Set row size of blob data
  RETURN_IF_NOT_OK(SetBlobDataSize(blob_data));
  // Write data to disk with multiple threads
  RETURN_IF_NOT_OK(ParallelWriteData(blob_data, bin_raw_data));
  MS_LOG(INFO) << "Successfully wrote " << bin_raw_data.size() << " records.";

  RETURN_IF_NOT_OK(UnlockWriter(*fd_ptr, parallel_writer));

  return Status::OK();
}

Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
                                 std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
                                 bool parallel_writer) {
  std::map<uint64_t, std::vector<json>> raw_data_json;
  std::map<uint64_t, std::vector<json>> blob_data_json;

  (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
                       [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
                         auto &py_raw_data = pair.second;
                         std::vector<json> json_raw_data;
                         (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
                                              [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
                         return std::make_pair(pair.first, std::move(json_raw_data));
                       });

  (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()),
                       [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
                         auto &py_blob_data = pair.second;
                         std::vector<json> jsonBlobData;
                         (void)std::transform(py_blob_data.begin(), py_blob_data.end(),
                                              std::back_inserter(jsonBlobData),
                                              [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
                         return std::make_pair(pair.first, std::move(jsonBlobData));
                       });

  // Serialize blob page
  auto blob_data_iter = blob_data.begin();
  auto schema_count = blob_data.size();
  auto row_count = blob_data_iter->second.size();

  std::vector<std::vector<uint8_t>> bin_blob_data(row_count * schema_count);
  // Serialize blob data
  RETURN_IF_NOT_OK(SerializeRawData(blob_data_json, bin_blob_data, row_count));
  return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
}

Status ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
                                      const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  auto shards = BreakIntoShards();
  // define the number of threads
  int thread_num = static_cast<int>(shard_count_);
  CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid data, thread_num should be positive.");
  if (thread_num > kMaxThreadCount) {
    thread_num = kMaxThreadCount;
  }
  int left_thread = shard_count_;
  int current_thread = 0;
  while (left_thread) {
    if (left_thread < thread_num) {
      thread_num = left_thread;
    }
    // Start one thread for each shard
    std::vector<std::thread> thread_set(thread_num);
    if (thread_num <= kMaxThreadCount) {
      for (int x = 0; x < thread_num; ++x) {
        int start_row = shards[current_thread + x].first;
        int end_row = shards[current_thread + x].second;
        thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row,
                                    std::ref(blob_data), std::ref(bin_raw_data));
      }
      // Wait for the threads to finish
      for (int x = 0; x < thread_num; ++x) {
        thread_set[x].join();
      }
      left_thread -= thread_num;
      current_thread += thread_num;
    }
  }
  return Status::OK();
}

Status ShardWriter::WriteByShard(int shard_id, int start_row, int end_row,
                                 const std::vector<std::vector<uint8_t>> &blob_data,
                                 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row
                << ", schema size: " << schema_count_;
  if (start_row == end_row) {
    return Status::OK();
  }
  vector<std::pair<int, int>> rows_in_group;
  std::shared_ptr<Page> last_raw_page = nullptr;
  std::shared_ptr<Page> last_blob_page = nullptr;
  // a failed lookup (e.g. a brand-new shard with no pages yet) just leaves the page as nullptr
  (void)SetLastRawPage(shard_id, last_raw_page);
  (void)SetLastBlobPage(shard_id, last_blob_page);

  RETURN_IF_NOT_OK(CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page));
  RETURN_IF_NOT_OK(AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
  RETURN_IF_NOT_OK(NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
  RETURN_IF_NOT_OK(ShiftRawPage(shard_id, rows_in_group, last_raw_page));
  RETURN_IF_NOT_OK(WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data));

  return Status::OK();
}

Status ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
                                std::vector<std::pair<int, int>> &rows_in_group,
                                const std::shared_ptr<Page> &last_raw_page,
                                const std::shared_ptr<Page> &last_blob_page) {
  auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;

  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
  auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0;
  auto n_byte_raw = last_raw_page_size - last_raw_offset;

  int page_start_row = start_row;
  CHECK_FAIL_RETURN_UNEXPECTED(start_row <= end_row,
                               "Invalid data, start row: " + std::to_string(start_row) +
                                 " should be less than or equal to end row: " + std::to_string(end_row));

  CHECK_FAIL_RETURN_UNEXPECTED(
    end_row <= static_cast<int>(blob_data_size_.size()) && end_row <= static_cast<int>(raw_data_size_.size()),
    "Invalid data, end row: " + std::to_string(end_row) + " should be less than or equal to blob data size: " +
      std::to_string(blob_data_size_.size()) + " and raw data size: " + std::to_string(raw_data_size_.size()) + ".");
  for (int i = start_row; i < end_row; ++i) {
    // n_byte_blob == 0 indicates the first chunk, which AppendBlobPage handles
    if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ ||
        n_byte_raw + raw_data_size_[i] > page_size_) {
      rows_in_group.emplace_back(page_start_row, i);
      page_start_row = i;
      n_byte_blob = blob_data_size_[i];
      n_byte_raw = raw_data_size_[i];
    } else {
      n_byte_blob += blob_data_size_[i];
      n_byte_raw += raw_data_size_[i];
    }
  }

  // Don't forget the last group
  rows_in_group.emplace_back(page_start_row, end_row);
  return Status::OK();
}
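
// Example (illustrative): with page_size_ = 100, a last blob page already
// holding 40 bytes, and incoming blob sizes {50, 30, 40} (raw sizes small
// enough to ignore), row 0 still fits the old page (40 + 50 <= 100) but row 1
// does not, so rows_in_group becomes {(0, 1), (1, 3)}: chunk (0, 1) is appended
// to the last page and chunk (1, 3) goes to a new page.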

Status ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
                                   const std::vector<std::pair<int, int>> &rows_in_group,
                                   const std::shared_ptr<Page> &last_blob_page) {
  auto blob_row = rows_in_group[0];
  if (blob_row.first == blob_row.second) {
    return Status::OK();
  }
  // Write to disk
  auto page_id = last_blob_page->GetPageID();
  auto bytes_page = last_blob_page->GetPageSize();
  auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
  }

  RETURN_IF_NOT_OK(FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row));

  // Update the last blob page
  bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
  last_blob_page->SetPageSize(bytes_page);
  uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first;
  last_blob_page->SetEndRowID(end_row);
  (void)shard_header_->SetPage(last_blob_page);
  return Status::OK();
}

Status ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
                                const std::vector<std::pair<int, int>> &rows_in_group,
                                const std::shared_ptr<Page> &last_blob_page) {
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1;
  auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0;
  // index 0 is the chunk appended to the last blob page by AppendBlobPage
  for (uint32_t i = 1; i < rows_in_group.size(); ++i) {
    auto blob_row = rows_in_group[i];

    // Write one blob page to disk
    auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
    if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
      file_streams_[shard_id]->close();
      RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
    }

    RETURN_IF_NOT_OK(FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row));
    // Create new page info for the header
    auto page_size =
      std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
    std::vector<std::pair<int, uint64_t>> row_group_ids;
    auto start_row = current_row;
    auto end_row = start_row + blob_row.second - blob_row.first;
    auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size);
    (void)shard_header_->AddPage(std::make_shared<Page>(page));
    current_row = end_row;
  }
  return Status::OK();
}

Status ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                 std::shared_ptr<Page> &last_raw_page) {
  auto blob_row = rows_in_group[0];
  if (blob_row.first == blob_row.second) {
    return Status::OK();
  }
  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
  if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) +
        last_raw_page_size <=
      page_size_) {
    return Status::OK();
  }
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second;
  auto last_raw_page_id = last_raw_page->GetPageID();
  auto shift_size = last_raw_page_size - last_row_group_id_offset;

  std::vector<uint8_t> buf(shift_size);

  // Read the last row group from the previous raw data page
  CHECK_FAIL_RETURN_UNEXPECTED(
    shard_id >= 0 && shard_id < static_cast<int>(file_streams_.size()),
    "Invalid data, shard_id should be in range [0, " + std::to_string(file_streams_.size()) + ").");

  auto &io_seekg = file_streams_[shard_id]->seekg(
    page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
  }

  auto &io_read = file_streams_[shard_id]->read(reinterpret_cast<char *>(&buf[0]), buf.size());
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to read file.");
  }

  // Merge into a new row group at the new raw data page
  auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
  }

  auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&buf[0]), buf.size());
  if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to write file.");
  }
  last_raw_page->DeleteLastGroupId();
  (void)shard_header_->SetPage(last_raw_page);

  // Refresh the page info in the header
  int row_group_id = last_raw_page->GetLastRowGroupID().first + 1;
  std::vector<std::pair<int, uint64_t>> row_group_ids;
  row_group_ids.emplace_back(row_group_id, 0);
  int page_type_id = last_raw_page->GetPageID();
  auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size);
  (void)shard_header_->AddPage(std::make_shared<Page>(page));

  // Reset the last raw page
  (void)SetLastRawPage(shard_id, last_raw_page);
  return Status::OK();
}

Status ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                 std::shared_ptr<Page> &last_raw_page,
                                 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1;
  for (uint32_t i = 0; i < rows_in_group.size(); ++i) {
    const auto &blob_row = rows_in_group[i];
    if (blob_row.first == blob_row.second) {
      continue;
    }
    auto raw_size =
      std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0);
    if (!last_raw_page) {
      RETURN_IF_NOT_OK(EmptyRawPage(shard_id, last_raw_page));
    } else if (last_raw_page->GetPageSize() + raw_size > page_size_) {
      RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page));
      RETURN_IF_NOT_OK(EmptyRawPage(shard_id, last_raw_page));
    }
    RETURN_IF_NOT_OK(AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data));
  }
  RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page));
  return Status::OK();
}

Status ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
  auto row_group_ids = std::vector<std::pair<int, uint64_t>>();
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1;
  auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0);
  RETURN_IF_NOT_OK(shard_header_->AddPage(std::make_shared<Page>(page)));
  (void)SetLastRawPage(shard_id, last_raw_page);
  return Status::OK();
}

Status ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                  const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page,
                                  const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds();
  auto last_raw_page_id = last_raw_page->GetPageID();
  auto n_bytes = last_raw_page->GetPageSize();

  // Seek to the end of the previous raw data page
  auto &io_seekp =
    file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
  }

  if (chunk_id > 0) {
    row_group_ids.emplace_back(++last_row_group_id, n_bytes);
  }
  n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first,
                             raw_data_size_.begin() + rows_in_group[chunk_id].second, 0);
  RETURN_IF_NOT_OK(FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data));

  // Update the previous raw data page
  last_raw_page->SetPageSize(n_bytes);
  last_raw_page->SetRowGroupIds(row_group_ids);
  RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page));

  return Status::OK();
}

Status ShardWriter::FlushBlobChunk(const std::shared_ptr<std::fstream> &out,
                                   const std::vector<std::vector<uint8_t>> &blob_data,
                                   const std::pair<int, int> &blob_row) {
  CHECK_FAIL_RETURN_UNEXPECTED(
    blob_row.first <= blob_row.second && blob_row.second <= static_cast<int>(blob_data.size()) && blob_row.first >= 0,
    "Invalid data, blob_row: " + std::to_string(blob_row.first) + ", " + std::to_string(blob_row.second) +
      " is invalid.");
  for (int j = blob_row.first; j < blob_row.second; ++j) {
    // Write the size of the blob
    uint64_t line_len = blob_data[j].size();
    auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
    if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
      out->close();
      RETURN_STATUS_UNEXPECTED("Failed to write file.");
    }

    // Write the data of the blob
    auto line = blob_data[j];
    auto &io_handle_data = out->write(reinterpret_cast<char *>(&line[0]), line_len);
    if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) {
      out->close();
      RETURN_STATUS_UNEXPECTED("Failed to write file.");
    }
  }
  return Status::OK();
}

Status ShardWriter::FlushRawChunk(const std::shared_ptr<std::fstream> &out,
                                  const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
                                  const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) {
    // Write the sizes of the multiple schemas
    for (uint32_t j = 0; j < schema_count_; ++j) {
      uint64_t line_len = bin_raw_data[i * schema_count_ + j].size();
      auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        out->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }
    }
    // Write the data of the multiple schemas
    for (uint32_t j = 0; j < schema_count_; ++j) {
      auto line = bin_raw_data[i * schema_count_ + j];
      auto &io_handle = out->write(reinterpret_cast<char *>(&line[0]), line.size());
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        out->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }
    }
  }
  return Status::OK();
}

// Allocate data to shards evenly
std::vector<std::pair<int, int>> ShardWriter::BreakIntoShards() {
  std::vector<std::pair<int, int>> shards;
  int row_in_shard = row_count_ / shard_count_;
  int remains = row_count_ % shard_count_;

  std::vector<int> v_list(shard_count_);
  std::iota(v_list.begin(), v_list.end(), 0);

  std::mt19937 g = mindspore::dataset::GetRandomDevice();
  std::shuffle(v_list.begin(), v_list.end(), g);
  std::unordered_set<int> set(v_list.begin(), v_list.begin() + remains);

  if (shard_count_ <= kMaxShardCount) {
    int start_row = 0;
    for (int i = 0; i < shard_count_; ++i) {
      int end_row = start_row + row_in_shard;
      if (set.count(i)) {
        end_row++;
      }
      shards.emplace_back(start_row, end_row);
      start_row = end_row;
    }
  }
  return shards;
}
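
// Example: row_count_ = 10 and shard_count_ = 3 give row_in_shard = 3 and
// remains = 1; one randomly chosen shard receives the extra row, e.g.
// shards = {(0, 4), (4, 7), (7, 10)} when shard 0 is picked.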

Status ShardWriter::WriteShardHeader() {
  RETURN_UNEXPECTED_IF_NULL(shard_header_);
  int64_t compression_temp = compression_size_;
  uint64_t compression_size = compression_temp > 0 ? compression_temp : 0;
  shard_header_->SetCompressionSize(compression_size);

  auto shard_header = shard_header_->SerializeHeader();
  // Write the header data to multiple files
  CHECK_FAIL_RETURN_UNEXPECTED(
    shard_count_ <= static_cast<int>(file_streams_.size()) && shard_count_ <= static_cast<int>(shard_header.size()),
    "Invalid data, shard count should be less than or equal to the number of file streams: " +
      std::to_string(file_streams_.size()) + ", and header count: " + std::to_string(shard_header.size()) + ".");
  if (shard_count_ <= kMaxShardCount) {
    for (int shard_id = 0; shard_id < shard_count_; ++shard_id) {
      auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg);
      if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
      }

      std::vector<uint8_t> bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end());
      uint64_t line_len = bin_header.size();
      if (line_len + kInt64Len > header_size_) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Invalid data, shard header is larger than the reserved header size.");
      }
      auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&line_len), kInt64Len);
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }

      auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast<char *>(&bin_header[0]), line_len);
      if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }
      file_streams_[shard_id]->close();
    }
  }
  return Status::OK();
}
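
// Resulting per-shard file layout (sketch): an 8-byte header length followed by
// the serialized header, inside a region of header_size_ bytes reserved at the
// start of the file; page k then begins at offset header_size_ + k * page_size_,
// which matches the seekp() arithmetic used throughout this file.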

Status ShardWriter::SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                     std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count) {
  // define the number of threads
  uint32_t thread_num = std::thread::hardware_concurrency();
  if (thread_num == 0) {
    thread_num = kThreadNumber;
  }
  // Set the number of samples processed by each thread
  int group_num = ceil(row_count * 1.0 / thread_num);
  std::vector<std::thread> thread_set(thread_num);
  int work_thread_num = 0;
  for (uint32_t x = 0; x < thread_num; ++x) {
    int start_num = x * group_num;
    int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num;
    if (start_num >= end_num) {
      continue;
    }
    // Define the run boundary and start the worker thread
    thread_set[x] =
      std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data));
    work_thread_num++;
  }
  for (int x = 0; x < work_thread_num; ++x) {
    // Block the main thread until the workers finish
    thread_set[x].join();
  }
  CHECK_FAIL_RETURN_SYNTAX_ERROR(!flag_, "Error raised in FillArray function.");
  return Status::OK();
}
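
// Thread partition example: row_count = 10 with thread_num = 4 gives
// group_num = ceil(10 / 4) = 3, so the launched ranges are [0, 3), [3, 6),
// [6, 9) and [9, 10); any tail range with start_num >= end_num is skipped.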

Status ShardWriter::SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  raw_data_size_ = std::vector<uint64_t>(row_count_, 0);
  for (uint32_t i = 0; i < row_count_; ++i) {
    raw_data_size_[i] = std::accumulate(
      bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_,
      static_cast<uint64_t>(0),
      [](uint64_t accumulator, const std::vector<uint8_t> &row) { return accumulator + kInt64Len + row.size(); });
  }
  CHECK_FAIL_RETURN_SYNTAX_ERROR(
    *std::max_element(raw_data_size_.begin(), raw_data_size_.end()) <= page_size_,
    "Invalid data, page size: " + std::to_string(page_size_) + " is too small to hold a raw data row!");
  return Status::OK();
}

Status ShardWriter::SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data) {
  blob_data_size_ = std::vector<uint64_t>(row_count_);
  (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(),
                       [](const std::vector<uint8_t> &row) { return kInt64Len + row.size(); });
  CHECK_FAIL_RETURN_SYNTAX_ERROR(
    *std::max_element(blob_data_size_.begin(), blob_data_size_.end()) <= page_size_,
    "Invalid data, page size: " + std::to_string(page_size_) + " is too small to hold a blob data row!");
  return Status::OK();
}

Status ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
  // Get the last raw page
  auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw);
  CHECK_FAIL_RETURN_SYNTAX_ERROR(last_raw_page_id >= 0, "Invalid data, last_raw_page_id: " +
                                                          std::to_string(last_raw_page_id) +
                                                          " should be non-negative.");
  RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_raw_page_id, &last_raw_page));
  return Status::OK();
}

Status ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page) {
  // Get the last blob page
  auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob);
  CHECK_FAIL_RETURN_SYNTAX_ERROR(last_blob_page_id >= 0, "Invalid data, last_blob_page_id: " +
                                                           std::to_string(last_blob_page_id) +
                                                           " should be non-negative.");
  RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_blob_page_id, &last_blob_page));
  return Status::OK();
}

Status ShardWriter::Initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
                               const std::vector<std::string> &file_names) {
  RETURN_UNEXPECTED_IF_NULL(writer_ptr);
  RETURN_IF_NOT_OK((*writer_ptr)->Open(file_names, false));
  RETURN_IF_NOT_OK((*writer_ptr)->SetHeaderSize(kDefaultHeaderSize));
  RETURN_IF_NOT_OK((*writer_ptr)->SetPageSize(kDefaultPageSize));
  return Status::OK();
}
}  // namespace mindrecord
}  // namespace mindspore