• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_
18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_
19 
#include <signal.h>
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
#include <sys/prctl.h>
#endif
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <stack>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "minddata/mindrecord/include/common/log_adapter.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_category.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
55 
56 namespace mindspore {
57 namespace mindrecord {
/// Row-group read result: first = nested lists of uint64_t offsets, second = nested lists of json column values.
/// These are the same shapes as the offset_ptr / col_val_ptr arguments used by the ShardReader read helpers.
/// NOTE(review): the outermost index is presumably the shard id - confirm against the implementation.
using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>;
/// Summary of one row group, in order: file name where the row group is located, actual row group size,
/// offset address of the row group in the file, list of image offsets in page [startOffset, endOffset),
/// list of columns data.
using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>;
/// Content produced for one task: the task type plus a list of (byte buffer, json) tuples.
/// NOTE(review): the byte buffer is presumably the raw blob of a row and the json its labels - confirm.
using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>;
/// Size of the iterator buffer used in row-reader mode.
constexpr int kNumBatchInMap = 1000;
62 
63 class MINDRECORD_API ShardReader {
64  public:
65   ShardReader();
66 
67   virtual ~ShardReader();
68 
69   /// \brief open files and initialize reader, c++ API
70   /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
71   /// \param[in] load_dataset load dataset from single file or not
72   /// \param[in] n_consumer number of threads when reading
73   /// \param[in] selected_columns column list to be populated
74   /// \param[in] operators operators applied to data, operator type is shuffle, sample or category
75   /// \param[in] num_padded the number of padded samples
76   /// \param[in] load_mode
77   ///            LoadMode::kNormal: cache whole meta data for dataset
78   ///            LoadMode::kLazy: cache part meta data for dataset
79   ///            LoadMode::kLowest: don't cache meta data
80   /// \return MSRStatus the status of MSRStatus
81   Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
82               const std::vector<std::string> &selected_columns = {},
83               const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int64_t num_padded = 0,
84               LoadMode load_mode = LoadMode::kFast);
85 
86   /// \brief close reader
87   /// \return null
88   void Close();
89 
90   /// \brief read the file, get schema meta,statistics and index, multiple-thread mode
91   /// \return MSRStatus the status of MSRStatus
92   Status Open(int n_consumer);
93 
94   /// \brief increase number of random file stream for parallel read
95   /// \param[in] n_new_consumers number of new file streams to be added
96   /// \return MSRStatus the status of MSRStatus
97   Status ExtendRandomFileStreams(const int n_new_consumers);
98 
99   /// \brief decrease number of random file streams for parallel read
100   /// \param[in] n_remove_consumers number of file streams to be removed
101   /// \return MSRStatus the status of MSRStatus
102   Status ShrinkRandomFileStreams(const int n_remove_consumers);
103 
104   /// \brief launch threads to get batches
105   /// \param[in] is_simple_reader trigger threads if false; do nothing if true
106   /// \return MSRStatus the status of MSRStatus
107   Status Launch(bool is_simple_reader = false);
108 
109   /// \brief aim to get the meta data
110   /// \return the metadata
111   std::shared_ptr<ShardHeader> GetShardHeader() const;
112 
113   /// \brief aim to get columns context
114   /// \return the columns
115   std::shared_ptr<ShardColumn> GetShardColumn() const;
116 
117   /// \brief get the number of shards
118   /// \return # of shards
119   int GetShardCount() const;
120 
121   /// \brief get the number of rows in database
122   /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
123   /// \param[in] load_dataset load dataset from single file or not
124   /// \param[in] op smart pointer refer to ShardCategory or ShardSample object
125   /// \param[out] count # of rows
126   /// \return MSRStatus the status of MSRStatus
127   Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
128                         const std::shared_ptr<ShardOperator> &op, int64_t *count, const int64_t num_padded);
129 
130   /// \brief shuffle task with incremental seed
131   /// \return void
132   void ShuffleTask();
133 
134   /// \brief get the number of rows in database
135   /// \return # of rows
136   int64_t GetNumRows() const;
137 
138   /// \brief get the number of rows after sampling
139   /// \return # of rows
140   int64_t GetNumRowsAfterSampling() const;
141 
142   /// \brief Read the summary of row groups
143   /// \return the tuple of 4 elements
144   ///         1. Sharding ID
145   ///         2. Row group ID
146   ///         3. The row ID started in row group
147   ///         4. # of rows in row group
148   std::vector<std::tuple<int, int, int, uint64_t>> ReadRowGroupSummary();
149 
150   /// \brief Read 1 row group data, excluding images
151   /// \param[in] groupID row group ID
152   /// \param[in] shard_id sharding ID
153   /// \param[in] columns multi-columns retrieved
154   /// \return the tuple of 5 elements
155   ///         1. file name where row group is located
156   ///         2. Actual row group size
157   ///         3. Offset address of row group in file
158   ///         4. The list of image offset in page [startOffset, endOffset)
159   ///         5. The list of columns data
160   Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
161                            std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);
162 
163   /// \brief Read 1 row group data, excluding images, following an index field criteria
164   /// \param[in] groupID row group ID
165   /// \param[in] shard_id sharding ID
166   /// \param[in] column-value pair of criteria to fulfill
167   /// \param[in] columns multi-columns retrieved
168   /// \return the tuple of 5 elements
169   ///         1. file name where row group is located
170   ///         2. Actual row group size
171   ///         3. Offset address of row group in file
172   ///         4. The list of image offset in page [startOffset, endOffset)
173   ///         5. The list of columns data
174   Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
175                               const std::vector<std::string> &columns,
176                               std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);
177 
178   /// \brief return a batch, given that one is ready
179   /// \return a batch of images and image data
180   std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> GetNext();
181 
182   /// \brief return a row by id
183   /// \return a batch of images and image data
184   Status GetNextById(const int64_t &task_id, const int32_t &consumer_id,
185                      std::shared_ptr<TASK_CONTENT> *task_content_ptr);
186 
187   /// \brief  get blob filed list
188   /// \return blob field list
189   std::pair<ShardType, std::vector<std::string>> GetBlobFields();
190 
191   /// \brief reset reader
192   /// \return null
193   void Reset();
194 
195   /// \brief set flag of all-in-index
196   /// \return null
SetAllInIndex(bool all_in_index)197   void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
198 
199   /// \brief get all classes
200   Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);
201 
202   /// \brief get a read-only ptr to the sampled ids for this epoch
203   const std::vector<int64_t> *GetSampleIds();
204 
205   /// \brief get the size of blob data
206   Status GetTotalBlobSize(int64_t *total_blob_size);
207 
208   /// \brief extract uncompressed data based on column list
209   Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
210                         std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr);
211 
212   /// \brief get load mode
213   LoadMode GetLoadMode() const;
214 
215   /// \brief get next sample ids in slow load mode
216   std::vector<int64_t> GetNextSampleIds();
217 
218  protected:
219   /// \brief sqlite call back function
220   static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
221 
222  private:
223   /// \brief wrap up labels to json format
224   Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
225                             std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id,
226                             const std::vector<std::string> &columns,
227                             std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
228 
229   /// \brief convert json format to expected type
230   Status ConvertJsonValue(const std::vector<std::string> &label, const std::vector<std::string> &columns,
231                           const json &schema, json *value);
232 
233   /// \brief read all rows for specified columns
234   Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr);
235 
236   /// \brief read row meta by shard_id and sample_id
237   Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
238                                           const int32_t &consumer_id, const uint32_t &sample_id,
239                                           std::shared_ptr<ROW_GROUPS> *row_group_ptr);
240 
241   /// \brief read all rows in one shard
242   Status ReadAllRowsInShard(int shard_id, const int32_t &consumer_id, const std::string &sql,
243                             const std::vector<std::string> &columns,
244                             std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
245                             std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
246 
247   /// \brief initialize reader
248   Status Init(const std::vector<std::string> &file_paths, bool load_dataset);
249 
250   /// \brief validate column list
251   Status CheckColumnList(const std::vector<std::string> &selected_columns);
252 
253   /// \brief populate one row by task list in row-reader mode
254   void ConsumerByRow(int consumer_id);
255 
256   /// \brief get offset address of images within page
257   std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
258                                                     const std::pair<std::string, std::string> &criteria = {"", ""});
259 
260   /// \brief get page id by category
261   Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
262                             std::shared_ptr<std::vector<uint64_t>> *pages_ptr);
263   /// \brief execute sqlite query with prepare statement
264   Status QueryWithPageIdBlobAndCriteria(sqlite3 *db, const string &sql, const int &page_id, const string &criteria,
265                                         std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
266   /// \brief verify the validity of dataset
267   Status VerifyDataset(sqlite3 **db, const string &file);
268 
269   /// \brief get column values
270   Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
271                    const std::pair<std::string, std::string> &criteria, std::shared_ptr<std::vector<json>> *labels_ptr);
272 
273   /// \brief get column values from raw data page
274   Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
275                            const std::pair<std::string, std::string> &criteria,
276                            std::shared_ptr<std::vector<json>> *labels_ptr);
277 
278   /// \brief create category-applied task list
279   Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);
280 
281   /// \brief create task list in row-reader mode
282   Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
283                           const std::vector<std::shared_ptr<ShardOperator>> &operators);
284 
285   /// \brief create task list in row-reader mode and lazy mode
286   Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
287                               const std::vector<std::shared_ptr<ShardOperator>> &operators);
288 
289   /// \brief create task in slow load mode
290   Status CreateSlowTasksByRow();
291 
292   /// \brief crate task list
293   Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
294                      const std::vector<std::shared_ptr<ShardOperator>> &operators);
295 
296   /// \brief check if all specified columns are in index table
297   void CheckIfColumnInIndex(const std::vector<std::string> &columns);
298 
299   /// \brief open multiple file handle
300   void FileStreamsOperator();
301 
302   /// \brief read one row by one task
303   Status ConsumerOneTask(int64_t task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt);
304 
305   /// \brief get labels from binary file
306   Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
307                                  const std::vector<std::vector<std::string>> &label_offsets,
308                                  std::shared_ptr<std::vector<json>> *labels_ptr);
309 
310   /// \brief get classes in one shard
311   void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
312                          std::shared_ptr<std::set<std::string>> category_ptr);
313 
314   /// \brief get number of classes
315   int64_t GetNumClasses(const std::string &category_field);
316 
317   /// \brief get meta of header
318   Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
319                  std::shared_ptr<std::vector<std::string>> *addresses_ptr);
320 
321  protected:
322   uint64_t header_size_;                       // header size
323   uint64_t page_size_;                         // page size
324   int shard_count_;                            // number of shards
325   std::shared_ptr<ShardHeader> shard_header_;  // shard header
326   std::shared_ptr<ShardColumn> shard_column_;  // shard column
327 
328   std::vector<sqlite3 *> database_paths_;                                        // sqlite handle list
329   std::vector<string> file_paths_;                                               // file paths
330   std::vector<std::shared_ptr<std::fstream>> file_streams_;                      // single-file handle list
331   std::vector<std::vector<std::shared_ptr<std::fstream>>> file_streams_random_;  // multiple-file handle list
332 
333  private:
334   int n_consumer_;                                         // number of workers (threads)
335   std::vector<std::string> selected_columns_;              // columns which will be read
336   std::map<string, uint64_t> column_schema_id_;            // column-schema map
337   std::vector<std::shared_ptr<ShardOperator>> operators_;  // data operators, including shuffle, sample and category
338   ShardTaskList tasks_;                                    // shard task list
339   std::mutex shard_locker_;                                // locker of shard
340 
341   // flags
342   bool all_in_index_ = true;  // if all columns are stored in index-table
343   bool interrupt_ = false;    // reader interrupted
344 
345   int64_t num_padded_;  // number of padding samples
346 
347   // Delivery/Iterator mode begin
348   std::vector<std::thread> thread_set_;  // thread list
349   int64_t num_rows_;                     // number of rows
350   int64_t total_blob_size_;              // total size of blob data
351   std::mutex mtx_delivery_;              // locker for delivery
352   std::condition_variable cv_delivery_;  // conditional variable for delivery
353   std::condition_variable cv_iterator_;  // conditional variable for iterator
354   std::atomic<int> sample_id_position_;  // index into the sample ids vector for the current sample id
355   std::atomic<int> deliver_id_;          // delivery ID which is picked up by iterator
356   // map of delivery
357   std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
358   // Delivery/Iterator mode end
359 
360   // all metadata in the index is not loaded during initialization
361   LoadMode load_mode_;
362 
363   // indicate shard_id : inc_count
364   // 0 : 15  -  shard0 has 15 samples
365   // 1 : 41  -  shard1 has 26 samples
366   // 2 : 58  -  shard2 has 17 samples
367   std::vector<int64_t> shard_sample_count_;
368 };
369 }  // namespace mindrecord
370 }  // namespace mindspore
371 
372 #endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_
373