• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/pipeline/async_rqueue.h"
18 
19 #include <utility>
20 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
21 #include "include/common/utils/signal_util.h"
22 #endif
23 #include "utils/log_adapter.h"
24 #include "utils/ms_exception.h"
25 #include "mindrt/include/fork_utils.h"
26 #include "include/common/profiler.h"
27 
28 #include "utils/profile.h"
29 
30 namespace mindspore {
31 namespace runtime {
32 constexpr size_t kThreadNameThreshold = 15;
33 thread_local kThreadWaitLevel current_level_{kThreadWaitLevel::kLevelUnknown};
34 
~AsyncRQueue()35 AsyncRQueue::~AsyncRQueue() {
36   try {
37     WorkerJoin();
38   } catch (const std::exception &e) {
39     MS_LOG(INFO) << "WorkerJoin failed, error msg:" << e.what();
40   }
41 }
42 
SetThreadName() const43 void AsyncRQueue::SetThreadName() const {
44 // Set thread name for gdb debug
45 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
46   (void)pthread_setname_np(pthread_self(), name_.substr(0, kThreadNameThreshold).c_str());
47 #endif
48 }
49 
WorkerLoop()50 void AsyncRQueue::WorkerLoop() {
51 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
52   // cppcheck-suppress unreadVariable
53   SignalGuard sig([](int, siginfo_t *, void *) {
54     int this_pid = getpid();
55     MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
56     (void)kill(this_pid, SIGTERM);
57   });
58 #endif
59 
60   // Thread init.
61   SetThreadName();
62   runtime::ProfilerAnalyzer::GetInstance().SetThreadIdToName(std::this_thread::get_id(), name_);
63   {
64     // cppcheck-suppress unreadVariable
65     std::unique_lock<std::mutex> lock(level_mutex_);
66     thread_id_to_wait_level_[std::this_thread::get_id()] = wait_level_;
67   }
68 
69   while (true) {
70     std::shared_ptr<AsyncTask> task = tasks_queue_.Head();
71 
72     MS_LOG(DEBUG) << "Get task";
73     MS_EXCEPTION_IF_NULL(task);
74     if (task->task_type() == kExitTask) {
75       tasks_queue_.Dequeue();
76       MS_LOG(DEBUG) << "Thread exit";
77       return;
78     }
79 
80     try {
81       task->Run();
82       tasks_queue_.Dequeue();
83     } catch (const std::exception &e) {
84       MS_LOG(INFO) << "Run task failed, error msg:" << e.what();
85       {
86         MsException::Instance().SetException();
87         // MsException is unreliable because it gets modified everywhere.
88         auto e_ptr = std::current_exception();
89         while (!tasks_queue_.IsEmpty()) {
90           auto &t = tasks_queue_.Head();
91           if (t->task_type() == kExitTask) {
92             break;
93           }
94           t->SetException(e_ptr);
95           tasks_queue_.Dequeue();
96         }
97       }
98     }
99   }
100 }
101 
Push(const AsyncTaskPtr & task)102 void AsyncRQueue::Push(const AsyncTaskPtr &task) {
103   if (worker_ == nullptr) {
104     worker_ = std::make_unique<std::thread>(&AsyncRQueue::WorkerLoop, this);
105   }
106 
107   if (current_level_ == kThreadWaitLevel::kLevelUnknown) {
108     // cppcheck-suppress unreadVariable
109     std::unique_lock<std::mutex> lock(level_mutex_);
110     current_level_ = thread_id_to_wait_level_[std::this_thread::get_id()];
111   }
112 
113   if (current_level_ >= wait_level_) {
114     MS_LOG(EXCEPTION) << "Cannot push task from thread " << current_level_ << " to queue " << wait_level_;
115   }
116   tasks_queue_.Enqueue(task);
117 }
118 
Wait()119 void AsyncRQueue::Wait() {
120   if (worker_ == nullptr) {
121     return;
122   }
123   if (current_level_ == kThreadWaitLevel::kLevelUnknown) {
124     // cppcheck-suppress unreadVariable
125     std::unique_lock<std::mutex> lock(level_mutex_);
126     current_level_ = thread_id_to_wait_level_[std::this_thread::get_id()];
127   }
128 
129   if (current_level_ >= wait_level_) {
130     MS_LOG(DEBUG) << "No need to wait, current level " << current_level_ << " AsyncQueue name " << name_;
131     // Only need to wait the low level thread.
132     return;
133   }
134 
135   MS_LOG(DEBUG) << "Start to wait thread " << name_;
136   while (!tasks_queue_.IsEmpty()) {
137   }
138   MsException::Instance().CheckException();
139   MS_LOG(DEBUG) << "End to wait thread " << name_;
140 }
141 
Empty()142 bool AsyncRQueue::Empty() { return tasks_queue_.IsEmpty(); }
143 
Clear()144 void AsyncRQueue::Clear() {
145   {
146     if (tasks_queue_.IsEmpty()) {
147       return;
148     }
149 
150     ClearTaskWithException();
151 
152     // Avoid to push task after WorkerJoin.
153     if (worker_ != nullptr && worker_->joinable()) {
154       auto task = std::make_shared<WaitTask>();
155       tasks_queue_.Enqueue(task);
156     }
157   }
158   // There is still one task in progress
159   Wait();
160 }
161 
Reset()162 void AsyncRQueue::Reset() {
163   {
164     if (tasks_queue_.IsEmpty()) {
165       return;
166     }
167 
168     ClearTaskWithException();
169     MS_LOG(DEBUG) << "Reset AsyncQueue";
170   }
171 }
172 
ClearTaskWithException()173 void AsyncRQueue::ClearTaskWithException() {
174   while (!tasks_queue_.IsEmpty()) {
175     auto &t = tasks_queue_.Head();
176     t->SetException(std::make_exception_ptr(std::runtime_error("Clean up tasks that are not yet running")));
177     tasks_queue_.Dequeue();
178   }
179 }
180 
WorkerJoin()181 void AsyncRQueue::WorkerJoin() {
182   try {
183     if (worker_ == nullptr) {
184       return;
185     }
186     // Avoid worker thread join itself which will cause deadlock
187     if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
188       {
189         auto task = std::make_shared<ExitTask>();
190         tasks_queue_.Enqueue(task);
191         MS_LOG(DEBUG) << "Push exit task and notify all";
192       }
193       worker_->join();
194       MS_LOG(DEBUG) << "Worker join finish";
195       MsException::Instance().CheckException();
196     }
197   } catch (const std::exception &e) {
198     MS_LOG(ERROR) << "WorkerJoin failed: " << e.what();
199   } catch (...) {
200     MS_LOG(ERROR) << "WorkerJoin failed";
201   }
202 }
203 
ChildAfterFork()204 void AsyncRQueue::ChildAfterFork() {
205   MS_LOG(DEBUG) << "AsyncQueue reinitialize after fork";
206   if (worker_ != nullptr) {
207     MS_LOG(DEBUG) << "Release and recreate worker_.";
208     (void)worker_.release();
209     worker_ = std::make_unique<std::thread>(&AsyncRQueue::WorkerLoop, this);
210   }
211   MS_LOG(DEBUG) << "AsyncQueue reinitialize after fork done.";
212 }
213 }  // namespace runtime
214 }  // namespace mindspore
215