/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_

#include <signal.h>
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
#include <sys/prctl.h>
#endif
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <stack>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "minddata/mindrecord/include/common/log_adapter.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_category.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"

namespace mindspore {
namespace mindrecord {
using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>;
using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>;
using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>;
const int kNumBatchInMap = 1000;  // iterator buffer size in row-reader mode

class MINDRECORD_API ShardReader {
 public:
  ShardReader();

  virtual ~ShardReader();

  /// \brief open files and initialize reader, C++ API
  /// \param[in] file_paths the path of ONE file (any file in the dataset is fine) or a file list
  /// \param[in] load_dataset load dataset from a single file or not
  /// \param[in] n_consumer number of threads used for reading
  /// \param[in] selected_columns column list to be populated
  /// \param[in] operators operators applied to the data; operator types are shuffle, sample and category
  /// \param[in] num_padded the number of padded samples
  /// \param[in] load_mode
  ///            LoadMode::kFast: cache whole meta data for dataset
  ///            LoadMode::kLazy: cache part meta data for dataset
  ///            LoadMode::kSlow: don't cache meta data
  /// \return Status the status of the operation
  Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
              const std::vector<std::string> &selected_columns = {},
              const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int64_t num_padded = 0,
              LoadMode load_mode = LoadMode::kFast);

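  // A minimal usage sketch of the Open/Launch/GetNext/Close flow declared in this class (not part
  // of the API). The file name, column names and status handling are assumptions, and GetNext()
  // returning an empty batch at end of data is assumed rather than guaranteed by this header:
  //
  //   ShardReader reader;
  //   Status rc = reader.Open({"train.mindrecord0"}, true, 4, {"file_name", "label"});
  //   if (rc.IsOk()) {
  //     rc = reader.Launch();
  //     for (auto batch = reader.GetNext(); !batch.empty(); batch = reader.GetNext()) {
  //       // each element pairs the selected blob columns (name -> bytes) with its json label
  //     }
  //     reader.Close();
  //   }
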
  /// \brief close reader
  /// \return null
  void Close();

  /// \brief read the file, get schema meta data, statistics and index, multiple-thread mode
  /// \return Status the status of the operation
  Status Open(int n_consumer);

  /// \brief increase the number of random file streams for parallel read
  /// \param[in] n_new_consumers number of new file streams to be added
  /// \return Status the status of the operation
  Status ExtendRandomFileStreams(const int n_new_consumers);

  /// \brief decrease the number of random file streams for parallel read
  /// \param[in] n_remove_consumers number of file streams to be removed
  /// \return Status the status of the operation
  Status ShrinkRandomFileStreams(const int n_remove_consumers);

  /// \brief launch threads to get batches
  /// \param[in] is_simple_reader trigger threads if false; do nothing if true
  /// \return Status the status of the operation
  Status Launch(bool is_simple_reader = false);

  /// \brief aim to get the meta data
  /// \return the metadata
  std::shared_ptr<ShardHeader> GetShardHeader() const;

  /// \brief aim to get columns context
  /// \return the columns
  std::shared_ptr<ShardColumn> GetShardColumn() const;

  /// \brief get the number of shards
  /// \return # of shards
  int GetShardCount() const;

  /// \brief get the number of rows in database
  /// \param[in] file_paths the path of ONE file (any file in the dataset is fine) or a file list
  /// \param[in] load_dataset load dataset from a single file or not
  /// \param[in] op smart pointer referring to a ShardCategory or ShardSample object
  /// \param[out] count # of rows
  /// \param[in] num_padded the number of padded samples
  /// \return Status the status of the operation
  Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
                        const std::shared_ptr<ShardOperator> &op, int64_t *count, const int64_t num_padded);

  /// \brief shuffle task with incremental seed
  /// \return void
  void ShuffleTask();

  /// \brief get the number of rows in database
  /// \return # of rows
  int64_t GetNumRows() const;

  /// \brief get the number of rows after sampling
  /// \return # of rows
  int64_t GetNumRowsAfterSampling() const;

  /// \brief Read the summary of row groups
  /// \return a list of tuples, each with 4 elements
  ///         1. Shard ID
  ///         2. Row group ID
  ///         3. The row ID started in row group
  ///         4. # of rows in row group
  std::vector<std::tuple<int, int, int, uint64_t>> ReadRowGroupSummary();

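  // A minimal sketch of unpacking the 4-element tuples returned by ReadRowGroupSummary(); the
  // variable names are illustrative only:
  //
  //   for (const auto &summary : reader.ReadRowGroupSummary()) {
  //     int shard_id = std::get<0>(summary);      // shard ID
  //     int group_id = std::get<1>(summary);      // row group ID
  //     int start_row = std::get<2>(summary);     // first row ID in the row group
  //     uint64_t n_rows = std::get<3>(summary);   // number of rows in the row group
  //   }
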
  /// \brief Read 1 row group data, excluding images
  /// \param[in] group_id row group ID
  /// \param[in] shard_id shard ID
  /// \param[in] columns multi-columns retrieved
  /// \return the tuple of 5 elements
  ///         1. file name where row group is located
  ///         2. Actual row group size
  ///         3. Offset address of row group in file
  ///         4. The list of image offset in page [startOffset, endOffset)
  ///         5. The list of columns data
  Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
                           std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);

  /// \brief Read 1 row group data, excluding images, following an index field criteria
  /// \param[in] group_id row group ID
  /// \param[in] shard_id shard ID
  /// \param[in] criteria column-value pair of criteria to fulfill
  /// \param[in] columns multi-columns retrieved
  /// \return the tuple of 5 elements
  ///         1. file name where row group is located
  ///         2. Actual row group size
  ///         3. Offset address of row group in file
  ///         4. The list of image offset in page [startOffset, endOffset)
  ///         5. The list of columns data
  Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
                              const std::vector<std::string> &columns,
                              std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);

  /// \brief return a batch, given that one is ready
  /// \return a batch of images and image data
  std::vector<std::tuple<std::map<std::string, std::vector<uint8_t>>, json>> GetNext();

  /// \brief return a row by id
  /// \return a batch of images and image data
  Status GetNextById(const int64_t &task_id, const int32_t &consumer_id,
                     std::shared_ptr<TASK_CONTENT> *task_content_ptr);

  /// \brief get blob field list
  /// \return blob field list
  std::pair<ShardType, std::vector<std::string>> GetBlobFields();

  /// \brief reset reader
  /// \return null
  void Reset();

  /// \brief set flag of all-in-index
  /// \return null
  void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }

  /// \brief get all classes
  Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);

  /// \brief get a read-only ptr to the sampled ids for this epoch
  const std::vector<int64_t> *GetSampleIds();

  /// \brief get the size of blob data
  Status GetTotalBlobSize(int64_t *total_blob_size);

  /// \brief extract uncompressed data based on column list
  Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
                        std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr);

  /// \brief get load mode
  LoadMode GetLoadMode() const;

  /// \brief get next sample ids in slow load mode
  std::vector<int64_t> GetNextSampleIds();

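  // A minimal sketch of fetching one row by id and splitting its raw blob into per-column buffers,
  // assuming a reader that has already been opened and launched; the task/consumer ids, status
  // handling and the shape of the returned content are assumptions, not guarantees of this header:
  //
  //   auto task_content = std::make_shared<TASK_CONTENT>();
  //   Status rc = reader.GetNextById(0 /* task_id */, 0 /* consumer_id */, &task_content);
  //   if (rc.IsOk() && !task_content->second.empty()) {
  //     const auto &raw_blob = std::get<0>(task_content->second[0]);  // compressed blob bytes
  //     const auto &label = std::get<1>(task_content->second[0]);     // json label
  //     auto blob_columns = std::make_shared<std::vector<std::vector<uint8_t>>>();
  //     rc = reader.UnCompressBlob(raw_blob, &blob_columns);          // one buffer per blob column
  //   }
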
 protected:
  /// \brief sqlite callback function
  static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);

 private:
  /// \brief wrap up labels to json format
  Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
                            std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id,
                            const std::vector<std::string> &columns,
                            std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);

  /// \brief convert json format to expected type
  Status ConvertJsonValue(const std::vector<std::string> &label, const std::vector<std::string> &columns,
                          const json &schema, json *value);

  /// \brief read all rows for specified columns
  Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr);

  /// \brief read row meta by shard_id and sample_id
  Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
                                          const int32_t &consumer_id, const uint32_t &sample_id,
                                          std::shared_ptr<ROW_GROUPS> *row_group_ptr);

  /// \brief read all rows in one shard
  Status ReadAllRowsInShard(int shard_id, const int32_t &consumer_id, const std::string &sql,
                            const std::vector<std::string> &columns,
                            std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
                            std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);

  /// \brief initialize reader
  Status Init(const std::vector<std::string> &file_paths, bool load_dataset);

  /// \brief validate column list
  Status CheckColumnList(const std::vector<std::string> &selected_columns);

  /// \brief populate one row by task list in row-reader mode
  void ConsumerByRow(int consumer_id);

  /// \brief get offset address of images within page
  std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
                                                    const std::pair<std::string, std::string> &criteria = {"", ""});

  /// \brief get page ids by category
  Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
                            std::shared_ptr<std::vector<uint64_t>> *pages_ptr);

  /// \brief execute sqlite query with prepared statement
  Status QueryWithPageIdBlobAndCriteria(sqlite3 *db, const string &sql, const int &page_id, const string &criteria,
                                        std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);

  /// \brief verify the validity of dataset
  Status VerifyDataset(sqlite3 **db, const string &file);

  /// \brief get column values
  Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
                   const std::pair<std::string, std::string> &criteria,
                   std::shared_ptr<std::vector<json>> *labels_ptr);

  /// \brief get column values from raw data page
  Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
                           const std::pair<std::string, std::string> &criteria,
                           std::shared_ptr<std::vector<json>> *labels_ptr);

  /// \brief create category-applied task list
  Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);

  /// \brief create task list in row-reader mode
  Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                          const std::vector<std::shared_ptr<ShardOperator>> &operators);

  /// \brief create task list in row-reader mode and lazy mode
  Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                              const std::vector<std::shared_ptr<ShardOperator>> &operators);

  /// \brief create task in slow load mode
  Status CreateSlowTasksByRow();

  /// \brief create task list
  Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                     const std::vector<std::shared_ptr<ShardOperator>> &operators);

  /// \brief check if all specified columns are in the index table
  void CheckIfColumnInIndex(const std::vector<std::string> &columns);

  /// \brief open multiple file handles
  void FileStreamsOperator();

  /// \brief read one row by one task
  Status ConsumerOneTask(int64_t task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt);

  /// \brief get labels from binary file
  Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
                                 const std::vector<std::vector<std::string>> &label_offsets,
                                 std::shared_ptr<std::vector<json>> *labels_ptr);

  /// \brief get classes in one shard
  void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
                         std::shared_ptr<std::set<std::string>> category_ptr);

  /// \brief get number of classes
  int64_t GetNumClasses(const std::string &category_field);

  /// \brief get meta of header
  Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
                 std::shared_ptr<std::vector<std::string>> *addresses_ptr);

 protected:
  uint64_t header_size_;                       // header size
  uint64_t page_size_;                         // page size
  int shard_count_;                            // number of shards
  std::shared_ptr<ShardHeader> shard_header_;  // shard header
  std::shared_ptr<ShardColumn> shard_column_;  // shard column

  std::vector<sqlite3 *> database_paths_;                                        // sqlite handle list
  std::vector<string> file_paths_;                                               // file paths
  std::vector<std::shared_ptr<std::fstream>> file_streams_;                      // single-file handle list
  std::vector<std::vector<std::shared_ptr<std::fstream>>> file_streams_random_;  // multiple-file handle list

 private:
  int n_consumer_;                                         // number of workers (threads)
  std::vector<std::string> selected_columns_;              // columns which will be read
  std::map<string, uint64_t> column_schema_id_;            // column-schema map
  std::vector<std::shared_ptr<ShardOperator>> operators_;  // data operators, including shuffle, sample and category
  ShardTaskList tasks_;                                    // shard task list
  std::mutex shard_locker_;                                // shard locker

  // flags
  bool all_in_index_ = true;  // whether all columns are stored in the index table
  bool interrupt_ = false;    // reader interrupted

  int64_t num_padded_;  // number of padded samples

  // Delivery/Iterator mode begin
  std::vector<std::thread> thread_set_;  // thread list
  int64_t num_rows_;                     // number of rows
  int64_t total_blob_size_;              // total size of blob data
  std::mutex mtx_delivery_;              // locker for delivery
  std::condition_variable cv_delivery_;  // conditional variable for delivery
  std::condition_variable cv_iterator_;  // conditional variable for iterator
  std::atomic<int> sample_id_position_;  // index into the sample ids vector for the current sample id
  std::atomic<int> deliver_id_;          // delivery ID which is picked up by iterator
  // map of delivery
  std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
  // Delivery/Iterator mode end

  // load mode; index metadata may not be fully loaded during initialization
  LoadMode load_mode_;

  // cumulative sample count per shard (shard_id : cumulative count), e.g.
  // 0 : 15 - shard0 has 15 samples
  // 1 : 41 - shard1 has 26 samples
  // 2 : 58 - shard2 has 17 samples
  std::vector<int64_t> shard_sample_count_;
};
}  // namespace mindrecord
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_READER_H_