/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_

#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
// Maximum number of healthy rows cached by ReplaceStrategy as replacement candidates
constexpr int64_t kCachedRowsSize = 16;

class ExecutionTree;

// A ParallelOp provides a multi-threaded DatasetOp
template <typename T, typename S>
class ParallelOp : public DatasetOp {
 public:
  /// Constructor
  /// \param num_workers - The number of worker threads
  /// \param op_connector_size - Size of the output connector for this operator
  /// \param sampler - The sampler for the op
  ParallelOp(int32_t num_workers, int32_t op_connector_size, const std::shared_ptr<SamplerRT> sampler = nullptr)
      : DatasetOp(op_connector_size, sampler),
        num_workers_paused_(0),
        epoch_sync_flag_(false),
        num_workers_(num_workers),
        next_worker_id_(0),
        worker_connector_size_(op_connector_size),
        strategy_{nullptr} {
    // reduce excessive memory usage with high parallelism
    constexpr int32_t worker_limit = 4;
    if (num_workers_ > worker_limit) {
      worker_connector_size_ = std::max(1, op_connector_size * worker_limit / num_workers_);
    }
  }
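
  // Illustration of the cap above (comment only, no extra behavior): with op_connector_size = 16
  // and num_workers_ = 8, worker_connector_size_ becomes std::max(1, 16 * 4 / 8) = 8, so the
  // total number of queued rows stays near op_connector_size * worker_limit instead of growing
  // linearly with the number of workers.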

  // Destructor
  ~ParallelOp() override = default;

  /// A print method typically used for debugging
  /// \param out - The output stream to write output to
  /// \param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override {
    DatasetOp::Print(out, show_all);
    out << " [workers: " << num_workers_ << "]";
  }

  std::string Name() const override { return kParallelOp; }

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param po - reference to the ParallelOp to display
  // @return - the output stream
  friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) {
    po.Print(out, false);
    return out;
  }

  int32_t NumWorkers() const override {
    std::unique_lock<std::mutex> _lock(mux_);
    return num_workers_;
  }

  // pause all the worker threads and the collector thread
  Status WaitForWorkers() override {
    // reset the number of paused workers to 0
    num_workers_paused_ = 0;
    int32_t num_workers = NumWorkers();
    for (int32_t wkr_id = 0; wkr_id < num_workers; wkr_id++) {
      RETURN_IF_NOT_OK(SendWaitFlagToWorker(NextWorkerID()));
    }
    // wait until all workers are done processing their work in local_queue_
    RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
    next_worker_id_ = 0;
    // clear the WaitPost for the next Wait()
    wait_for_workers_post_.Clear();
    return Status::OK();
  }

  // wake up all the worker threads and the collector thread
  Status PostForWorkers() override {
    // wake up the existing workers
    for (auto &item : worker_tasks_) {
      item->Post();
    }

    // wake up the collector thread
    wait_for_collector_.Set();

    return Status::OK();
  }
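
  // The pause/resume handshake implemented by the two methods above, in order:
  //   1. WaitForWorkers() sends one wait flag per worker via SendWaitFlagToWorker().
  //   2. Each worker (in its derived WorkerEntry) forwards the flag to its output queue and
  //      pauses; Collector() counts the flags in num_workers_paused_ and, once all workers are
  //      accounted for, sets wait_for_workers_post_ and blocks on wait_for_collector_.
  //   3. The caller mutates state (e.g. adds or removes worker queues), then PostForWorkers()
  //      posts every worker task and sets wait_for_collector_ to resume the pipeline.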

  /// Add new workers to the parallelOp. The function waits for all workers to finish their
  /// current rows, then adds the new threads to the list.
  /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
  ///     push rows to workers_in_queue
  /// \param num_new_workers - The number of workers to add (default=1)
  /// \return Status The status code returned
  Status AddNewWorkers(int32_t num_new_workers = 1) override {
    // wait for workers to process the current rows
    RETURN_IF_NOT_OK(WaitForWorkers());
    for (int32_t i = 0; i < num_new_workers; i++) {
      RETURN_IF_NOT_OK(worker_in_queues_.AddQueue(tree_->AllTasks()));
      RETURN_IF_NOT_OK(worker_out_queues_.AddQueue(tree_->AllTasks()));
    }

    for (int32_t i = 0; i < num_new_workers; i++) {
      Task *new_task;
      RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
        Name() + "::WorkerEntry", std::bind(&ParallelOp::WorkerEntry, this, num_workers_), &new_task, id()));
      CHECK_FAIL_RETURN_UNEXPECTED(new_task != nullptr, "Cannot create a new worker.");
      worker_tasks_.push_back(new_task);
      {
        std::unique_lock<std::mutex> _lock(mux_);
        num_workers_++;
      }
      MS_LOG(INFO) << "A new worker has been added to op: " << Name() << "::" << id()
                   << " num_workers=" << num_workers_;
    }

    // wake up all the worker threads and the collector thread
    RETURN_IF_NOT_OK(PostForWorkers());

    return Status::OK();
  }

  /// Remove workers from the parallelOp. The function waits for all workers to finish their
  /// current rows, then removes the last threads from the list.
  /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
  ///     push rows to workers_in_queue
  /// \param num_workers - The number of workers to remove (default=1)
  /// \return Status The status code returned
  Status RemoveWorkers(int32_t num_workers = 1) override {
    // wait for workers to process the current rows
    RETURN_IF_NOT_OK(WaitForWorkers());
    for (int32_t i = 0; i < num_workers; i++) {
      // the quit flag goes to the last worker; wake it up so it can read the flag and exit
      RETURN_IF_NOT_OK(SendQuitFlagToWorker(static_cast<size_t>(num_workers_) - 1));
      worker_tasks_[static_cast<size_t>(num_workers_) - 1]->Post();
      RETURN_IF_NOT_OK(worker_tasks_[static_cast<size_t>(num_workers_) - 1]->Join());
      RETURN_IF_NOT_OK(worker_in_queues_.RemoveLastQueue());
      worker_tasks_.pop_back();
      {
        std::unique_lock<std::mutex> _lock(mux_);
        num_workers_--;
      }
      MS_LOG(INFO) << "Worker ID " << num_workers_ << " is requested to be removed in operator: " << NameWithID()
                   << " num_workers=" << num_workers_;
    }

    // wake up all the worker threads and the collector thread
    RETURN_IF_NOT_OK(PostForWorkers());

    return Status::OK();
  }
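
  // A minimal usage sketch (hypothetical caller, e.g. an autotune pass; the queue_utilization_*
  // flags are illustrative names, not part of this class):
  //
  //   std::shared_ptr<ParallelOp<TensorRow, TensorRow>> op = ...;
  //   if (queue_utilization_high) {
  //     RETURN_IF_NOT_OK(op->AddNewWorkers(2));  // scale up by two workers
  //   } else if (queue_utilization_low) {
  //     RETURN_IF_NOT_OK(op->RemoveWorkers(1));  // scale down by one worker
  //   }
  //
  // Both calls pause the pipeline via WaitForWorkers() and resume it via PostForWorkers(), so
  // they must be issued from the op's main thread.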

 protected:
  /// Interface for derived classes to implement. All derived classes must provide the entry
  /// function with the main execution loop for worker threads.
  /// \return Status The status code returned
  virtual Status WorkerEntry(int32_t workerId) = 0;
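
  // A minimal sketch of a derived worker loop, assuming T = S = TensorRow (hypothetical MyOp;
  // concrete ops such as MapOp implement their own WorkerEntry with real processing logic):
  //
  //   Status MyOp::WorkerEntry(int32_t worker_id) {
  //     TaskManager::FindMe()->Post();
  //     TensorRow row;
  //     RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&row));
  //     while (!row.eof()) {
  //       if (row.wait()) {
  //         // forward the wait flag to Collector(), then pause until PostForWorkers()
  //         RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(row)));
  //         RETURN_IF_NOT_OK(TaskManager::FindMe()->Wait());
  //         TaskManager::FindMe()->Clear();
  //       } else {
  //         if (row.Flags() == TensorRow::TensorRowFlags::kFlagNone) {
  //           // ... transform the row in place here ...
  //         }
  //         RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(row)));
  //       }
  //       RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&row));
  //     }
  //     // forward the eof so Collector() can terminate
  //     return worker_out_queues_[worker_id]->EmplaceBack(std::move(row));
  //   }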

  /// Register the worker queues and the pausing WaitPost for interrupt services, then launch
  /// the worker threads and the collector thread. Called once when the op starts running.
  /// \return Status The status code returned
  virtual Status RegisterAndLaunchThreads() {
    RETURN_UNEXPECTED_IF_NULL(tree_);
    worker_in_queues_.Init(num_workers_, worker_connector_size_);
    worker_out_queues_.Init(num_workers_, worker_connector_size_);

    // Registers QueueList and individual Queues for interrupt services
    RETURN_IF_NOT_OK(worker_in_queues_.Register(tree_->AllTasks()));
    RETURN_IF_NOT_OK(worker_out_queues_.Register(tree_->AllTasks()));
    RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));

    RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
                                          std::bind(&ParallelOp::WorkerEntry, this, std::placeholders::_1),
                                          &worker_tasks_, Name() + "::WorkerEntry", id()));
    RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ParallelOp::Collector, this), Name() + "::Collector", id()));

    return Status::OK();
  }

  /// Strategy interface used by the collector thread to decide what to do with each incoming row.
  class RowHandlingStrategy {
   public:
    explicit RowHandlingStrategy(ParallelOp *op) : op_(op) {}
    virtual ~RowHandlingStrategy() = default;

    virtual Status HandleHealthyRow([[maybe_unused]] TensorRow *row) {
      ++this->op_->ep_step_;
      ++this->op_->total_step_;
      RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
        static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
      return this->op_->out_connector_->Add(std::move(*row));
    }

    virtual Status HandleErrorRow([[maybe_unused]] TensorRow *row) = 0;

    virtual Status HandleEOE([[maybe_unused]] TensorRow *row) {
      this->op_->current_repeats_++;
      // check whether this is the end of a real epoch (not all eoe signals end of epoch)
      if (this->op_->current_repeats_ % this->op_->GetOpNumRepeatsPerEpoch() == 0) {
        this->op_->current_epochs_++;
        RETURN_IF_NOT_OK(this->op_->callback_manager_.EpochEnd(
          CallbackParam(this->op_->current_epochs_, this->op_->ep_step_, this->op_->total_step_)));
        this->op_->ep_step_ = 0;
      }
      return op_->out_connector_->Add(std::move(*row));
    }

    virtual Status HandleEOF([[maybe_unused]] TensorRow *row) {
      RETURN_IF_NOT_OK(this->op_->callback_manager_.End(CallbackParam(
        static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
      return op_->out_connector_->Add(std::move(*row));
    }

   protected:
    ParallelOp *op_;
  };

  /// Strategy that fails the pipeline as soon as an error row reaches the collector.
  class ErrorStrategy : public RowHandlingStrategy {
   public:
    using RowHandlingStrategy::RowHandlingStrategy;
    Status HandleErrorRow([[maybe_unused]] TensorRow *row) override {
      return Status(StatusCode::kMDUnexpectedError,
                    "[Internal Error] Error row is detected in collector while Error strategy is set to error out!");
    }
  };

  /// Strategy that silently drops error rows.
  class SkipStrategy : public RowHandlingStrategy {
   public:
    using RowHandlingStrategy::RowHandlingStrategy;
    Status HandleErrorRow([[maybe_unused]] TensorRow *row) override { return Status::OK(); }
  };

  /// Strategy that replaces each error row with a clone of a recently seen healthy row.
  class ReplaceStrategy : public RowHandlingStrategy {
   public:
    using RowHandlingStrategy::RowHandlingStrategy;

    Status HandleHealthyRow([[maybe_unused]] TensorRow *row) override {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
                                   "[Internal Error] Number of cached rows is beyond the number set.");
      if (backup_index_ < kCachedRowsSize - 1) {  // cache has used row(s) or is not full
        if (IsCacheFull()) {
          // remove the last element from cache (a used row)
          PopFromCache();
        }
        RETURN_IF_NOT_OK(AddToCache(*row));
      } else {  // cache is full of unused rows
        if (missing_errors_ > 0) {
          // send a cached row to the next op and cache the current row
          RETURN_IF_NOT_OK(AddFromCache());
          PopFromCache();
          missing_errors_--;
          RETURN_IF_NOT_OK(AddToCache(*row));
        }
      }
      // send the healthy row to the next op
      ++this->op_->ep_step_;
      ++this->op_->total_step_;
      RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
        static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
      return this->op_->out_connector_->Add(std::move(*row));
    }

    Status HandleErrorRow([[maybe_unused]] TensorRow *row) override {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
                                   "[Internal Error] Number of cached rows is beyond the number set.");
      // cache is not full of unused rows
      if (backup_index_ != kCachedRowsSize - 1) {
        missing_errors_++;
        return Status::OK();
      }
      // cache is full of unused rows and we have an error row
      return AddFromCache();
    }

    Status HandleEOE([[maybe_unused]] TensorRow *row) override {
      CHECK_FAIL_RETURN_UNEXPECTED(missing_errors_ == 0 || !IsCacheEmpty(),
                                   "All data is garbage and cannot be replaced.");
      // send outstanding rows first and then send the eoe
      while (missing_errors_ > 0) {
        RETURN_IF_NOT_OK(AddFromCache());
        missing_errors_--;
      }
      return RowHandlingStrategy::HandleEOE(row);
    }

    Status HandleEOF([[maybe_unused]] TensorRow *row) override {
      // release memory
      std::deque<TensorRow>().swap(backup_rows);
      return RowHandlingStrategy::HandleEOF(row);
    }

   private:
    Status AddFromCache() {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() > 0, "Cannot add a row from cache since cache is empty!");
      const TensorRow &cached_row = backup_rows[static_cast<size_t>(backup_index_) % backup_rows.size()];
      TensorRow copy_row;
      RETURN_IF_NOT_OK(cached_row.Clone(&copy_row));
      backup_index_--;
      ++this->op_->ep_step_;
      ++this->op_->total_step_;
      RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
        static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
      return this->op_->out_connector_->Add(std::move(copy_row));
    }

    Status AddToCache(const TensorRow &row) {
      CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() < kCachedRowsSize,
                                   "[Internal Error] Inserting another row to cache while cache is already full.");
      CHECK_FAIL_RETURN_UNEXPECTED(
        backup_index_ < kCachedRowsSize - 1,
        "[Internal Error] Inserting another row to cache while cache is already full of unused rows.");
      TensorRow copy_row;
      RETURN_IF_NOT_OK(row.Clone(&copy_row));
      (void)backup_rows.emplace_front(std::move(copy_row));
      backup_index_++;
      return Status::OK();
    }

    void PopFromCache() { backup_rows.pop_back(); }
    bool IsCacheFull() const { return backup_rows.size() == kCachedRowsSize; }
    bool IsCacheEmpty() const { return backup_rows.empty(); }

    std::deque<TensorRow> backup_rows{};  // holds copies of some healthy rows collected (NOT error, skip, eoe, eof)
    int32_t backup_index_{-1};            // index of the backup we should pick next time (can be negative if we run
                                          // out of unused cached rows)
    int32_t missing_errors_{0};           // the number of unaddressed error rows (each needs a replacement in output)
  };
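
  // Worked example of ReplaceStrategy with kCachedRowsSize = 16: while fewer than 16 unused
  // clones are cached, each healthy row is cloned into backup_rows (evicting the oldest used
  // clone when the deque is full) and backup_index_ points at the newest unused clone. An error
  // row either bumps missing_errors_ (cache not yet full of unused clones) or is immediately
  // replaced by a cached clone. At EOE, any outstanding missing_errors_ are paid off with cached
  // clones before the EOE is forwarded, so downstream ops see the expected row count per epoch.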

  /// The collector thread pulls rows from the worker output queues in round-robin order and
  /// dispatches each one to the active RowHandlingStrategy.
  virtual Status Collector() {
    TaskManager::FindMe()->Post();
    // num_rows received, including eoe
    int64_t num_rows = 0;
    current_repeats_ = 0;
    current_epochs_ = 0;
    SetStrategy();
    // num_step of the current epoch and the total
    ep_step_ = 0, total_step_ = 0;
    do {
      TensorRow row;
      RETURN_IF_NOT_OK(worker_out_queues_[static_cast<int>(num_rows++ % NumWorkers())]->PopFront(&row));
      if (row.wait()) {
        // When the collector receives a signal from a worker thread, it increments an atomic int.
        // Once num_workers signals are received, it wakes up the main thread.
        if (++num_workers_paused_ == num_workers_) {
          wait_for_workers_post_.Set();
          RETURN_IF_NOT_OK(wait_for_collector_.Wait());
          wait_for_collector_.Clear();
          num_rows = 0;
        }
        continue;
      } else if (row.eoe()) {
        RETURN_IF_NOT_OK(strategy_->HandleEOE(&row));
      } else if (row.eof()) {
        RETURN_IF_NOT_OK(strategy_->HandleEOF(&row));
        break;
      } else if (row.skip()) {
        continue;
      } else if (row.error()) {
        RETURN_IF_NOT_OK(strategy_->HandleErrorRow(&row));
      } else if (row.Flags() == TensorRow::TensorRowFlags::kFlagNone) {
        RETURN_IF_NOT_OK(strategy_->HandleHealthyRow(&row));
      }
    } while (true);
    return Status::OK();
  }

  // Wait post used to perform the pausing logic
  WaitPost wait_for_workers_post_;

  // Wait post used to pause the collector thread
  WaitPost wait_for_collector_;

  // Count of the number of workers that have signaled the master
  std::atomic_int num_workers_paused_;

  /// Whether or not to sync worker threads at the end of each epoch
  bool epoch_sync_flag_;

  /// The number of worker threads
  int32_t num_workers_;

  std::vector<Task *> worker_tasks_;

  /// Return the next worker id in round-robin order
  int32_t NextWorkerID() {
    int32_t next_worker = next_worker_id_;
    next_worker_id_ = (next_worker_id_ + 1) % num_workers_;
    return next_worker;
  }

 public:
  int32_t NumWorkers() override {
    std::unique_lock<std::mutex> _lock(mux_);
    return num_workers_;
  }

 private:
  /// Select the row handling strategy: only MapOp honors the error_samples_mode config;
  /// every other op errors out when an error row reaches the collector.
  void SetStrategy() {
    if (Name() != kMapOp) {
      strategy_ = std::make_unique<ErrorStrategy>(this);
      return;
    }
    if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kSkip) {
      strategy_ = std::make_unique<SkipStrategy>(this);
    } else if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kReplace) {
      strategy_ = std::make_unique<ReplaceStrategy>(this);
    } else {
      strategy_ = std::make_unique<ErrorStrategy>(this);
    }
  }

 protected:
  std::atomic_int next_worker_id_;

  std::map<int32_t, std::atomic_bool> quit_ack_;

  /// The size of the input/output worker queues
  int32_t worker_connector_size_;
  /// queues to hold the input rows to workers
  QueueList<T> worker_in_queues_;
  /// queues to hold the output rows from workers
  QueueList<S> worker_out_queues_;

  // lock for num_workers_ read and write
  mutable std::mutex mux_;

 private:
  std::unique_ptr<RowHandlingStrategy> strategy_;
  int32_t ep_step_{0};
  int32_t total_step_{0};
  int32_t current_epochs_{0};
  int32_t current_repeats_{0};
};
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_