1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_ 19 20 #include <libgen.h> 21 #include <sys/file.h> 22 #include <unistd.h> 23 #include <algorithm> 24 #include <array> 25 #include <chrono> 26 #include <exception> 27 #include <fstream> 28 #include <functional> 29 #include <map> 30 #include <memory> 31 #include <mutex> 32 #include <random> 33 #include <string> 34 #include <thread> 35 #include <tuple> 36 #include <utility> 37 #include <vector> 38 #include "minddata/mindrecord/include/common/shard_utils.h" 39 #include "minddata/mindrecord/include/shard_column.h" 40 #include "minddata/mindrecord/include/shard_error.h" 41 #include "minddata/mindrecord/include/shard_header.h" 42 #include "minddata/mindrecord/include/shard_index.h" 43 #include "pybind11/pybind11.h" 44 #include "pybind11/stl.h" 45 #include "utils/log_adapter.h" 46 47 namespace mindspore { 48 namespace mindrecord { 49 class __attribute__((visibility("default"))) ShardWriter { 50 public: 51 ShardWriter(); 52 53 ~ShardWriter(); 54 55 /// \brief Open file at the beginning 56 /// \param[in] paths the file names list 57 /// \param[in] append new data at the end of file if true, otherwise overwrite file 58 /// \return Status 59 Status Open(const std::vector<std::string> &paths, bool append = false); 60 61 /// \brief Open file at the ending 62 /// \param[in] paths the file names list 63 /// \return MSRStatus the status of MSRStatus 64 Status OpenForAppend(const std::string &path); 65 66 /// \brief Write header to disk 67 /// \return MSRStatus the status of MSRStatus 68 Status Commit(); 69 70 /// \brief Set file size 71 /// \param[in] header_size the size of header, only (1<<N) is accepted 72 /// \return MSRStatus the status of MSRStatus 73 Status SetHeaderSize(const uint64_t &header_size); 74 75 /// \brief Set page size 76 /// \param[in] page_size the size of page, only (1<<N) is accepted 77 /// \return MSRStatus the status of MSRStatus 78 Status SetPageSize(const uint64_t &page_size); 79 80 /// \brief Set shard header 81 /// \param[in] header_data the info of header 82 /// WARNING, only called when file is empty 83 /// \return MSRStatus the status of MSRStatus 84 Status SetShardHeader(std::shared_ptr<ShardHeader> header_data); 85 86 /// \brief write raw data by group size 87 /// \param[in] raw_data the vector of raw json data, vector format 88 /// \param[in] blob_data the vector of image data 89 /// \param[in] sign validate data or not 90 /// \return MSRStatus the status of MSRStatus to judge if write successfully 91 Status WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, 92 bool sign = true, bool parallel_writer = false); 93 94 /// \brief write raw data by group size for call from python 95 /// \param[in] raw_data the vector of raw json data, python-handle format 96 /// \param[in] blob_data the vector of blob json data, python-handle format 97 /// \param[in] sign validate data or not 98 /// \return MSRStatus the status of MSRStatus to judge if write successfully 99 Status WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, 100 std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true, 101 bool parallel_writer = false); 102 103 Status MergeBlobData(const std::vector<string> &blob_fields, 104 const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data, 105 std::shared_ptr<std::vector<uint8_t>> *output); 106 107 static Status Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, const std::vector<std::string> &file_names); 108 109 private: 110 /// \brief write shard header data to disk 111 Status WriteShardHeader(); 112 113 /// \brief erase error data 114 void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data); 115 116 /// \brief populate error data 117 void PopulateMutexErrorData(const int &row, const std::string &message, std::map<int, std::string> &err_raw_data); 118 119 /// \brief check data 120 void CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data, 121 std::map<int, std::string> &err_raw_data); 122 123 /// \brief write shard header data to disk 124 Status ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data, 125 bool sign, std::shared_ptr<std::pair<int, int>> *count_ptr); 126 127 /// \brief fill data array in multiple thread run 128 void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data, 129 std::vector<std::vector<uint8_t>> &bin_data); 130 131 /// \brief serialized raw data 132 Status SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &bin_data, 133 uint32_t row_count); 134 135 /// \brief write all data parallel 136 Status ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data, 137 const std::vector<std::vector<uint8_t>> &bin_raw_data); 138 139 /// \brief write data shard by shard 140 Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, 141 const std::vector<std::vector<uint8_t>> &bin_raw_data); 142 143 /// \brief break image data up into multiple row groups 144 Status CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, 145 std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page, 146 const std::shared_ptr<Page> &last_blob_page); 147 148 /// \brief append partial blob data to previous page 149 Status AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, 150 const std::vector<std::pair<int, int>> &rows_in_group, 151 const std::shared_ptr<Page> &last_blob_page); 152 153 /// \brief write new blob data page to disk 154 Status NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, 155 const std::vector<std::pair<int, int>> &rows_in_group, 156 const std::shared_ptr<Page> &last_blob_page); 157 158 /// \brief shift last row group to next raw page for new appending 159 Status ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, 160 std::shared_ptr<Page> &last_raw_page); 161 162 /// \brief write raw data page to disk 163 Status WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, 164 std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data); 165 166 /// \brief generate empty raw data page 167 Status EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); 168 169 /// \brief append a row group at the end of raw page 170 Status AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id, 171 int &last_row_groupId, std::shared_ptr<Page> last_raw_page, 172 const std::vector<std::vector<uint8_t>> &bin_raw_data); 173 174 /// \brief write blob chunk to disk 175 Status FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data, 176 const std::pair<int, int> &blob_row); 177 178 /// \brief write raw chunk to disk 179 Status FlushRawChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::pair<int, int>> &rows_in_group, 180 const int &chunk_id, const std::vector<std::vector<uint8_t>> &bin_raw_data); 181 182 /// \brief break up into tasks by shard 183 std::vector<std::pair<int, int>> BreakIntoShards(); 184 185 /// \brief calculate raw data size row by row 186 Status SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data); 187 188 /// \brief calculate blob data size row by row 189 Status SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data); 190 191 /// \brief populate last raw page pointer 192 Status SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); 193 194 /// \brief populate last blob page pointer 195 Status SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page); 196 197 /// \brief check the data by schema 198 Status CheckData(const std::map<uint64_t, std::vector<json>> &raw_data); 199 200 /// \brief check the data and type 201 Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, 202 std::map<int, std::string> &err_raw_data); 203 204 /// \brief Lock writer and save pages info 205 Status LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr); 206 207 /// \brief Unlock writer and save pages info 208 Status UnlockWriter(int fd, bool parallel_writer = false); 209 210 /// \brief Check raw data before writing 211 Status WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, 212 bool sign, int *schema_count, int *row_count); 213 214 /// \brief Get full path from file name 215 Status GetFullPathFromFileName(const std::vector<std::string> &paths); 216 217 /// \brief Open files 218 Status OpenDataFiles(bool append); 219 220 /// \brief Remove lock file 221 Status RemoveLockFile(); 222 223 /// \brief Remove lock file 224 Status InitLockFile(); 225 226 private: 227 const std::string kLockFileSuffix = "_Locker"; 228 const std::string kPageFileSuffix = "_Pages"; 229 std::string lock_file_; // lock file for parallel run 230 std::string pages_file_; // temporary file of pages info for parallel run 231 232 int shard_count_; // number of files 233 uint64_t header_size_; // header size 234 uint64_t page_size_; // page size 235 uint32_t row_count_; // count of rows 236 uint32_t schema_count_; // count of schemas 237 238 std::vector<uint64_t> raw_data_size_; // Raw data size 239 std::vector<uint64_t> blob_data_size_; // Blob data size 240 241 std::vector<std::string> file_paths_; // file paths 242 std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles 243 std::shared_ptr<ShardHeader> shard_header_; // shard header 244 std::shared_ptr<ShardColumn> shard_column_; // shard columns 245 246 std::map<uint64_t, std::map<int, std::string>> err_mg_; // used for storing error raw_data info 247 248 std::mutex check_mutex_; // mutex for data check 249 std::atomic<bool> flag_{false}; 250 std::atomic<int64_t> compression_size_; 251 }; 252 } // namespace mindrecord 253 } // namespace mindspore 254 255 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_ 256