1 /** 2 * Copyright 2022-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_PARALLEL_THREADPOOL_H_ 18 #define MINDSPORE_CORE_MINDRT_RUNTIME_PARALLEL_THREADPOOL_H_ 19 20 #include <queue> 21 #include <vector> 22 #include <mutex> 23 #include <atomic> 24 #include <string> 25 #include <condition_variable> 26 #include "thread/actor_threadpool.h" 27 28 namespace mindspore { 29 typedef struct Distributor { 30 int started = 0; 31 int task_num = 0; 32 } Distributor; 33 typedef struct ParallelTask : public Task { ParallelTaskParallelTask34 ParallelTask() : Task(nullptr, nullptr) {} 35 std::atomic<Distributor> distributor; 36 std::atomic_bool valid{false}; 37 std::atomic_bool occupied{false}; 38 } ParallelTask; 39 class ParallelThreadPool; 40 class ParallelWorker : public Worker { 41 public: ParallelWorker(ThreadPool * pool,size_t index)42 explicit ParallelWorker(ThreadPool *pool, size_t index) : Worker(pool, index) { 43 parallel_pool_ = reinterpret_cast<ParallelThreadPool *>(pool_); 44 } 45 void CreateThread() override; 46 bool RunLocalKernelTask() override; 47 ~ParallelWorker() override; 48 void RunOtherPoolTask(ParallelTask *p_task); 49 void WaitOtherPoolTask(); 50 void DirectRunOtherPoolTask(); 51 void ActivateByOtherPoolTask(ParallelTask *task = nullptr); 52 53 protected: 54 void WaitUntilActive() override; 55 56 private: 57 void ParallelRun(); 58 bool RunQueueActorTask(); 59 ParallelThreadPool *parallel_pool_{nullptr}; 60 bool wait_do_other_task_{true}; 61 ParallelTask *other_task_ = nullptr; 62 std::mutex other_task_mutex_; 63 std::condition_variable cv_other_task_; 64 bool enable_shared_thread_pool_ = false; 65 }; 66 67 class MS_CORE_API ParallelThreadPool : public ActorThreadPool { 68 public: 69 static ParallelThreadPool *CreateThreadPool(size_t actor_thread_num, size_t all_thread_num, 70 const std::vector<int> &core_list, BindMode bind_mode, 71 std::string runner_id = ""); ~ParallelThreadPool()72 ~ParallelThreadPool() override { 73 MS_LOG(INFO) << "free parallel thread pool."; 74 // wait until actor queue is empty 75 bool terminate = false; 76 int count = 0; 77 do { 78 { 79 #ifdef USE_HQUEUE 80 terminate = actor_queue_.Empty(); 81 #else 82 std::lock_guard<std::mutex> _l(actor_mutex_); 83 terminate = actor_queue_.empty(); 84 #endif 85 } 86 if (!terminate) { 87 ActiveWorkers(); 88 std::this_thread::yield(); 89 } 90 } while (!terminate && count++ < kMaxCount); 91 MS_LOG(INFO) << "Wait for all worker to delete."; 92 for (auto &worker : workers_) { 93 delete worker; 94 worker = nullptr; 95 } 96 MS_LOG(INFO) << "delete workers."; 97 workers_.clear(); 98 tasks_size_ = 0; 99 if (tasks_) { 100 delete[] tasks_; 101 } 102 } 103 104 int ParallelLaunch(const Func &func, Content content, int task_num) override; 105 PushActorToQueue(ActorBase * actor)106 void PushActorToQueue(ActorBase *actor) override { 107 if (!actor) { 108 return; 109 } 110 { 111 #ifdef USE_HQUEUE 112 while (!actor_queue_.Enqueue(actor)) { 113 } 114 #else 115 std::lock_guard<std::mutex> _l(actor_mutex_); 116 actor_queue_.push(actor); 117 #endif 118 } 119 THREAD_DEBUG("actor[%s] enqueue success", actor->GetAID().Name().c_str()); 120 size_t size = workers_.size() > tasks_size_ ? tasks_size_ : workers_.size(); 121 for (size_t i = 0; i < size; i++) { 122 if (!enable_shared_) { 123 workers_[i]->Active(); 124 } else { 125 static_cast<ParallelWorker *>(workers_[i])->ActivateByOtherPoolTask(nullptr); 126 } 127 } 128 } 129 130 inline bool RunTaskOnce(int start, int end); 131 132 bool RunParallel(); 133 tasks_size()134 size_t tasks_size() const { return tasks_size_; } 135 136 bool IsIdlePool(); 137 138 int GetPoolRef(); 139 140 bool SetRunnerID(const std::string &runner_id) override; 141 142 std::vector<ParallelWorker *> GetParallelPoolWorkers(); 143 144 void UseThreadPool(int num); 145 GetPoolBindRunnerID()146 inline std::string GetPoolBindRunnerID() { return bind_runner_id_; } 147 GetEnableShared()148 inline bool GetEnableShared() { return enable_shared_; } 149 150 private: ParallelThreadPool()151 ParallelThreadPool() {} 152 int CreateParallelThreads(size_t actor_thread_num, size_t all_thread_num, const std::vector<int> &core_list); 153 154 std::atomic_int tasks_start_{0}; 155 std::atomic_int tasks_end_{0}; 156 ParallelTask *tasks_; 157 size_t tasks_size_ = 0; 158 bool enable_shared_ = false; 159 std::string bind_runner_id_; 160 std::mutex mutex_pool_ref_count_; 161 std::atomic_int pool_ref_count_{0}; 162 int thread_num_; 163 }; 164 } // namespace mindspore 165 #endif // MINDSPORE_CORE_MINDRT_RUNTIME_PARALLEL_THREADPOOL_H_ 166