1 /**
2 * Copyright 2023-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/pipeline/async_rqueue.h"
18
19 #include <utility>
20 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
21 #include "include/common/utils/signal_util.h"
22 #endif
23 #include "utils/log_adapter.h"
24 #include "utils/ms_exception.h"
25 #include "mindrt/include/fork_utils.h"
26 #include "include/common/profiler.h"
27
28 #include "utils/profile.h"
29
30 namespace mindspore {
31 namespace runtime {
32 constexpr size_t kThreadNameThreshold = 15;
33 thread_local kThreadWaitLevel current_level_{kThreadWaitLevel::kLevelUnknown};
34
~AsyncRQueue()35 AsyncRQueue::~AsyncRQueue() {
36 try {
37 WorkerJoin();
38 } catch (const std::exception &e) {
39 MS_LOG(INFO) << "WorkerJoin failed, error msg:" << e.what();
40 }
41 }
42
SetThreadName() const43 void AsyncRQueue::SetThreadName() const {
44 // Set thread name for gdb debug
45 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
46 (void)pthread_setname_np(pthread_self(), name_.substr(0, kThreadNameThreshold).c_str());
47 #endif
48 }
49
WorkerLoop()50 void AsyncRQueue::WorkerLoop() {
51 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
52 // cppcheck-suppress unreadVariable
53 SignalGuard sig([](int, siginfo_t *, void *) {
54 int this_pid = getpid();
55 MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
56 (void)kill(this_pid, SIGTERM);
57 });
58 #endif
59
60 // Thread init.
61 SetThreadName();
62 runtime::ProfilerAnalyzer::GetInstance().SetThreadIdToName(std::this_thread::get_id(), name_);
63 {
64 // cppcheck-suppress unreadVariable
65 std::unique_lock<std::mutex> lock(level_mutex_);
66 thread_id_to_wait_level_[std::this_thread::get_id()] = wait_level_;
67 }
68
69 while (true) {
70 std::shared_ptr<AsyncTask> task = tasks_queue_.Head();
71
72 MS_LOG(DEBUG) << "Get task";
73 MS_EXCEPTION_IF_NULL(task);
74 if (task->task_type() == kExitTask) {
75 tasks_queue_.Dequeue();
76 MS_LOG(DEBUG) << "Thread exit";
77 return;
78 }
79
80 try {
81 task->Run();
82 tasks_queue_.Dequeue();
83 } catch (const std::exception &e) {
84 MS_LOG(INFO) << "Run task failed, error msg:" << e.what();
85 {
86 MsException::Instance().SetException();
87 // MsException is unreliable because it gets modified everywhere.
88 auto e_ptr = std::current_exception();
89 while (!tasks_queue_.IsEmpty()) {
90 auto &t = tasks_queue_.Head();
91 if (t->task_type() == kExitTask) {
92 break;
93 }
94 t->SetException(e_ptr);
95 tasks_queue_.Dequeue();
96 }
97 }
98 }
99 }
100 }
101
Push(const AsyncTaskPtr & task)102 void AsyncRQueue::Push(const AsyncTaskPtr &task) {
103 if (worker_ == nullptr) {
104 worker_ = std::make_unique<std::thread>(&AsyncRQueue::WorkerLoop, this);
105 }
106
107 if (current_level_ == kThreadWaitLevel::kLevelUnknown) {
108 // cppcheck-suppress unreadVariable
109 std::unique_lock<std::mutex> lock(level_mutex_);
110 current_level_ = thread_id_to_wait_level_[std::this_thread::get_id()];
111 }
112
113 if (current_level_ >= wait_level_) {
114 MS_LOG(EXCEPTION) << "Cannot push task from thread " << current_level_ << " to queue " << wait_level_;
115 }
116 tasks_queue_.Enqueue(task);
117 }
118
Wait()119 void AsyncRQueue::Wait() {
120 if (worker_ == nullptr) {
121 return;
122 }
123 if (current_level_ == kThreadWaitLevel::kLevelUnknown) {
124 // cppcheck-suppress unreadVariable
125 std::unique_lock<std::mutex> lock(level_mutex_);
126 current_level_ = thread_id_to_wait_level_[std::this_thread::get_id()];
127 }
128
129 if (current_level_ >= wait_level_) {
130 MS_LOG(DEBUG) << "No need to wait, current level " << current_level_ << " AsyncQueue name " << name_;
131 // Only need to wait the low level thread.
132 return;
133 }
134
135 MS_LOG(DEBUG) << "Start to wait thread " << name_;
136 while (!tasks_queue_.IsEmpty()) {
137 }
138 MsException::Instance().CheckException();
139 MS_LOG(DEBUG) << "End to wait thread " << name_;
140 }
141
Empty()142 bool AsyncRQueue::Empty() { return tasks_queue_.IsEmpty(); }
143
Clear()144 void AsyncRQueue::Clear() {
145 {
146 if (tasks_queue_.IsEmpty()) {
147 return;
148 }
149
150 ClearTaskWithException();
151
152 // Avoid to push task after WorkerJoin.
153 if (worker_ != nullptr && worker_->joinable()) {
154 auto task = std::make_shared<WaitTask>();
155 tasks_queue_.Enqueue(task);
156 }
157 }
158 // There is still one task in progress
159 Wait();
160 }
161
Reset()162 void AsyncRQueue::Reset() {
163 {
164 if (tasks_queue_.IsEmpty()) {
165 return;
166 }
167
168 ClearTaskWithException();
169 MS_LOG(DEBUG) << "Reset AsyncQueue";
170 }
171 }
172
ClearTaskWithException()173 void AsyncRQueue::ClearTaskWithException() {
174 while (!tasks_queue_.IsEmpty()) {
175 auto &t = tasks_queue_.Head();
176 t->SetException(std::make_exception_ptr(std::runtime_error("Clean up tasks that are not yet running")));
177 tasks_queue_.Dequeue();
178 }
179 }
180
WorkerJoin()181 void AsyncRQueue::WorkerJoin() {
182 try {
183 if (worker_ == nullptr) {
184 return;
185 }
186 // Avoid worker thread join itself which will cause deadlock
187 if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
188 {
189 auto task = std::make_shared<ExitTask>();
190 tasks_queue_.Enqueue(task);
191 MS_LOG(DEBUG) << "Push exit task and notify all";
192 }
193 worker_->join();
194 MS_LOG(DEBUG) << "Worker join finish";
195 MsException::Instance().CheckException();
196 }
197 } catch (const std::exception &e) {
198 MS_LOG(ERROR) << "WorkerJoin failed: " << e.what();
199 } catch (...) {
200 MS_LOG(ERROR) << "WorkerJoin failed";
201 }
202 }
203
ChildAfterFork()204 void AsyncRQueue::ChildAfterFork() {
205 MS_LOG(DEBUG) << "AsyncQueue reinitialize after fork";
206 if (worker_ != nullptr) {
207 MS_LOG(DEBUG) << "Release and recreate worker_.";
208 (void)worker_.release();
209 worker_ = std::make_unique<std::thread>(&AsyncRQueue::WorkerLoop, this);
210 }
211 MS_LOG(DEBUG) << "AsyncQueue reinitialize after fork done.";
212 }
213 } // namespace runtime
214 } // namespace mindspore
215