/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_writer.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "./securec.h"

using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::DEBUG;
using mindspore::MsLogLevel::ERROR;
using mindspore::MsLogLevel::INFO;

namespace mindspore {
namespace mindrecord {
ShardWriter::ShardWriter()
    : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) {
  compression_size_ = 0;
}

ShardWriter::~ShardWriter() {
  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
    file_streams_[i]->close();
  }
}

Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
  // Resolve each file name to an absolute path
  for (const auto &path : paths) {
    CHECK_FAIL_RETURN_UNEXPECTED(CheckIsValidUtf8(path),
                                 "Invalid data, file name: " + path + " contains invalid utf-8 character.");
    char resolved_path[PATH_MAX] = {0};
    char buf[PATH_MAX] = {0};
    CHECK_FAIL_RETURN_UNEXPECTED(strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) == EOK,
                                 "Failed to call securec func [strncpy_s], path: " + path);
#if defined(_WIN32) || defined(_WIN64)
    RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX));
    RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX));
#else
    CHECK_FAIL_RETURN_UNEXPECTED(realpath(dirname(&(buf[0])), resolved_path) != nullptr,
                                 "Invalid file, path: " + std::string(resolved_path));
    if (realpath(common::SafeCStr(path), resolved_path) == nullptr) {
      // The data file itself may not exist yet; resolving its directory above is sufficient
      MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << " does not exist yet, it will be created.";
    }
#endif
    file_paths_.emplace_back(string(resolved_path));
  }
  return Status::OK();
}

Status ShardWriter::OpenDataFiles(bool append) {
  // Open files
  for (const auto &file : file_paths_) {
    std::optional<std::string> dir = "";
    std::optional<std::string> local_file_name = "";
    FileUtils::SplitDirAndFileName(file, &dir, &local_file_name);
    if (!dir.has_value()) {
      dir = ".";
    }

    auto realpath = FileUtils::GetRealPath(dir.value().data());
    CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);

    std::optional<std::string> whole_path = "";
    FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);

    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
    if (!append) {
      // If not appending and the mindrecord file already exists, return FAILED
      fs->open(whole_path.value(), std::ios::in | std::ios::binary);
      if (fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED("Invalid file, mindrecord file already exists in path: " + file);
      }
      fs->close();
      // Open the mindrecord file for writing
      fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
      if (!fs->good()) {
        RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file);
      }
    } else {
      // Open the mindrecord file for appending
      fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary);
      if (!fs->good()) {
        fs->close();
        RETURN_STATUS_UNEXPECTED("Failed to open file for appending data, path: " + file);
      }
    }
    MS_LOG(INFO) << "Succeed to open shard file, path: " << file;
    file_streams_.push_back(fs);
  }
  return Status::OK();
}

Status ShardWriter::RemoveLockFile() {
  // Remove temporary file
  int ret = std::remove(pages_file_.c_str());
  if (ret == 0) {
    MS_LOG(DEBUG) << "Succeed to remove page file, path: " << pages_file_;
  }

  ret = std::remove(lock_file_.c_str());
  if (ret == 0) {
    MS_LOG(DEBUG) << "Succeed to remove lock file, path: " << lock_file_;
  }
  return Status::OK();
}

Status ShardWriter::InitLockFile() {
  CHECK_FAIL_RETURN_UNEXPECTED(file_paths_.size() != 0, "Invalid data, file_paths_ is not initialized.");

  lock_file_ = file_paths_[0] + kLockFileSuffix;
  pages_file_ = file_paths_[0] + kPageFileSuffix;
  RETURN_IF_NOT_OK(RemoveLockFile());
  return Status::OK();
}

Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
  shard_count_ = paths.size();
  CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ <= kMaxSchemaCount,
                               "Invalid data, schema_count_ must be less than or equal to " +
                                 std::to_string(kMaxSchemaCount) + ", but got " + std::to_string(schema_count_));

  // Get full path from file name
  RETURN_IF_NOT_OK(GetFullPathFromFileName(paths));
  // Open files
  RETURN_IF_NOT_OK(OpenDataFiles(append));
  // Init lock file
  RETURN_IF_NOT_OK(InitLockFile());
  return Status::OK();
}

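// A minimal usage sketch of the writer life cycle (illustrative only; `header`,
// `raw_data` and `blob_data` are assumed to be prepared by the caller):
//   auto writer = std::make_unique<ShardWriter>();
//   RETURN_IF_NOT_OK(writer->Open({"/path/train.mindrecord0", "/path/train.mindrecord1"}, false));
//   RETURN_IF_NOT_OK(writer->SetShardHeader(header));                      // schema + index fields
//   RETURN_IF_NOT_OK(writer->WriteRawData(raw_data, blob_data, true, false));
//   RETURN_IF_NOT_OK(writer->Commit());                                    // flush header, drop lock file
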
Status ShardWriter::OpenForAppend(const std::string &path) {
  CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(path), "Invalid file, path: " + path);
  std::shared_ptr<json> header_ptr;
  RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(path, &header_ptr));
  auto ds = std::make_shared<std::vector<std::string>>();
  RETURN_IF_NOT_OK(GetDatasetFiles(path, (*header_ptr)["shard_addresses"], &ds));
  ShardHeader header = ShardHeader();
  RETURN_IF_NOT_OK(header.BuildDataset(*ds));
  shard_header_ = std::make_shared<ShardHeader>(header);
  RETURN_IF_NOT_OK(SetHeaderSize(shard_header_->GetHeaderSize()));
  RETURN_IF_NOT_OK(SetPageSize(shard_header_->GetPageSize()));
  compression_size_ = shard_header_->GetCompressionSize();
  RETURN_IF_NOT_OK(Open(*ds, true));
  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
  return Status::OK();
}

Status ShardWriter::Commit() {
  // Read pages file
  std::ifstream page_file(pages_file_.c_str());
  if (page_file.good()) {
    page_file.close();
    RETURN_IF_NOT_OK(shard_header_->FileToPages(pages_file_));
  }
  RETURN_IF_NOT_OK(WriteShardHeader());
  MS_LOG(INFO) << "Succeed to write meta data.";
  // Remove lock file
  RETURN_IF_NOT_OK(RemoveLockFile());

  return Status::OK();
}

Status ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) {
  RETURN_IF_NOT_OK(header_data->InitByFiles(file_paths_));
  // Generate index fields automatically when none were set
  std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
  if (fields.empty()) {
    MS_LOG(DEBUG) << "Index field is not set, it will be generated automatically.";
    std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
    for (const auto &schema : schemas) {
      json jsonSchema = schema->GetSchema()["schema"];
      for (const auto &el : jsonSchema.items()) {
        if (el.value()["type"] == "string" ||
            (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
            (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
          fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
        }
      }
    }
    // If fields is still empty, the dataset contains only blob data and no index is added
    if (!fields.empty()) {
      RETURN_IF_NOT_OK(header_data->AddIndexFields(fields));
    }
  }

  shard_header_ = header_data;
  shard_header_->SetHeaderSize(header_size_);
  shard_header_->SetPageSize(page_size_);
  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
  return Status::OK();
}

Status ShardWriter::SetHeaderSize(const uint64_t &header_size) {
  // header_size [16KB, 128MB]
  CHECK_FAIL_RETURN_UNEXPECTED(header_size >= kMinHeaderSize && header_size <= kMaxHeaderSize,
                               "Invalid data, header size: " + std::to_string(header_size) +
                                 " bytes should be in range [" + std::to_string(kMinHeaderSize) + ", " +
                                 std::to_string(kMaxHeaderSize) + "] bytes.");
  CHECK_FAIL_RETURN_UNEXPECTED(
    header_size % 4 == 0, "Invalid data, header size " + std::to_string(header_size) + " should be divisible by four.");
  header_size_ = header_size;
  return Status::OK();
}

Status ShardWriter::SetPageSize(const uint64_t &page_size) {
  // page_size [32KB, 256MB]
  CHECK_FAIL_RETURN_UNEXPECTED(page_size >= kMinPageSize && page_size <= kMaxPageSize,
                               "Invalid data, page size: " + std::to_string(page_size) +
                                 " bytes should be in range [" + std::to_string(kMinPageSize) + ", " +
                                 std::to_string(kMaxPageSize) + "] bytes.");
  CHECK_FAIL_RETURN_UNEXPECTED(page_size % 4 == 0,
                               "Invalid data, page size " + std::to_string(page_size) + " should be divisible by four.");
  page_size_ = page_size;
  return Status::OK();
}

void ShardWriter::DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data,
                                  std::vector<std::vector<uint8_t>> &blob_data) {
  // Collect the locations of the bad rows, in descending order so that erasing
  // by index does not invalidate the remaining locations
  std::set<int, std::greater<int>> delete_set;
  for (auto &err_mg : err_mg_) {
    uint64_t id = err_mg.first;
    auto sub_err_mg = err_mg.second;
    for (auto &subMg : sub_err_mg) {
      int loc = subMg.first;
      std::string message = subMg.second;
      MS_LOG(ERROR) << "Invalid input, the " << loc + 1 << "th data is invalid, " << message;
      (void)delete_set.insert(loc);
    }
  }

  auto it = raw_data.begin();
  if (delete_set.size() == it->second.size()) {
    raw_data.clear();
    blob_data.clear();
    return;
  }

  // Delete wrong raw data
  for (auto &loc : delete_set) {
    // delete row data
    for (auto &raw : raw_data) {
      (void)raw.second.erase(raw.second.begin() + loc);
    }

    // delete blob data
    (void)blob_data.erase(blob_data.begin() + loc);
  }
}

void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message,
                                         std::map<int, std::string> &err_raw_data) {
  std::lock_guard<std::mutex> lock(check_mutex_);
  (void)err_raw_data.insert(std::make_pair(row, message));
}

Status ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
                                          std::map<int, std::string> &err_raw_data) {
  auto data_type = std::string(value["type"].get<std::string>());
  if ((data_type == "int32" && !data[key].is_number_integer()) ||
      (data_type == "int64" && !data[key].is_number_integer()) ||
      (data_type == "float32" && !data[key].is_number_float()) ||
      (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) {
    std::string message =
      "field: " + key + ", type: " + data_type + ", value: " + data[key].dump() + " is not matched.";
    PopulateMutexErrorData(i, message, err_raw_data);
    RETURN_STATUS_UNEXPECTED(message);
  }

  if (data_type == "int32" && data[key].is_number_integer()) {
    int64_t temp_value = data[key];
    // Out of range when the value falls below INT32_MIN or above INT32_MAX
    if (static_cast<int64_t>(temp_value) < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) ||
        static_cast<int64_t>(temp_value) > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
      std::string message =
        "field: " + key + ", type: " + data_type + ", value: " + data[key].dump() + " is out of range.";
      PopulateMutexErrorData(i, message, err_raw_data);
      RETURN_STATUS_UNEXPECTED(message);
    }
  }
  return Status::OK();
}

void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
                                 std::map<int, std::string> &err_raw_data) {
  if (start_row < 0 || start_row > end_row || end_row > static_cast<int>(sub_raw_data.size())) {
    return;
  }
  for (int i = start_row; i < end_row; i++) {
    json data = sub_raw_data[i];

    for (auto iter = schema.begin(); iter != schema.end(); iter++) {
      std::string key = iter.key();
      json value = iter.value();
      if (data.find(key) == data.end()) {
        std::string message = "'" + key + "' object can not be found in data: " + value.dump();
        PopulateMutexErrorData(i, message, err_raw_data);
        break;
      }

      if (value.size() == kInt2) {
        // Skip the check since all shaped data is stored as blob
        continue;
      }

      if (CheckDataTypeAndValue(key, value, data, i, err_raw_data).IsError()) {
        break;
      }
    }
  }
}

Status ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &raw_data) {
  auto rawdata_iter = raw_data.begin();

  // make sure the raw data matches the schema
  for (; rawdata_iter != raw_data.end(); ++rawdata_iter) {
    // used for storing errors
    std::map<int, std::string> sub_err_mg;
    int schema_id = rawdata_iter->first;
    std::shared_ptr<Schema> schema_ptr;
    RETURN_IF_NOT_OK(shard_header_->GetSchemaByID(schema_id, &schema_ptr));
    json schema = schema_ptr->GetSchema()["schema"];
    for (const auto &field : schema_ptr->GetBlobFields()) {
      (void)schema.erase(field);
    }
    std::vector<json> sub_raw_data = rawdata_iter->second;

    // calculate the start position and end position for each thread
    int batch_size = rawdata_iter->second.size() / shard_count_;
    int thread_num = shard_count_;
    CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid data, thread_num should be positive.");
    if (thread_num > kMaxThreadCount) {
      thread_num = kMaxThreadCount;
    }
    std::vector<std::thread> thread_set(thread_num);

    // start multiple threads
    int start_row = 0, end_row = 0;
    for (int x = 0; x < thread_num; ++x) {
      if (x != thread_num - 1) {
        start_row = batch_size * x;
        end_row = batch_size * (x + 1);
      } else {
        start_row = batch_size * x;
        end_row = rawdata_iter->second.size();
      }
      thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema,
                                  std::ref(sub_raw_data), std::ref(sub_err_mg));
    }
    CHECK_FAIL_RETURN_UNEXPECTED(
      thread_num <= kMaxThreadCount,
      "Invalid data, thread_num should be less than or equal to " + std::to_string(kMaxThreadCount));
    // Wait for the threads to finish
    for (int x = 0; x < thread_num; ++x) {
      thread_set[x].join();
    }

    (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg));
  }
  return Status::OK();
}

Status ShardWriter::ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                    std::vector<std::vector<uint8_t>> &blob_data, bool sign,
                                    std::shared_ptr<std::pair<int, int>> *count_ptr) {
  RETURN_UNEXPECTED_IF_NULL(count_ptr);
  auto rawdata_iter = raw_data.begin();
  schema_count_ = raw_data.size();
  CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ > 0, "Invalid data, schema count should be positive.");

  // keep schema_id
  std::set<int64_t> schema_ids;
  row_count_ = (rawdata_iter->second).size();

  // Determine if the number of schemas is the same
  CHECK_FAIL_RETURN_UNEXPECTED(shard_header_->GetSchemas().size() == schema_count_,
                               "Invalid data, schema count: " + std::to_string(schema_count_) + " is not matched.");
  // Determine raw_data size == blob_data size
  CHECK_FAIL_RETURN_UNEXPECTED(raw_data[0].size() == blob_data.size(),
                               "Invalid data, raw data size: " + std::to_string(raw_data[0].size()) +
                                 " is not equal to blob data size: " + std::to_string(blob_data.size()) + ".");

  // Determine whether the number of samples corresponding to each schema is the same
  for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
    CHECK_FAIL_RETURN_UNEXPECTED(row_count_ == rawdata_iter->second.size(),
                                 "Invalid data, number of samples: " + std::to_string(rawdata_iter->second.size()) +
                                   " for schema is not matched.");
    (void)schema_ids.insert(rawdata_iter->first);
  }
  const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
  // Every schema defined in the header must have matching data
  CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(schemas.begin(), schemas.end(),
                                            [schema_ids](const std::shared_ptr<Schema> &schema) {
                                              return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
                                            }),
                               "Invalid data, schema id of data is not matched.");
  if (!sign) {
    *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
    return Status::OK();
  }

  // check the data against the schema
  RETURN_IF_NOT_OK(CheckData(raw_data));

  // delete wrong data from raw data
  DeleteErrorData(raw_data, blob_data);

  // update raw count
  row_count_ = row_count_ - err_mg_.begin()->second.size();
  *count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
  return Status::OK();
}

void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,
                            std::vector<std::vector<uint8_t>> &bin_data) {
  // Guard against out-of-range access when more threads are spawned than rows
  if (start >= end) {
    flag_ = true;
    return;
  }
  int schema_count = static_cast<int>(raw_data.size());
  std::map<uint64_t, vector<json>>::const_iterator rawdata_iter;
  for (int x = start; x < end; ++x) {
    int cnt = 0;
    for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
      const json &line = raw_data.at(rawdata_iter->first)[x];
      std::vector<std::uint8_t> bline = json::to_msgpack(line);

      // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2]
      bin_data[x * schema_count + cnt] = bline;
      cnt++;
    }
  }
}

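// For example (a sketch, assuming two schemas s0/s1 and rows r0/r1), FillArray
// serializes each row once per schema and lays the msgpack buffers out row-major:
//   bin_data = { msgpack(r0[s0]), msgpack(r0[s1]), msgpack(r1[s0]), msgpack(r1[s1]) };
// so the record for row i and schema j lives at bin_data[i * schema_count + j].
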
Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr) {
  if (!parallel_writer) {
    *fd_ptr = std::make_unique<int>(0);
    return Status::OK();
  }

#if defined(_WIN32) || defined(_WIN64)
  const int fd = 0;
  MS_LOG(DEBUG) << "Lock file done by Python.";

#else
  const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
  if (fd >= 0) {
    flock(fd, LOCK_EX);
  } else {
    // open failed, so there is no valid descriptor to close
    RETURN_STATUS_UNEXPECTED("Failed to lock file, path: " + lock_file_);
  }
#endif

  // Open files
  file_streams_.clear();
  for (const auto &file : file_paths_) {
    auto realpath = FileUtils::GetRealPath(file.data());
    if (!realpath.has_value()) {
      close(fd);
      RETURN_STATUS_UNEXPECTED("Failed to get real path, path: " + file);
    }
    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
    fs->open(realpath.value(), std::ios::in | std::ios::out | std::ios::binary);
    if (fs->fail()) {
      close(fd);
      RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file);
    }
    file_streams_.push_back(fs);
  }
  auto status = shard_header_->FileToPages(pages_file_);
  if (status.IsError()) {
    close(fd);
    RETURN_STATUS_UNEXPECTED("Error raised in FileToPages function.");
  }
  *fd_ptr = std::make_unique<int>(fd);
  return Status::OK();
}

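// Note on the locking protocol: when several writer processes share one output
// (parallel_writer == true), LockWriter serializes them with an advisory
// flock(2) on the sidecar lock file, and the page layout is round-tripped
// through the pages file (FileToPages here, PagesToFile in UnlockWriter) so
// each process sees the pages appended by the previous lock holder. On Windows
// the lock is taken on the Python side instead, which is why fd is a dummy 0.
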
Status ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
  if (!parallel_writer) {
    return Status::OK();
  }
  RETURN_IF_NOT_OK(shard_header_->PagesToFile(pages_file_));
  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
    file_streams_[i]->close();
  }
#if defined(_WIN32) || defined(_WIN64)
  MS_LOG(DEBUG) << "Unlock file done by Python.";

#else
  flock(fd, LOCK_UN);
  close(fd);
#endif
  return Status::OK();
}

Status ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
                                         std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count,
                                         int *row_count) {
  // check the free disk size
  std::shared_ptr<uint64_t> size_ptr;
  RETURN_IF_NOT_OK(GetDiskSize(file_paths_[0], kFreeSize, &size_ptr));
  CHECK_FAIL_RETURN_UNEXPECTED(*size_ptr >= kMinFreeDiskSize,
                               "Not enough free disk space to be used, free disk size: " + std::to_string(*size_ptr));
  // compress blob
  if (shard_column_->CheckCompressBlob()) {
    for (auto &blob : blob_data) {
      int64_t compression_bytes = 0;
      blob = shard_column_->CompressBlob(blob, &compression_bytes);
      compression_size_ += compression_bytes;
    }
  }

  // Add a 4-byte dummy blob entry per row if there are no blob fields
  if (blob_data.size() == 0 && raw_data.size() > 0) {
    blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
  }

  // Add a dummy id if all fields are blob fields
  if (blob_data.size() > 0 && raw_data.size() == 0) {
    raw_data.insert(std::pair<uint64_t, std::vector<json>>(0, std::vector<json>(blob_data.size(), kDummyId)));
  }
  std::shared_ptr<std::pair<int, int>> count_ptr;
  RETURN_IF_NOT_OK(ValidateRawData(raw_data, blob_data, sign, &count_ptr));
  *schema_count = (*count_ptr).first;
  *row_count = (*count_ptr).second;
  return Status::OK();
}
Status ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
                                  const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
                                  std::shared_ptr<std::vector<uint8_t>> *output) {
  if (blob_fields.empty()) {
    return Status::OK();
  }
  if (blob_fields.size() == 1) {
    auto &blob = row_bin_data.at(blob_fields[0]);
    auto blob_size = blob->size();
    *output = std::make_shared<std::vector<uint8_t>>(blob_size);
    std::copy(blob->begin(), blob->end(), (*output)->begin());
  } else {
    size_t output_size = 0;
    for (auto &field : blob_fields) {
      output_size += row_bin_data.at(field)->size();
    }
    output_size += blob_fields.size() * sizeof(uint64_t);
    *output = std::make_shared<std::vector<uint8_t>>(output_size);
    std::vector<uint8_t> buf(sizeof(uint64_t), 0);
    size_t idx = 0;
    for (auto &field : blob_fields) {
      auto &b = row_bin_data.at(field);
      uint64_t blob_size = b->size();
      // encode the length prefix in big endian
      for (size_t i = 0; i < buf.size(); ++i) {
        buf[buf.size() - 1 - i] = (std::numeric_limits<uint8_t>::max()) & blob_size;
        blob_size >>= 8u;
      }
      std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
      idx += buf.size();
      std::copy(b->begin(), b->end(), (*output)->begin() + idx);
      idx += b->size();
    }
  }
  return Status::OK();
}

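// A worked example of the multi-field framing above (a sketch; the byte values
// are illustrative): merging two blob fields of 3 and 2 bytes produces
//   [ 0,0,0,0,0,0,0,3,  b0,b1,b2,  0,0,0,0,0,0,0,2,  c0,c1 ]
// i.e. each field is preceded by its length as an 8-byte big-endian integer,
// so a reader can split the merged buffer back into fields without a schema.
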
Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                 std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
  // Lock writer if loading data in parallel
  std::unique_ptr<int> fd_ptr;
  RETURN_IF_NOT_OK(LockWriter(parallel_writer, &fd_ptr));

  // Get the count of schemas and rows
  int schema_count = 0;
  int row_count = 0;

  // Serialize raw data
  RETURN_IF_NOT_OK(WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count));
  CHECK_FAIL_RETURN_UNEXPECTED(row_count >= kInt0, "Invalid data, raw data size should be non-negative.");
  if (row_count == kInt0) {
    return Status::OK();
  }
  std::vector<std::vector<uint8_t>> bin_raw_data(row_count * schema_count);
  // Serialize raw data
  RETURN_IF_NOT_OK(SerializeRawData(raw_data, bin_raw_data, row_count));
  // Set row size of raw data
  RETURN_IF_NOT_OK(SetRawDataSize(bin_raw_data));
  // Set row size of blob data
  RETURN_IF_NOT_OK(SetBlobDataSize(blob_data));
  // Write data to disk with multiple threads
  RETURN_IF_NOT_OK(ParallelWriteData(blob_data, bin_raw_data));
  MS_LOG(INFO) << "Succeed to write " << bin_raw_data.size() << " records.";

  RETURN_IF_NOT_OK(UnlockWriter(*fd_ptr, parallel_writer));

  return Status::OK();
}

Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
                                 std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
                                 bool parallel_writer) {
  std::map<uint64_t, std::vector<json>> raw_data_json;
  std::map<uint64_t, std::vector<json>> blob_data_json;

  (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
                       [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
                         auto &py_raw_data = pair.second;
                         std::vector<json> json_raw_data;
                         (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
                                              [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
                         return std::make_pair(pair.first, std::move(json_raw_data));
                       });

  (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()),
                       [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
                         auto &py_blob_data = pair.second;
                         std::vector<json> json_blob_data;
                         (void)std::transform(py_blob_data.begin(), py_blob_data.end(),
                                              std::back_inserter(json_blob_data),
                                              [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
                         return std::make_pair(pair.first, std::move(json_blob_data));
                       });

  // Serialize blob page
  auto blob_data_iter = blob_data.begin();
  auto schema_count = blob_data.size();
  auto row_count = blob_data_iter->second.size();

  std::vector<std::vector<uint8_t>> bin_blob_data(row_count * schema_count);
  // Serialize blob data
  RETURN_IF_NOT_OK(SerializeRawData(blob_data_json, bin_blob_data, row_count));
  return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
}

Status ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
                                      const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  auto shards = BreakIntoShards();
  // define the number of threads
  int thread_num = static_cast<int>(shard_count_);
  CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid data, thread_num should be positive.");
  if (thread_num > kMaxThreadCount) {
    thread_num = kMaxThreadCount;
  }
  int left_thread = shard_count_;
  int current_thread = 0;
  while (left_thread) {
    if (left_thread < thread_num) {
      thread_num = left_thread;
    }
    // Start one thread for one shard
    std::vector<std::thread> thread_set(thread_num);
    if (thread_num <= kMaxThreadCount) {
      for (int x = 0; x < thread_num; ++x) {
        int start_row = shards[current_thread + x].first;
        int end_row = shards[current_thread + x].second;
        thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row,
                                    std::ref(blob_data), std::ref(bin_raw_data));
      }
      // Wait for the threads to finish
      for (int x = 0; x < thread_num; ++x) {
        thread_set[x].join();
      }
      left_thread -= thread_num;
      current_thread += thread_num;
    }
  }
  return Status::OK();
}

Status ShardWriter::WriteByShard(int shard_id, int start_row, int end_row,
                                 const std::vector<std::vector<uint8_t>> &blob_data,
                                 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row
                << ", schema size: " << schema_count_;
  if (start_row == end_row) {
    return Status::OK();
  }
  vector<std::pair<int, int>> rows_in_group;
  std::shared_ptr<Page> last_raw_page = nullptr;
  std::shared_ptr<Page> last_blob_page = nullptr;
  // The lookups may fail on the first write when no page exists yet; the error
  // status is deliberately ignored so the pages stay null and get created below.
  (void)SetLastRawPage(shard_id, last_raw_page);
  (void)SetLastBlobPage(shard_id, last_blob_page);

  RETURN_IF_NOT_OK(CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page));
  RETURN_IF_NOT_OK(AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
  RETURN_IF_NOT_OK(NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page));
  RETURN_IF_NOT_OK(ShiftRawPage(shard_id, rows_in_group, last_raw_page));
  RETURN_IF_NOT_OK(WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data));

  return Status::OK();
}

Status ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
                                std::vector<std::pair<int, int>> &rows_in_group,
                                const std::shared_ptr<Page> &last_raw_page,
                                const std::shared_ptr<Page> &last_blob_page) {
  auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;

  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
  auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0;
  auto n_byte_raw = last_raw_page_size - last_raw_offset;

  int page_start_row = start_row;
  CHECK_FAIL_RETURN_UNEXPECTED(start_row <= end_row,
                               "Invalid data, start row: " + std::to_string(start_row) +
                                 " should be less than or equal to end row: " + std::to_string(end_row));

  CHECK_FAIL_RETURN_UNEXPECTED(
    end_row <= static_cast<int>(blob_data_size_.size()) && end_row <= static_cast<int>(raw_data_size_.size()),
    "Invalid data, end row: " + std::to_string(end_row) + " should be less than blob data size: " +
      std::to_string(blob_data_size_.size()) + " and raw data size: " + std::to_string(raw_data_size_.size()) + ".");
  for (int i = start_row; i < end_row; ++i) {
    // n_byte_blob == 0 indicates the group handled by AppendBlobPage
    if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ ||
        n_byte_raw + raw_data_size_[i] > page_size_) {
      rows_in_group.emplace_back(page_start_row, i);
      page_start_row = i;
      n_byte_blob = blob_data_size_[i];
      n_byte_raw = raw_data_size_[i];
    } else {
      n_byte_blob += blob_data_size_[i];
      n_byte_raw += raw_data_size_[i];
    }
  }

  // Don't forget the last group
  rows_in_group.emplace_back(page_start_row, end_row);
  return Status::OK();
}

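// A small worked example of the cut above (a sketch with made-up sizes): with
// page_size_ = 100, an existing blob page holding 40 bytes, and incoming rows
// of blob sizes {50, 30, 90}, row 0 still fits (40 + 50 <= 100) but row 1 does
// not, so rows_in_group becomes {[0, 1), [1, 2), [2, 3)}: the first group is
// appended to the open page and each following group starts a fresh page.
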
Status ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
                                   const std::vector<std::pair<int, int>> &rows_in_group,
                                   const std::shared_ptr<Page> &last_blob_page) {
  auto blob_row = rows_in_group[0];
  if (blob_row.first == blob_row.second) {
    return Status::OK();
  }
  // Write to disk
  auto page_id = last_blob_page->GetPageID();
  auto bytes_page = last_blob_page->GetPageSize();
  auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
  }

  RETURN_IF_NOT_OK(FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row));

  // Update last blob page
  bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
  last_blob_page->SetPageSize(bytes_page);
  uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first;
  last_blob_page->SetEndRowID(end_row);
  (void)shard_header_->SetPage(last_blob_page);
  return Status::OK();
}

Status ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
                                const std::vector<std::pair<int, int>> &rows_in_group,
                                const std::shared_ptr<Page> &last_blob_page) {
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1;
  auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0;
  // index 0 is the group handled by AppendBlobPage
  for (uint32_t i = 1; i < rows_in_group.size(); ++i) {
    auto blob_row = rows_in_group[i];

    // Write one blob page to disk
    auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
    if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
      file_streams_[shard_id]->close();
      RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
    }

    RETURN_IF_NOT_OK(FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row));
    // Create new page info for header
    auto page_size =
      std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
    std::vector<std::pair<int, uint64_t>> row_group_ids;
    auto start_row = current_row;
    auto end_row = start_row + blob_row.second - blob_row.first;
    auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size);
    (void)shard_header_->AddPage(std::make_shared<Page>(page));
    current_row = end_row;
  }
  return Status::OK();
}

Status ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                 std::shared_ptr<Page> &last_raw_page) {
  auto blob_row = rows_in_group[0];
  if (blob_row.first == blob_row.second) {
    return Status::OK();
  }
  auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
  if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) +
        last_raw_page_size <=
      page_size_) {
    return Status::OK();
  }
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second;
  auto last_raw_page_id = last_raw_page->GetPageID();
  auto shift_size = last_raw_page_size - last_row_group_id_offset;

  std::vector<uint8_t> buf(shift_size);

  // Read the last row group from the previous raw data page
  CHECK_FAIL_RETURN_UNEXPECTED(
    shard_id >= 0 && shard_id < static_cast<int>(file_streams_.size()),
    "Invalid data, shard_id should be in range [0, " + std::to_string(file_streams_.size()) + ").");

  auto &io_seekg = file_streams_[shard_id]->seekg(
    page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg);
  if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
  }

  auto &io_read = file_streams_[shard_id]->read(reinterpret_cast<char *>(&buf[0]), buf.size());
  if (!io_read.good() || io_read.fail() || io_read.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to read file.");
  }

  // Merge into a new row group on a new raw data page
  auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
  }

  auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&buf[0]), buf.size());
  if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to write file.");
  }
  last_raw_page->DeleteLastGroupId();
  (void)shard_header_->SetPage(last_raw_page);

  // Refresh page info in header
  int row_group_id = last_raw_page->GetLastRowGroupID().first + 1;
  std::vector<std::pair<int, uint64_t>> row_group_ids;
  row_group_ids.emplace_back(row_group_id, 0);
  int page_type_id = last_raw_page->GetPageID();
  auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size);
  (void)shard_header_->AddPage(std::make_shared<Page>(page));

  // Reset: last raw page
  RETURN_IF_NOT_OK(SetLastRawPage(shard_id, last_raw_page));
  return Status::OK();
}

Status ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                 std::shared_ptr<Page> &last_raw_page,
                                 const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1;
  for (uint32_t i = 0; i < rows_in_group.size(); ++i) {
    const auto &blob_row = rows_in_group[i];
    if (blob_row.first == blob_row.second) {
      continue;
    }
    auto raw_size =
      std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0);
    if (!last_raw_page) {
      RETURN_IF_NOT_OK(EmptyRawPage(shard_id, last_raw_page));
    } else if (last_raw_page->GetPageSize() + raw_size > page_size_) {
      RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page));
      RETURN_IF_NOT_OK(EmptyRawPage(shard_id, last_raw_page));
    }
    RETURN_IF_NOT_OK(AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data));
  }
  RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page));
  return Status::OK();
}

Status ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
  auto row_group_ids = std::vector<std::pair<int, uint64_t>>();
  auto page_id = shard_header_->GetLastPageId(shard_id);
  auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1;
  auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0);
  RETURN_IF_NOT_OK(shard_header_->AddPage(std::make_shared<Page>(page)));
  RETURN_IF_NOT_OK(SetLastRawPage(shard_id, last_raw_page));
  return Status::OK();
}

Status ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
                                  const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page,
                                  const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds();
  auto last_raw_page_id = last_raw_page->GetPageID();
  auto n_bytes = last_raw_page->GetPageSize();

  // previous raw data page
  auto &io_seekp =
    file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg);
  if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
    file_streams_[shard_id]->close();
    RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
  }

  if (chunk_id > 0) {
    row_group_ids.emplace_back(++last_row_group_id, n_bytes);
  }
  n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first,
                             raw_data_size_.begin() + rows_in_group[chunk_id].second, 0);
  RETURN_IF_NOT_OK(FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data));

  // Update previous raw data page
  last_raw_page->SetPageSize(n_bytes);
  last_raw_page->SetRowGroupIds(row_group_ids);
  RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page));

  return Status::OK();
}

Status ShardWriter::FlushBlobChunk(const std::shared_ptr<std::fstream> &out,
                                   const std::vector<std::vector<uint8_t>> &blob_data,
                                   const std::pair<int, int> &blob_row) {
  CHECK_FAIL_RETURN_UNEXPECTED(
    blob_row.first <= blob_row.second && blob_row.second <= static_cast<int>(blob_data.size()) && blob_row.first >= 0,
    "Invalid data, blob_row: " + std::to_string(blob_row.first) + ", " + std::to_string(blob_row.second) +
      " is invalid.");
  for (int j = blob_row.first; j < blob_row.second; ++j) {
    // Write the size of the blob
    uint64_t line_len = blob_data[j].size();
    auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
    if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
      out->close();
      RETURN_STATUS_UNEXPECTED("Failed to write file.");
    }

    // Write the data of the blob without copying the row
    const auto &line = blob_data[j];
    auto &io_handle_data = out->write(reinterpret_cast<const char *>(line.data()), line_len);
    if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) {
      out->close();
      RETURN_STATUS_UNEXPECTED("Failed to write file.");
    }
  }
  return Status::OK();
}

Status ShardWriter::FlushRawChunk(const std::shared_ptr<std::fstream> &out,
                                  const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
                                  const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) {
    // Write the sizes of the multiple schemas
    for (uint32_t j = 0; j < schema_count_; ++j) {
      uint64_t line_len = bin_raw_data[i * schema_count_ + j].size();
      auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        out->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }
    }
    // Write the data of the multiple schemas without copying the rows
    for (uint32_t j = 0; j < schema_count_; ++j) {
      const auto &line = bin_raw_data[i * schema_count_ + j];
      auto &io_handle = out->write(reinterpret_cast<const char *>(line.data()), line.size());
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        out->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }
    }
  }
  return Status::OK();
}

// Allocate data to shards evenly
std::vector<std::pair<int, int>> ShardWriter::BreakIntoShards() {
  std::vector<std::pair<int, int>> shards;
  int row_in_shard = row_count_ / shard_count_;
  int remains = row_count_ % shard_count_;

  std::vector<int> v_list(shard_count_);
  std::iota(v_list.begin(), v_list.end(), 0);

  std::mt19937 g = mindspore::dataset::GetRandomDevice();
  std::shuffle(v_list.begin(), v_list.end(), g);
  std::unordered_set<int> set(v_list.begin(), v_list.begin() + remains);

  if (shard_count_ <= kMaxShardCount) {
    int start_row = 0;
    for (int i = 0; i < shard_count_; ++i) {
      int end_row = start_row + row_in_shard;
      if (set.count(i)) {
        end_row++;
      }
      shards.emplace_back(start_row, end_row);
      start_row = end_row;
    }
  }
  return shards;
}

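// For example (a sketch): 11 rows over 3 shards gives row_in_shard = 3 and
// remains = 2, so two randomly chosen shards get one extra row each and the
// returned ranges might be {[0, 4), [4, 8), [8, 11)}; the randomness only
// decides which shards absorb the remainder, not the order of the rows.
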
Status ShardWriter::WriteShardHeader() {
  RETURN_UNEXPECTED_IF_NULL(shard_header_);
  int64_t compression_temp = compression_size_;
  uint64_t compression_size = compression_temp > 0 ? compression_temp : 0;
  shard_header_->SetCompressionSize(compression_size);

  auto shard_header = shard_header_->SerializeHeader();
  // Write header data to multiple files
  CHECK_FAIL_RETURN_UNEXPECTED(
    shard_count_ <= static_cast<int>(file_streams_.size()) && shard_count_ <= static_cast<int>(shard_header.size()),
    "Invalid data, shard count should be less than or equal to file size: " + std::to_string(file_streams_.size()) +
      ", and header size: " + std::to_string(shard_header.size()) + ".");
  if (shard_count_ <= kMaxShardCount) {
    for (int shard_id = 0; shard_id < shard_count_; ++shard_id) {
      auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg);
      if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Failed to seekp file.");
      }

      std::vector<uint8_t> bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end());
      uint64_t line_len = bin_header.size();
      if (line_len + kInt64Len > header_size_) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Shard header is too big.");
      }
      auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&line_len), kInt64Len);
      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }

      auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast<char *>(&bin_header[0]), line_len);
      if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) {
        file_streams_[shard_id]->close();
        RETURN_STATUS_UNEXPECTED("Failed to write file.");
      }
      file_streams_[shard_id]->close();
    }
  }
  return Status::OK();
}

Status ShardWriter::SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
                                     std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count) {
  // define the number of threads
  uint32_t thread_num = std::thread::hardware_concurrency();
  if (thread_num == 0) {
    thread_num = kThreadNumber;
  }
  // Set the number of samples processed by each thread
  int group_num = ceil(row_count * 1.0 / thread_num);
  std::vector<std::thread> thread_set(thread_num);
  uint32_t work_thread_num = 0;
  for (uint32_t x = 0; x < thread_num; ++x) {
    int start_num = x * group_num;
    int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num;
    if (start_num >= end_num) {
      continue;
    }
    // Define the run boundary and start the child thread
    thread_set[x] =
      std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data));
    work_thread_num++;
  }
  for (uint32_t x = 0; x < work_thread_num; ++x) {
    // Block the main thread until all worker threads are done
    thread_set[x].join();
  }
  CHECK_FAIL_RETURN_SYNTAX_ERROR(!flag_, "Error raised in FillArray function.");
  return Status::OK();
}

Status ShardWriter::SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data) {
  raw_data_size_ = std::vector<uint64_t>(row_count_, 0);
  for (uint32_t i = 0; i < row_count_; ++i) {
    // Seed the accumulator with a uint64_t so the sum is not truncated to int
    raw_data_size_[i] = std::accumulate(
      bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_,
      static_cast<uint64_t>(0),
      [](uint64_t accumulator, const std::vector<uint8_t> &row) { return accumulator + kInt64Len + row.size(); });
  }
  CHECK_FAIL_RETURN_SYNTAX_ERROR(
    *std::max_element(raw_data_size_.begin(), raw_data_size_.end()) <= page_size_,
    "Invalid data, page size: " + std::to_string(page_size_) + " is too small to save a raw row!");
  return Status::OK();
}

Status ShardWriter::SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data) {
  blob_data_size_ = std::vector<uint64_t>(row_count_);
  (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(),
                       [](const std::vector<uint8_t> &row) { return kInt64Len + row.size(); });
  CHECK_FAIL_RETURN_SYNTAX_ERROR(
    *std::max_element(blob_data_size_.begin(), blob_data_size_.end()) <= page_size_,
    "Invalid data, page size: " + std::to_string(page_size_) + " is too small to save a blob row!");
  return Status::OK();
}

Status ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
  // Get the last raw page
  auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw);
  CHECK_FAIL_RETURN_SYNTAX_ERROR(last_raw_page_id >= 0, "Invalid data, last_raw_page_id: " +
                                                          std::to_string(last_raw_page_id) +
                                                          " should be non-negative.");
  RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_raw_page_id, &last_raw_page));
  return Status::OK();
}

Status ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page) {
  // Get the last blob page
  auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob);
  CHECK_FAIL_RETURN_SYNTAX_ERROR(last_blob_page_id >= 0, "Invalid data, last_blob_page_id: " +
                                                           std::to_string(last_blob_page_id) +
                                                           " should be non-negative.");
  RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_blob_page_id, &last_blob_page));
  return Status::OK();
}

Status ShardWriter::Initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
                               const std::vector<std::string> &file_names) {
  RETURN_UNEXPECTED_IF_NULL(writer_ptr);
  RETURN_IF_NOT_OK((*writer_ptr)->Open(file_names, false));
  RETURN_IF_NOT_OK((*writer_ptr)->SetHeaderSize(kDefaultHeaderSize));
  RETURN_IF_NOT_OK((*writer_ptr)->SetPageSize(kDefaultPageSize));
  return Status::OK();
}
}  // namespace mindrecord
}  // namespace mindspore