1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/mindrecord/include/shard_writer.h"
18 #include "utils/file_utils.h"
19 #include "utils/ms_utils.h"
20 #include "minddata/mindrecord/include/common/shard_utils.h"
21 #include "securec.h"
22
23 namespace mindspore {
24 namespace mindrecord {
ShardWriter()25 ShardWriter::ShardWriter()
26 : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) {
27 compression_size_ = 0;
28 }
29
~ShardWriter()30 ShardWriter::~ShardWriter() {
31 for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
32 file_streams_[i]->close();
33 }
34 }
35
GetFullPathFromFileName(const std::vector<std::string> & paths)36 Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
37 // Get full path from file name
38 for (const auto &path : paths) {
39 CHECK_FAIL_RETURN_UNEXPECTED_MR(CheckIsValidUtf8(path),
40 "Invalid file, mindrecord file name: " + path +
41 " contains invalid uft-8 character. Please rename mindrecord file name.");
42 // get realpath
43 std::optional<std::string> dir = "";
44 std::optional<std::string> local_file_name = "";
45 FileUtils::SplitDirAndFileName(path, &dir, &local_file_name);
46 if (!dir.has_value()) {
47 dir = ".";
48 }
49
50 auto realpath = FileUtils::GetRealPath(dir.value().c_str());
51 CHECK_FAIL_RETURN_UNEXPECTED_MR(
52 realpath.has_value(),
53 "Invalid dir, failed to get the realpath of mindrecord file dir. Please check path: " + dir.value());
54
55 std::optional<std::string> whole_path = "";
56 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
57
58 (void)file_paths_.emplace_back(whole_path.value());
59 }
60 return Status::OK();
61 }
62
OpenDataFiles(bool append,bool overwrite)63 Status ShardWriter::OpenDataFiles(bool append, bool overwrite) {
64 // Open files
65 for (const auto &file : file_paths_) {
66 std::optional<std::string> dir = "";
67 std::optional<std::string> local_file_name = "";
68 FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
69 if (!dir.has_value()) {
70 dir = ".";
71 }
72
73 auto realpath = FileUtils::GetRealPath(dir.value().c_str());
74 CHECK_FAIL_RETURN_UNEXPECTED_MR(
75 realpath.has_value(), "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);
76
77 std::optional<std::string> whole_path = "";
78 FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
79
80 std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
81 if (!append) {
82 // if not append && mindrecord or db file exist
83 fs->open(whole_path.value(), std::ios::in | std::ios::binary);
84 std::ifstream fs_db(whole_path.value() + ".db");
85 if (fs->good() || fs_db.good()) {
86 fs->close();
87 fs_db.close();
88 if (overwrite) {
89 auto res1 = std::remove(whole_path.value().c_str());
90 CHECK_FAIL_RETURN_UNEXPECTED_MR(!std::ifstream(whole_path.value()) == true,
91 "Invalid file, failed to remove the old files when trying to overwrite "
92 "mindrecord files. Please check file path and permission: " +
93 file);
94 if (res1 == 0) {
95 MS_LOG(WARNING) << "Succeed to remove the old mindrecord files, path: " << file;
96 }
97 auto db_file = whole_path.value() + ".db";
98 auto res2 = std::remove(db_file.c_str());
99 CHECK_FAIL_RETURN_UNEXPECTED_MR(!std::ifstream(whole_path.value() + ".db") == true,
100 "Invalid file, failed to remove the old mindrecord meta files when trying to "
101 "overwrite mindrecord files. Please check file path and permission: " +
102 file + ".db");
103 if (res2 == 0) {
104 MS_LOG(WARNING) << "Succeed to remove the old mindrecord metadata files, path: " << file + ".db";
105 }
106 } else {
107 RETURN_STATUS_UNEXPECTED_MR(
108 "Invalid file, mindrecord files already exist. Please check file path: " + file +
109 +".\nIf you do not want to keep the files, set the 'overwrite' parameter to True and try again.");
110 }
111 } else {
112 fs->close();
113 fs_db.close();
114 }
115 // open the mindrecord file to write
116 fs->open(whole_path.value().data(), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
117 if (!fs->good()) {
118 fs->close();
119 RETURN_STATUS_UNEXPECTED_MR(
120 "Invalid file, failed to open files for writing mindrecord files. Please check file path, permission and "
121 "open file limit: " +
122 file);
123 }
124 } else {
125 // open the mindrecord file to append
126 fs->open(whole_path.value().data(), std::ios::out | std::ios::in | std::ios::binary);
127 if (!fs->good()) {
128 fs->close();
129 RETURN_STATUS_UNEXPECTED_MR(
130 "Invalid file, failed to open files for appending mindrecord files. Please check file path, permission and "
131 "open file limit: " +
132 file);
133 }
134 }
135 MS_LOG(INFO) << "Succeed to open mindrecord shard file, path: " << file;
136 file_streams_.push_back(fs);
137 }
138 return Status::OK();
139 }
140
RemoveLockFile()141 Status ShardWriter::RemoveLockFile() {
142 // Remove temporary file
143 int ret = std::remove(pages_file_.c_str());
144 if (ret == 0) {
145 MS_LOG(DEBUG) << "Succeed to remove page file, path: " << pages_file_;
146 }
147
148 ret = std::remove(lock_file_.c_str());
149 if (ret == 0) {
150 MS_LOG(DEBUG) << "Succeed to remove lock file, path: " << lock_file_;
151 }
152 return Status::OK();
153 }
154
InitLockFile()155 Status ShardWriter::InitLockFile() {
156 CHECK_FAIL_RETURN_UNEXPECTED_MR(file_paths_.size() != 0, "[Internal ERROR] 'file_paths_' is not initialized.");
157
158 lock_file_ = file_paths_[0] + kLockFileSuffix;
159 pages_file_ = file_paths_[0] + kPageFileSuffix;
160 RETURN_IF_NOT_OK_MR(RemoveLockFile());
161 return Status::OK();
162 }
163
Open(const std::vector<std::string> & paths,bool append,bool overwrite)164 Status ShardWriter::Open(const std::vector<std::string> &paths, bool append, bool overwrite) {
165 shard_count_ = paths.size();
166 CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_count_ <= kMaxSchemaCount,
167 "[Internal ERROR] 'schema_count_' should be less than or equal to " +
168 std::to_string(kMaxSchemaCount) + ", but got: " + std::to_string(schema_count_));
169
170 // Get full path from file name
171 RETURN_IF_NOT_OK_MR(GetFullPathFromFileName(paths));
172 // Open files
173 RETURN_IF_NOT_OK_MR(OpenDataFiles(append, overwrite));
174 // Init lock file
175 RETURN_IF_NOT_OK_MR(InitLockFile());
176 return Status::OK();
177 }
178
OpenForAppend(const std::string & path)179 Status ShardWriter::OpenForAppend(const std::string &path) {
180 RETURN_IF_NOT_OK_MR(CheckFile(path));
181 std::shared_ptr<json> header_ptr;
182 RETURN_IF_NOT_OK_MR(ShardHeader::BuildSingleHeader(path, &header_ptr));
183 auto ds = std::make_shared<std::vector<std::string>>();
184 RETURN_IF_NOT_OK_MR(GetDatasetFiles(path, (*header_ptr)["shard_addresses"], &ds));
185 ShardHeader header = ShardHeader();
186 RETURN_IF_NOT_OK_MR(header.BuildDataset(*ds));
187 shard_header_ = std::make_shared<ShardHeader>(header);
188 RETURN_IF_NOT_OK_MR(SetHeaderSize(shard_header_->GetHeaderSize()));
189 RETURN_IF_NOT_OK_MR(SetPageSize(shard_header_->GetPageSize()));
190 compression_size_ = shard_header_->GetCompressionSize();
191 RETURN_IF_NOT_OK_MR(Open(*ds, true));
192 shard_column_ = std::make_shared<ShardColumn>(shard_header_);
193 return Status::OK();
194 }
195
Commit()196 Status ShardWriter::Commit() {
197 // Read pages file
198 std::ifstream page_file(pages_file_.c_str(), std::ios::in);
199 if (page_file.good()) {
200 page_file.close();
201 RETURN_IF_NOT_OK_MR(shard_header_->FileToPages(pages_file_));
202 }
203 RETURN_IF_NOT_OK_MR(WriteShardHeader());
204 MS_LOG(INFO) << "Succeed to write meta data.";
205 // Remove lock file
206 RETURN_IF_NOT_OK_MR(RemoveLockFile());
207
208 return Status::OK();
209 }
210
SetShardHeader(std::shared_ptr<ShardHeader> header_data)211 Status ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) {
212 CHECK_FAIL_RETURN_UNEXPECTED_MR(
213 header_data->GetSchemaCount() > 0,
214 "Invalid data, schema is not found in header, please use 'add_schema' to add a schema for new mindrecord files.");
215 RETURN_IF_NOT_OK_MR(header_data->InitByFiles(file_paths_));
216 // set fields in mindrecord when empty
217 std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
218 if (fields.empty()) {
219 MS_LOG(DEBUG) << "Index field is not set, it will be generated automatically.";
220 std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
221 for (const auto &schema : schemas) {
222 json jsonSchema = schema->GetSchema()["schema"];
223 for (const auto &el : jsonSchema.items()) {
224 if (el.value()["type"] == "string" ||
225 (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) ||
226 (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
227 (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
228 (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
229 fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
230 }
231 }
232 }
233 // only blob data
234 if (!fields.empty()) {
235 RETURN_IF_NOT_OK_MR(header_data->AddIndexFields(fields));
236 }
237 }
238
239 shard_header_ = header_data;
240 shard_header_->SetHeaderSize(header_size_);
241 shard_header_->SetPageSize(page_size_);
242 shard_column_ = std::make_shared<ShardColumn>(shard_header_);
243 return Status::OK();
244 }
245
SetHeaderSize(const uint64_t & header_size)246 Status ShardWriter::SetHeaderSize(const uint64_t &header_size) {
247 // header_size [16KB, 128MB]
248 CHECK_FAIL_RETURN_UNEXPECTED_MR(header_size >= kMinHeaderSize && header_size <= kMaxHeaderSize,
249 "Invalid data, header size: " + std::to_string(header_size) +
250 " should be in range [" + std::to_string(kMinHeaderSize) + " bytes, " +
251 std::to_string(kMaxHeaderSize) + " bytes].");
252 CHECK_FAIL_RETURN_UNEXPECTED_MR(
253 header_size % 4 == 0, "Invalid data, header size " + std::to_string(header_size) + " should be divided by four.");
254 header_size_ = header_size;
255 return Status::OK();
256 }
257
SetPageSize(const uint64_t & page_size)258 Status ShardWriter::SetPageSize(const uint64_t &page_size) {
259 // PageSize [32KB, 256MB]
260 CHECK_FAIL_RETURN_UNEXPECTED_MR(page_size >= kMinPageSize && page_size <= kMaxPageSize,
261 "Invalid data, page size: " + std::to_string(page_size) + " should be in range [" +
262 std::to_string(kMinPageSize) + " bytes, " + std::to_string(kMaxPageSize) +
263 " bytes].");
264 CHECK_FAIL_RETURN_UNEXPECTED_MR(
265 page_size % 4 == 0, "Invalid data, page size " + std::to_string(page_size) + " should be divided by four.");
266 page_size_ = page_size;
267 return Status::OK();
268 }
269
DeleteErrorData(std::map<uint64_t,std::vector<json>> & raw_data,std::vector<std::vector<uint8_t>> & blob_data)270 void ShardWriter::DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data,
271 std::vector<std::vector<uint8_t>> &blob_data) {
272 // get wrong data location
273 std::set<int, std::greater<int>> delete_set;
274 for (auto &err_mg : err_mg_) {
275 uint64_t id = err_mg.first;
276 auto sub_err_mg = err_mg.second;
277 for (auto &subMg : sub_err_mg) {
278 int loc = subMg.first;
279 std::string message = subMg.second;
280 MS_LOG(ERROR) << "Invalid input, the " << loc + 1
281 << " th data provided by user is invalid while writing mindrecord files. Please fix the error: "
282 << message;
283 (void)delete_set.insert(loc);
284 }
285 }
286
287 auto it = raw_data.begin();
288 if (delete_set.size() == it->second.size()) {
289 raw_data.clear();
290 blob_data.clear();
291 return;
292 }
293
294 // delete wrong raw data
295 for (auto &loc : delete_set) {
296 // delete row data
297 for (auto &raw : raw_data) {
298 (void)raw.second.erase(raw.second.begin() + loc);
299 }
300
301 // delete blob data
302 (void)blob_data.erase(blob_data.begin() + loc);
303 }
304 }
305
PopulateMutexErrorData(const int & row,const std::string & message,std::map<int,std::string> & err_raw_data)306 void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message,
307 std::map<int, std::string> &err_raw_data) {
308 std::lock_guard<std::mutex> lock(check_mutex_);
309 (void)err_raw_data.insert(std::make_pair(row, message));
310 }
311
CheckDataTypeAndValue(const std::string & key,const json & value,const json & data,const int & i,std::map<int,std::string> & err_raw_data)312 Status ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
313 std::map<int, std::string> &err_raw_data) {
314 auto data_type = std::string(value["type"].get<std::string>());
315 if ((data_type == "int32" && !data[key].is_number_integer()) ||
316 (data_type == "int64" && !data[key].is_number_integer()) ||
317 (data_type == "float32" && !data[key].is_number_float()) ||
318 (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) {
319 std::string message = "Invalid input, for field: " + key + ", type: " + data_type +
320 " and value: " + data[key].dump() + " do not match while writing mindrecord files.";
321 PopulateMutexErrorData(i, message, err_raw_data);
322 RETURN_STATUS_UNEXPECTED_MR(message);
323 }
324
325 if (data_type == "int32" && data[key].is_number_integer()) {
326 int64_t temp_value = data[key];
327 if (static_cast<int64_t>(temp_value) < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
328 static_cast<int64_t>(temp_value) > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
329 std::string message = "Invalid input, for field: " + key + "and its type: " + data_type +
330 ", value: " + data[key].dump() + " is out of range while writing mindrecord files.";
331 PopulateMutexErrorData(i, message, err_raw_data);
332 RETURN_STATUS_UNEXPECTED_MR(message);
333 }
334 }
335 return Status::OK();
336 }
337
CheckSliceData(int start_row,int end_row,json schema,const std::vector<json> & sub_raw_data,std::map<int,std::string> & err_raw_data)338 void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
339 std::map<int, std::string> &err_raw_data) {
340 if (start_row < 0 || start_row > end_row || end_row > static_cast<int>(sub_raw_data.size())) {
341 return;
342 }
343 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
344 pthread_setname_np(pthread_self(),
345 std::string(__func__ + std::to_string(start_row) + ":" + std::to_string(end_row)).c_str());
346 #endif
347 for (int i = start_row; i < end_row; i++) {
348 json data = sub_raw_data[i];
349
350 for (auto iter = schema.begin(); iter != schema.end(); iter++) {
351 std::string key = iter.key();
352 json value = iter.value();
353 if (data.find(key) == data.end()) {
354 std::string message = "'" + key + "' object can not found in data: " + value.dump();
355 PopulateMutexErrorData(i, message, err_raw_data);
356 break;
357 }
358
359 if (value.size() == kInt2) {
360 // Skip check since all shaped data will store as blob
361 continue;
362 }
363
364 if (CheckDataTypeAndValue(key, value, data, i, err_raw_data).IsError()) {
365 break;
366 }
367 }
368 }
369 }
370
CheckData(const std::map<uint64_t,std::vector<json>> & raw_data)371 Status ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &raw_data) {
372 auto rawdata_iter = raw_data.begin();
373
374 // make sure rawdata match schema
375 for (; rawdata_iter != raw_data.end(); ++rawdata_iter) {
376 // used for storing error
377 std::map<int, std::string> sub_err_mg;
378 int schema_id = rawdata_iter->first;
379 std::shared_ptr<Schema> schema_ptr;
380 RETURN_IF_NOT_OK_MR(shard_header_->GetSchemaByID(schema_id, &schema_ptr));
381 json schema = schema_ptr->GetSchema()["schema"];
382 for (const auto &field : schema_ptr->GetBlobFields()) {
383 (void)schema.erase(field);
384 }
385 std::vector<json> sub_raw_data = rawdata_iter->second;
386
387 // calculate start position and end position for each thread
388 int batch_size = rawdata_iter->second.size() / shard_count_;
389 int thread_num = shard_count_;
390 CHECK_FAIL_RETURN_UNEXPECTED_MR(thread_num > 0, "[Internal ERROR] 'thread_num' should be positive.");
391 if (thread_num > kMaxThreadCount) {
392 thread_num = kMaxThreadCount;
393 }
394 std::vector<std::thread> thread_set(thread_num);
395
396 // start multiple thread
397 int start_row = 0, end_row = 0;
398 for (int x = 0; x < thread_num; ++x) {
399 if (x != thread_num - 1) {
400 start_row = batch_size * x;
401 end_row = batch_size * (x + 1);
402 } else {
403 start_row = batch_size * x;
404 end_row = rawdata_iter->second.size();
405 }
406 thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema,
407 std::ref(sub_raw_data), std::ref(sub_err_mg));
408 }
409 CHECK_FAIL_RETURN_UNEXPECTED_MR(
410 thread_num <= kMaxThreadCount,
411 "[Internal ERROR] 'thread_num' should be less than or equal to " + std::to_string(kMaxThreadCount));
412 // Wait for threads done
413 for (int x = 0; x < thread_num; ++x) {
414 thread_set[x].join();
415 }
416
417 (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg));
418 }
419 return Status::OK();
420 }
421
ValidateRawData(std::map<uint64_t,std::vector<json>> & raw_data,std::vector<std::vector<uint8_t>> & blob_data,bool sign,std::shared_ptr<std::pair<int,int>> * count_ptr)422 Status ShardWriter::ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
423 std::vector<std::vector<uint8_t>> &blob_data, bool sign,
424 std::shared_ptr<std::pair<int, int>> *count_ptr) {
425 RETURN_UNEXPECTED_IF_NULL_MR(count_ptr);
426 auto rawdata_iter = raw_data.begin();
427 schema_count_ = raw_data.size();
428 CHECK_FAIL_RETURN_UNEXPECTED_MR(schema_count_ > 0, "Invalid data, the number of schema should be positive but got: " +
429 std::to_string(schema_count_) +
430 ". Please check the input schema.");
431
432 // keep schema_id
433 std::set<int64_t> schema_ids;
434 row_count_ = (rawdata_iter->second).size();
435
436 // Determine if the number of schemas is the same
437 CHECK_FAIL_RETURN_UNEXPECTED_MR(shard_header_->GetSchemas().size() == schema_count_,
438 "[Internal ERROR] 'schema_count_' and the schema count in schema: " +
439 std::to_string(schema_count_) + " do not match.");
440 // Determine raw_data size == blob_data size
441 CHECK_FAIL_RETURN_UNEXPECTED_MR(raw_data[0].size() == blob_data.size(),
442 "[Internal ERROR] raw data size: " + std::to_string(raw_data[0].size()) +
443 " is not equal to blob data size: " + std::to_string(blob_data.size()) + ".");
444
445 // Determine whether the number of samples corresponding to each schema is the same
446 for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
447 CHECK_FAIL_RETURN_UNEXPECTED_MR(row_count_ == rawdata_iter->second.size(),
448 "[Internal ERROR] 'row_count_': " + std::to_string(rawdata_iter->second.size()) +
449 " for each schema is not the same.");
450 (void)schema_ids.insert(rawdata_iter->first);
451 }
452 const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
453 // There is not enough data which is not matching the number of schema
454 CHECK_FAIL_RETURN_UNEXPECTED_MR(!std::any_of(schemas.begin(), schemas.end(),
455 [schema_ids](const std::shared_ptr<Schema> &schema) {
456 return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
457 }),
458 "[Internal ERROR] schema id in 'schemas' can not found in 'schema_ids'.");
459 if (!sign) {
460 *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
461 return Status::OK();
462 }
463
464 // check the data according the schema
465 RETURN_IF_NOT_OK_MR(CheckData(raw_data));
466
467 // delete wrong data from raw data
468 DeleteErrorData(raw_data, blob_data);
469
470 // update raw count
471 row_count_ = row_count_ - err_mg_.begin()->second.size();
472 *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
473 return Status::OK();
474 }
475
FillArray(int start,int end,std::map<uint64_t,vector<json>> & raw_data,std::vector<std::vector<uint8_t>> & bin_data)476 void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,
477 std::vector<std::vector<uint8_t>> &bin_data) {
478 // Prevent excessive thread opening and cause cross-border
479 if (start >= end) {
480 flag_ = true;
481 return;
482 }
483 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
484 pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(start) + ":" + std::to_string(end)).c_str());
485 #endif
486 int schema_count = static_cast<int>(raw_data.size());
487 std::map<uint64_t, vector<json>>::const_iterator rawdata_iter;
488 for (int x = start; x < end; ++x) {
489 int cnt = 0;
490 for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
491 const json &line = raw_data.at(rawdata_iter->first)[x];
492 std::vector<std::uint8_t> bline = json::to_msgpack(line);
493
494 // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2]
495 bin_data[x * schema_count + cnt] = bline;
496 cnt++;
497 }
498 }
499 }
500
LockWriter(bool parallel_writer,std::unique_ptr<int> * fd_ptr)501 Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr) {
502 if (!parallel_writer) {
503 *fd_ptr = std::make_unique<int>(0);
504 return Status::OK();
505 }
506
507 #if defined(_WIN32) || defined(_WIN64)
508 const int fd = 0;
509 MS_LOG(DEBUG) << "Lock file done by Python.";
510
511 #else
512 const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
513 if (fd >= 0) {
514 flock(fd, LOCK_EX);
515 } else {
516 close(fd);
517 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to lock file, path: " + lock_file_);
518 }
519 #endif
520
521 // Open files
522 file_streams_.clear();
523 for (const auto &file : file_paths_) {
524 auto realpath = FileUtils::GetRealPath(file.c_str());
525 if (!realpath.has_value()) {
526 close(fd);
527 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to get real path, path: " + file);
528 }
529 std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
530 fs->open(realpath.value(), std::ios::in | std::ios::out | std::ios::binary);
531 if (fs->fail()) {
532 close(fd);
533 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to open file, path: " + file);
534 }
535 file_streams_.push_back(fs);
536 }
537 auto status = shard_header_->FileToPages(pages_file_);
538 if (status.IsError()) {
539 close(fd);
540 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Error raised in FileToPages function.");
541 }
542 *fd_ptr = std::make_unique<int>(fd);
543 return Status::OK();
544 }
545
UnlockWriter(int fd,bool parallel_writer)546 Status ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
547 if (!parallel_writer) {
548 return Status::OK();
549 }
550 RETURN_IF_NOT_OK_MR(shard_header_->PagesToFile(pages_file_));
551 for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
552 file_streams_[i]->close();
553 }
554 #if defined(_WIN32) || defined(_WIN64)
555 MS_LOG(DEBUG) << "Unlock file done by Python.";
556
557 #else
558 flock(fd, LOCK_UN);
559 close(fd);
560 #endif
561 return Status::OK();
562 }
563
WriteRawDataPreCheck(std::map<uint64_t,std::vector<json>> & raw_data,std::vector<std::vector<uint8_t>> & blob_data,bool sign,int * schema_count,int * row_count)564 Status ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
565 std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count,
566 int *row_count) {
567 // check the free disk size
568 std::string env_free_disk_check = common::GetEnv("MS_FREE_DISK_CHECK");
569 transform(env_free_disk_check.begin(), env_free_disk_check.end(), env_free_disk_check.begin(), ::tolower);
570 bool free_disk_check = true;
571 if (env_free_disk_check == "false") {
572 free_disk_check = false;
573 MS_LOG(INFO) << "environment MS_FREE_DISK_CHECK is false, free disk checking will be turned off.";
574 } else if (env_free_disk_check == "true" || env_free_disk_check == "") {
575 free_disk_check = true;
576 MS_LOG(INFO) << "environment MS_FREE_DISK_CHECK is true, free disk checking will be turned on.";
577 } else {
578 MS_LOG(WARNING) << "environment MS_FREE_DISK_CHECK: " << env_free_disk_check
579 << " is configured wrong, free disk checking will be turned on.";
580 }
581 if (free_disk_check) {
582 std::shared_ptr<uint64_t> size_ptr;
583 RETURN_IF_NOT_OK_MR(GetDiskSize(file_paths_[0], kFreeSize, &size_ptr));
584 CHECK_FAIL_RETURN_UNEXPECTED_MR(
585 *size_ptr >= kMinFreeDiskSize,
586 "No free disk to be used while writing mindrecord files, available free disk size: " + std::to_string(*size_ptr));
587 }
588 // compress blob
589 if (shard_column_->CheckCompressBlob()) {
590 for (auto &blob : blob_data) {
591 int64_t compression_bytes = 0;
592 blob = shard_column_->CompressBlob(blob, &compression_bytes);
593 compression_size_ += compression_bytes;
594 }
595 }
596
597 // Add 4-bytes dummy blob data if no any blob fields
598 if (blob_data.size() == 0 && raw_data.size() > 0) {
599 blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
600 }
601
602 // Add dummy id if all are blob fields
603 if (blob_data.size() > 0 && raw_data.size() == 0) {
604 raw_data.insert(std::pair<uint64_t, std::vector<json>>(0, std::vector<json>(blob_data.size(), kDummyId)));
605 }
606 std::shared_ptr<std::pair<int, int>> count_ptr;
607 RETURN_IF_NOT_OK_MR(ValidateRawData(raw_data, blob_data, sign, &count_ptr));
608 *schema_count = (*count_ptr).first;
609 *row_count = (*count_ptr).second;
610 return Status::OK();
611 }
MergeBlobData(const std::vector<string> & blob_fields,const std::map<std::string,std::unique_ptr<std::vector<uint8_t>>> & row_bin_data,std::shared_ptr<std::vector<uint8_t>> * output)612 Status ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
613 const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
614 std::shared_ptr<std::vector<uint8_t>> *output) {
615 if (blob_fields.empty()) {
616 return Status::OK();
617 }
618 if (blob_fields.size() == 1) {
619 auto &blob = row_bin_data.at(blob_fields[0]);
620 auto blob_size = blob->size();
621 *output = std::make_shared<std::vector<uint8_t>>(blob_size);
622 (void)std::copy(blob->begin(), blob->end(), (*output)->begin());
623 } else {
624 size_t output_size = 0;
625 for (auto &field : blob_fields) {
626 output_size += row_bin_data.at(field)->size();
627 }
628 output_size += blob_fields.size() * sizeof(uint64_t);
629 *output = std::make_shared<std::vector<uint8_t>>(output_size);
630 std::vector<uint8_t> buf(sizeof(uint64_t), 0);
631 size_t idx = 0;
632 for (auto &field : blob_fields) {
633 auto &b = row_bin_data.at(field);
634 uint64_t blob_size = b->size();
635 // big edian
636 for (size_t i = 0; i < buf.size(); ++i) {
637 buf[buf.size() - 1 - i] = (std::numeric_limits<uint8_t>::max()) & blob_size;
638 blob_size >>= 8u;
639 }
640 (void)std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
641 idx += buf.size();
642 (void)std::copy(b->begin(), b->end(), (*output)->begin() + idx);
643 idx += b->size();
644 }
645 }
646 return Status::OK();
647 }
648
WriteRawData(std::map<uint64_t,std::vector<json>> & raw_data,std::vector<std::vector<uint8_t>> & blob_data,bool sign,bool parallel_writer)649 Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
650 std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
651 // Lock Writer if loading data parallel
652 std::unique_ptr<int> fd_ptr;
653 RETURN_IF_NOT_OK_MR(LockWriter(parallel_writer, &fd_ptr));
654
655 // Get the count of schemas and rows
656 int schema_count = 0;
657 int row_count = 0;
658
659 // Serialize raw data
660 RETURN_IF_NOT_OK_MR(WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count));
661 CHECK_FAIL_RETURN_UNEXPECTED_MR(row_count >= kInt0, "[Internal ERROR] the size of raw data should be positive.");
662 if (row_count == kInt0) {
663 return Status::OK();
664 }
665 std::vector<std::vector<uint8_t>> bin_raw_data(row_count * schema_count);
666 // Serialize raw data
667 RETURN_IF_NOT_OK_MR(SerializeRawData(raw_data, bin_raw_data, row_count));
668 // Set row size of raw data
669 RETURN_IF_NOT_OK_MR(SetRawDataSize(bin_raw_data));
670 // Set row size of blob data
671 RETURN_IF_NOT_OK_MR(SetBlobDataSize(blob_data));
672 // Write data to disk with multi threads
673 RETURN_IF_NOT_OK_MR(ParallelWriteData(blob_data, bin_raw_data));
674 MS_LOG(INFO) << "Succeed to write " << bin_raw_data.size() << " records.";
675
676 RETURN_IF_NOT_OK_MR(UnlockWriter(*fd_ptr, parallel_writer));
677
678 return Status::OK();
679 }
680
WriteRawData(std::map<uint64_t,std::vector<py::handle>> & raw_data,std::map<uint64_t,std::vector<py::handle>> & blob_data,bool sign,bool parallel_writer)681 Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
682 std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
683 bool parallel_writer) {
684 std::map<uint64_t, std::vector<json>> raw_data_json;
685 std::map<uint64_t, std::vector<json>> blob_data_json;
686
687 (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
688 [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
689 auto &py_raw_data = pair.second;
690 std::vector<json> json_raw_data;
691 (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
692 [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
693 return std::make_pair(pair.first, std::move(json_raw_data));
694 });
695
696 (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()),
697 [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
698 auto &py_blob_data = pair.second;
699 std::vector<json> jsonBlobData;
700 (void)std::transform(py_blob_data.begin(), py_blob_data.end(),
701 std::back_inserter(jsonBlobData),
702 [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
703 return std::make_pair(pair.first, std::move(jsonBlobData));
704 });
705
706 // Serialize blob page
707 auto blob_data_iter = blob_data.begin();
708 auto schema_count = blob_data.size();
709 auto row_count = blob_data_iter->second.size();
710
711 std::vector<std::vector<uint8_t>> bin_blob_data(row_count * schema_count);
712 // Serialize blob data
713 RETURN_IF_NOT_OK_MR(SerializeRawData(blob_data_json, bin_blob_data, row_count));
714 return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
715 }
716
ParallelWriteData(const std::vector<std::vector<uint8_t>> & blob_data,const std::vector<std::vector<uint8_t>> & bin_raw_data)717 Status ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
718 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
719 auto shards = BreakIntoShards();
720 // define the number of thread
721 int thread_num = static_cast<int>(shard_count_);
722 CHECK_FAIL_RETURN_UNEXPECTED_MR(thread_num > 0, "[Internal ERROR] 'thread_num' should be positive.");
723 if (thread_num > kMaxThreadCount) {
724 thread_num = kMaxThreadCount;
725 }
726 int left_thread = shard_count_;
727 int current_thread = 0;
728 while (left_thread) {
729 if (left_thread < thread_num) {
730 thread_num = left_thread;
731 }
732 // Start one thread for one shard
733 std::vector<std::thread> thread_set(thread_num);
734 if (thread_num <= kMaxThreadCount) {
735 for (int x = 0; x < thread_num; ++x) {
736 int start_row = shards[current_thread + x].first;
737 int end_row = shards[current_thread + x].second;
738 thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row,
739 std::ref(blob_data), std::ref(bin_raw_data));
740 }
741 // Wait for threads done
742 for (int x = 0; x < thread_num; ++x) {
743 thread_set[x].join();
744 }
745 left_thread -= thread_num;
746 current_thread += thread_num;
747 }
748 }
749 return Status::OK();
750 }
751
WriteByShard(int shard_id,int start_row,int end_row,const std::vector<std::vector<uint8_t>> & blob_data,const std::vector<std::vector<uint8_t>> & bin_raw_data)752 Status ShardWriter::WriteByShard(int shard_id, int start_row, int end_row,
753 const std::vector<std::vector<uint8_t>> &blob_data,
754 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
755 MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row
756 << ", schema size: " << schema_count_;
757 if (start_row == end_row) {
758 return Status::OK();
759 }
760 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
761 pthread_setname_np(pthread_self(), std::string(__func__ + std::to_string(shard_id)).c_str());
762 #endif
763 vector<std::pair<int, int>> rows_in_group;
764 std::shared_ptr<Page> last_raw_page = nullptr;
765 std::shared_ptr<Page> last_blob_page = nullptr;
766 RETURN_IF_NOT_OK_MR(SetLastRawPage(shard_id, last_raw_page));
767 RETURN_IF_NOT_OK_MR(SetLastBlobPage(shard_id, last_blob_page));
768
769 RETURN_IF_NOT_OK_MR(CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page));
770 RETURN_IF_NOT_OK_MR(AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
771 RETURN_IF_NOT_OK_MR(NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
772 RETURN_IF_NOT_OK_MR(ShiftRawPage(shard_id, rows_in_group, last_raw_page));
773 RETURN_IF_NOT_OK_MR(WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data));
774
775 return Status::OK();
776 }
777
CutRowGroup(int start_row,int end_row,const std::vector<std::vector<uint8_t>> & blob_data,std::vector<std::pair<int,int>> & rows_in_group,const std::shared_ptr<Page> & last_raw_page,const std::shared_ptr<Page> & last_blob_page)778 Status ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
779 std::vector<std::pair<int, int>> &rows_in_group,
780 const std::shared_ptr<Page> &last_raw_page,
781 const std::shared_ptr<Page> &last_blob_page) {
782 auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;
783
784 auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
785 auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0;
786 auto n_byte_raw = last_raw_page_size - last_raw_offset;
787
788 int page_start_row = start_row;
789 CHECK_FAIL_RETURN_UNEXPECTED_MR(start_row <= end_row,
790 "[Internal ERROR] 'start_row': " + std::to_string(start_row) +
791 " should be less than or equal to 'end_row': " + std::to_string(end_row));
792
793 CHECK_FAIL_RETURN_UNEXPECTED_MR(
794 end_row <= static_cast<int>(blob_data_size_.size()) && end_row <= static_cast<int>(raw_data_size_.size()),
795 "[Internal ERROR] 'end_row': " + std::to_string(end_row) + " should be less than 'blob_data_size': " +
796 std::to_string(blob_data_size_.size()) + " and 'raw_data_size': " + std::to_string(raw_data_size_.size()) + ".");
797 for (int i = start_row; i < end_row; ++i) {
798 // n_byte_blob(0) indicate appendBlobPage
799 if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ ||
800 n_byte_raw + raw_data_size_[i] > page_size_) {
801 rows_in_group.emplace_back(page_start_row, i);
802 page_start_row = i;
803 n_byte_blob = blob_data_size_[i];
804 n_byte_raw = raw_data_size_[i];
805 } else {
806 n_byte_blob += blob_data_size_[i];
807 n_byte_raw += raw_data_size_[i];
808 }
809 }
810
811 // Not forget last one
812 rows_in_group.emplace_back(page_start_row, end_row);
813 return Status::OK();
814 }
815
AppendBlobPage(const int & shard_id,const std::vector<std::vector<uint8_t>> & blob_data,const std::vector<std::pair<int,int>> & rows_in_group,const std::shared_ptr<Page> & last_blob_page)816 Status ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
817 const std::vector<std::pair<int, int>> &rows_in_group,
818 const std::shared_ptr<Page> &last_blob_page) {
819 auto blob_row = rows_in_group[0];
820 if (blob_row.first == blob_row.second) {
821 return Status::OK();
822 }
823 // Write disk
824 auto page_id = last_blob_page->GetPageID();
825 auto bytes_page = last_blob_page->GetPageSize();
826 auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg);
827 if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
828 file_streams_[shard_id]->close();
829 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
830 }
831
832 (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row);
833
834 // Update last blob page
835 bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
836 last_blob_page->SetPageSize(bytes_page);
837 uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first;
838 last_blob_page->SetEndRowID(end_row);
839 (void)shard_header_->SetPage(last_blob_page);
840 return Status::OK();
841 }
842
NewBlobPage(const int & shard_id,const std::vector<std::vector<uint8_t>> & blob_data,const std::vector<std::pair<int,int>> & rows_in_group,const std::shared_ptr<Page> & last_blob_page)843 Status ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
844 const std::vector<std::pair<int, int>> &rows_in_group,
845 const std::shared_ptr<Page> &last_blob_page) {
846 auto page_id = shard_header_->GetLastPageId(shard_id);
847 auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1;
848 auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0;
849 // index(0) indicate appendBlobPage
850 for (uint32_t i = 1; i < rows_in_group.size(); ++i) {
851 auto blob_row = rows_in_group[i];
852
853 // Write 1 blob page to disk
854 auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
855 if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
856 file_streams_[shard_id]->close();
857 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
858 }
859
860 (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row);
861 // Create new page info for header
862 auto page_size =
863 std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
864 std::vector<std::pair<int, uint64_t>> row_group_ids;
865 auto start_row = current_row;
866 auto end_row = start_row + blob_row.second - blob_row.first;
867 auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size);
868 (void)shard_header_->AddPage(std::make_shared<Page>(page));
869 current_row = end_row;
870 }
871 return Status::OK();
872 }
873
ShiftRawPage(const int & shard_id,const std::vector<std::pair<int,int>> & rows_in_group,std::shared_ptr<Page> & last_raw_page)874 Status ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
875 std::shared_ptr<Page> &last_raw_page) {
876 auto blob_row = rows_in_group[0];
877 if (blob_row.first == blob_row.second) {
878 return Status::OK();
879 }
880 auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
881 if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) +
882 last_raw_page_size <=
883 page_size_) {
884 return Status::OK();
885 }
886 auto page_id = shard_header_->GetLastPageId(shard_id);
887 auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second;
888 auto last_raw_page_id = last_raw_page->GetPageID();
889 auto shift_size = last_raw_page_size - last_row_group_id_offset;
890
891 std::vector<uint8_t> buf(shift_size);
892
893 // Read last row group from previous raw data page
894 CHECK_FAIL_RETURN_UNEXPECTED_MR(
895 shard_id >= 0 && shard_id < file_streams_.size(),
896 "[Internal ERROR] 'shard_id' should be in range [0, " + std::to_string(file_streams_.size()) + ").");
897
898 auto &io_seekg = file_streams_[shard_id]->seekg(
899 page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg);
900 if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
901 file_streams_[shard_id]->close();
902 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
903 }
904
905 auto &io_read = file_streams_[shard_id]->read(reinterpret_cast<char *>(&buf[0]), buf.size());
906 if (!io_read.good() || io_read.fail() || io_read.bad()) {
907 file_streams_[shard_id]->close();
908 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to read file.");
909 }
910
911 // Merge into new row group at new raw data page
912 auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
913 if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
914 file_streams_[shard_id]->close();
915 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
916 }
917
918 auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&buf[0]), buf.size());
919 if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
920 file_streams_[shard_id]->close();
921 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
922 }
923 last_raw_page->DeleteLastGroupId();
924 (void)shard_header_->SetPage(last_raw_page);
925
926 // Refresh page info in header
927 int row_group_id = last_raw_page->GetLastRowGroupID().first + 1;
928 std::vector<std::pair<int, uint64_t>> row_group_ids;
929 row_group_ids.emplace_back(row_group_id, 0);
930 int page_type_id = last_raw_page->GetPageID();
931 auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size);
932 (void)shard_header_->AddPage(std::make_shared<Page>(page));
933
934 // Reset: last raw page
935 RETURN_IF_NOT_OK_MR(SetLastRawPage(shard_id, last_raw_page));
936 return Status::OK();
937 }
938
WriteRawPage(const int & shard_id,const std::vector<std::pair<int,int>> & rows_in_group,std::shared_ptr<Page> & last_raw_page,const std::vector<std::vector<uint8_t>> & bin_raw_data)939 Status ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
940 std::shared_ptr<Page> &last_raw_page,
941 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
942 int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1;
943 for (uint32_t i = 0; i < rows_in_group.size(); ++i) {
944 const auto &blob_row = rows_in_group[i];
945 if (blob_row.first == blob_row.second) {
946 continue;
947 }
948 auto raw_size =
949 std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0);
950 if (!last_raw_page) {
951 RETURN_IF_NOT_OK_MR(EmptyRawPage(shard_id, last_raw_page));
952 } else if (last_raw_page->GetPageSize() + raw_size > page_size_) {
953 RETURN_IF_NOT_OK_MR(shard_header_->SetPage(last_raw_page));
954 RETURN_IF_NOT_OK_MR(EmptyRawPage(shard_id, last_raw_page));
955 }
956 RETURN_IF_NOT_OK_MR(AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data));
957 }
958 RETURN_IF_NOT_OK_MR(shard_header_->SetPage(last_raw_page));
959 return Status::OK();
960 }
961
EmptyRawPage(const int & shard_id,std::shared_ptr<Page> & last_raw_page)962 Status ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
963 auto row_group_ids = std::vector<std::pair<int, uint64_t>>();
964 auto page_id = shard_header_->GetLastPageId(shard_id);
965 auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1;
966 auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0);
967 RETURN_IF_NOT_OK_MR(shard_header_->AddPage(std::make_shared<Page>(page)));
968 RETURN_IF_NOT_OK_MR(SetLastRawPage(shard_id, last_raw_page));
969 return Status::OK();
970 }
971
AppendRawPage(const int & shard_id,const std::vector<std::pair<int,int>> & rows_in_group,const int & chunk_id,int & last_row_group_id,std::shared_ptr<Page> last_raw_page,const std::vector<std::vector<uint8_t>> & bin_raw_data)972 Status ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
973 const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page,
974 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
975 std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds();
976 auto last_raw_page_id = last_raw_page->GetPageID();
977 auto n_bytes = last_raw_page->GetPageSize();
978
979 // previous raw data page
980 auto &io_seekp =
981 file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg);
982 if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
983 file_streams_[shard_id]->close();
984 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekg file.");
985 }
986
987 if (chunk_id > 0) {
988 row_group_ids.emplace_back(++last_row_group_id, n_bytes);
989 }
990 n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first,
991 raw_data_size_.begin() + rows_in_group[chunk_id].second, 0);
992 RETURN_IF_NOT_OK_MR(FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data));
993
994 // Update previous raw data page
995 last_raw_page->SetPageSize(n_bytes);
996 last_raw_page->SetRowGroupIds(row_group_ids);
997 RETURN_IF_NOT_OK_MR(shard_header_->SetPage(last_raw_page));
998
999 return Status::OK();
1000 }
1001
FlushBlobChunk(const std::shared_ptr<std::fstream> & out,const std::vector<std::vector<uint8_t>> & blob_data,const std::pair<int,int> & blob_row)1002 Status ShardWriter::FlushBlobChunk(const std::shared_ptr<std::fstream> &out,
1003 const std::vector<std::vector<uint8_t>> &blob_data,
1004 const std::pair<int, int> &blob_row) {
1005 CHECK_FAIL_RETURN_UNEXPECTED_MR(
1006 blob_row.first <= blob_row.second && blob_row.second <= static_cast<int>(blob_data.size()) && blob_row.first >= 0,
1007 "[Internal ERROR] 'blob_row': " + std::to_string(blob_row.first) + ", " + std::to_string(blob_row.second) +
1008 " is invalid.");
1009 for (int j = blob_row.first; j < blob_row.second; ++j) {
1010 // Write the size of blob
1011 uint64_t line_len = blob_data[j].size();
1012 auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
1013 if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
1014 out->close();
1015 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
1016 }
1017
1018 // Write the data of blob
1019 auto line = blob_data[j];
1020 auto &io_handle_data = out->write(reinterpret_cast<char *>(&line[0]), line_len);
1021 if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) {
1022 out->close();
1023 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
1024 }
1025 }
1026 return Status::OK();
1027 }
1028
FlushRawChunk(const std::shared_ptr<std::fstream> & out,const std::vector<std::pair<int,int>> & rows_in_group,const int & chunk_id,const std::vector<std::vector<uint8_t>> & bin_raw_data)1029 Status ShardWriter::FlushRawChunk(const std::shared_ptr<std::fstream> &out,
1030 const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
1031 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
1032 for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) {
1033 // Write the size of multi schemas
1034 for (uint32_t j = 0; j < schema_count_; ++j) {
1035 uint64_t line_len = bin_raw_data[i * schema_count_ + j].size();
1036 auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
1037 if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
1038 out->close();
1039 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
1040 }
1041 }
1042 // Write the data of multi schemas
1043 for (uint32_t j = 0; j < schema_count_; ++j) {
1044 auto line = bin_raw_data[i * schema_count_ + j];
1045 auto &io_handle = out->write(reinterpret_cast<char *>(&line[0]), line.size());
1046 if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
1047 out->close();
1048 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
1049 }
1050 }
1051 }
1052 return Status::OK();
1053 }
1054
1055 // Allocate data to shards evenly
BreakIntoShards()1056 std::vector<std::pair<int, int>> ShardWriter::BreakIntoShards() {
1057 std::vector<std::pair<int, int>> shards;
1058 int row_in_shard = row_count_ / shard_count_;
1059 int remains = row_count_ % shard_count_;
1060
1061 std::vector<int> v_list(shard_count_);
1062 std::iota(v_list.begin(), v_list.end(), 0);
1063
1064 std::mt19937 g = GetRandomDevice();
1065 std::shuffle(v_list.begin(), v_list.end(), g);
1066 std::unordered_set<int> set(v_list.begin(), v_list.begin() + remains);
1067
1068 if (shard_count_ <= kMaxShardCount) {
1069 int start_row = 0;
1070 for (int i = 0; i < shard_count_; ++i) {
1071 int end_row = start_row + row_in_shard;
1072 if (set.count(i) == 1) {
1073 end_row++;
1074 }
1075 shards.emplace_back(start_row, end_row);
1076 start_row = end_row;
1077 }
1078 }
1079 return shards;
1080 }
1081
WriteShardHeader()1082 Status ShardWriter::WriteShardHeader() {
1083 RETURN_UNEXPECTED_IF_NULL_MR(shard_header_);
1084 int64_t compression_temp = compression_size_;
1085 uint64_t compression_size = compression_temp > 0 ? compression_temp : 0;
1086 shard_header_->SetCompressionSize(compression_size);
1087
1088 auto shard_header = shard_header_->SerializeHeader();
1089 // Write header data to multi files
1090 CHECK_FAIL_RETURN_UNEXPECTED_MR(
1091 shard_count_ <= static_cast<int>(file_streams_.size()) && shard_count_ <= static_cast<int>(shard_header.size()),
1092 "[Internal ERROR] 'shard_count_' should be less than or equal to 'file_stream_' size: " +
1093 std::to_string(file_streams_.size()) + ", and 'shard_header' size: " + std::to_string(shard_header.size()) + ".");
1094 if (shard_count_ <= kMaxShardCount) {
1095 for (int shard_id = 0; shard_id < shard_count_; ++shard_id) {
1096 auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg);
1097 if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
1098 file_streams_[shard_id]->close();
1099 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to seekp file.");
1100 }
1101
1102 std::vector<uint8_t> bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end());
1103 uint64_t line_len = bin_header.size();
1104 if (line_len + kInt64Len > header_size_) {
1105 file_streams_[shard_id]->close();
1106 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] shard header is too big.");
1107 }
1108 auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&line_len), kInt64Len);
1109 if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
1110 file_streams_[shard_id]->close();
1111 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
1112 }
1113
1114 auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast<char *>(&bin_header[0]), line_len);
1115 if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) {
1116 file_streams_[shard_id]->close();
1117 RETURN_STATUS_UNEXPECTED_MR("[Internal ERROR] Failed to write file.");
1118 }
1119 file_streams_[shard_id]->close();
1120 }
1121 }
1122 return Status::OK();
1123 }
1124
SerializeRawData(std::map<uint64_t,std::vector<json>> & raw_data,std::vector<std::vector<uint8_t>> & bin_data,uint32_t row_count)1125 Status ShardWriter::SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
1126 std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count) {
1127 // define the number of thread
1128 uint32_t thread_num = std::thread::hardware_concurrency();
1129 if (thread_num == 0) {
1130 thread_num = kThreadNumber;
1131 }
1132 // Set the number of samples processed by each thread
1133 int group_num = static_cast<int>(ceil(row_count * 1.0 / thread_num));
1134 std::vector<std::thread> thread_set(thread_num);
1135 int work_thread_num = 0;
1136 for (uint32_t x = 0; x < thread_num; ++x) {
1137 int start_num = x * group_num;
1138 int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num;
1139 if (start_num >= end_num) {
1140 continue;
1141 }
1142 // Define the run boundary and start the child thread
1143 thread_set[x] =
1144 std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data));
1145 work_thread_num++;
1146 }
1147 for (uint32_t x = 0; x < work_thread_num; ++x) {
1148 // Set obstacles to prevent the main thread from running
1149 thread_set[x].join();
1150 }
1151 CHECK_FAIL_RETURN_SYNTAX_ERROR_MR(flag_ != true, "[Internal ERROR] Error raised in FillArray function.");
1152 return Status::OK();
1153 }
1154
SetRawDataSize(const std::vector<std::vector<uint8_t>> & bin_raw_data)1155 Status ShardWriter::SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data) {
1156 raw_data_size_ = std::vector<uint64_t>(row_count_, 0);
1157 for (uint32_t i = 0; i < row_count_; ++i) {
1158 raw_data_size_[i] = std::accumulate(
1159 bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_, 0,
1160 [](uint64_t accumulator, const std::vector<uint8_t> &row) { return accumulator + kInt64Len + row.size(); });
1161 }
1162 CHECK_FAIL_RETURN_SYNTAX_ERROR_MR(*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) <= page_size_,
1163 "Invalid data, Page size: " + std::to_string(page_size_) +
1164 " is too small to save a raw row. Please try to use the mindrecord api "
1165 "'set_page_size(value)' to enable larger page size, and the value range is in [" +
1166 std::to_string(kMinPageSize) + " bytes, " + std::to_string(kMaxPageSize) +
1167 " bytes].");
1168 return Status::OK();
1169 }
1170
SetBlobDataSize(const std::vector<std::vector<uint8_t>> & blob_data)1171 Status ShardWriter::SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data) {
1172 blob_data_size_ = std::vector<uint64_t>(row_count_);
1173 (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(),
1174 [](const std::vector<uint8_t> &row) { return kInt64Len + row.size(); });
1175 CHECK_FAIL_RETURN_SYNTAX_ERROR_MR(*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) <= page_size_,
1176 "Invalid data, Page size: " + std::to_string(page_size_) +
1177 " is too small to save a blob row. Please try to use the mindrecord api "
1178 "'set_page_size(value)' to enable larger page size, and the value range is in [" +
1179 std::to_string(kMinPageSize) + " bytes, " + std::to_string(kMaxPageSize) +
1180 " bytes].");
1181 return Status::OK();
1182 }
1183
SetLastRawPage(const int & shard_id,std::shared_ptr<Page> & last_raw_page)1184 Status ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
1185 // Get last raw page
1186 auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw);
1187 if (last_raw_page_id == -1) {
1188 return Status::OK();
1189 }
1190 RETURN_IF_NOT_OK_MR(shard_header_->GetPage(shard_id, last_raw_page_id, &last_raw_page));
1191 return Status::OK();
1192 }
1193
SetLastBlobPage(const int & shard_id,std::shared_ptr<Page> & last_blob_page)1194 Status ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page) {
1195 // Get last blob page
1196 auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob);
1197 if (last_blob_page_id == -1) {
1198 return Status::OK();
1199 }
1200 RETURN_IF_NOT_OK_MR(shard_header_->GetPage(shard_id, last_blob_page_id, &last_blob_page));
1201 return Status::OK();
1202 }
1203
Initialize(const std::unique_ptr<ShardWriter> * writer_ptr,const std::vector<std::string> & file_names)1204 Status ShardWriter::Initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
1205 const std::vector<std::string> &file_names) {
1206 RETURN_UNEXPECTED_IF_NULL_MR(writer_ptr);
1207 RETURN_IF_NOT_OK_MR((*writer_ptr)->Open(file_names, false));
1208 RETURN_IF_NOT_OK_MR((*writer_ptr)->SetHeaderSize(kDefaultHeaderSize));
1209 RETURN_IF_NOT_OK_MR((*writer_ptr)->SetPageSize(kDefaultPageSize));
1210 return Status::OK();
1211 }
1212 } // namespace mindrecord
1213 } // namespace mindspore
1214