• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
18 
19 #include <algorithm>
20 #include <deque>
21 #include <map>
22 #include <memory>
23 #include <mutex>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 #include "minddata/dataset/include/dataset/constants.h"
28 #include "minddata/dataset/engine/datasetops/dataset_op.h"
29 #include "minddata/dataset/engine/execution_tree.h"
30 #include "minddata/dataset/engine/datasetops/source/io_block.h"
31 #include "minddata/dataset/util/status.h"
32 
33 namespace mindspore {
34 namespace dataset {
35 constexpr int64_t kCachedRowsSize = 16;
36 
37 class ExecutionTree;
38 
39 // A ParallelOp provides a multi-threaded DatasetOp
40 template <typename T, typename S>
41 class ParallelOp : public DatasetOp {
42  public:
43   /// Constructor
44   /// \param num_workers
45   /// \param op_connector_size - size of the output connector for this operator
46   /// \param sampler - The sampler for the op
47   ParallelOp(int32_t num_workers, int32_t op_connector_size, const std::shared_ptr<SamplerRT> sampler = nullptr)
DatasetOp(op_connector_size,sampler)48       : DatasetOp(op_connector_size, sampler),
49         num_workers_paused_(0),
50         epoch_sync_flag_(false),
51         num_workers_(num_workers),
52         next_worker_id_(0),
53         worker_connector_size_(op_connector_size),
54         strategy_{nullptr} {
55     // reduce excessive memory usage with high parallelism
56     constexpr int32_t worker_limit = 4;
57     if (num_workers_ > worker_limit) {
58       worker_connector_size_ = std::max(1, op_connector_size * worker_limit / num_workers_);
59     }
60   }
61   // Destructor
62   ~ParallelOp() override = default;
63 
64   /// A print method typically used for debugging
65   /// \param out - The output stream to write output to
66   /// \param show_all - A bool to control if you want to show all info or just a summary
Print(std::ostream & out,bool show_all)67   void Print(std::ostream &out, bool show_all) const override {
68     DatasetOp::Print(out, show_all);
69     out << " [workers: " << num_workers_ << "]";
70   }
71 
Name()72   std::string Name() const override { return kParallelOp; }
73 
74   // << Stream output operator overload
75   // @notes This allows you to write the debug print info using stream operators
76   // @param out - reference to the output stream being overloaded
77   // @param pO - reference to the ParallelOp to display
78   // @return - the output stream must be returned
79   friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) {
80     po.Print(out, false);
81     return out;
82   }
83 
NumWorkers()84   int32_t NumWorkers() const override {
85     int32_t num_workers = 1;
86     {
87       std::unique_lock<std::mutex> _lock(mux_);
88       num_workers = num_workers_;
89     }
90     return num_workers;
91   }
92 
93   // pause all the worker thread and collector thread
WaitForWorkers()94   Status WaitForWorkers() override {
95     // reset num_paused workers to 0
96     num_workers_paused_ = 0;
97     uint32_t num_workers = NumWorkers();
98     for (int32_t wkr_id = 0; wkr_id < num_workers; wkr_id++) {
99       RETURN_IF_NOT_OK(SendWaitFlagToWorker(NextWorkerID()));
100     }
101     // wait until all workers are done processing their work in local_queue_
102     RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
103     next_worker_id_ = 0;
104     // clear the WaitPost for the next Wait()
105     wait_for_workers_post_.Clear();
106     return Status::OK();
107   }
108 
109   // wakeup all the worker threads and collector thread
PostForWorkers()110   Status PostForWorkers() override {
111     // wakeup old workers
112     for (auto &item : worker_tasks_) {
113       item->Post();
114     }
115 
116     // wakeup the collector thread
117     wait_for_collector_.Set();
118 
119     return Status::OK();
120   }
121 
122   /// Add a new worker to the parallelOp. The function will have to wait for all workers to process current rows.
123   /// Then it adds a new thread to the list.
124   /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
125   /// push rows to workers_in_queue
126   /// \return Status The status code returned
127   Status AddNewWorkers(int32_t num_new_workers = 1) override {
128     // wait for workers to process the current rows
129     RETURN_IF_NOT_OK(WaitForWorkers());
130     for (int32_t i = 0; i < num_new_workers; i++) {
131       RETURN_IF_NOT_OK(worker_in_queues_.AddQueue(tree_->AllTasks()));
132       RETURN_IF_NOT_OK(worker_out_queues_.AddQueue(tree_->AllTasks()));
133     }
134 
135     for (int32_t i = 0; i < num_new_workers; i++) {
136       Task *new_task;
137       RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
138         Name() + "::WorkerEntry", std::bind(&ParallelOp::WorkerEntry, this, num_workers_), &new_task, id()));
139       CHECK_FAIL_RETURN_UNEXPECTED(new_task != nullptr, "Cannot create a new worker.");
140       worker_tasks_.push_back(new_task);
141       {
142         std::unique_lock<std::mutex> _lock(mux_);
143         num_workers_++;
144       }
145       MS_LOG(INFO) << "A new worker has been added to op: " << Name() << "::" << id()
146                    << " num_workers=" << num_workers_;
147     }
148 
149     // wakeup all the workers threads and collector thread
150     RETURN_IF_NOT_OK(PostForWorkers());
151 
152     return Status::OK();
153   }
154 
155   /// Add a new worker to the parallelOp. The function will have to wait for all workers to process current rows.
156   /// Then it adds a new thread to the list.
157   /// \note The caller of this function has to be the main thread of the Op, since it's the only entity responsible to
158   /// push rows to workers_in_queue
159   /// \return Status The status code returned
160   Status RemoveWorkers(int32_t num_workers = 1) override {
161     // wait for workers to process the current rows
162     RETURN_IF_NOT_OK(WaitForWorkers());
163     for (size_t i = 0; i < num_workers; i++) {
164       RETURN_IF_NOT_OK(SendQuitFlagToWorker(static_cast<size_t>(num_workers_) - 1));
165       worker_tasks_[num_workers_ - 1]->Post();  // wakeup the worker
166       RETURN_IF_NOT_OK(worker_tasks_[static_cast<size_t>(num_workers_) - 1]->Join());
167       RETURN_IF_NOT_OK(worker_in_queues_.RemoveLastQueue());
168       worker_tasks_.pop_back();
169       {
170         std::unique_lock<std::mutex> _lock(mux_);
171         num_workers_--;
172       }
173       MS_LOG(INFO) << "Worker ID " << num_workers_ << " is requested to be removed in operator: " << NameWithID()
174                    << " num_workers=" << num_workers_;
175     }
176 
177     // wakeup all the workers threads and collector thread
178     RETURN_IF_NOT_OK(PostForWorkers());
179 
180     return Status::OK();
181   }
182 
183  protected:
184   /// Interface for derived classes to implement. All derived classes must provide the entry
185   /// function with the main execution loop for worker threads.
186   /// \return Status The status code returned
187   virtual Status WorkerEntry(int32_t workerId) = 0;
188 
189   /// Called first when function is called
190   /// \return Status The status code returned
RegisterAndLaunchThreads()191   virtual Status RegisterAndLaunchThreads() {
192     RETURN_UNEXPECTED_IF_NULL(tree_);
193     worker_in_queues_.Init(num_workers_, worker_connector_size_);
194     worker_out_queues_.Init(num_workers_, worker_connector_size_);
195 
196     // Registers QueueList and individual Queues for interrupt services
197     RETURN_IF_NOT_OK(worker_in_queues_.Register(tree_->AllTasks()));
198     RETURN_IF_NOT_OK(worker_out_queues_.Register(tree_->AllTasks()));
199     RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
200 
201     RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
202                                           std::bind(&ParallelOp::WorkerEntry, this, std::placeholders::_1),
203                                           &worker_tasks_, Name() + "::WorkerEntry", id()));
204     RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ParallelOp::Collector, this), Name() + "::Collector", id()));
205 
206     return Status::OK();
207   }
208 
209   class RowHandlingStrategy {
210    public:
RowHandlingStrategy(ParallelOp * op)211     explicit RowHandlingStrategy(ParallelOp *op) : op_(op) {}
212     virtual ~RowHandlingStrategy() = default;
213 
HandleHealthyRow(TensorRow * row)214     virtual Status HandleHealthyRow([[maybe_unused]] TensorRow *row) {
215       ++this->op_->ep_step_;
216       ++this->op_->total_step_;
217       RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
218         static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
219       return this->op_->out_connector_->Add(std::move(*row));
220     }
221     virtual Status HandleErrorRow([[maybe_unused]] TensorRow *row) = 0;
222 
HandleEOE(TensorRow * row)223     virtual Status HandleEOE([[maybe_unused]] TensorRow *row) {
224       this->op_->current_repeats_++;
225       // check whether this is the end of a real epoch (not all eoe signals end of epoch)
226       if (this->op_->current_repeats_ % this->op_->GetOpNumRepeatsPerEpoch() == 0) {
227         this->op_->current_epochs_++;
228         RETURN_IF_NOT_OK(this->op_->callback_manager_.EpochEnd(
229           CallbackParam(this->op_->current_epochs_, this->op_->ep_step_, this->op_->total_step_)));
230         this->op_->ep_step_ = 0;
231       }
232       return op_->out_connector_->Add(std::move(*row));
233     }
HandleEOF(TensorRow * row)234     virtual Status HandleEOF([[maybe_unused]] TensorRow *row) {
235       RETURN_IF_NOT_OK(this->op_->callback_manager_.End(CallbackParam(
236         static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
237       return op_->out_connector_->Add(std::move(*row));
238     }
239 
240    protected:
241     ParallelOp *op_;
242   };
243 
244   class ErrorStrategy : public RowHandlingStrategy {
245    public:
246     using RowHandlingStrategy::RowHandlingStrategy;
HandleErrorRow(TensorRow * row)247     Status HandleErrorRow([[maybe_unused]] TensorRow *row) override {
248       return Status(StatusCode::kMDUnexpectedError,
249                     "[Internal Error] Error row is detected in collector while Error strategy is set to error out!");
250     }
251   };
252 
253   class SkipStrategy : public RowHandlingStrategy {
254    public:
255     using RowHandlingStrategy::RowHandlingStrategy;
HandleErrorRow(TensorRow * row)256     Status HandleErrorRow([[maybe_unused]] TensorRow *row) override { return Status::OK(); }
257   };
258 
259   class ReplaceStrategy : public RowHandlingStrategy {
260    public:
261     using RowHandlingStrategy::RowHandlingStrategy;
262 
HandleHealthyRow(TensorRow * row)263     Status HandleHealthyRow([[maybe_unused]] TensorRow *row) override {
264       CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
265                                    "[Internal Error] Number of cached rows is beyond the number set.");
266       if (backup_index_ < kCachedRowsSize - 1) {  // cache has used row(s) or is not full
267         if (IsCacheFull()) {
268           // remove the last element from cache (a used row)
269           PopFromCache();
270         }
271         RETURN_IF_NOT_OK(AddToCache(*row));
272       } else {  // cache is full of unused rows
273         if (missing_errors_ > 0) {
274           // send a cached row to next op and cache the current row
275           RETURN_IF_NOT_OK(AddFromCache());
276           PopFromCache();
277           missing_errors_--;
278           RETURN_IF_NOT_OK(AddToCache(*row));
279         }
280       }
281       // send the healthy row to next op
282       ++this->op_->ep_step_;
283       ++this->op_->total_step_;
284       RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
285         static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
286       return this->op_->out_connector_->Add(std::move(*row));
287     }
288 
HandleErrorRow(TensorRow * row)289     Status HandleErrorRow([[maybe_unused]] TensorRow *row) override {
290       CHECK_FAIL_RETURN_UNEXPECTED(backup_index_ < kCachedRowsSize,
291                                    "[Internal Error] Number of cached rows is beyond the number set.");
292       // cache is not full of unused rows
293       if (backup_index_ != kCachedRowsSize - 1) {
294         missing_errors_++;
295         return Status::OK();
296       }
297       // cache is full of unused rows and we have an error row
298       return AddFromCache();
299     }
300 
HandleEOE(TensorRow * row)301     Status HandleEOE([[maybe_unused]] TensorRow *row) override {
302       CHECK_FAIL_RETURN_UNEXPECTED(missing_errors_ == 0 || !IsCacheEmpty(),
303                                    "All data is garbage and cannot be replaced.");
304       // send outstanding rows first and then send eoe
305       while (missing_errors_ > 0) {
306         RETURN_IF_NOT_OK(AddFromCache());
307         missing_errors_--;
308       }
309       return RowHandlingStrategy::HandleEOE(row);
310     }
311 
HandleEOF(TensorRow * row)312     Status HandleEOF([[maybe_unused]] TensorRow *row) override {
313       // release memory
314       std::deque<TensorRow>().swap(backup_rows);
315       return RowHandlingStrategy::HandleEOF(row);
316     }
317 
318    private:
AddFromCache()319     Status AddFromCache() {
320       CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() > 0, "Cannot add a row from cache since cache is empty!");
321       const TensorRow &cached_row = backup_rows[static_cast<size_t>(backup_index_) % backup_rows.size()];
322       TensorRow copy_row;
323       RETURN_IF_NOT_OK(cached_row.Clone(&copy_row));
324       backup_index_--;
325       ++this->op_->ep_step_;
326       ++this->op_->total_step_;
327       RETURN_IF_NOT_OK(this->op_->callback_manager_.StepEnd(CallbackParam(
328         static_cast<int64_t>(this->op_->current_epochs_) + 1, this->op_->ep_step_, this->op_->total_step_)));
329       return this->op_->out_connector_->Add(std::move(copy_row));
330     }
331 
AddToCache(const TensorRow & row)332     Status AddToCache(const TensorRow &row) {
333       CHECK_FAIL_RETURN_UNEXPECTED(backup_rows.size() < kCachedRowsSize,
334                                    "[Internal Error] Inserting another row to cache while cache is already full.");
335       CHECK_FAIL_RETURN_UNEXPECTED(
336         backup_index_ < kCachedRowsSize - 1,
337         "[Internal Error] Inserting another row to cache while cache is already full of unused rows.");
338       TensorRow copy_row;
339       RETURN_IF_NOT_OK(row.Clone(&copy_row));
340       (void)backup_rows.emplace_front(std::move(copy_row));
341       backup_index_++;
342       return Status::OK();
343     }
344 
PopFromCache()345     void PopFromCache() { backup_rows.pop_back(); }
IsCacheFull()346     bool IsCacheFull() const { return backup_rows.size() == kCachedRowsSize; }
IsCacheEmpty()347     bool IsCacheEmpty() const { return backup_rows.size() == 0; }
348     std::deque<TensorRow> backup_rows{};  // will hold a copy of some healthy rows collected (NOT error, skip, eoe, eof)
349     int32_t backup_index_{-1};  // index of the backup we should pick next time (can be negative if we run out of
350     // unused cached rows)
351     int32_t missing_errors_{0};  // the number of unaddressed error rows (that we need to send a replacement to output)
352   };
353 
Collector()354   virtual Status Collector() {
355     TaskManager::FindMe()->Post();
356     // num_rows received, including eoe,
357     int64_t num_rows = 0;
358     current_repeats_ = 0;
359     current_epochs_ = 0;
360     SetStrategy();
361     // num_step of current epoch and the total
362     ep_step_ = 0, total_step_ = 0;
363     do {
364       TensorRow row;
365       RETURN_IF_NOT_OK(worker_out_queues_[static_cast<const int>(num_rows++ % NumWorkers())]->PopFront(&row));
366       if (row.wait()) {
367         // When collector receives the signal from worker thread, it increments an atomic int
368         // If num_worker signals are received, wakes up the main thread
369         if (++num_workers_paused_ == num_workers_) {
370           wait_for_workers_post_.Set();
371           RETURN_IF_NOT_OK(wait_for_collector_.Wait());
372           wait_for_collector_.Clear();
373           num_rows = 0;
374         }
375         continue;
376       } else if (row.eoe()) {
377         RETURN_IF_NOT_OK(strategy_->HandleEOE(&row));
378       } else if (row.eof()) {
379         RETURN_IF_NOT_OK(strategy_->HandleEOF(&row));
380         break;
381       } else if (row.skip()) {
382         continue;
383       } else if (row.error()) {
384         RETURN_IF_NOT_OK(strategy_->HandleErrorRow(&row));
385       } else if (row.Flags() == TensorRow::TensorRowFlags::kFlagNone) {
386         RETURN_IF_NOT_OK(strategy_->HandleHealthyRow(&row));
387       }
388     } while (true);
389     return Status::OK();
390   }
391 
392   // Wait post used to perform the pausing logic
393   WaitPost wait_for_workers_post_;
394 
395   // Wait post used to pause and resume the collector thread
396   WaitPost wait_for_collector_;
397 
398   // Count number of workers that have signaled master
399   std::atomic_int num_workers_paused_;
400 
401   /// Whether or not to sync worker threads at the end of each epoch
402   bool epoch_sync_flag_;
403 
404   /// The number of worker threads
405   int32_t num_workers_;
406 
407   std::vector<Task *> worker_tasks_;
408 
NextWorkerID()409   int32_t NextWorkerID() {
410     int32_t next_worker = next_worker_id_;
411     next_worker_id_ = (next_worker_id_ + 1) % num_workers_;
412     return next_worker;
413   }
414 
415  public:
NumWorkers()416   int32_t NumWorkers() override {
417     int32_t num_workers = 1;
418     {
419       std::unique_lock<std::mutex> _lock(mux_);
420       num_workers = num_workers_;
421     }
422     return num_workers;
423   }
424 
425  private:
SetStrategy()426   void SetStrategy() {
427     if (Name() != kMapOp) {
428       strategy_ = std::make_unique<ErrorStrategy>(this);
429       return;
430     }
431     if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kSkip) {
432       strategy_ = std::make_unique<SkipStrategy>(this);
433     } else if (GlobalContext::config_manager()->error_samples_mode() == ErrorSamplesMode::kReplace) {
434       strategy_ = std::make_unique<ReplaceStrategy>(this);
435     } else {
436       strategy_ = std::make_unique<ErrorStrategy>(this);
437     }
438   }
439 
440  protected:
441   std::atomic_int next_worker_id_;
442 
443   std::map<int32_t, std::atomic_bool> quit_ack_;
444 
445   /// The size of input/output worker queues
446   int32_t worker_connector_size_;
447   /// queues to hold the input rows to workers
448   QueueList<T> worker_in_queues_;
449   /// queues to hold the output from workers
450   QueueList<S> worker_out_queues_;
451 
452   // lock for num_workers_ read and write
453   mutable std::mutex mux_;
454 
455  private:
456   std::unique_ptr<RowHandlingStrategy> strategy_;
457   int32_t ep_step_{0};
458   int32_t total_step_{0};
459   int32_t current_epochs_{0};
460   int32_t current_repeats_{0};
461 };
462 }  // namespace dataset
463 }  // namespace mindspore
464 
465 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
466