/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/ps/static_analysis/async_eval_result.h"
#include "pipeline/jit/ps/debug/trace.h"
#include "utils/symbolic.h"
#include "utils/compile_config.h"
#include "include/common/debug/common.h"
#include "pipeline/jit/ps/base.h"
#include "include/common/utils/utils.h"

namespace mindspore {
namespace abstract {
namespace {
constexpr auto kStateStop = "Stop";
}  // namespace
thread_local std::string AnalysisSchedule::thread_id_ = "m";

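// Dispatcher loop: runs until the Stop sentinel is handled and no infer thread
// is left. It wakes whenever all active threads are blocked and the schedule
// queue is non-empty, then lets SetNextReady() pick the next task to resume.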
void AnalysisSchedule::Schedule() {
  const auto checkPeriod = std::chrono::seconds(3);
  while (run_ || infer_thread_count_.load() > 0) {
    std::unique_lock<std::mutex> lock(activate_thread_lock_);
    auto ok = activate_thread_cv_.wait_for(lock, checkPeriod,
                                           [this] { return activate_threads_.empty() && !schedule_list_.empty(); });
    if (ok) {
      SetNextReady();
    }
  }
  MS_LOG(DEBUG) << "Exited successfully.";
}

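// Removes the calling thread from the active set when the given task is not
// ready, then wakes the dispatcher so another task can be scheduled.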
void AnalysisSchedule::YieldTask(AsyncInferTask *async_infer_task) {
  MS_EXCEPTION_IF_NULL(async_infer_task);
  {
    std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
    if (async_infer_task->ready() == 0) {
      MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << thread_id()
                    << " async_infer_task thread id:" << async_infer_task->thread_id();
      (void)activate_threads_.erase(thread_id());
    }
  }
  activate_thread_cv_.notify_one();
}

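// Records the first exception, unblocks every queued task and clears the
// primitive eval cache so that all infer threads can unwind.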
void AnalysisSchedule::HandleException(const std::exception &ex) {
  // Only record the first exception's information.
  if (!StaticAnalysisException::Instance().HasException()) {
    StaticAnalysisException::Instance().SetException();

    // If it is a Python exception, record the eval stack.
    if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
      try {
        MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
        std::ostringstream exceptionStream;
        exceptionStream << ex.what() << std::endl;
        trace::GetTraceStackInfo(exceptionStream);
        if (!trace::GetCNodeDebugStack().empty()) {
          MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
        }
      } catch (const std::exception &e) {
        // Ignored.
      }
    }
  }
  // Free all the locks so that all the threads can continue to run.
  std::lock_guard<std::mutex> lock(activate_thread_lock_);
  for (auto &item : schedule_list_) {
    MS_EXCEPTION_IF_NULL(item);
    item->SetException();
  }
  schedule_list_.clear();
  // The global primitive evaluate cache should be cleared,
  // since it may contain invalid results once an exception is raised.
  auto &prim_eval_cache = AnalysisResultCacheMgr::GetInstance().prim_eval_cache();
  if (prim_eval_cache != nullptr) {
    prim_eval_cache->Clear();
  }
}

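// Enqueues the Stop sentinel task and joins the dispatcher thread.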
void AnalysisSchedule::Stop() {
  AsyncInferTaskPtr stop_task = AsyncInferTask::MakeShared(std::make_shared<AsyncAbstract>(), kStateStop);
  Add2Schedule(stop_task);
  if (dispatcher_ != nullptr && dispatcher_->joinable()) {
    try {
      dispatcher_->join();
    } catch (const std::exception &e) {
      MS_LOG(WARNING) << " Analysis schedule thread join exception:" << e.what();
    }
  }
  MS_LOG(DEBUG) << "Set analysis schedule to stop";
}

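// Blocks the caller (releasing the Python GIL) until all infer threads have
// finished, then rethrows any exception recorded during static analysis.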
void AnalysisSchedule::Wait() {
  EnterWaiting();
  if (infer_thread_count_.load() > 0) {
    py::gil_scoped_release infer_gil_release;
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " waiting.";
    std::unique_lock<std::mutex> lock(infer_thread_lock_);
    infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
  }
  MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " active.";
  if (infer_thread_count_.load() < 0) {
    MS_LOG(ERROR) << "Something is wrong, the infer thread count is negative: " << infer_thread_count_;
  }
  MS_LOG(DEBUG) << "Infer finished.";
  StaticAnalysisException::Instance().CheckException();
}

void AnalysisSchedule::WaitForRun() const {
  // Control the order to run.
  AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
  control_run_order->set_result(std::make_shared<AbstractScalar>(1));
  AsyncInferTaskPtr async_task = AsyncInferTask::MakeShared(control_run_order);
  AnalysisSchedule::GetInstance().Add2Schedule(async_task);
  (void)async_task->GetResult();
}

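// Appends a task to the schedule queue and wakes the dispatcher.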
void AnalysisSchedule::Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr) {
  std::lock_guard<std::mutex> lock(activate_thread_lock_);
  MS_EXCEPTION_IF_NULL(async_infer_task_ptr);
  schedule_list_.push_back(async_infer_task_ptr);
  activate_thread_cv_.notify_one();
  MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->thread_id() << " address: " << async_infer_task_ptr.get()
                << " The active thread count: " << activate_threads_.size()
                << " The infer_thread_count: " << infer_thread_count_
                << " schedule list size: " << schedule_list_.size();
}

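// Picks the next task to resume: handles the Stop sentinel, prefers a task
// that already has a result, otherwise tries to borrow a possible result from
// a finished switch branch, and as a last resort raises an endless-loop
// exception on the front task.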
void AnalysisSchedule::SetNextReady() {
  if (schedule_list_.empty()) {
    return;
  }
  // Check the exit flag.
  if (schedule_list_.front() != nullptr && schedule_list_.front()->thread_id() == kStateStop) {
    run_ = false;
    schedule_list_.pop_front();
    return;
  }
  // Check whether the analysis has entered an endless loop.
  auto it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
    MS_EXCEPTION_IF_NULL(item);
    return item->HasResult();
  });
  while (it == schedule_list_.end()) {
    if (IntToSize(infer_thread_count_.load()) > schedule_list_.size()) {
      MS_LOG(DEBUG) << "There are still tasks to be added. Please wait. "
                    << " infer_count: " << infer_thread_count_.load() << " schedule: " << schedule_list_.size();
      return;
    }

    (void)std::for_each(schedule_list_.begin(), schedule_list_.end(),
                        [](const auto &item) { MS_LOG(DEBUG) << "Leave infer thread: " << item->thread_id(); });
    if (enable_waiting_branch_eval()) {
      // Try to set one of the possible results, for branches whose value is ignored.
      auto possible_it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
        MS_EXCEPTION_IF_NULL(item);
        return item->SetPossibleResult(true);
      });
      if (possible_it != schedule_list_.end()) {
        MS_EXCEPTION_IF_NULL(*possible_it);
        MS_LOG(DEBUG) << "Try to set one branch result from the other branch, ignore value: true, infer thread: "
                      << (*possible_it)->thread_id() << " , result: " << (*possible_it)->HasResult();
        it = possible_it;
        break;
      }
      // Try to set one of the possible results, regardless of the value.
      possible_it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
        MS_EXCEPTION_IF_NULL(item);
        return item->SetPossibleResult(false);
      });
      if (possible_it != schedule_list_.end()) {
        MS_EXCEPTION_IF_NULL(*possible_it);
        MS_LOG(DEBUG) << "Try to set one branch result from the other branch, ignore value: false, infer thread: "
                      << (*possible_it)->thread_id() << " , result: " << (*possible_it)->HasResult();
        it = possible_it;
        break;
      }
    }
    MS_EXCEPTION_IF_NULL(schedule_list_.front());
    // We have entered an endless loop since no task has a ready result.
    (void)activate_threads_.insert(schedule_list_.front()->thread_id());
    // Let the first thread trigger the endless-loop exception.
    MS_LOG(DEBUG) << "Entered an endless loop, no ready result. Set the async task to trigger the exception: "
                  << schedule_list_.front().get() << " The active thread count: " << activate_threads_.size();
    schedule_list_.front()->SetEndLessLoopException();
    schedule_list_.pop_front();
    return;
  }

  auto async_task = *it;
  MS_EXCEPTION_IF_NULL(async_task);
  (void)activate_threads_.insert(async_task->thread_id());
  async_task->SetReady();
  (void)schedule_list_.erase(it);
  MS_LOG(DEBUG) << " SetReady succeeded. The active thread count: " << activate_threads_.size()
                << " The infer_thread_count: " << infer_thread_count_
                << " schedule list size: " << schedule_list_.size() << " async: " << async_task->thread_id()
                << " address: " << async_task.get();
}

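// Blocks until the real result is available; any placeholder result borrowed
// from the other switch branch is dropped first.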
AbstractBasePtr AsyncAbstract::GetResult() {
  ClearPossibleResult();
  auto async_task = AsyncInferTask::MakeShared(shared_from_this());
  MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
  AnalysisSchedule::GetInstance().Add2Schedule(async_task);
  auto ret = async_task->GetResult();
  MS_LOG(DEBUG) << GetInferThread() << " successfully got async result: " << async_task.get() << " " << ret->ToString();
  return ret;
}
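
// Drops a placeholder AsyncAbstractFuncAtom result that was borrowed from the
// other switch branch, so that GetResult() waits for the real one.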
void AsyncAbstract::ClearPossibleResult() {
  std::lock_guard<std::mutex> lock(lock_);
  if (result_ != nullptr && result_->isa<AsyncAbstractFuncAtom>()) {
    result_ = nullptr;
  }
}

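// Tries to borrow the finished result of the other switch branch when this
// branch has none yet. With `first` set, only a branch whose value is ignored
// qualifies. The borrowed abstract is broadened so that the two branches are
// not folded into a single constant.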
bool AsyncAbstract::SetPossibleResult(bool first) {
  std::lock_guard<std::mutex> lock(lock_);
  bool condition = not_copy_from_other_ && switchAbstract_ != nullptr && switchAbstract_->HasResult();
  if (first && condition) {
    condition = switchAbstract_->ignore_value_;
  }
  if (condition) {
    result_ = switchAbstract_->TryGetResult();
    // Set the result from the other branch's abstract
    // when there is no available branch to infer.
    // Just copy the type, otherwise the two branches would be optimized to a const value.
    MS_EXCEPTION_IF_NULL(result_->BuildValue());
    if (!result_->BuildValue()->isa<ValueAny>()) {
      result_ = AbstractBroaden(result_);
    }
    if (NeedWaitForBranches(result_)) {
      result_ = AsyncAbstractFuncAtom::MakeShared(shared_from_this(), std::vector<size_t>{0});
    }
    not_copy_from_other_ = false;
    return true;
  }
  return false;
}

namespace {
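// Follows the index path through nested AbstractSequence elements and returns
// the AbstractFuncAtom found there; raises an internal exception otherwise.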
AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const std::vector<std::size_t> &index,
                                               const std::size_t offset) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<AbstractFuncAtom>()) {
    return abs->cast<AbstractFuncAtomPtr>();
  }
  if (abs->isa<AbstractSequence>()) {
    auto abs_seq = abs->cast_ptr<AbstractSequence>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    const auto &elements = abs_seq->elements();
    if (offset >= index.size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Offset " << offset << " is greater than or equal to vector size: " << index.size();
    }
    if (index[offset] >= elements.size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "At offset " << offset
                                 << ", elements size of AsyncAbstract result: " << abs->ToString()
                                 << " is less than or equal to index: " << index[offset];
    }
    auto resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1);
    MS_EXCEPTION_IF_NULL(resolved);
    if (!resolved->isa<AbstractFuncAtom>()) {
      MS_LOG(INTERNAL_EXCEPTION) << "AsyncAbstract result cannot be resolved to AbstractFuncAtom, but: "
                                 << resolved->ToString();
    }
    MS_LOG(DEBUG) << "Return abstract: " << resolved->ToString();
    return resolved;
  }
  MS_LOG(INTERNAL_EXCEPTION) << "AsyncAbstract cannot be resolved to AbstractFuncAtom or AbstractSequence, but: "
                             << abs->ToString();
}
}  // namespace
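
// Returns true if the abstract is a function abstract, or a sequence that
// contains one, in which case the real branch result has to be waited for.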
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  if (abstract->isa<AbstractFunction>()) {
    return true;
  }
  if (abstract->isa<AbstractSequence>()) {
    auto seq = abstract->cast_ptr<AbstractSequence>();
    MS_EXCEPTION_IF_NULL(seq);
    auto elements = seq->elements();
    if (std::any_of(elements.begin(), elements.end(),
                    [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
      return true;
    }
  }
  return false;
}

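// Resolves this placeholder to the concrete AbstractFuncAtom, blocking on the
// underlying AsyncAbstract if needed; the resolution is cached in resolved_.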
AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
  if (resolved_ != nullptr) {
    return resolved_;
  }
  // Release the GIL for C++ code.
  py::gil_scoped_release infer_gil_release;

  MS_LOG(DEBUG) << "Try to GetResult from async_abstract: " << async_abstract_->ToString();
  const auto &result = async_abstract_->GetResult();
  resolved_ = GetAbstractFuncRecursively(result, index_, 0);
  return resolved_;
}

std::string AsyncAbstractFuncAtom::ToString() const {
  if (resolved_ == nullptr) {
    return "AsyncAbstractFuncAtom(Not Resolved)";
  }

  std::ostringstream buffer;
  buffer << "AsyncAbstractFuncAtom(";
  buffer << resolved_->ToString();
  buffer << ")";

  return buffer.str();
}

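// Clears the primitive eval cache together with all per-node analysis caches.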
void AnalysisResultCacheMgr::Clear() {
  prim_eval_cache_->Clear();
  std::lock_guard<std::mutex> lock(lock_);
  cache_.clear();
  switch_cache_.clear();
  switch_cache_for_check_.clear();
}

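// Makes sure an AsyncAbstract entry exists in the switch cache for the config.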
void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
  std::lock_guard<std::mutex> lock(lock_);
  AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
  if (async_eval_result == nullptr) {
    async_eval_result = std::make_shared<AsyncAbstract>();
    switch_cache_.set(conf, async_eval_result);
  }
}

AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
  // Don't call lock_.lock() here: switch_cache_ is thread-safe by itself, and
  // GetResult() may block until the result has been set.
  AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
  if (async_eval_result == nullptr) {
    return nullptr;
  }
  return async_eval_result->GetResult();
}

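// Stores current_abs into the given cache entry for conf. If the entry already
// holds a result, the previous and current abstracts are joined.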
void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &current_abs,
                                           AnalysisConfigAsyncResultCache *cache) {
  MS_EXCEPTION_IF_NULL(conf);
  MS_EXCEPTION_IF_NULL(cache);
  if (current_abs == nullptr) {
    MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
  }
  std::lock_guard<std::mutex> lock(lock_);
  AsyncAbstractPtr async_eval_result = cache->get(conf);
  if (async_eval_result == nullptr) {
    async_eval_result = std::make_shared<AsyncAbstract>();
    async_eval_result->set_result(current_abs);
    cache->set(conf, async_eval_result);
  } else {
    auto previous_abs = async_eval_result->TryGetResult();
    AbstractBasePtrList abstract_list;
    if (previous_abs != nullptr) {
      abstract_list.push_back(previous_abs);
      abstract_list.push_back(current_abs);
      // Join the results of the two branches.
      MS_EXCEPTION_IF_NULL(conf->node());
      MS_LOG(DEBUG) << "Join node: " << conf->node()->DebugString() << ", previous_abs: " << previous_abs->ToString()
                    << ", and current_abs: " << current_abs->ToString();
      auto joined_result = AnalysisEngine::ProcessEvalResults(abstract_list, conf->node());
      async_eval_result->set_result(joined_result->abstract());
    } else {
      async_eval_result->set_result(current_abs);
    }
  }
}

void AnalysisResultCacheMgr::CheckSwitchValueJoinable(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
  SetCacheValue(conf, arg, &switch_cache_for_check_);
}

void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
  SetCacheValue(conf, arg, &switch_cache_);
}

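// Formats the type, shape and value of each argument abstract for logging.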
std::string ArgsToString(const AbstractBasePtrList &args_abs_list) {
  std::ostringstream buffer;
  for (const auto &item : args_abs_list) {
    MS_EXCEPTION_IF_NULL(item);
    MS_EXCEPTION_IF_NULL(item->BuildType());
    MS_EXCEPTION_IF_NULL(item->BuildShape());
    MS_EXCEPTION_IF_NULL(item->BuildValue());
    buffer << " # " << item->BuildType()->ToString() << ", " << item->BuildShape()->ToString() << ", "
           << item->BuildValue()->ToString() << "\n";
  }
  return buffer.str();
}
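
// Waiting-branch evaluation is enabled unless the compile config
// NOT_WAIT_BRANCH_EVAL is set to "1".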
bool enable_waiting_branch_eval() {
  static const bool enable_waiting_branch_eval = common::GetCompileConfig("NOT_WAIT_BRANCH_EVAL") != "1";
  return enable_waiting_branch_eval;
}
}  // namespace abstract
}  // namespace mindspore