/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BATCH_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BATCH_OP_H_

#include <algorithm>
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>;

class BatchOp : public ParallelOp {
 public:
  class Builder {
   public:
    // Builder constructor for Batch, batch size needs to be specified
    // @param int32_t batch_size
    explicit Builder(int32_t batch_size);

    // Default destructor
    ~Builder() = default;

    // Set the number of parallel workers on batch
    // @param int32_t num_workers
    // @return Builder & reference to builder class object
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Set drop for batch op, default false
    // @param bool drop
    // @return Builder & reference to builder class object
    Builder &SetDrop(bool drop) {
      builder_drop_ = drop;
      return *this;
    }

    Builder &SetPaddingMap(const PadInfo &pad_map, bool pad = true) {
      builder_pad_ = pad;
      builder_pad_map_ = pad_map;
      return *this;
    }

    // Set connector size for batch
    // @param int32_t op_connector_size
    // @return Builder & reference to builder class object
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = (op_connector_size == 0 ? builder_op_connector_size_ : op_connector_size);
      return *this;
    }

    /// \param in_col_name
    /// \return Builder & reference to builder class object
    Builder &SetInColNames(const std::vector<std::string> &in_col_name) {
      builder_in_names_ = in_col_name;
      return *this;
    }

    /// \param out_col_name
    /// \return Builder & reference to builder class object
    Builder &SetOutColNames(const std::vector<std::string> &out_col_name) {
      builder_out_names_ = out_col_name;
      return *this;
    }

#ifdef ENABLE_PYTHON
    // Set the per-batch map function, invoked on each batch after it is formed
    // @param py::function batch_map_func - python function to call, GIL required before calling
    // @return Builder & reference to builder class object
    Builder &SetBatchMapFunc(py::function batch_map_func) {
      builder_batch_map_func_ = batch_map_func;
      return *this;
    }

    // SetBatchSizeFunc, a function called into Python to determine the size of each batch
    // @param py::function batch_size_func - python function to call, GIL required before calling
    // @return Builder & reference to builder class object
    Builder &SetBatchSizeFunc(py::function batch_size_func) {
      builder_batch_size_func_ = batch_size_func;
      return *this;
    }
#endif

    // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg
    // @return Status The status code returned
    Status Build(std::shared_ptr<BatchOp> *);

   private:
    bool builder_drop_;
    bool builder_pad_;
    int32_t builder_batch_size_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    std::vector<std::string> builder_in_names_;
    std::vector<std::string> builder_out_names_;
    PadInfo builder_pad_map_;
#ifdef ENABLE_PYTHON
    py::function builder_batch_size_func_;
    py::function builder_batch_map_func_;
#endif
  };
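  // Illustrative Builder usage (a sketch, not part of the API surface of this
  // header; assumes the built op is subsequently attached to an execution tree):
  //   std::shared_ptr<BatchOp> batch_op;
  //   Status rc = BatchOp::Builder(32)  // batch size of 32
  //                 .SetDrop(true)      // drop the final partial batch
  //                 .SetNumWorkers(4)   // four parallel workers
  //                 .Build(&batch_op);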

  enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 };

  // Parameters associated with one batch.
  // This struct is used for both internal control and python callback.
  // This struct is bound to python with read-only access.
  struct CBatchInfo {
    CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl)
        : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {}
    CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {}
    CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {}
    explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {}
    int64_t epoch_num_;        // i-th epoch. i starts from 0
    int64_t batch_num_;        // i-th batch since the start of current epoch. i starts from 0
    int64_t total_batch_num_;  // i-th batch since the start of first epoch. i starts from 0
    batchCtrl ctrl_;           // No control=0, EOE=1, EOF=2, Quit=3
    int64_t get_batch_num() const { return batch_num_; }
    int64_t get_epoch_num() const { return epoch_num_; }
  };
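  // Example of how the CBatchInfo counters evolve: over 2 epochs of 3 batches
  // each, workers would observe (epoch_num_, batch_num_, total_batch_num_) as
  //   (0,0,0) (0,1,1) (0,2,2) (1,0,3) (1,1,4) (1,2,5)
  // i.e. batch_num_ resets at each epoch boundary while total_batch_num_ does not.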

#ifdef ENABLE_PYTHON
  BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
          const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
          py::function batch_size_func, py::function batch_map_func, PadInfo pad_map);
#else
  BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
          const std::vector<std::string> &, PadInfo pad_map);
#endif

  // BatchOp destructor
  ~BatchOp() {}

  // @param int32_t workerId
  // @return Status The status code returned
  Status EofReceived(int32_t) override;

  // @param int32_t workerId
  // @return Status The status code returned
  Status EoeReceived(int32_t) override;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param bo - reference to the BatchOp to display
  // @return - the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) {
    bo.Print(out, false);
    return out;
  }

  // Main loop of batch
  // @return Status The status code returned
  Status operator()() override;

  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return kBatchOp; }

  // Batch the rows in src table, then put the resulting batch into the dest row
  // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching
  // @param TensorRow *dest - output row to hold the batched tensors
  // @param dsize_t batch_size - batch size
  // @return Status The status code returned
  static Status BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *dest, dsize_t batch_size);

  // @param table
  // @param const PadInfo &pad_info pad info
  // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
  // @return Status The status code returned
  static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
                           const std::unordered_map<std::string, int32_t> &column_name_id_map);
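  // Illustrative PadInfo construction for PadColumns (a sketch with hypothetical
  // names; `pad_value` would be a caller-constructed scalar Tensor holding the
  // fill value for the column):
  //   PadInfo pad_info;
  //   pad_info["image"] = {TensorShape({256, 256, 3}), pad_value};  // pad "image" up to 256x256x3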

  int64_t GetTreeBatchSize() override;

 protected:
  Status ComputeColMap() override;

 private:
  // Worker thread for doing the memcpy of batch
  // @param int32_t worker_id
  // @return Status The status code returned
  Status WorkerEntry(int32_t worker_id) override;

  // Generate row with batched tensors
  // @return Status The status code returned
  Status MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row);

#ifdef ENABLE_PYTHON
  // Function that calls pyfunc to perform map on batch
  // @param std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair - contains un-batched tensors
  // @return Status The status code returned
  Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
#endif

  // @param const PadInfo &pad_info pad info to unpack
  // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
  // @param std::set<int32_t> *pad_cols - col ids to perform pad on
  // @param std::vector<std::shared_ptr<Tensor>> *pad_vals - default padding value for each column
  // @param std::vector<std::vector<dsize_t>> *pad_shapes - padding shape specified by user
  // @return Status The status code returned
  static Status UnpackPadInfo(const PadInfo &pad_info,
                              const std::unordered_map<std::string, int32_t> &column_name_id_map,
                              std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
                              std::vector<std::vector<dsize_t>> *pad_shapes);

  // The number of threads pulling from the output connector of the Op below
  // @return int32_t, 1
  int32_t NumConsumers() const override { return 1; }

  // Get the batch size for the next batch
  // @return Status The status code returned
  Status GetBatchSize(int32_t *batch_size, CBatchInfo info);

  // Do the initialization of all queues then start all worker threads
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp();

  /// \brief Gets the next row
  /// \param row[out] - Fetched TensorRow
  /// \return Status The status code returned
  Status GetNextRowPullMode(TensorRow *const row) override;

#ifdef ENABLE_PYTHON
  // Invoke batch size function with current BatchInfo to generate batch size.
  // @return Status The status code returned
  Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info);

  // Invoke batch map function with current BatchInfo to generate tensors to batch.
  // @return Status The status code returned
  Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
#endif

  int32_t start_batch_size_;
  const bool drop_;                                     // bool for whether to drop remainder or not
  const bool pad_;                                      // bool for whether to perform padding on tensor
  const std::vector<std::string> in_col_names_;         // input column name for per_batch_map
  std::vector<std::string> out_col_names_;              // output column name for per_batch_map
  PadInfo pad_info_;                                    // column names to perform padding on
  std::unique_ptr<ChildIterator> child_iterator_;       // child iterator for fetching TensorRows 1 by 1
  std::unordered_map<std::string, int32_t> child_map_;  // col_name_id_map of the child node
  QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_;  // internal queue for syncing workers
  int64_t batch_num_;
  int64_t batch_cnt_;
#ifdef ENABLE_PYTHON
  py::function batch_size_func_;  // Function pointer of batch size function
  py::function batch_map_func_;   // Function pointer of per batch map function
#endif
};
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_BATCH_OP_H_