/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/thread_pool.h"
#include <algorithm>
#include <exception>
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_exception.h"

namespace mindspore {
namespace common {
#if ENABLE_D || ENABLE_GPU
const size_t kDeviceNum = 8;
#endif
const size_t kMaxThreadNum = 23;

ThreadPool::ThreadPool() {
  // Reserve one core for the calling thread; hardware_concurrency() may return 0
  // when the core count cannot be determined.
  size_t process_core_num = std::thread::hardware_concurrency();
  process_core_num = process_core_num > 1 ? process_core_num - 1 : 1;
#if ENABLE_D || ENABLE_GPU
  // Share the available cores among the devices of one node.
  max_thread_num_ = process_core_num / kDeviceNum;
#else
  max_thread_num_ = process_core_num;
#endif
  if (max_thread_num_ < 1) {
    max_thread_num_ = 1;
  }
  if (max_thread_num_ > kMaxThreadNum) {
    max_thread_num_ = kMaxThreadNum;
  }
}

// Worker loop: pop one task at a time, run it, and record any exception so the
// submitting thread can rethrow it later via MsException.
void ThreadPool::SyncRunLoop() {
  while (true) {
    Task task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
      if (exit_run_) {
        return;
      }
      task = task_queue_.front();
      task_queue_.pop();
    }
    try {
      task();
    } catch (const std::exception &) {
      MsException::Instance().SetException();
    }
    {
      std::unique_lock<std::mutex> task_lock(task_mutex_);
      ++task_finished_count_;
    }
    finished_cond_var_.notify_one();
  }
}

// Run the given tasks on the pool and block until all of them have finished.
// Worker threads are created lazily, bounded by max_thread_num_ and the task count.
bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
  if (tasks.size() == 1) {
    // A single task is executed inline on the calling thread.
    auto ret = tasks[0]();
    return ret == SUCCESS;
  }
  std::unique_lock<std::mutex> lock(pool_mtx_);
  exit_run_ = false;
  size_t task_num = tasks.size();
  size_t thread_num = sync_run_threads_.size();
  if (thread_num < max_thread_num_ && thread_num < task_num) {
    auto new_thread_num = max_thread_num_;
    if (task_num < max_thread_num_) {
      new_thread_num = task_num;
    }
    for (size_t i = thread_num; i < new_thread_num; ++i) {
      sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
    }
  }

  for (auto &task : tasks) {
    std::lock_guard<std::mutex> task_lock(task_mutex_);
    task_queue_.push(task);
    task_cond_var_.notify_one();
  }
  {
    // Wait until every queued task has been executed, then reset the counter.
    std::unique_lock<std::mutex> task_lock(task_mutex_);
    finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
    task_finished_count_ = 0;
  }
  return true;
}

ThreadPool &ThreadPool::GetInstance() {
  static ThreadPool instance{};
  return instance;
}

// Ask the workers to exit, wake them all, and join them.
void ThreadPool::ClearThreadPool() {
  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
  if (exit_run_) {
    return;
  }
  exit_run_ = true;
  task_cond_var_.notify_all();
  for (auto &it : sync_run_threads_) {
    if (it.joinable()) {
      it.join();
    }
  }
  sync_run_threads_.clear();
}

ThreadPool::~ThreadPool() {
  try {
    ClearThreadPool();
  } catch (...) {
    // Ignore exceptions during shutdown.
  }
}
}  // namespace common
}  // namespace mindspore
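
// Minimal usage sketch, assuming the declarations in common/thread_pool.h:
// Task is a std::function alias and SUCCESS is the status value it is compared
// against in SyncRun. SyncRun blocks until every submitted task has run;
// ClearThreadPool signals the workers to exit and joins them.
//
//   std::vector<mindspore::common::Task> tasks;
//   for (int i = 0; i < 4; ++i) {
//     tasks.emplace_back([]() { return mindspore::common::SUCCESS; });
//   }
//   (void)mindspore::common::ThreadPool::GetInstance().SyncRun(tasks);
//   mindspore::common::ThreadPool::GetInstance().ClearThreadPool();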