• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_
19 
#include <dirent.h>
#include <signal.h>
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
#include <sys/prctl.h>
#endif
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <stack>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_category.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "utils/log_adapter.h"
57 
// Marks a class/function as exported from the shared library using the GCC/Clang
// `visibility("default")` attribute.
// Guarded so that the build system (e.g. -DAPI_PUBLIC=...) or a previously included
// header can supply its own definition without a macro-redefinition diagnostic.
#ifndef API_PUBLIC
#define API_PUBLIC __attribute__((visibility("default")))
#endif
59 
60 namespace mindspore {
61 namespace mindrecord {
62 using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>;
63 using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>;
64 using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>;
65 const int kNumBatchInMap = 1000;  // iterator buffer size in row-reader mode
66 
67 class API_PUBLIC ShardReader {
68  public:
69   ShardReader();
70 
71   virtual ~ShardReader();
72 
73   /// \brief open files and initialize reader, c++ API
74   /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
75   /// \param[in] load_dataset load dataset from single file or not
76   /// \param[in] n_consumer number of threads when reading
77   /// \param[in] selected_columns column list to be populated
78   /// \param[in] operators operators applied to data, operator type is shuffle, sample or category
79   /// \param[in] num_padded the number of padded samples
80   /// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization
81   /// \return MSRStatus the status of MSRStatus
82   Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
83               const std::vector<std::string> &selected_columns = {},
84               const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
85               bool lazy_load = false);
86 
87   /// \brief close reader
88   /// \return null
89   void Close();
90 
91   /// \brief read the file, get schema meta,statistics and index, single-thread mode
92   /// \return MSRStatus the status of MSRStatus
93   Status Open();
94 
95   /// \brief read the file, get schema meta,statistics and index, multiple-thread mode
96   /// \return MSRStatus the status of MSRStatus
97   Status Open(int n_consumer);
98 
99   /// \brief launch threads to get batches
100   /// \param[in] is_simple_reader trigger threads if false; do nothing if true
101   /// \return MSRStatus the status of MSRStatus
102   Status Launch(bool is_simple_reader = false);
103 
104   /// \brief aim to get the meta data
105   /// \return the metadata
106   std::shared_ptr<ShardHeader> GetShardHeader() const;
107 
108   /// \brief aim to get columns context
109   /// \return the columns
110   std::shared_ptr<ShardColumn> GetShardColumn() const;
111 
112   /// \brief get the number of shards
113   /// \return # of shards
114   int GetShardCount() const;
115 
116   /// \brief get the number of rows in database
117   /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
118   /// \param[in] load_dataset load dataset from single file or not
119   /// \param[in] op smart pointer refer to ShardCategory or ShardSample object
120   /// \param[out] count # of rows
121   /// \return MSRStatus the status of MSRStatus
122   Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
123                         const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded);
124 
125   /// \brief shuffle task with incremental seed
126   /// \return void
127   void ShuffleTask();
128 
129   /// \brief get the number of rows in database
130   /// \return # of rows
131   int GetNumRows() const;
132 
133   /// \brief Read the summary of row groups
134   /// \return the tuple of 4 elements
135   ///         1. Sharding ID
136   ///         2. Row group ID
137   ///         3. The row ID started in row group
138   ///         4. # of rows in row group
139   std::vector<std::tuple<int, int, int, uint64_t>> ReadRowGroupSummary();
140 
141   /// \brief Read 1 row group data, excluding images
142   /// \param[in] groupID row group ID
143   /// \param[in] shard_id sharding ID
144   /// \param[in] columns multi-columns retrieved
145   /// \return the tuple of 5 elements
146   ///         1. file name where row group is located
147   ///         2. Actual row group size
148   ///         3. Offset address of row group in file
149   ///         4. The list of image offset in page [startOffset, endOffset)
150   ///         5. The list of columns data
151   Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
152                            std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);
153 
154   /// \brief Read 1 row group data, excluding images, following an index field criteria
155   /// \param[in] groupID row group ID
156   /// \param[in] shard_id sharding ID
157   /// \param[in] column-value pair of criteria to fulfill
158   /// \param[in] columns multi-columns retrieved
159   /// \return the tuple of 5 elements
160   ///         1. file name where row group is located
161   ///         2. Actual row group size
162   ///         3. Offset address of row group in file
163   ///         4. The list of image offset in page [startOffset, endOffset)
164   ///         5. The list of columns data
165   Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
166                               const std::vector<std::string> &columns,
167                               std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);
168 
169   /// \brief return a batch, given that one is ready
170   /// \return a batch of images and image data
171   std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext();
172 
173   /// \brief return a row by id
174   /// \return a batch of images and image data
175   TASK_CONTENT GetNextById(const int64_t &task_id, const int32_t &consumer_id);
176   /// \brief  get blob filed list
177   /// \return blob field list
178   std::pair<ShardType, std::vector<std::string>> GetBlobFields();
179 
180   /// \brief reset reader
181   /// \return null
182   void Reset();
183 
184   /// \brief set flag of all-in-index
185   /// \return null
SetAllInIndex(bool all_in_index)186   void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
187 
188   /// \brief get all classes
189   Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);
190 
191   /// \brief get a read-only ptr to the sampled ids for this epoch
192   const std::vector<int> *GetSampleIds();
193 
194   /// \brief get the size of blob data
195   Status GetTotalBlobSize(int64_t *total_blob_size);
196 
197   /// \brief extract uncompressed data based on column list
198   Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
199                         std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr);
200 
201  protected:
202   /// \brief sqlite call back function
203   static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
204 
205  private:
206   /// \brief wrap up labels to json format
207   Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
208                             std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id,
209                             const std::vector<std::string> &columns,
210                             std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
211 
212   /// \brief read all rows for specified columns
213   Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr);
214 
215   /// \brief read row meta by shard_id and sample_id
216   Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
217                                           const uint32_t &sample_id, std::shared_ptr<ROW_GROUPS> *row_group_ptr);
218 
219   /// \brief read all rows in one shard
220   Status ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
221                             std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
222                             std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
223 
224   /// \brief initialize reader
225   Status Init(const std::vector<std::string> &file_paths, bool load_dataset);
226 
227   /// \brief validate column list
228   Status CheckColumnList(const std::vector<std::string> &selected_columns);
229 
230   /// \brief populate one row by task list in row-reader mode
231   void ConsumerByRow(int consumer_id);
232 
233   /// \brief get offset address of images within page
234   std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
235                                                     const std::pair<std::string, std::string> &criteria = {"", ""});
236 
237   /// \brief get page id by category
238   Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
239                             std::shared_ptr<std::vector<uint64_t>> *pages_ptr);
240   /// \brief execute sqlite query with prepare statement
241   Status QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
242                            std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
243   /// \brief verify the validity of dataset
244   Status VerifyDataset(sqlite3 **db, const string &file);
245 
246   /// \brief get column values
247   Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
248                    const std::pair<std::string, std::string> &criteria, std::shared_ptr<std::vector<json>> *labels_ptr);
249 
250   /// \brief get column values from raw data page
251   Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
252                            const std::pair<std::string, std::string> &criteria,
253                            std::shared_ptr<std::vector<json>> *labels_ptr);
254 
255   /// \brief create category-applied task list
256   Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);
257 
258   /// \brief create task list in row-reader mode
259   Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
260                           const std::vector<std::shared_ptr<ShardOperator>> &operators);
261 
262   /// \brief create task list in row-reader mode and lazy mode
263   Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
264                               const std::vector<std::shared_ptr<ShardOperator>> &operators);
265 
266   /// \brief crate task list
267   Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
268                      const std::vector<std::shared_ptr<ShardOperator>> &operators);
269 
270   /// \brief check if all specified columns are in index table
271   void CheckIfColumnInIndex(const std::vector<std::string> &columns);
272 
273   /// \brief open multiple file handle
274   void FileStreamsOperator();
275 
276   /// \brief read one row by one task
277   Status ConsumerOneTask(int task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt);
278 
279   /// \brief get labels from binary file
280   Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
281                                  const std::vector<std::vector<std::string>> &label_offsets,
282                                  std::shared_ptr<std::vector<json>> *labels_ptr);
283 
284   /// \brief get classes in one shard
285   void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
286                          std::shared_ptr<std::set<std::string>> category_ptr);
287 
288   /// \brief get number of classes
289   int64_t GetNumClasses(const std::string &category_field);
290 
291   /// \brief get meta of header
292   Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
293                  std::shared_ptr<std::vector<std::string>> *addresses_ptr);
294 
295  protected:
296   uint64_t header_size_;                       // header size
297   uint64_t page_size_;                         // page size
298   int shard_count_;                            // number of shards
299   std::shared_ptr<ShardHeader> shard_header_;  // shard header
300   std::shared_ptr<ShardColumn> shard_column_;  // shard column
301 
302   std::vector<sqlite3 *> database_paths_;                                        // sqlite handle list
303   std::vector<string> file_paths_;                                               // file paths
304   std::vector<std::shared_ptr<std::fstream>> file_streams_;                      // single-file handle list
305   std::vector<std::vector<std::shared_ptr<std::fstream>>> file_streams_random_;  // multiple-file handle list
306 
307  private:
308   int n_consumer_;                                         // number of workers (threads)
309   std::vector<std::string> selected_columns_;              // columns which will be read
310   std::map<string, uint64_t> column_schema_id_;            // column-schema map
311   std::vector<std::shared_ptr<ShardOperator>> operators_;  // data operators, including shuffle, sample and category
312   ShardTaskList tasks_;                                    // shard task list
313   std::mutex shard_locker_;                                // locker of shard
314 
315   // flags
316   bool all_in_index_ = true;  // if all columns are stored in index-table
317   bool interrupt_ = false;    // reader interrupted
318 
319   int num_padded_;  // number of padding samples
320 
321   // Delivery/Iterator mode begin
322   const std::string kThreadName = "THRD_ITER_";  // prefix of thread name
323   std::vector<std::thread> thread_set_;          // thread list
324   int num_rows_;                                 // number of rows
325   int64_t total_blob_size_;                      // total size of blob data
326   std::mutex mtx_delivery_;                      // locker for delivery
327   std::condition_variable cv_delivery_;          // conditional variable for delivery
328   std::condition_variable cv_iterator_;          // conditional variable for iterator
329   std::atomic<int> sample_id_position_;          // index into the sample ids vector for the current sample id
330   std::atomic<int> deliver_id_;                  // delivery ID which is picked up by iterator
331   // map of delivery
332   std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
333   // Delivery/Iterator mode end
334 
335   // all metadata in the index is not loaded during initialization
336   bool lazy_load_;
337 
338   // indicate shard_id : inc_count
339   // 0 : 15  -  shard0 has 15 samples
340   // 1 : 41  -  shard1 has 26 samples
341   // 2 : 58  -  shard2 has 17 samples
342   std::vector<uint32_t> shard_sample_count_;
343 };
344 }  // namespace mindrecord
345 }  // namespace mindspore
346 
347 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_
348