1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/static_analysis/async_eval_result.h"
18 #include <debug/trace.h>
19 #include "utils/symbolic.h"
20 #include "debug/common.h"
21 #include "pipeline/jit/base.h"
22 #include "utils/utils.h"
23
24 namespace mindspore {
25 namespace abstract {
// Out-of-class definition of the static member AnalysisSchedule::instance_.
// NOTE(review): presumably the object handed out by AnalysisSchedule::GetInstance() - confirm in the header.
AnalysisSchedule AnalysisSchedule::instance_;
27
Schedule()28 void AnalysisSchedule::Schedule() {
29 const auto checkPeriod = std::chrono::seconds(3);
30 while (notExit_ || infer_thread_count_.load() > 0) {
31 std::unique_lock<std::mutex> lock(activate_thread_lock_);
32 if (activate_threads_.size() > 1) {
33 MS_LOG(ERROR) << "There is something wrong."
34 << " The active thread count: " << activate_threads_.size()
35 << " The infer_thread_count: " << infer_thread_count_
36 << " schedule list size: " << scheduleList_.size();
37 }
38
39 auto ok = activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return activate_threads_.empty(); });
40 if (ok && (!SetNextReady())) {
41 // If schedule list is empty, wait.
42 (void)activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return !scheduleList_.empty(); });
43 }
44 }
45 MS_LOG(DEBUG) << "Success to exit. The active thread count: " << activate_threads_.size()
46 << " The infer_thread_count: " << infer_thread_count_
47 << " schedule list size: " << scheduleList_.size();
48 }
49
Yield(const AsyncInferTask * async_infer_task)50 void AnalysisSchedule::Yield(const AsyncInferTask *async_infer_task) {
51 {
52 std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
53 // Double check ready()
54 if (async_infer_task->Ready() == 0) {
55 MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << GetThreadID()
56 << " async_infer_task thread id:" << async_infer_task->ThreadID();
57 (void)activate_threads_.erase(GetThreadID());
58 }
59 MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
60 << " The infer_thread_count: " << infer_thread_count_
61 << " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
62 << (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
63 }
64 activate_thread_cv_.notify_one();
65 }
66
HandleException(const std::exception & ex)67 void AnalysisSchedule::HandleException(const std::exception &ex) {
68 // Just record the first exception information.
69 if (!StaticAnalysisException::Instance().HasException()) {
70 StaticAnalysisException::Instance().SetException();
71
72 // If python Exception, record the eval stack.
73 if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
74 try {
75 MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
76 std::ostringstream exceptionStream;
77 trace::GetTraceStackInfo(exceptionStream);
78 if (!exceptionStream.str().empty()) {
79 MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
80 }
81 } catch (const std::exception &e) {
82 // Ignored.
83 }
84 }
85 }
86 // Free all the locks. Let all the threads continue to run.
87 std::lock_guard<std::mutex> lock(activate_thread_lock_);
88 for (auto &item : scheduleList_) {
89 item->SetException();
90 }
91 scheduleList_.clear();
92 }
93
Wait()94 void AnalysisSchedule::Wait() {
95 EnterWaiting();
96 if (infer_thread_count_.load() > 0) {
97 py::gil_scoped_release infer_gil_release;
98 std::unique_lock<std::mutex> lock(infer_thread_lock_);
99 infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
100 }
101 if (infer_thread_count_.load() < 0) {
102 MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
103 }
104 if (IS_OUTPUT_ON(DEBUG)) {
105 AnalysisResultCacheMgr::GetInstance().Todo();
106 }
107 MS_LOG(INFO) << "Infer finished.";
108 StaticAnalysisException::Instance().CheckException();
109 }
110
Add2Schedule(const AsyncInferTaskPtr & async_infer_task_ptr)111 void AnalysisSchedule::Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr) {
112 std::lock_guard<std::mutex> lock(activate_thread_lock_);
113 MS_EXCEPTION_IF_NULL(async_infer_task_ptr);
114 scheduleList_.push_back(async_infer_task_ptr);
115 activate_thread_cv_.notify_one();
116 MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->ThreadID() << " address: " << async_infer_task_ptr.get()
117 << " The active thread count: " << activate_threads_.size()
118 << " The infer_thread_count: " << infer_thread_count_
119 << " schedule list size: " << scheduleList_.size();
120 }
SetNextReady()121 bool AnalysisSchedule::SetNextReady() {
122 if (scheduleList_.empty()) {
123 MS_LOG(DEBUG) << "The schedule list is empty. ";
124 return false;
125 }
126 // Check if enter endless loop
127 auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
128 MS_EXCEPTION_IF_NULL(item);
129 return item->HasResult();
130 });
131 if (it == scheduleList_.end()) {
132 if (IntToSize(infer_thread_count_.load()) >= scheduleList_.size()) {
133 MS_LOG(DEBUG) << "There is some task to be added. Please wait.";
134 return false;
135 }
136 MS_LOG(WARNING) << "Enter endless loop. The active thread count: " << activate_threads_.size()
137 << " The infer_thread_count: " << infer_thread_count_
138 << " schedule list size: " << scheduleList_.size();
139 // Enter endless loop if there is not ready result.
140 (void)activate_threads_.insert(scheduleList_.front()->ThreadID());
141 // Let the first thread to trigger endless loop exception.
142 MS_LOG(DEBUG) << "Enter endless loop if there is not ready result.Set the async to trigger exception:"
143 << scheduleList_.front().get() << " The active thread count: " << activate_threads_.size();
144 scheduleList_.front()->SetEndLessLoopException();
145 scheduleList_.pop_front();
146 return true;
147 }
148 auto async_task = *it;
149 (void)activate_threads_.insert(async_task->ThreadID());
150 async_task->SetReady();
151 (void)scheduleList_.erase(it);
152 MS_LOG(DEBUG) << " Success to SetReady. The active thread count: " << activate_threads_.size()
153 << " The infer_thread_count: " << infer_thread_count_ << " schedule list size: " << scheduleList_.size()
154 << " async: " << async_task->ThreadID() << " address: " << async_task.get();
155
156 return true;
157 }
// Identifier of the current thread; thread_local, so every thread has its own copy.
// It is written by AnalysisSchedule::SetThreadID and read by AnalysisSchedule::GetThreadID.
// The thread id format is XXXX.YYYY.ZZZZ; the initial value in every thread is "1".
thread_local std::string localThreadID = "1";
SetThreadID(const std::string & threadID)160 void AnalysisSchedule::SetThreadID(const std::string &threadID) { localThreadID = threadID; }
161
GetThreadID()162 std::string &AnalysisSchedule::GetThreadID() { return localThreadID; }
163
// Out-of-class definition of the static member AnalysisResultCacheMgr::instance_.
// NOTE(review): presumably the object handed out by AnalysisResultCacheMgr::GetInstance() - confirm in the header.
AnalysisResultCacheMgr AnalysisResultCacheMgr::instance_;
165
Clear()166 void AnalysisResultCacheMgr::Clear() {
167 std::lock_guard<std::mutex> lock(lock_);
168 cache_.clear();
169 switch_cache_.clear();
170 todo_.clear();
171 }
172
PushTodo(const AnfNodeConfigPtr & conf)173 void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
174 std::lock_guard<std::mutex> lock(todo_lock_);
175 todo_.push_back(conf);
176 }
177
InitSwitchValue(const AnfNodeConfigPtr & conf)178 void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
179 std::lock_guard<std::mutex> lock(lock_);
180 AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
181 if (async_eval_result == nullptr) {
182 async_eval_result = std::make_shared<AsyncAbstract>();
183 switch_cache_.set(conf, async_eval_result);
184 }
185 }
186
TryGetSwitchValue(const AnfNodeConfigPtr & conf)187 AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) {
188 // don't call lock_.lock(). switch_cache is protected. and it waits for result.
189 AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
190 // Conf has been visited and set value.
191 if (async_eval_result != nullptr) {
192 return async_eval_result->TryGetResult();
193 }
194 return nullptr;
195 }
196
GetSwitchValue(const AnfNodeConfigPtr & conf)197 AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
198 StaticAnalysisException::Instance().CheckException();
199 // don't call lock_.lock(). switch_cache is protected. and it waits for result.
200 AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
201 // Conf has been visited and set value.
202 if (async_eval_result != nullptr) {
203 // Add to schedule
204 auto async_infer_task = AsyncInferTask::MakeShared(async_eval_result);
205 MS_LOG(DEBUG) << " add to schedule: " << async_infer_task.get();
206 AnalysisSchedule::GetInstance().Add2Schedule(async_infer_task);
207 // Maybe blocked for waiting. AsyncAbstract maybe null, if time out.
208 auto result = async_infer_task->GetResult();
209 if (result == nullptr) {
210 result = std::make_shared<AbstractTimeOut>();
211 MS_LOG(ERROR) << "AsyncAbstract of NodeConfig " << conf->node()->ToString()
212 << " is nullptr. There is something wrong.";
213 StaticAnalysisException::Instance().CheckException();
214 }
215 return result;
216 }
217 return nullptr;
218 }
219
SetSwitchValue(const AnfNodeConfigPtr & conf,const AbstractBasePtr & arg)220 void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
221 MS_EXCEPTION_IF_NULL(conf);
222 if (arg == nullptr) {
223 MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
224 }
225 std::lock_guard<std::mutex> lock(lock_);
226 AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
227 if (async_eval_result == nullptr) {
228 async_eval_result = std::make_shared<AsyncAbstract>();
229 async_eval_result->SetResult(arg);
230 switch_cache_.set(conf, async_eval_result);
231 } else {
232 auto ab1 = async_eval_result->TryGetResult();
233 AbstractBasePtrList absList;
234 if (ab1 != nullptr) {
235 absList.push_back(arg);
236 absList.push_back(ab1);
237 // Join two branches's result
238 auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
239 async_eval_result->SetResult(joined_result->abstract());
240 if (!(*joined_result == *ab1)) {
241 PushTodo(conf);
242 }
243 } else {
244 async_eval_result->SetResult(arg);
245 }
246 }
247 }
248
Todo()249 void AnalysisResultCacheMgr::Todo() {
250 std::lock_guard<std::mutex> lock(todo_lock_);
251 while (!todo_.empty()) {
252 AnfNodeConfigPtr conf = todo_.front();
253 MS_EXCEPTION_IF_NULL(conf);
254 todo_.pop_front();
255 if (GetValue(conf) == nullptr) {
256 MS_LOG(INFO) << conf->node()->ToString() << " not in globle cache.";
257 continue;
258 }
259 if (TryGetSwitchValue(conf) == nullptr) {
260 MS_LOG(INFO) << conf->node()->ToString() << " not in switch cache";
261 continue;
262 }
263 auto switch_value = TryGetSwitchValue(conf);
264 auto abstract = GetValue(conf)->abstract();
265 MS_EXCEPTION_IF_NULL(switch_value);
266 MS_EXCEPTION_IF_NULL(abstract);
267 if (!(*abstract == *switch_value)) {
268 MS_LOG(WARNING) << " Switch Value is not eq. "
269 << " switchCache: " << switch_value->ToString() << " globleCache: " << abstract->ToString()
270 << "\t\tConf: " << conf->ToString();
271 }
272 }
273 }
274
ArgsToString(const AbstractBasePtrList & args_spec_list)275 std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
276 std::ostringstream buffer;
277 buffer << "(";
278 for (const auto &item : args_spec_list) {
279 buffer << item->ToString() << " # ";
280 }
281 buffer << " )";
282 return buffer.str();
283 }
284 } // namespace abstract
285 } // namespace mindspore
286