• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "common/thread_pool.h"
#include <algorithm>
#include <atomic>
#include <exception>
#include <utility>
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_exception.h"
23 
24 namespace mindspore {
25 namespace common {
#if ENABLE_D || ENABLE_GPU
// Divisor applied to the usable core count in ENABLE_D / ENABLE_GPU builds
// (presumably the number of devices sharing one host -- TODO confirm).
constexpr size_t kDeviceNum = 8;
#endif
// Hard upper bound on the number of SyncRun worker threads.
constexpr size_t kMaxThreadNum = 23;
30 
ThreadPool()31 ThreadPool::ThreadPool() {
32   size_t process_core_num = std::thread::hardware_concurrency() - 1;
33   if (process_core_num < 1) {
34     process_core_num = 1;
35   }
36 #if ENABLE_D || ENABLE_GPU
37   max_thread_num_ = process_core_num / kDeviceNum;
38 #else
39   max_thread_num_ = process_core_num;
40 #endif
41   if (max_thread_num_ < 1) {
42     max_thread_num_ = 1;
43   }
44   if (max_thread_num_ > kMaxThreadNum) {
45     max_thread_num_ = kMaxThreadNum;
46   }
47 }
48 
SyncRunLoop()49 void ThreadPool::SyncRunLoop() {
50   while (true) {
51     Task task;
52     {
53       std::unique_lock<std::mutex> lock(task_mutex_);
54       task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
55       if (exit_run_) {
56         return;
57       }
58       task = task_queue_.front();
59       task_queue_.pop();
60     }
61     try {
62       task();
63     } catch (std::exception &e) {
64       MsException::Instance().SetException();
65     }
66     {
67       std::unique_lock<std::mutex> task_lock(task_mutex_);
68       task_finished_count_ = task_finished_count_ + 1;
69     }
70     finished_cond_var_.notify_one();
71   }
72 }
73 
SyncRun(const std::vector<Task> & tasks)74 bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
75   if (tasks.size() == 1) {
76     auto ret = tasks[0]();
77     return ret == SUCCESS;
78   }
79   std::unique_lock<std::mutex> lock(pool_mtx_);
80   exit_run_ = false;
81   size_t task_num = tasks.size();
82   size_t thread_num = sync_run_threads_.size();
83   if (thread_num < max_thread_num_ && thread_num < task_num) {
84     auto new_thread_num = max_thread_num_;
85     if (task_num < max_thread_num_) {
86       new_thread_num = task_num;
87     }
88     for (size_t i = thread_num; i < new_thread_num; ++i) {
89       sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
90     }
91   }
92 
93   for (auto &task : tasks) {
94     std::lock_guard<std::mutex> task_lock(task_mutex_);
95     task_queue_.push(task);
96     task_cond_var_.notify_one();
97   }
98   {
99     std::unique_lock<std::mutex> task_lock(task_mutex_);
100     finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
101     task_finished_count_ = 0;
102   }
103   return true;
104 }
105 
GetInstance()106 ThreadPool &ThreadPool::GetInstance() {
107   static ThreadPool instance{};
108   return instance;
109 }
110 
ClearThreadPool()111 void ThreadPool::ClearThreadPool() {
112   std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
113   if (exit_run_) {
114     return;
115   }
116   exit_run_ = true;
117   task_cond_var_.notify_all();
118   for (auto &it : sync_run_threads_) {
119     if (it.joinable()) {
120       it.join();
121     }
122   }
123   sync_run_threads_.clear();
124 }
125 
~ThreadPool()126 ThreadPool::~ThreadPool() {
127   try {
128     ClearThreadPool();
129   } catch (...) {
130     // exit
131   }
132 }
133 }  // namespace common
134 }  // namespace mindspore
135