/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/mindrecord/include/shard_writer.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "securec.h"

namespace mindspore {
namespace mindrecord {
ShardWriter::ShardWriter()
    : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) {
  compression_size_ = 0;
}

ShardWriter::~ShardWriter() {
  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
    file_streams_[i]->close();
  }
}

Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
  // Get full path from file name
  for (const auto &path : paths) {
    CHECK_FAIL_RETURN_UNEXPECTED_MR(CheckIsValidUtf8(path),
                                    "Invalid file, mindrecord file name: " + path +
                                      " contains invalid utf-8 character. Please rename mindrecord file name.");
    // get realpath
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(path, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().c_str());
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      realpath.has_value(),
      "Invalid dir, failed to get the realpath of mindrecord file dir. Please check path: " + dir.value());

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    (void)file_paths_.emplace_back(whole_path.value());
  }
  return Status::OK();
}

Status ShardWriter::OpenDataFiles(bool append, bool overwrite) {
  // Open files
  for (const auto &file : file_paths_) {
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().c_str());
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      realpath.has_value(), "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
    if (!append) {
      // if not append && mindrecord or db file exist
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      std::ifstream fs_db(whole_path.value() + ".db");
      if (fs->good() || fs_db.good()) {
        fs->close();
        fs_db.close();
        if (overwrite) {
          auto res1 = std::remove(whole_path.value().c_str());
          CHECK_FAIL_RETURN_UNEXPECTED_MR(!std::ifstream(whole_path.value()).good(),
                                          "Invalid file, failed to remove the old files when trying to overwrite "
                                          "mindrecord files. Please check file path and permission: " +
                                            file);
          if (res1 == 0) {
            MS_LOG(WARNING) << "Succeed to remove the old mindrecord files, path: " << file;
          }
          auto db_file = whole_path.value() + ".db";
          auto res2 = std::remove(db_file.c_str());
          CHECK_FAIL_RETURN_UNEXPECTED_MR(!std::ifstream(db_file).good(),
                                          "Invalid file, failed to remove the old mindrecord meta files when trying to "
                                          "overwrite mindrecord files. Please check file path and permission: " +
                                            file + ".db");
          if (res2 == 0) {
            MS_LOG(WARNING) << "Succeed to remove the old mindrecord metadata files, path: " << file + ".db";
          }
        } else {
          RETURN_STATUS_UNEXPECTED_MR(
            "Invalid file, mindrecord files already exist. Please check file path: " + file +
            ".\nIf you do not want to keep the files, set the 'overwrite' parameter to True and try again.");
        }
      } else {
        fs->close();
        fs_db.close();
      }
      // open the mindrecord file to write
      fs->open(whole_path.value().data(), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED_MR(
          "Invalid file, failed to open files for writing mindrecord files. Please check file path, permission and "
          "open file limit: " +
          file);
      }
    } else {
      // open the mindrecord file to append
      fs->open(whole_path.value().data(), std::ios::out | std::ios::in | std::ios::binary);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED_MR(
          "Invalid file, failed to open files for appending mindrecord files. Please check file path, permission and "
          "open file limit: " +
          file);
      }
    }
    MS_LOG(INFO) << "Succeed to open mindrecord shard file, path: " << file;
    file_streams_.push_back(fs);
  }
  return Status::OK();
}

Status ShardWriter::RemoveLockFile() {
  // Remove temporary file
  int ret = std::remove(pages_file_.c_str());
  if (ret == 0) {
    MS_LOG(DEBUG) << "Succeed to remove page file, path: " << pages_file_;
  }

  ret = std::remove(lock_file_.c_str());
  if (ret == 0) {
    MS_LOG(DEBUG) << "Succeed to remove lock file, path: " << lock_file_;
  }
  return Status::OK();
}

Status ShardWriter::InitLockFile() {
  CHECK_FAIL_RETURN_UNEXPECTED_MR(file_paths_.size() != 0, "[Internal ERROR] 'file_paths_' is not initialized.");

  lock_file_ = file_paths_[0] + kLockFileSuffix;
  pages_file_ = file_paths_[0] + kPageFileSuffix;
  RETURN_IF_NOT_OK_MR(RemoveLockFile());
  return Status::OK();
}

Status ShardWriter::Open(const std::vector<std::string> &paths, bool append, bool overwrite) {
  shard_count_ = paths.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_count_ <= kMaxSchemaCount,
                                  "[Internal ERROR] 'schema_count_' should be less than or equal to " +
                                    std::to_string(kMaxSchemaCount) + ", but got: " + std::to_string(schema_count_));

  // Get full path from file name
  RETURN_IF_NOT_OK_MR(GetFullPathFromFileName(paths));
  // Open files
  RETURN_IF_NOT_OK_MR(OpenDataFiles(append, overwrite));
  // Init lock file
  RETURN_IF_NOT_OK_MR(InitLockFile());
  return Status::OK();
}
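
// A minimal usage sketch (illustrative only; error handling elided and default
// arguments assumed for parameters not shown):
//   auto writer = std::make_unique<ShardWriter>();
//   (void)writer->Open(file_names, false, true);   // create new files, overwrite old ones
//   (void)writer->SetShardHeader(header);          // header carries the schema and index fields
//   (void)writer->WriteRawData(raw_data, blob_data);
//   (void)writer->Commit();                        // write meta data and drop the lock file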

Status ShardWriter::OpenForAppend(const std::string &path) {
  RETURN_IF_NOT_OK_MR(CheckFile(path));
  std::shared_ptr<json> header_ptr;
  RETURN_IF_NOT_OK_MR(ShardHeader::BuildSingleHeader(path, &header_ptr));
  auto ds = std::make_shared<std::vector<std::string>>();
  RETURN_IF_NOT_OK_MR(GetDatasetFiles(path, (*header_ptr)["shard_addresses"], &ds));
  ShardHeader header = ShardHeader();
  RETURN_IF_NOT_OK_MR(header.BuildDataset(*ds));
  shard_header_ = std::make_shared<ShardHeader>(header);
  RETURN_IF_NOT_OK_MR(SetHeaderSize(shard_header_->GetHeaderSize()));
  RETURN_IF_NOT_OK_MR(SetPageSize(shard_header_->GetPageSize()));
  compression_size_ = shard_header_->GetCompressionSize();
  RETURN_IF_NOT_OK_MR(Open(*ds, true));
  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
  return Status::OK();
}

Status ShardWriter::Commit() {
  // Read pages file
  std::ifstream page_file(pages_file_.c_str(), std::ios::in);
  if (page_file.good()) {
    page_file.close();
    RETURN_IF_NOT_OK_MR(shard_header_->FileToPages(pages_file_));
  }
  RETURN_IF_NOT_OK_MR(WriteShardHeader());
  MS_LOG(INFO) << "Succeed to write meta data.";
  // Remove lock file
  RETURN_IF_NOT_OK_MR(RemoveLockFile());

  return Status::OK();
}

Status ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) {
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    header_data->GetSchemaCount() > 0,
    "Invalid data, schema is not found in header, please use 'add_schema' to add a schema for new mindrecord files.");
  RETURN_IF_NOT_OK_MR(header_data->InitByFiles(file_paths_));
  // set fields in mindrecord when empty
  std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
  if (fields.empty()) {
    MS_LOG(DEBUG) << "Index field is not set, it will be generated automatically.";
    std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
    for (const auto &schema : schemas) {
      json jsonSchema = schema->GetSchema()["schema"];
      for (const auto &el : jsonSchema.items()) {
        if (el.value()["type"] == "string" ||
            (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
          fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
        }
      }
    }
    // if fields is still empty here, the schema contains only blob data and no index is added
    if (!fields.empty()) {
      RETURN_IF_NOT_OK_MR(header_data->AddIndexFields(fields));
    }
  }

  shard_header_ = header_data;
  shard_header_->SetHeaderSize(header_size_);
  shard_header_->SetPageSize(page_size_);
  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
  return Status::OK();
}

Status ShardWriter::SetHeaderSize(const uint64_t &header_size) {
  // header_size [16KB, 128MB]
  CHECK_FAIL_RETURN_UNEXPECTED_MR(header_size >= kMinHeaderSize && header_size <= kMaxHeaderSize,
                                  "Invalid data, header size: " + std::to_string(header_size) +
                                    " should be in range [" + std::to_string(kMinHeaderSize) + " bytes, " +
                                    std::to_string(kMaxHeaderSize) + " bytes].");
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    header_size % 4 == 0, "Invalid data, header size " + std::to_string(header_size) + " should be divisible by four.");
  header_size_ = header_size;
  return Status::OK();
}

Status ShardWriter::SetPageSize(const uint64_t &page_size) {
  // PageSize [32KB, 256MB]
  CHECK_FAIL_RETURN_UNEXPECTED_MR(page_size >= kMinPageSize && page_size <= kMaxPageSize,
                                  "Invalid data, page size: " + std::to_string(page_size) + " should be in range [" +
                                    std::to_string(kMinPageSize) + " bytes, " + std::to_string(kMaxPageSize) +
                                    " bytes].");
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    page_size % 4 == 0, "Invalid data, page size " + std::to_string(page_size) + " should be divisible by four.");
  page_size_ = page_size;
  return Status::OK();
}

void ShardWriter::DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data,
                                  std::vector<std::vector<uint8_t>> &blob_data) {
  // get wrong data location
  std::set<int, std::greater<int>> delete_set;
  for (auto &err_mg : err_mg_) {
    uint64_t id = err_mg.first;
    auto sub_err_mg = err_mg.second;
    for (auto &subMg : sub_err_mg) {
      int loc = subMg.first;
      std::string message = subMg.second;
      MS_LOG(ERROR) << "Invalid input, the " << (loc + 1)
                    << "th data provided by user is invalid while writing mindrecord files. Please fix the error: "
                    << message;
      (void)delete_set.insert(loc);
    }
  }

  auto it = raw_data.begin();
  if (delete_set.size() == it->second.size()) {
    raw_data.clear();
    blob_data.clear();
    return;
  }

  // delete wrong raw data; indices are visited in descending order so erases stay valid
  for (auto &loc : delete_set) {
    // delete row data
    for (auto &raw : raw_data) {
      (void)raw.second.erase(raw.second.begin() + loc);
    }

    // delete blob data
    (void)blob_data.erase(blob_data.begin() + loc);
  }
}

void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message,
                                         std::map<int, std::string> &err_raw_data) {
  std::lock_guard<std::mutex> lock(check_mutex_);
  (void)err_raw_data.insert(std::make_pair(row, message));
}

Status ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
                                          std::map<int, std::string> &err_raw_data) {
  auto data_type = std::string(value["type"].get<std::string>());
  if ((data_type == "int32" && !data[key].is_number_integer()) ||
      (data_type == "int64" && !data[key].is_number_integer()) ||
      (data_type == "float32" && !data[key].is_number_float()) ||
      (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) {
    std::string message = "Invalid input, for field: " + key + ", type: " + data_type +
                          " and value: " + data[key].dump() + " do not match while writing mindrecord files.";
    PopulateMutexErrorData(i, message, err_raw_data);
    RETURN_STATUS_UNEXPECTED_MR(message);
  }

  if (data_type == "int32" && data[key].is_number_integer()) {
    int64_t temp_value = data[key];
    if (temp_value < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) ||
        temp_value > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
      std::string message = "Invalid input, for field: " + key + " and its type: " + data_type +
                            ", value: " + data[key].dump() + " is out of range while writing mindrecord files.";
      PopulateMutexErrorData(i, message, err_raw_data);
      RETURN_STATUS_UNEXPECTED_MR(message);
    }
  }
  return Status::OK();
}
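
// Example (illustrative): for a schema field {"label": {"type": "int32"}}, a
// record {"label": 3000000000} passes the integer type check above but exceeds
// INT32_MAX, so the range check records an error for that row. Note the range
// test uses '||': a value out of range on either side is rejected.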

void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
                                 std::map<int, std::string> &err_raw_data) {
  if (start_row < 0 || start_row > end_row || end_row > static_cast<int>(sub_raw_data.size())) {
    return;
  }
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(),
                     std::string(__func__ + std::to_string(start_row) + ":" + std::to_string(end_row)).c_str());
#endif
  for (int i = start_row; i < end_row; i++) {
    json data = sub_raw_data[i];

    for (auto iter = schema.begin(); iter != schema.end(); iter++) {
      std::string key = iter.key();
      json value = iter.value();
      if (data.find(key) == data.end()) {
        std::string message = "'" + key + "' object cannot be found in data: " + value.dump();
        PopulateMutexErrorData(i, message, err_raw_data);
        break;
      }

      if (value.size() == kInt2) {
        // Skip check since all shaped data will be stored as blob
        continue;
      }

      if (CheckDataTypeAndValue(key, value, data, i, err_raw_data).IsError()) {
        break;
      }
    }
  }
}

Status ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &raw_data) {
  auto rawdata_iter = raw_data.begin();

  // make sure raw data matches the schema
  for (; rawdata_iter != raw_data.end(); ++rawdata_iter) {
    // used for storing errors
    std::map<int, std::string> sub_err_mg;
    int schema_id = rawdata_iter->first;
    std::shared_ptr<Schema> schema_ptr;
    RETURN_IF_NOT_OK_MR(shard_header_->GetSchemaByID(schema_id, &schema_ptr));
    json schema = schema_ptr->GetSchema()["schema"];
    for (const auto &field : schema_ptr->GetBlobFields()) {
      (void)schema.erase(field);
    }
    std::vector<json> sub_raw_data = rawdata_iter->second;

    // calculate start position and end position for each thread
    int batch_size = rawdata_iter->second.size() / shard_count_;
    int thread_num = shard_count_;
    CHECK_FAIL_RETURN_UNEXPECTED_MR(thread_num > 0, "[Internal ERROR] 'thread_num' should be positive.");
    if (thread_num > kMaxThreadCount) {
      thread_num = kMaxThreadCount;
    }
    std::vector<std::thread> thread_set(thread_num);

    // start multiple threads
    int start_row = 0, end_row = 0;
    for (int x = 0; x < thread_num; ++x) {
      if (x != thread_num - 1) {
        start_row = batch_size * x;
        end_row = batch_size * (x + 1);
      } else {
        start_row = batch_size * x;
        end_row = rawdata_iter->second.size();
      }
      thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema,
                                  std::ref(sub_raw_data), std::ref(sub_err_mg));
    }
    CHECK_FAIL_RETURN_UNEXPECTED_MR(
      thread_num <= kMaxThreadCount,
      "[Internal ERROR] 'thread_num' should be less than or equal to " + std::to_string(kMaxThreadCount));
    // Wait for threads done
    for (int x = 0; x < thread_num; ++x) {
      thread_set[x].join();
    }

    (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg));
  }
  return Status::OK();
}

Status ShardWriter::ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                    std::vector<std::vector<uint8_t>> &blob_data, bool sign,
                                    std::shared_ptr<std::pair<int, int>> *count_ptr) {
  RETURN_UNEXPECTED_IF_NULL_MR(count_ptr);
  auto rawdata_iter = raw_data.begin();
  schema_count_ = raw_data.size();
  CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_count_ > 0, "Invalid data, the number of schema should be positive but got: " +
                                                       std::to_string(schema_count_) +
                                                       ". Please check the input schema.");

  // keep schema_id
  std::set<int64_t> schema_ids;
  row_count_ = (rawdata_iter->second).size();

  // Determine if the number of schemas is the same
  CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_header_->GetSchemas().size() == schema_count_,
                                  "[Internal ERROR] 'schema_count_' and the schema count in schema: " +
                                    std::to_string(schema_count_) + " do not match.");
  // Determine raw_data size == blob_data size
  CHECK_FAIL_RETURN_UNEXPECTED_MR(raw_data[0].size() == blob_data.size(),
                                  "[Internal ERROR] raw data size: " + std::to_string(raw_data[0].size()) +
                                    " is not equal to blob data size: " + std::to_string(blob_data.size()) + ".");

  // Determine whether the number of samples corresponding to each schema is the same
  for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
    CHECK_FAIL_RETURN_UNEXPECTED_MR(row_count_ == rawdata_iter->second.size(),
                                    "[Internal ERROR] 'row_count_': " + std::to_string(rawdata_iter->second.size()) +
                                      " for each schema is not the same.");
    (void)schema_ids.insert(rawdata_iter->first);
  }
  const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
  // Every schema defined in the header must have matching raw data
  CHECK_FAIL_RETURN_UNEXPECTED_MR(!std::any_of(schemas.begin(), schemas.end(),
                                               [schema_ids](const std::shared_ptr<Schema> &schema) {
                                                 return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
                                               }),
                                  "[Internal ERROR] schema id in 'schemas' cannot be found in 'schema_ids'.");
  if (!sign) {
    *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
    return Status::OK();
  }

  // check the data according to the schema
  RETURN_IF_NOT_OK_MR(CheckData(raw_data));

  // delete wrong data from raw data
  DeleteErrorData(raw_data, blob_data);

  // update raw count
  row_count_ = row_count_ - err_mg_.begin()->second.size();
  *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
  return Status::OK();
}

void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,
                            std::vector<std::vector<uint8_t>> &bin_data) {
  // Guard against out-of-range access when more threads are spawned than needed
  if (start >= end) {
    flag_ = true;
    return;
  }
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(start) + ":" + std::to_string(end)).c_str());
#endif
  int schema_count = static_cast<int>(raw_data.size());
  std::map<uint64_t, vector<json>>::const_iterator rawdata_iter;
  for (int x = start; x < end; ++x) {
    int cnt = 0;
    for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
      const json &line = raw_data.at(rawdata_iter->first)[x];
      std::vector<std::uint8_t> bline = json::to_msgpack(line);

      // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2]
      bin_data[x * schema_count + cnt] = bline;
      cnt++;
    }
  }
}
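
// Worked example (illustrative): with 2 schemas and rows x = 0..1, the
// serialized rows land in bin_data as
//   bin_data[0] = row0/schema0, bin_data[1] = row0/schema1,
//   bin_data[2] = row1/schema0, bin_data[3] = row1/schema1,
// i.e. index = x * schema_count + cnt, matching the storage form noted above.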

Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr) {
  if (!parallel_writer) {
    *fd_ptr = std::make_unique<int>(0);
    return Status::OK();
  }

#if defined(_WIN32) || defined(_WIN64)
  const int fd = 0;
  MS_LOG(DEBUG) << "Lock file done by Python.";

#else
  const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
  if (fd >= 0) {
    flock(fd, LOCK_EX);
  } else {
    // open failed, so there is no valid descriptor to close
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to lock file, path: " + lock_file_);
  }
#endif

  // Open files
  file_streams_.clear();
  for (const auto &file : file_paths_) {
    auto realpath = FileUtils::GetRealPath(file.c_str());
    if (!realpath.has_value()) {
      close(fd);
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to get real path, path: " + file);
    }
    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
    fs->open(realpath.value(), std::ios::in | std::ios::out | std::ios::binary);
    if (fs->fail()) {
      close(fd);
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to open file, path: " + file);
    }
    file_streams_.push_back(fs);
  }
  auto status = shard_header_->FileToPages(pages_file_);
  if (status.IsError()) {
    close(fd);
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Error raised in FileToPages function.");
  }
  *fd_ptr = std::make_unique<int>(fd);
  return Status::OK();
}

Status ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
  if (!parallel_writer) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK_MR(shard_header_->PagesToFile(pages_file_));
  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
    file_streams_[i]->close();
  }
#if defined(_WIN32) || defined(_WIN64)
  MS_LOG(DEBUG) << "Unlock file done by Python.";

#else
  flock(fd, LOCK_UN);
  close(fd);
#endif
  return Status::OK();
}

Status ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
                                         std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count,
                                         int *row_count) {
  // check the free disk size
  std::string env_free_disk_check = common::GetEnv("MS_FREE_DISK_CHECK");
  (void)std::transform(env_free_disk_check.begin(), env_free_disk_check.end(), env_free_disk_check.begin(), ::tolower);
  bool free_disk_check = true;
  if (env_free_disk_check == "false") {
    free_disk_check = false;
    MS_LOG(INFO) << "Environment variable 'MS_FREE_DISK_CHECK' is false, free disk checking will be turned off.";
  } else if (env_free_disk_check == "true" || env_free_disk_check == "") {
    free_disk_check = true;
    MS_LOG(INFO) << "Environment variable 'MS_FREE_DISK_CHECK' is true, free disk checking will be turned on.";
  } else {
    MS_LOG(WARNING) << "Environment variable 'MS_FREE_DISK_CHECK': " << env_free_disk_check
                    << " is configured incorrectly, free disk checking will be turned on.";
  }
  if (free_disk_check) {
    std::shared_ptr<uint64_t> size_ptr;
    RETURN_IF_NOT_OK_MR(GetDiskSize(file_paths_[0], kFreeSize, &size_ptr));
    CHECK_FAIL_RETURN_UNEXPECTED_MR(*size_ptr >= kMinFreeDiskSize,
                                    "Not enough free disk space while writing mindrecord files, available free disk "
                                    "size: " +
                                      std::to_string(*size_ptr));
  }
  // compress blob
  if (shard_column_->CheckCompressBlob()) {
    for (auto &blob : blob_data) {
      int64_t compression_bytes = 0;
      blob = shard_column_->CompressBlob(blob, &compression_bytes);
      compression_size_ += compression_bytes;
    }
  }

  // Add 4-byte dummy blob data if there are no blob fields
  if (blob_data.size() == 0 && raw_data.size() > 0) {
    blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
  }

  // Add dummy id if all fields are blob fields
  if (blob_data.size() > 0 && raw_data.size() == 0) {
    raw_data.insert(std::pair<uint64_t, std::vector<json>>(0, std::vector<json>(blob_data.size(), kDummyId)));
  }
  std::shared_ptr<std::pair<int, int>> count_ptr;
  RETURN_IF_NOT_OK_MR(ValidateRawData(raw_data, blob_data, sign, &count_ptr));
  *schema_count = (*count_ptr).first;
  *row_count = (*count_ptr).second;
  return Status::OK();
}
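
// Example (based on the parsing above): free disk checking can be disabled with
//   export MS_FREE_DISK_CHECK=false
// Any value other than "true" or "false" (case-insensitive) keeps the check on.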

Status ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
                                  const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
                                  std::shared_ptr<std::vector<uint8_t>> *output) {
  if (blob_fields.empty()) {
    return Status::OK();
  }
  if (blob_fields.size() == 1) {
    auto &blob = row_bin_data.at(blob_fields[0]);
    auto blob_size = blob->size();
    *output = std::make_shared<std::vector<uint8_t>>(blob_size);
    (void)std::copy(blob->begin(), blob->end(), (*output)->begin());
  } else {
    size_t output_size = 0;
    for (auto &field : blob_fields) {
      output_size += row_bin_data.at(field)->size();
    }
    output_size += blob_fields.size() * sizeof(uint64_t);
    *output = std::make_shared<std::vector<uint8_t>>(output_size);
    std::vector<uint8_t> buf(sizeof(uint64_t), 0);
    size_t idx = 0;
    for (auto &field : blob_fields) {
      auto &b = row_bin_data.at(field);
      uint64_t blob_size = b->size();
      // big endian
      for (size_t i = 0; i < buf.size(); ++i) {
        buf[buf.size() - 1 - i] = (std::numeric_limits<uint8_t>::max()) & blob_size;
        blob_size >>= 8u;
      }
      (void)std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
      idx += buf.size();
      (void)std::copy(b->begin(), b->end(), (*output)->begin() + idx);
      idx += b->size();
    }
  }
  return Status::OK();
}
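
// Worked example (illustrative): merging two blob fields with payloads
// {0xAA, 0xBB, 0xCC} (3 bytes) and {0x11, 0x22} (2 bytes) produces
//   00 00 00 00 00 00 00 03 AA BB CC 00 00 00 00 00 00 00 02 11 22
// where each payload is preceded by its length as a big-endian uint64.
// With a single blob field, the payload is copied through with no length prefix.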

Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                 std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
  // Lock Writer if loading data in parallel
  std::unique_ptr<int> fd_ptr;
  RETURN_IF_NOT_OK_MR(LockWriter(parallel_writer, &fd_ptr));

  // Get the count of schemas and rows
  int schema_count = 0;
  int row_count = 0;

  // Pre-check the raw data and blob data
  RETURN_IF_NOT_OK_MR(WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count));
  CHECK_FAIL_RETURN_UNEXPECTED_MR(row_count >= kInt0, "[Internal ERROR] the size of raw data should be non-negative.");
  if (row_count == kInt0) {
    return Status::OK();
  }
  std::vector<std::vector<uint8_t>> bin_raw_data(row_count * schema_count);
  // Serialize raw data
  RETURN_IF_NOT_OK_MR(SerializeRawData(raw_data, bin_raw_data, row_count));
  // Set row size of raw data
  RETURN_IF_NOT_OK_MR(SetRawDataSize(bin_raw_data));
  // Set row size of blob data
  RETURN_IF_NOT_OK_MR(SetBlobDataSize(blob_data));
  // Write data to disk with multiple threads
  RETURN_IF_NOT_OK_MR(ParallelWriteData(blob_data, bin_raw_data));
  MS_LOG(INFO) << "Succeed to write " << bin_raw_data.size() << " records.";

  RETURN_IF_NOT_OK_MR(UnlockWriter(*fd_ptr, parallel_writer));

  return Status::OK();
}

Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
                                 std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
                                 bool parallel_writer) {
  std::map<uint64_t, std::vector<json>> raw_data_json;
  std::map<uint64_t, std::vector<json>> blob_data_json;

  (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
                       [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
                         auto &py_raw_data = pair.second;
                         std::vector<json> json_raw_data;
                         (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
                                              [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
                         return std::make_pair(pair.first, std::move(json_raw_data));
                       });

  (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()),
                       [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
                         auto &py_blob_data = pair.second;
                         std::vector<json> jsonBlobData;
                         (void)std::transform(py_blob_data.begin(), py_blob_data.end(),
                                              std::back_inserter(jsonBlobData),
                                              [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
                         return std::make_pair(pair.first, std::move(jsonBlobData));
                       });

  // Serialize blob page
  auto blob_data_iter = blob_data.begin();
  auto schema_count = blob_data.size();
  auto row_count = blob_data_iter->second.size();

  std::vector<std::vector<uint8_t>> bin_blob_data(row_count * schema_count);
  // Serialize blob data
  RETURN_IF_NOT_OK_MR(SerializeRawData(blob_data_json, bin_blob_data, row_count));
  return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
}

Status ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
                                      const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  auto shards = BreakIntoShards();
  // define the number of threads
  int thread_num = static_cast<int>(shard_count_);
  CHECK_FAIL_RETURN_UNEXPECTED_MR(thread_num > 0, "[Internal ERROR] 'thread_num' should be positive.");
  if (thread_num > kMaxThreadCount) {
    thread_num = kMaxThreadCount;
  }
  int left_thread = shard_count_;
  int current_thread = 0;
  while (left_thread) {
    if (left_thread < thread_num) {
      thread_num = left_thread;
    }
    // Start one thread for one shard
    std::vector<std::thread> thread_set(thread_num);
    if (thread_num <= kMaxThreadCount) {
      for (int x = 0; x < thread_num; ++x) {
        int start_row = shards[current_thread + x].first;
        int end_row = shards[current_thread + x].second;
        thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row,
                                    std::ref(blob_data), std::ref(bin_raw_data));
      }
      // Wait for threads done
      for (int x = 0; x < thread_num; ++x) {
        thread_set[x].join();
      }
      left_thread -= thread_num;
      current_thread += thread_num;
    }
  }
  return Status::OK();
}

Status ShardWriter::WriteByShard(int shard_id, int start_row, int end_row,
                                 const std::vector<std::vector<uint8_t>> &blob_data,
                                 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row
                << ", schema size: " << schema_count_;
  if (start_row == end_row) {
    return Status::OK();
  }
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
  pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(shard_id)).c_str());
#endif
  vector<std::pair<int, int>> rows_in_group;
  std::shared_ptr<Page> last_raw_page = nullptr;
  std::shared_ptr<Page> last_blob_page = nullptr;
  RETURN_IF_NOT_OK_MR(SetLastRawPage(shard_id, last_raw_page));
  RETURN_IF_NOT_OK_MR(SetLastBlobPage(shard_id, last_blob_page));

  RETURN_IF_NOT_OK_MR(CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page));
  RETURN_IF_NOT_OK_MR(AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
  RETURN_IF_NOT_OK_MR(NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
  RETURN_IF_NOT_OK_MR(ShiftRawPage(shard_id, rows_in_group, last_raw_page));
  RETURN_IF_NOT_OK_MR(WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data));

  return Status::OK();
}
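
// Per-shard write pipeline (summary of the calls above): CutRowGroup splits
// [start_row, end_row) into page-sized row groups; AppendBlobPage fills the
// remaining space of the last blob page with the first group; NewBlobPage
// writes the remaining groups to fresh blob pages; ShiftRawPage moves a
// partial row group to a new raw page when appending would overflow the page;
// and WriteRawPage flushes the raw (indexed) data for every group.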

Status ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
                                std::vector<std::pair<int, int>> &rows_in_group,
                                const std::shared_ptr<Page> &last_raw_page,
                                const std::shared_ptr<Page> &last_blob_page) {
  auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;

  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
  auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0;
  auto n_byte_raw = last_raw_page_size - last_raw_offset;

  int page_start_row = start_row;
  CHECK_FAIL_RETURN_UNEXPECTED_MR(start_row <= end_row,
                                  "[Internal ERROR] 'start_row': " + std::to_string(start_row) +
                                    " should be less than or equal to 'end_row': " + std::to_string(end_row));

  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    end_row <= static_cast<int>(blob_data_size_.size()) && end_row <= static_cast<int>(raw_data_size_.size()),
    "[Internal ERROR] 'end_row': " + std::to_string(end_row) + " should be less than or equal to 'blob_data_size': " +
      std::to_string(blob_data_size_.size()) + " and 'raw_data_size': " + std::to_string(raw_data_size_.size()) + ".");
  for (int i = start_row; i < end_row; ++i) {
    // n_byte_blob == 0 indicates appendBlobPage
    if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ ||
        n_byte_raw + raw_data_size_[i] > page_size_) {
      rows_in_group.emplace_back(page_start_row, i);
      page_start_row = i;
      n_byte_blob = blob_data_size_[i];
      n_byte_raw = raw_data_size_[i];
    } else {
      n_byte_blob += blob_data_size_[i];
      n_byte_raw += raw_data_size_[i];
    }
  }

  // Don't forget the last group
  rows_in_group.emplace_back(page_start_row, end_row);
  return Status::OK();
}
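
// Worked example (illustrative, ignoring the raw-data limit): with
// page_size_ = 100 and a last blob page already holding 40 bytes, per-row blob
// sizes {30, 40, 50} yield groups [start_row, start_row + 1) and
// [start_row + 1, end_row): row 0 still fits the old page (40 + 30 <= 100),
// row 1 would overflow it (70 + 40 > 100) and starts a new group, and row 2
// fits in that new group (40 + 50 <= 100).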

Status ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
                                   const std::vector<std::pair<int, int>> &rows_in_group,
                                   const std::shared_ptr<Page> &last_blob_page) {
  auto blob_row = rows_in_group[0];
  if (blob_row.first == blob_row.second) {
    return Status::OK();
  }
  // Write disk
  auto page_id = last_blob_page->GetPageID();
  auto bytes_page = last_blob_page->GetPageSize();
  auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekp file.");
  }

  (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row);

  // Update last blob page
  bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
  last_blob_page->SetPageSize(bytes_page);
  uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first;
  last_blob_page->SetEndRowID(end_row);
  (void)shard_header_->SetPage(last_blob_page);
  return Status::OK();
}

Status ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
                                const std::vector<std::pair<int, int>> &rows_in_group,
                                const std::shared_ptr<Page> &last_blob_page) {
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1;
  auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0;
  // index 0 is handled by AppendBlobPage
  for (uint32_t i = 1; i < rows_in_group.size(); ++i) {
    auto blob_row = rows_in_group[i];

    // Write 1 blob page to disk
    auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
    if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
      file_streams_[shard_id]->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekp file.");
    }

    (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row);
    // Create new page info for header
    auto page_size =
      std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
    std::vector<std::pair<int, uint64_t>> row_group_ids;
    auto start_row = current_row;
    auto end_row = start_row + blob_row.second - blob_row.first;
    auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size);
    (void)shard_header_->AddPage(std::make_shared<Page>(page));
    current_row = end_row;
  }
  return Status::OK();
}

Status ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                 std::shared_ptr<Page> &last_raw_page) {
  auto blob_row = rows_in_group[0];
  if (blob_row.first == blob_row.second) {
    return Status::OK();
  }
  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
  if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) +
        last_raw_page_size <=
      page_size_) {
    return Status::OK();
  }
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second;
  auto last_raw_page_id = last_raw_page->GetPageID();
  auto shift_size = last_raw_page_size - last_row_group_id_offset;

  std::vector<uint8_t> buf(shift_size);

  // Read last row group from previous raw data page
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    shard_id >= 0 && shard_id < static_cast<int>(file_streams_.size()),
    "[Internal ERROR] 'shard_id' should be in range [0, " + std::to_string(file_streams_.size()) + ").");

  auto &io_seekg = file_streams_[shard_id]->seekg(
    page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
  }

  auto &io_read = file_streams_[shard_id]->read(reinterpret_cast<char *>(&buf[0]), buf.size());
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
  }

  // Merge into new row group at new raw data page
  auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekp file.");
  }

  auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&buf[0]), buf.size());
  if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
  }
  last_raw_page->DeleteLastGroupId();
  (void)shard_header_->SetPage(last_raw_page);

  // Refresh page info in header
  int row_group_id = last_raw_page->GetLastRowGroupID().first + 1;
  std::vector<std::pair<int, uint64_t>> row_group_ids;
  row_group_ids.emplace_back(row_group_id, 0);
  int page_type_id = last_raw_page->GetPageID();
  auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size);
  (void)shard_header_->AddPage(std::make_shared<Page>(page));

  // Reset: last raw page
  RETURN_IF_NOT_OK_MR(SetLastRawPage(shard_id, last_raw_page));
  return Status::OK();
}

Status ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                 std::shared_ptr<Page> &last_raw_page,
                                 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1;
  for (uint32_t i = 0; i < rows_in_group.size(); ++i) {
    const auto &blob_row = rows_in_group[i];
    if (blob_row.first == blob_row.second) {
      continue;
    }
    auto raw_size =
      std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0);
    if (!last_raw_page) {
      RETURN_IF_NOT_OK_MR(EmptyRawPage(shard_id, last_raw_page));
    } else if (last_raw_page->GetPageSize() + raw_size > page_size_) {
      RETURN_IF_NOT_OK_MR(shard_header_->SetPage(last_raw_page));
      RETURN_IF_NOT_OK_MR(EmptyRawPage(shard_id, last_raw_page));
    }
    RETURN_IF_NOT_OK_MR(AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data));
  }
  RETURN_IF_NOT_OK_MR(shard_header_->SetPage(last_raw_page));
  return Status::OK();
}

Status ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
  auto row_group_ids = std::vector<std::pair<int, uint64_t>>();
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1;
  auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0);
  RETURN_IF_NOT_OK_MR(shard_header_->AddPage(std::make_shared<Page>(page)));
  RETURN_IF_NOT_OK_MR(SetLastRawPage(shard_id, last_raw_page));
  return Status::OK();
}

Status ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                  const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page,
                                  const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds();
  auto last_raw_page_id = last_raw_page->GetPageID();
  auto n_bytes = last_raw_page->GetPageSize();

  // Seek to the end of the previous raw data page
  auto &io_seekp =
    file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekp file.");
  }

  if (chunk_id > 0) {
    row_group_ids.emplace_back(++last_row_group_id, n_bytes);
  }
  n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first,
                             raw_data_size_.begin() + rows_in_group[chunk_id].second, 0);
  RETURN_IF_NOT_OK_MR(FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data));

  // Update previous raw data page
  last_raw_page->SetPageSize(n_bytes);
  last_raw_page->SetRowGroupIds(row_group_ids);
  RETURN_IF_NOT_OK_MR(shard_header_->SetPage(last_raw_page));

  return Status::OK();
}

Status ShardWriter::FlushBlobChunk(const std::shared_ptr<std::fstream> &out,
                                   const std::vector<std::vector<uint8_t>> &blob_data,
                                   const std::pair<int, int> &blob_row) {
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    blob_row.first <= blob_row.second && blob_row.second <= static_cast<int>(blob_data.size()) && blob_row.first >= 0,
    "[Internal ERROR] 'blob_row': " + std::to_string(blob_row.first) + ", " + std::to_string(blob_row.second) +
      " is invalid.");
  for (int j = blob_row.first; j < blob_row.second; ++j) {
    // Write the size of blob
    uint64_t line_len = blob_data[j].size();
    auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
    if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
      out->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
    }

    // Write the data of blob
    auto line = blob_data[j];
    auto &io_handle_data = out->write(reinterpret_cast<char *>(&line[0]), line_len);
    if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) {
      out->close();
      RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
    }
  }
  return Status::OK();
}
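
// On-disk layout written by FlushBlobChunk (sketch): each row in the chunk is
// stored as an 8-byte length followed by the blob payload,
//   [len(row_i) : kInt64Len bytes][row_i bytes][len(row_i+1)][row_i+1 bytes]...
// The length here is written in host byte order via the raw cast above, unlike
// the big-endian per-field prefixes built inside MergeBlobData.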

Status ShardWriter::FlushRawChunk(const std::shared_ptr<std::fstream> &out,
                                  const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
                                  const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) {
    // Write the sizes of the multi-schema row
    for (uint32_t j = 0; j < schema_count_; ++j) {
      uint64_t line_len = bin_raw_data[i * schema_count_ + j].size();
      auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        out->close();
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
      }
    }
    // Write the data of the multi-schema row
    for (uint32_t j = 0; j < schema_count_; ++j) {
      auto line = bin_raw_data[i * schema_count_ + j];
      auto &io_handle = out->write(reinterpret_cast<char *>(&line[0]), line.size());
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        out->close();
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
      }
    }
  }
  return Status::OK();
}

// Allocate data to shards evenly
std::vector<std::pair<int, int>> ShardWriter::BreakIntoShards() {
  std::vector<std::pair<int, int>> shards;
  int row_in_shard = row_count_ / shard_count_;
  int remains = row_count_ % shard_count_;

  std::vector<int> v_list(shard_count_);
  std::iota(v_list.begin(), v_list.end(), 0);

  std::mt19937 g = GetRandomDevice();
  std::shuffle(v_list.begin(), v_list.end(), g);
  std::unordered_set<int> set(v_list.begin(), v_list.begin() + remains);

  if (shard_count_ <= kMaxShardCount) {
    int start_row = 0;
    for (int i = 0; i < shard_count_; ++i) {
      int end_row = start_row + row_in_shard;
      if (set.count(i) == 1) {
        end_row++;
      }
      shards.emplace_back(start_row, end_row);
      start_row = end_row;
    }
  }
  return shards;
}
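
// Worked example (illustrative): row_count_ = 10, shard_count_ = 3 gives
// row_in_shard = 3 and remains = 1; one randomly chosen shard gets the extra
// row, e.g. shards = {(0, 4), (4, 7), (7, 10)} when shard 0 is picked.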

Status ShardWriter::WriteShardHeader() {
  RETURN_UNEXPECTED_IF_NULL_MR(shard_header_);
  int64_t compression_temp = compression_size_;
  uint64_t compression_size = compression_temp > 0 ? compression_temp : 0;
  shard_header_->SetCompressionSize(compression_size);

  auto shard_header = shard_header_->SerializeHeader();
  // Write header data to multi files
  CHECK_FAIL_RETURN_UNEXPECTED_MR(
    shard_count_ <= static_cast<int>(file_streams_.size()) && shard_count_ <= static_cast<int>(shard_header.size()),
    "[Internal ERROR] 'shard_count_' should be less than or equal to 'file_stream_' size: " +
      std::to_string(file_streams_.size()) + ", and 'shard_header' size: " + std::to_string(shard_header.size()) + ".");
  if (shard_count_ <= kMaxShardCount) {
    for (int shard_id = 0; shard_id < shard_count_; ++shard_id) {
      auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg);
      if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekp file.");
      }

      std::vector<uint8_t> bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end());
      uint64_t line_len = bin_header.size();
      if (line_len + kInt64Len > header_size_) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] shard header is too big.");
      }
      auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&line_len), kInt64Len);
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
      }

      auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast<char *>(&bin_header[0]), line_len);
      if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
      }
      file_streams_[shard_id]->close();
    }
  }
  return Status::OK();
}

Status ShardWriter::SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                     std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count) {
  // define the number of threads
  uint32_t thread_num = std::thread::hardware_concurrency();
  if (thread_num == 0) {
    thread_num = kThreadNumber;
  }
  // Set the number of samples processed by each thread
  int group_num = static_cast<int>(ceil(row_count * 1.0 / thread_num));
  std::vector<std::thread> thread_set(thread_num);
  int work_thread_num = 0;
  for (uint32_t x = 0; x < thread_num; ++x) {
    int start_num = x * group_num;
    int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num;
    if (start_num >= end_num) {
      continue;
    }
    // Define the run boundary and start the child thread
    thread_set[x] =
      std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data));
    work_thread_num++;
  }
  for (int x = 0; x < work_thread_num; ++x) {
    // Join the worker threads before returning to the caller
    thread_set[x].join();
  }
  CHECK_FAIL_RETURN_SYNTAX_ERROR_MR(!flag_, "[Internal ERROR] Error raised in FillArray function.");
  return Status::OK();
}
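
// Worked example (illustrative): row_count = 100 on an 8-core machine gives
// group_num = ceil(100 / 8) = 13; threads 0..6 serialize rows [0,13), [13,26),
// ..., [78,91), thread 7 takes the remainder [91,100), and any thread whose
// start index would pass row_count is simply not spawned.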

Status ShardWriter::SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  raw_data_size_ = std::vector<uint64_t>(row_count_, 0);
  for (uint32_t i = 0; i < row_count_; ++i) {
    raw_data_size_[i] = std::accumulate(
      bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_,
      static_cast<uint64_t>(0),
      [](uint64_t accumulator, const std::vector<uint8_t> &row) { return accumulator + kInt64Len + row.size(); });
  }
  CHECK_FAIL_RETURN_SYNTAX_ERROR_MR(*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) <= page_size_,
                                    "Invalid data, Page size: " + std::to_string(page_size_) +
                                      " is too small to save a raw row. Please try to use the mindrecord api "
                                      "'set_page_size(value)' to enable larger page size, and the value range is in [" +
                                      std::to_string(kMinPageSize) + " bytes, " + std::to_string(kMaxPageSize) +
                                      " bytes].");
  return Status::OK();
}

Status ShardWriter::SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data) {
  blob_data_size_ = std::vector<uint64_t>(row_count_);
  (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(),
                       [](const std::vector<uint8_t> &row) { return kInt64Len + row.size(); });
  CHECK_FAIL_RETURN_SYNTAX_ERROR_MR(*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) <= page_size_,
                                    "Invalid data, Page size: " + std::to_string(page_size_) +
                                      " is too small to save a blob row. Please try to use the mindrecord api "
                                      "'set_page_size(value)' to enable larger page size, and the value range is in [" +
                                      std::to_string(kMinPageSize) + " bytes, " + std::to_string(kMaxPageSize) +
                                      " bytes].");
  return Status::OK();
}

Status ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
  // Get last raw page
  auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw);
  if (last_raw_page_id == -1) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK_MR(shard_header_->GetPage(shard_id, last_raw_page_id, &last_raw_page));
  return Status::OK();
}

Status ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page) {
  // Get last blob page
  auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob);
  if (last_blob_page_id == -1) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK_MR(shard_header_->GetPage(shard_id, last_blob_page_id, &last_blob_page));
  return Status::OK();
}

Status ShardWriter::Initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
                               const std::vector<std::string> &file_names) {
  RETURN_UNEXPECTED_IF_NULL_MR(writer_ptr);
  RETURN_IF_NOT_OK_MR((*writer_ptr)->Open(file_names, false));
  RETURN_IF_NOT_OK_MR((*writer_ptr)->SetHeaderSize(kDefaultHeaderSize));
  RETURN_IF_NOT_OK_MR((*writer_ptr)->SetPageSize(kDefaultPageSize));
  return Status::OK();
}
}  // namespace mindrecord
}  // namespace mindspore