1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ 18 #pragma once 19 20 #include <cstdint> 21 #include <map> 22 #include <memory> 23 #include <queue> 24 #include <string> 25 #include <tuple> 26 #include <unordered_map> 27 #include <unordered_set> 28 #include <utility> 29 #include <vector> 30 31 #include "minddata/dataset/engine/data_schema.h" 32 #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" 33 #include "minddata/dataset/util/queue.h" 34 #include "minddata/dataset/util/status.h" 35 #include "minddata/mindrecord/include/shard_column.h" 36 #include "minddata/mindrecord/include/shard_error.h" 37 #include "minddata/mindrecord/include/shard_reader.h" 38 #include "minddata/mindrecord/include/common/shard_utils.h" 39 #include "minddata/dataset/util/wait_post.h" 40 41 namespace mindspore { 42 namespace dataset { 43 // Forward declares 44 template <typename T> 45 class Queue; 46 47 using mindrecord::ShardOperator; 48 using mindrecord::ShardReader; 49 using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindrecord::json>>; /// Row of data from ShardReader 50 51 const int32_t LOG_INTERVAL = 19; 52 53 class MindRecordOp : public MappableLeafOp { 54 public: 55 // Constructor of the MindRecordOp. 56 // @note The builder class should be used to call it 57 // @param num_mind_record_workers - The number of workers for the op (run by ShardReader) 58 // @param dataset_file - dataset files 59 // @param op_connector_queue_size - The output connector queue size 60 // @param columns_to_load - The list of columns to use (column name) 61 // @param operators - ShardOperators for Shuffle, Category, Sample 62 // @param sampler - sampler tells MindRecordOp what to read 63 MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset, 64 int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load, 65 const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded_, 66 const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_, 67 const ShuffleMode shuffle_mode_, std::unique_ptr<ShardReader> shard_reader, 68 std::shared_ptr<SamplerRT> sampler); 69 70 /// Destructor 71 ~MindRecordOp() override; 72 73 /// A print method typically used for debugging 74 /// @param out - The output stream to write output to 75 /// @param show_all - A bool to control if you want to show all info or just a summary 76 void Print(std::ostream &out, bool show_all) const override; 77 78 /// << Stream output operator overload 79 /// @notes This allows you to write the debug print info using stream operators 80 /// @param out - reference to the output stream being overloaded 81 /// @param op - reference to the MindRecordOp to display 82 /// @return - the output stream must be returned 83 friend std::ostream &operator<<(std::ostream &out, const MindRecordOp &op) { 84 op.Print(out, false); 85 return out; 86 } 87 88 // Worker thread pulls a number of IOBlock from IOBlock Queue, make a TensorRow and push it to Connector 89 // @param int32_t workerId - id of each worker 90 // @return Status The status code returned 91 Status WorkerEntry(int32_t worker_id) override; 92 93 // Called first when function is called 94 // @return 95 Status RegisterAndLaunchThreads() override; 96 97 /// Overrides base class reset method. When an operator does a reset, it cleans up any state 98 /// info from it's previous execution and then initializes itself so that it can be executed 99 /// again. 100 /// @return Status The status code returned 101 Status Reset() override; 102 103 static Status CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset, 104 const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded); 105 106 // Getter method dataset_file()107 std::vector<std::string> dataset_file() const { return dataset_file_; } 108 109 /// Getter method columns_to_load()110 std::vector<std::string> columns_to_load() const { return columns_to_load_; } 111 load_dataset()112 bool load_dataset() const { return load_dataset_; } 113 114 Status Init(); 115 116 /// Op name getter 117 /// @return Name of the current Op Name()118 std::string Name() const override { return "MindRecordOp"; } 119 120 private: 121 Status GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, int32_t worker_id); 122 123 /// Parses a single cell and puts the data into a tensor 124 /// @param tensor_row - the tensor row to put the parsed data in 125 /// @param columns_blob - the blob data received from the reader 126 /// @param columns_json - the data for fields received from the reader 127 Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob, 128 const mindrecord::json &columns_json, const mindrecord::TaskType task_type); 129 LoadTensorRow(row_id_type row_id,TensorRow * row)130 Status LoadTensorRow(row_id_type row_id, TensorRow *row) override { 131 return Status(StatusCode::kMDSyntaxError, "[Internal ERROR] Cannot call this method."); 132 } 133 // Private function for computing the assignment of the column name map. 134 // @return - Status 135 Status ComputeColMap() override; 136 137 protected: 138 Status PrepareData() override; 139 140 /// Add a new worker to the MindRecordOp. The function will have to wait for all workers to process current rows. 141 /// It will then update the shard reader. Finally, it adds a new thread to the list. 142 /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to 143 /// push rows to workers_in_queue 144 /// \return Status The status code returned 145 Status AddNewWorkers(int32_t num_new_workers = 1) override; 146 147 /// Remove a worker from MindRecordOp. The function will have to wait for all workers to process current rows. 148 /// It will then update the shard reader. Finally, it removes a thread from the list. 149 /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to 150 /// push rows to workers_in_queue 151 /// \return Status The status code returned 152 Status RemoveWorkers(int32_t num_workers = 1) override; 153 154 /// Initialize pull mode, calls PrepareData() within 155 /// @return Status The status code returned 156 Status InitPullMode() override; 157 158 /// Load a tensor row at location row_id for pull mode 159 /// \param row_id_type row_id - id for this tensor row 160 /// \param TensorRow row - loaded row 161 /// \return Status The status code returned 162 Status LoadTensorRowPullMode(row_id_type row_id, TensorRow *row) override; 163 164 private: 165 std::vector<std::string> dataset_file_; // dataset files 166 bool load_dataset_; // load dataset from single file or not 167 std::vector<std::string> columns_to_load_; // Columns to load from dataset 168 std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use 169 int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader 170 std::atomic<int32_t> ended_worker_; 171 172 int64_t num_padded_; 173 mindrecord::json sample_json_; 174 std::map<std::string, std::string> sample_bytes_; 175 176 std::unique_ptr<DataSchema> data_schema_; // Data schema for column typing 177 178 std::unique_ptr<ShardReader> shard_reader_; 179 180 std::mutex ended_worker_mutex_; 181 182 ShuffleMode shuffle_mode_; 183 }; 184 } // namespace dataset 185 } // namespace mindspore 186 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ 187