• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_
19 
#ifndef _MSC_VER
#include <libgen.h>
#include <sys/file.h>
#endif
#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <exception>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

#include "minddata/mindrecord/include/common/log_adapter.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_header.h"
#include "minddata/mindrecord/include/shard_index.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
48 
49 namespace mindspore {
50 namespace mindrecord {
51 class MINDRECORD_API ShardWriter {
52  public:
53   ShardWriter();
54 
55   ~ShardWriter();
56 
57   /// \brief Open file at the beginning
58   /// \param[in] paths the file names list
59   /// \param[in] append new data at the end of file if true, otherwise try to overwrite file
60   /// \param[in] overwrite a file with the same name if true
61   /// \return Status
62   Status Open(const std::vector<std::string> &paths, bool append = false, bool overwrite = false);
63 
64   /// \brief Open file at the ending
65   /// \param[in] paths the file names list
66   /// \return MSRStatus the status of MSRStatus
67   Status OpenForAppend(const std::string &path);
68 
69   /// \brief Write header to disk
70   /// \return MSRStatus the status of MSRStatus
71   Status Commit();
72 
73   /// \brief Set file size
74   /// \param[in] header_size the size of header, only (1<<N) is accepted
75   /// \return MSRStatus the status of MSRStatus
76   Status SetHeaderSize(const uint64_t &header_size);
77 
78   /// \brief Set page size
79   /// \param[in] page_size the size of page, only (1<<N) is accepted
80   /// \return MSRStatus the status of MSRStatus
81   Status SetPageSize(const uint64_t &page_size);
82 
83   /// \brief Set shard header
84   /// \param[in] header_data the info of header
85   ///        WARNING, only called when file is empty
86   /// \return MSRStatus the status of MSRStatus
87   Status SetShardHeader(std::shared_ptr<ShardHeader> header_data);
88 
89   /// \brief write raw data by group size
90   /// \param[in] raw_data the vector of raw json data, vector format
91   /// \param[in] blob_data the vector of image data
92   /// \param[in] sign validate data or not
93   /// \return MSRStatus the status of MSRStatus to judge if write successfully
94   Status WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,  // NOLINT
95                       bool sign = true, bool parallel_writer = false);
96 
97   /// \brief write raw data by group size for call from python
98   /// \param[in] raw_data the vector of raw json data, python-handle format
99   /// \param[in] blob_data the vector of blob json data, python-handle format
100   /// \param[in] sign validate data or not
101   /// \return MSRStatus the status of MSRStatus to judge if write successfully
102   Status WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,                     // NOLINT
103                       std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,  // NOLINT
104                       bool parallel_writer = false);
105 
106   Status MergeBlobData(const std::vector<string> &blob_fields,
107                        const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
108                        std::shared_ptr<std::vector<uint8_t>> *output);
109 
110   static Status Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, const std::vector<std::string> &file_names);
111 
112  private:
113   /// \brief write shard header data to disk
114   Status WriteShardHeader();
115 
116   /// \brief erase error data
117   void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data,  // NOLINT
118                        std::vector<std::vector<uint8_t>> &blob_data);    // NOLINT
119 
120   /// \brief populate error data
121   void PopulateMutexErrorData(const int &row, const std::string &message,
122                               std::map<int, std::string> &err_raw_data);  // NOLINT
123 
124   /// \brief check data
125   void CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
126                       std::map<int, std::string> &err_raw_data);  // NOLINT
127 
128   /// \brief write shard header data to disk
129   Status ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,  // NOLINT
130                          std::vector<std::vector<uint8_t>> &blob_data,     // NOLINT
131                          bool sign, std::shared_ptr<std::pair<int, int>> *count_ptr);
132 
133   /// \brief fill data array in multiple thread run
134   void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,  // NOLINT
135                  std::vector<std::vector<uint8_t>> &bin_data);                    // NOLINT
136 
137   /// \brief serialized raw data
138   Status SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,  // NOLINT
139                           std::vector<std::vector<uint8_t>> &bin_data,      // NOLINT
140                           uint32_t row_count);
141 
142   /// \brief write all data parallel
143   Status ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
144                            const std::vector<std::vector<uint8_t>> &bin_raw_data);
145 
146   /// \brief write data shard by shard
147   Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
148                       const std::vector<std::vector<uint8_t>> &bin_raw_data);
149 
150   /// \brief break image data up into multiple row groups
151   Status CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
152                      std::vector<std::pair<int, int>> &rows_in_group,  // NOLINT
153                      const std::shared_ptr<Page> &last_raw_page, const std::shared_ptr<Page> &last_blob_page);
154 
155   /// \brief append partial blob data to previous page
156   Status AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
157                         const std::vector<std::pair<int, int>> &rows_in_group,
158                         const std::shared_ptr<Page> &last_blob_page);
159 
160   /// \brief write new blob data page to disk
161   Status NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
162                      const std::vector<std::pair<int, int>> &rows_in_group,
163                      const std::shared_ptr<Page> &last_blob_page);
164 
165   /// \brief shift last row group to next raw page for new appending
166   Status ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
167                       std::shared_ptr<Page> &last_raw_page);  // NOLINT
168 
169   /// \brief write raw data page to disk
170   Status WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
171                       std::shared_ptr<Page> &last_raw_page,  // NOLINT
172                       const std::vector<std::vector<uint8_t>> &bin_raw_data);
173 
174   /// \brief generate empty raw data page
175   Status EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);  // NOLINT
176 
177   /// \brief append a row group at the end of raw page
178   Status AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
179                        int &last_row_groupId, std::shared_ptr<Page> last_raw_page,  // NOLINT
180                        const std::vector<std::vector<uint8_t>> &bin_raw_data);
181 
182   /// \brief write blob chunk to disk
183   Status FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
184                         const std::pair<int, int> &blob_row);
185 
186   /// \brief write raw chunk to disk
187   Status FlushRawChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::pair<int, int>> &rows_in_group,
188                        const int &chunk_id, const std::vector<std::vector<uint8_t>> &bin_raw_data);
189 
190   /// \brief break up into tasks by shard
191   std::vector<std::pair<int, int>> BreakIntoShards();
192 
193   /// \brief calculate raw data size row by row
194   Status SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);
195 
196   /// \brief calculate blob data size row by row
197   Status SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);
198 
199   /// \brief populate last raw page pointer
200   Status SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);  // NOLINT
201 
202   /// \brief populate last blob page pointer
203   Status SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);  // NOLINT
204 
205   /// \brief check the data by schema
206   Status CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);
207 
208   /// \brief check the data and type
209   Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
210                                std::map<int, std::string> &err_raw_data);  // NOLINT
211 
212   /// \brief Lock writer and save pages info
213   Status LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr);
214 
215   /// \brief Unlock writer and save pages info
216   Status UnlockWriter(int fd, bool parallel_writer = false);
217 
218   /// \brief Check raw data before writing
219   Status WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,  // NOLINT
220                               vector<vector<uint8_t>> &blob_data,               // NOLINT
221                               bool sign, int *schema_count, int *row_count);
222 
223   /// \brief Get full path from file name
224   Status GetFullPathFromFileName(const std::vector<std::string> &paths);
225 
226   /// \brief Open files
227   Status OpenDataFiles(bool append, bool overwrite);
228 
229   /// \brief Remove lock file
230   Status RemoveLockFile();
231 
232   /// \brief Remove lock file
233   Status InitLockFile();
234 
235  private:
236   const std::string kLockFileSuffix = "_Locker";
237   const std::string kPageFileSuffix = "_Pages";
238   std::string lock_file_;   // lock file for parallel run
239   std::string pages_file_;  // temporary file of pages info for parallel run
240 
241   int shard_count_;        // number of files
242   uint64_t header_size_;   // header size
243   uint64_t page_size_;     // page size
244   uint32_t row_count_;     // count of rows
245   uint32_t schema_count_;  // count of schemas
246 
247   std::vector<uint64_t> raw_data_size_;   // Raw data size
248   std::vector<uint64_t> blob_data_size_;  // Blob data size
249 
250   std::vector<std::string> file_paths_;                      // file paths
251   std::vector<std::shared_ptr<std::fstream>> file_streams_;  // file handles
252   std::shared_ptr<ShardHeader> shard_header_;                // shard header
253   std::shared_ptr<ShardColumn> shard_column_;                // shard columns
254 
255   std::map<uint64_t, std::map<int, std::string>> err_mg_;  // used for storing error raw_data info
256 
257   std::mutex check_mutex_;  // mutex for data check
258   std::atomic<bool> flag_{false};
259   std::atomic<int64_t> compression_size_;
260 };
261 }  // namespace mindrecord
262 }  // namespace mindspore
263 
264 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_WRITER_H_
265