• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_
19 
#include <libgen.h>
#include <sys/file.h>
#include <unistd.h>
#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <exception>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_index.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"
46 
47 namespace mindspore {
48 namespace mindrecord {
49 class __attribute__((visibility("default"))) ShardWriter {
50  public:
51   ShardWriter();
52 
53   ~ShardWriter();
54 
55   /// \brief Open file at the beginning
56   /// \param[in] paths the file names list
57   /// \param[in] append new data at the end of file if true, otherwise overwrite file
58   /// \return Status
59   Status Open(const std::vector<std::string> &paths, bool append = false);
60 
61   /// \brief Open file at the ending
62   /// \param[in] paths the file names list
63   /// \return MSRStatus the status of MSRStatus
64   Status OpenForAppend(const std::string &path);
65 
66   /// \brief Write header to disk
67   /// \return MSRStatus the status of MSRStatus
68   Status Commit();
69 
70   /// \brief Set file size
71   /// \param[in] header_size the size of header, only (1<<N) is accepted
72   /// \return MSRStatus the status of MSRStatus
73   Status SetHeaderSize(const uint64_t &header_size);
74 
75   /// \brief Set page size
76   /// \param[in] page_size the size of page, only (1<<N) is accepted
77   /// \return MSRStatus the status of MSRStatus
78   Status SetPageSize(const uint64_t &page_size);
79 
80   /// \brief Set shard header
81   /// \param[in] header_data the info of header
82   ///        WARNING, only called when file is empty
83   /// \return MSRStatus the status of MSRStatus
84   Status SetShardHeader(std::shared_ptr<ShardHeader> header_data);
85 
86   /// \brief write raw data by group size
87   /// \param[in] raw_data the vector of raw json data, vector format
88   /// \param[in] blob_data the vector of image data
89   /// \param[in] sign validate data or not
90   /// \return MSRStatus the status of MSRStatus to judge if write successfully
91   Status WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
92                       bool sign = true, bool parallel_writer = false);
93 
94   /// \brief write raw data by group size for call from python
95   /// \param[in] raw_data the vector of raw json data, python-handle format
96   /// \param[in] blob_data the vector of blob json data, python-handle format
97   /// \param[in] sign validate data or not
98   /// \return MSRStatus the status of MSRStatus to judge if write successfully
99   Status WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
100                       std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
101                       bool parallel_writer = false);
102 
103   Status MergeBlobData(const std::vector<string> &blob_fields,
104                        const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
105                        std::shared_ptr<std::vector<uint8_t>> *output);
106 
107   static Status Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, const std::vector<std::string> &file_names);
108 
109  private:
110   /// \brief write shard header data to disk
111   Status WriteShardHeader();
112 
113   /// \brief erase error data
114   void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data);
115 
116   /// \brief populate error data
117   void PopulateMutexErrorData(const int &row, const std::string &message, std::map<int, std::string> &err_raw_data);
118 
119   /// \brief check data
120   void CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
121                       std::map<int, std::string> &err_raw_data);
122 
123   /// \brief write shard header data to disk
124   Status ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data,
125                          bool sign, std::shared_ptr<std::pair<int, int>> *count_ptr);
126 
127   /// \brief fill data array in multiple thread run
128   void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,
129                  std::vector<std::vector<uint8_t>> &bin_data);
130 
131   /// \brief serialized raw data
132   Status SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &bin_data,
133                           uint32_t row_count);
134 
135   /// \brief write all data parallel
136   Status ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
137                            const std::vector<std::vector<uint8_t>> &bin_raw_data);
138 
139   /// \brief write data shard by shard
140   Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
141                       const std::vector<std::vector<uint8_t>> &bin_raw_data);
142 
143   /// \brief break image data up into multiple row groups
144   Status CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
145                      std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
146                      const std::shared_ptr<Page> &last_blob_page);
147 
148   /// \brief append partial blob data to previous page
149   Status AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
150                         const std::vector<std::pair<int, int>> &rows_in_group,
151                         const std::shared_ptr<Page> &last_blob_page);
152 
153   /// \brief write new blob data page to disk
154   Status NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
155                      const std::vector<std::pair<int, int>> &rows_in_group,
156                      const std::shared_ptr<Page> &last_blob_page);
157 
158   /// \brief shift last row group to next raw page for new appending
159   Status ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
160                       std::shared_ptr<Page> &last_raw_page);
161 
162   /// \brief write raw data page to disk
163   Status WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
164                       std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data);
165 
166   /// \brief generate empty raw data page
167   Status EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
168 
169   /// \brief append a row group at the end of raw page
170   Status AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
171                        int &last_row_groupId, std::shared_ptr<Page> last_raw_page,
172                        const std::vector<std::vector<uint8_t>> &bin_raw_data);
173 
174   /// \brief write blob chunk to disk
175   Status FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
176                         const std::pair<int, int> &blob_row);
177 
178   /// \brief write raw chunk to disk
179   Status FlushRawChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::pair<int, int>> &rows_in_group,
180                        const int &chunk_id, const std::vector<std::vector<uint8_t>> &bin_raw_data);
181 
182   /// \brief break up into tasks by shard
183   std::vector<std::pair<int, int>> BreakIntoShards();
184 
185   /// \brief calculate raw data size row by row
186   Status SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);
187 
188   /// \brief calculate blob data size row by row
189   Status SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);
190 
191   /// \brief populate last raw page pointer
192   Status SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
193 
194   /// \brief populate last blob page pointer
195   Status SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);
196 
197   /// \brief check the data by schema
198   Status CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);
199 
200   /// \brief check the data and type
201   Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
202                                std::map<int, std::string> &err_raw_data);
203 
204   /// \brief Lock writer and save pages info
205   Status LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr);
206 
207   /// \brief Unlock writer and save pages info
208   Status UnlockWriter(int fd, bool parallel_writer = false);
209 
210   /// \brief Check raw data before writing
211   Status WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
212                               bool sign, int *schema_count, int *row_count);
213 
214   /// \brief Get full path from file name
215   Status GetFullPathFromFileName(const std::vector<std::string> &paths);
216 
217   /// \brief Open files
218   Status OpenDataFiles(bool append);
219 
220   /// \brief Remove lock file
221   Status RemoveLockFile();
222 
223   /// \brief Remove lock file
224   Status InitLockFile();
225 
226  private:
227   const std::string kLockFileSuffix = "_Locker";
228   const std::string kPageFileSuffix = "_Pages";
229   std::string lock_file_;   // lock file for parallel run
230   std::string pages_file_;  // temporary file of pages info for parallel run
231 
232   int shard_count_;        // number of files
233   uint64_t header_size_;   // header size
234   uint64_t page_size_;     // page size
235   uint32_t row_count_;     // count of rows
236   uint32_t schema_count_;  // count of schemas
237 
238   std::vector<uint64_t> raw_data_size_;   // Raw data size
239   std::vector<uint64_t> blob_data_size_;  // Blob data size
240 
241   std::vector<std::string> file_paths_;                      // file paths
242   std::vector<std::shared_ptr<std::fstream>> file_streams_;  // file handles
243   std::shared_ptr<ShardHeader> shard_header_;                // shard header
244   std::shared_ptr<ShardColumn> shard_column_;                // shard columns
245 
246   std::map<uint64_t, std::map<int, std::string>> err_mg_;  // used for storing error raw_data info
247 
248   std::mutex check_mutex_;  // mutex for data check
249   std::atomic<bool> flag_{false};
250   std::atomic<int64_t> compression_size_;
251 };
252 }  // namespace mindrecord
253 }  // namespace mindspore
254 
255 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_
256