1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_ 19 20 #include <dirent.h> 21 #include <signal.h> 22 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__) 23 #include <sys/prctl.h> 24 #endif 25 #include <sys/stat.h> 26 #include <sys/types.h> 27 #include <unistd.h> 28 #include <algorithm> 29 #include <chrono> 30 #include <cstdint> 31 #include <fstream> 32 #include <iostream> 33 #include <map> 34 #include <memory> 35 #include <mutex> 36 #include <set> 37 #include <stack> 38 #include <string> 39 #include <thread> 40 #include <tuple> 41 #include <unordered_map> 42 #include <unordered_set> 43 #include <utility> 44 #include <vector> 45 #include "minddata/mindrecord/include/common/shard_utils.h" 46 #include "minddata/mindrecord/include/shard_category.h" 47 #include "minddata/mindrecord/include/shard_column.h" 48 #include "minddata/mindrecord/include/shard_distributed_sample.h" 49 #include "minddata/mindrecord/include/shard_error.h" 50 #include "minddata/mindrecord/include/shard_index_generator.h" 51 #include "minddata/mindrecord/include/shard_operator.h" 52 #include "minddata/mindrecord/include/shard_pk_sample.h" 53 #include "minddata/mindrecord/include/shard_reader.h" 54 #include "minddata/mindrecord/include/shard_sample.h" 55 #include "minddata/mindrecord/include/shard_shuffle.h" 56 #include "utils/log_adapter.h" 57 58 #define API_PUBLIC __attribute__((visibility("default"))) 59 60 namespace mindspore { 61 namespace mindrecord { 62 using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>; 63 using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>; 64 using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>; 65 const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode 66 67 class API_PUBLIC ShardReader { 68 public: 69 ShardReader(); 70 71 virtual ~ShardReader(); 72 73 /// \brief open files and initialize reader, c++ API 74 /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list 75 /// \param[in] load_dataset load dataset from single file or not 76 /// \param[in] n_consumer number of threads when reading 77 /// \param[in] selected_columns column list to be populated 78 /// \param[in] operators operators applied to data, operator type is shuffle, sample or category 79 /// \param[in] num_padded the number of padded samples 80 /// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization 81 /// \return MSRStatus the status of MSRStatus 82 Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4, 83 const std::vector<std::string> &selected_columns = {}, 84 const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0, 85 bool lazy_load = false); 86 87 /// \brief close reader 88 /// \return null 89 void Close(); 90 91 /// \brief read the file, get schema meta,statistics and index, single-thread mode 92 /// \return MSRStatus the status of MSRStatus 93 Status Open(); 94 95 /// \brief read the file, get schema meta,statistics and index, multiple-thread mode 96 /// \return MSRStatus the status of MSRStatus 97 Status Open(int n_consumer); 98 99 /// \brief launch threads to get batches 100 /// \param[in] is_simple_reader trigger threads if false; do nothing if true 101 /// \return MSRStatus the status of MSRStatus 102 Status Launch(bool is_simple_reader = false); 103 104 /// \brief aim to get the meta data 105 /// \return the metadata 106 std::shared_ptr<ShardHeader> GetShardHeader() const; 107 108 /// \brief aim to get columns context 109 /// \return the columns 110 std::shared_ptr<ShardColumn> GetShardColumn() const; 111 112 /// \brief get the number of shards 113 /// \return # of shards 114 int GetShardCount() const; 115 116 /// \brief get the number of rows in database 117 /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list 118 /// \param[in] load_dataset load dataset from single file or not 119 /// \param[in] op smart pointer refer to ShardCategory or ShardSample object 120 /// \param[out] count # of rows 121 /// \return MSRStatus the status of MSRStatus 122 Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset, 123 const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded); 124 125 /// \brief shuffle task with incremental seed 126 /// \return void 127 void ShuffleTask(); 128 129 /// \brief get the number of rows in database 130 /// \return # of rows 131 int GetNumRows() const; 132 133 /// \brief Read the summary of row groups 134 /// \return the tuple of 4 elements 135 /// 1. Sharding ID 136 /// 2. Row group ID 137 /// 3. The row ID started in row group 138 /// 4. # of rows in row group 139 std::vector<std::tuple<int, int, int, uint64_t>> ReadRowGroupSummary(); 140 141 /// \brief Read 1 row group data, excluding images 142 /// \param[in] groupID row group ID 143 /// \param[in] shard_id sharding ID 144 /// \param[in] columns multi-columns retrieved 145 /// \return the tuple of 5 elements 146 /// 1. file name where row group is located 147 /// 2. Actual row group size 148 /// 3. Offset address of row group in file 149 /// 4. The list of image offset in page [startOffset, endOffset) 150 /// 5. The list of columns data 151 Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns, 152 std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr); 153 154 /// \brief Read 1 row group data, excluding images, following an index field criteria 155 /// \param[in] groupID row group ID 156 /// \param[in] shard_id sharding ID 157 /// \param[in] column-value pair of criteria to fulfill 158 /// \param[in] columns multi-columns retrieved 159 /// \return the tuple of 5 elements 160 /// 1. file name where row group is located 161 /// 2. Actual row group size 162 /// 3. Offset address of row group in file 163 /// 4. The list of image offset in page [startOffset, endOffset) 164 /// 5. The list of columns data 165 Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, 166 const std::vector<std::string> &columns, 167 std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr); 168 169 /// \brief return a batch, given that one is ready 170 /// \return a batch of images and image data 171 std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext(); 172 173 /// \brief return a row by id 174 /// \return a batch of images and image data 175 TASK_CONTENT GetNextById(const int64_t &task_id, const int32_t &consumer_id); 176 /// \brief get blob filed list 177 /// \return blob field list 178 std::pair<ShardType, std::vector<std::string>> GetBlobFields(); 179 180 /// \brief reset reader 181 /// \return null 182 void Reset(); 183 184 /// \brief set flag of all-in-index 185 /// \return null SetAllInIndex(bool all_in_index)186 void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } 187 188 /// \brief get all classes 189 Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr); 190 191 /// \brief get a read-only ptr to the sampled ids for this epoch 192 const std::vector<int> *GetSampleIds(); 193 194 /// \brief get the size of blob data 195 Status GetTotalBlobSize(int64_t *total_blob_size); 196 197 /// \brief extract uncompressed data based on column list 198 Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data, 199 std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr); 200 201 protected: 202 /// \brief sqlite call back function 203 static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); 204 205 private: 206 /// \brief wrap up labels to json format 207 Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs, 208 std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id, 209 const std::vector<std::string> &columns, 210 std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); 211 212 /// \brief read all rows for specified columns 213 Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr); 214 215 /// \brief read row meta by shard_id and sample_id 216 Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id, 217 const uint32_t &sample_id, std::shared_ptr<ROW_GROUPS> *row_group_ptr); 218 219 /// \brief read all rows in one shard 220 Status ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns, 221 std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, 222 std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); 223 224 /// \brief initialize reader 225 Status Init(const std::vector<std::string> &file_paths, bool load_dataset); 226 227 /// \brief validate column list 228 Status CheckColumnList(const std::vector<std::string> &selected_columns); 229 230 /// \brief populate one row by task list in row-reader mode 231 void ConsumerByRow(int consumer_id); 232 233 /// \brief get offset address of images within page 234 std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id, 235 const std::pair<std::string, std::string> &criteria = {"", ""}); 236 237 /// \brief get page id by category 238 Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria, 239 std::shared_ptr<std::vector<uint64_t>> *pages_ptr); 240 /// \brief execute sqlite query with prepare statement 241 Status QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, 242 std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr); 243 /// \brief verify the validity of dataset 244 Status VerifyDataset(sqlite3 **db, const string &file); 245 246 /// \brief get column values 247 Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns, 248 const std::pair<std::string, std::string> &criteria, std::shared_ptr<std::vector<json>> *labels_ptr); 249 250 /// \brief get column values from raw data page 251 Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns, 252 const std::pair<std::string, std::string> &criteria, 253 std::shared_ptr<std::vector<json>> *labels_ptr); 254 255 /// \brief create category-applied task list 256 Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op); 257 258 /// \brief create task list in row-reader mode 259 Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, 260 const std::vector<std::shared_ptr<ShardOperator>> &operators); 261 262 /// \brief create task list in row-reader mode and lazy mode 263 Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, 264 const std::vector<std::shared_ptr<ShardOperator>> &operators); 265 266 /// \brief crate task list 267 Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, 268 const std::vector<std::shared_ptr<ShardOperator>> &operators); 269 270 /// \brief check if all specified columns are in index table 271 void CheckIfColumnInIndex(const std::vector<std::string> &columns); 272 273 /// \brief open multiple file handle 274 void FileStreamsOperator(); 275 276 /// \brief read one row by one task 277 Status ConsumerOneTask(int task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt); 278 279 /// \brief get labels from binary file 280 Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns, 281 const std::vector<std::vector<std::string>> &label_offsets, 282 std::shared_ptr<std::vector<json>> *labels_ptr); 283 284 /// \brief get classes in one shard 285 void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, 286 std::shared_ptr<std::set<std::string>> category_ptr); 287 288 /// \brief get number of classes 289 int64_t GetNumClasses(const std::string &category_field); 290 291 /// \brief get meta of header 292 Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr, 293 std::shared_ptr<std::vector<std::string>> *addresses_ptr); 294 295 protected: 296 uint64_t header_size_; // header size 297 uint64_t page_size_; // page size 298 int shard_count_; // number of shards 299 std::shared_ptr<ShardHeader> shard_header_; // shard header 300 std::shared_ptr<ShardColumn> shard_column_; // shard column 301 302 std::vector<sqlite3 *> database_paths_; // sqlite handle list 303 std::vector<string> file_paths_; // file paths 304 std::vector<std::shared_ptr<std::fstream>> file_streams_; // single-file handle list 305 std::vector<std::vector<std::shared_ptr<std::fstream>>> file_streams_random_; // multiple-file handle list 306 307 private: 308 int n_consumer_; // number of workers (threads) 309 std::vector<std::string> selected_columns_; // columns which will be read 310 std::map<string, uint64_t> column_schema_id_; // column-schema map 311 std::vector<std::shared_ptr<ShardOperator>> operators_; // data operators, including shuffle, sample and category 312 ShardTaskList tasks_; // shard task list 313 std::mutex shard_locker_; // locker of shard 314 315 // flags 316 bool all_in_index_ = true; // if all columns are stored in index-table 317 bool interrupt_ = false; // reader interrupted 318 319 int num_padded_; // number of padding samples 320 321 // Delivery/Iterator mode begin 322 const std::string kThreadName = "THRD_ITER_"; // prefix of thread name 323 std::vector<std::thread> thread_set_; // thread list 324 int num_rows_; // number of rows 325 int64_t total_blob_size_; // total size of blob data 326 std::mutex mtx_delivery_; // locker for delivery 327 std::condition_variable cv_delivery_; // conditional variable for delivery 328 std::condition_variable cv_iterator_; // conditional variable for iterator 329 std::atomic<int> sample_id_position_; // index into the sample ids vector for the current sample id 330 std::atomic<int> deliver_id_; // delivery ID which is picked up by iterator 331 // map of delivery 332 std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_; 333 // Delivery/Iterator mode end 334 335 // all metadata in the index is not loaded during initialization 336 bool lazy_load_; 337 338 // indicate shard_id : inc_count 339 // 0 : 15 - shard0 has 15 samples 340 // 1 : 41 - shard1 has 26 samples 341 // 2 : 58 - shard2 has 17 samples 342 std::vector<uint32_t> shard_sample_count_; 343 }; 344 } // namespace mindrecord 345 } // namespace mindspore 346 347 #endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_ 348