/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_MINDRT_RUNTIME_PARALLEL_THREADPOOL_H_
#define MINDSPORE_CORE_MINDRT_RUNTIME_PARALLEL_THREADPOOL_H_

#include <queue>
#include <vector>
#include <mutex>
#include <atomic>
#include <string>
#include <condition_variable>
#include "thread/actor_threadpool.h"

namespace mindspore {
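// Shard-distribution state for one ParallelTask: `started` counts shards
// already claimed by workers, `task_num` is the total shard count. The two
// ints are packed into a single std::atomic<Distributor> below so both
// fields can be read and updated together (for an 8-byte POD this is
// typically lock-free, though that depends on the platform).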
typedef struct Distributor {
  int started = 0;
  int task_num = 0;
} Distributor;
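// A task whose shards any worker may claim: `distributor` hands out shard
// indices, `valid` marks the slot as holding live work, and `occupied`
// appears to guard the slot against concurrent reuse (the scheduling logic
// itself lives in the .cc file).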
typedef struct ParallelTask : public Task {
  ParallelTask() : Task(nullptr, nullptr) {}
  std::atomic<Distributor> distributor;
  std::atomic_bool valid{false};
  std::atomic_bool occupied{false};
} ParallelTask;
class ParallelThreadPool;
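// A worker that, in addition to its own pool's kernel tasks, can execute
// tasks handed over from another pool when shared-thread-pool mode is
// enabled (see RunOtherPoolTask / ActivateByOtherPoolTask).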
class ParallelWorker : public Worker {
 public:
  explicit ParallelWorker(ThreadPool *pool, size_t index) : Worker(pool, index) {
    // ParallelThreadPool is only forward-declared at this point, hence reinterpret_cast
    parallel_pool_ = reinterpret_cast<ParallelThreadPool *>(pool_);
  }
  void CreateThread() override;
  bool RunLocalKernelTask() override;
  ~ParallelWorker() override;
  void RunOtherPoolTask(ParallelTask *p_task);
  void WaitOtherPoolTask();
  void DirectRunOtherPoolTask();
  void ActivateByOtherPoolTask(ParallelTask *task = nullptr);

 protected:
  void WaitUntilActive() override;

 private:
  void ParallelRun();
  bool RunQueueActorTask();
  ParallelThreadPool *parallel_pool_{nullptr};
  bool wait_do_other_task_{true};
  ParallelTask *other_task_ = nullptr;
  std::mutex other_task_mutex_;
  std::condition_variable cv_other_task_;
  bool enable_shared_thread_pool_ = false;
};

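// An actor thread pool whose workers cooperatively drain ParallelTask shards
// and which can be shared between runners, identified by a runner id.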
class MS_CORE_API ParallelThreadPool : public ActorThreadPool {
 public:
  static ParallelThreadPool *CreateThreadPool(size_t actor_thread_num, size_t all_thread_num,
                                              const std::vector<int> &core_list, BindMode bind_mode,
                                              std::string runner_id = "");
  ~ParallelThreadPool() override {
    MS_LOG(INFO) << "Free parallel thread pool.";
    // wait until the actor queue is empty
    bool terminate = false;
    int count = 0;
    do {
      {
#ifdef USE_HQUEUE
        terminate = actor_queue_.Empty();
#else
        std::lock_guard<std::mutex> _l(actor_mutex_);
        terminate = actor_queue_.empty();
#endif
      }
      if (!terminate) {
        ActiveWorkers();
        std::this_thread::yield();
      }
    } while (!terminate && count++ < kMaxCount);
    MS_LOG(INFO) << "Wait for all workers to be deleted.";
    for (auto &worker : workers_) {
      delete worker;
      worker = nullptr;
    }
    MS_LOG(INFO) << "Workers deleted.";
    workers_.clear();
    tasks_size_ = 0;
    if (tasks_) {
      delete[] tasks_;
    }
  }

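  // Splits `func` into `task_num` shards, publishes them via the tasks_
  // slots, and lets the workers (and, it appears, the calling thread) drain
  // them; the scheduling itself is implemented in the .cc file.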
  int ParallelLaunch(const Func &func, Content content, int task_num) override;

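  // Enqueues `actor`, then wakes at most min(workers_.size(), tasks_size_)
  // workers. When thread sharing is enabled, the wake-up goes through
  // ActivateByOtherPoolTask so a sleeping worker can also pick up
  // cross-pool work.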
  void PushActorToQueue(ActorBase *actor) override {
    if (!actor) {
      return;
    }
    {
#ifdef USE_HQUEUE
      while (!actor_queue_.Enqueue(actor)) {
      }
#else
      std::lock_guard<std::mutex> _l(actor_mutex_);
      actor_queue_.push(actor);
#endif
    }
    THREAD_DEBUG("actor[%s] enqueue success", actor->GetAID().Name().c_str());
    size_t size = workers_.size() > tasks_size_ ? tasks_size_ : workers_.size();
    for (size_t i = 0; i < size; i++) {
      if (!enable_shared_) {
        workers_[i]->Active();
      } else {
        static_cast<ParallelWorker *>(workers_[i])->ActivateByOtherPoolTask(nullptr);
      }
    }
  }

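  // Shard-execution helpers used by the workers: RunTaskOnce appears to run
  // the shards in [start, end), while RunParallel drains whatever valid
  // tasks_ slots remain.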
  inline bool RunTaskOnce(int start, int end);

  bool RunParallel();

  size_t tasks_size() const { return tasks_size_; }

  bool IsIdlePool();

  int GetPoolRef();

  bool SetRunnerID(const std::string &runner_id) override;

  std::vector<ParallelWorker *> GetParallelPoolWorkers();

  void UseThreadPool(int num);

  inline std::string GetPoolBindRunnerID() { return bind_runner_id_; }

  inline bool GetEnableShared() { return enable_shared_; }

 private:
  ParallelThreadPool() {}
  int CreateParallelThreads(size_t actor_thread_num, size_t all_thread_num, const std::vector<int> &core_list);

  std::atomic_int tasks_start_{0};
  std::atomic_int tasks_end_{0};
  // initialized to nullptr so the destructor's `if (tasks_)` check is
  // well-defined even if CreateParallelThreads never ran
  ParallelTask *tasks_ = nullptr;
  size_t tasks_size_ = 0;
  bool enable_shared_ = false;
  std::string bind_runner_id_;
  std::mutex mutex_pool_ref_count_;
  std::atomic_int pool_ref_count_{0};
  int thread_num_ = 0;
};
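
// A minimal usage sketch (illustrative only). `Power_NoBind` and the Func
// signature `int(void *cdata, int task_id, float, float)` are taken from
// thread/threadpool.h rather than this header, and `my_data` is a
// hypothetical payload:
//
//   auto *pool = ParallelThreadPool::CreateThreadPool(
//       /* actor_thread_num = */ 2, /* all_thread_num = */ 4,
//       /* core_list = */ {}, /* bind_mode = */ Power_NoBind);
//   int ret = pool->ParallelLaunch(
//       [](void *cdata, int task_id, float, float) -> int {
//         // process shard `task_id` of the workload in `cdata`
//         return 0;  // THREAD_OK
//       },
//       /* content = */ &my_data, /* task_num = */ 4);
//   delete pool;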
}  // namespace mindspore
#endif  // MINDSPORE_CORE_MINDRT_RUNTIME_PARALLEL_THREADPOOL_H_