/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/ps/static_analysis/async_eval_result.h"
#include "pipeline/jit/ps/debug/trace.h"
#include "utils/symbolic.h"
#include "utils/compile_config.h"
#include "include/common/debug/common.h"
#include "pipeline/jit/ps/base.h"
#include "include/common/utils/utils.h"

namespace mindspore {
namespace abstract {
namespace {
constexpr auto kStateStop = "Stop";
}  // namespace
thread_local std::string AnalysisSchedule::thread_id_ = "m";

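// Dispatcher loop: wake up on notification (or every three seconds), and once no
// thread is active and a task is queued, hand control to the next ready task.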
void AnalysisSchedule::Schedule() {
  const auto checkPeriod = std::chrono::seconds(3);
  while (run_ || infer_thread_count_.load() > 0) {
    std::unique_lock<std::mutex> lock(activate_thread_lock_);
    auto ok = activate_thread_cv_.wait_for(lock, checkPeriod,
                                           [this] { return activate_threads_.empty() && !schedule_list_.empty(); });
    if (ok) {
      SetNextReady();
    }
  }
  MS_LOG(DEBUG) << "Exited successfully.";
}

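// If the given task is not ready, remove the current thread from the active set
// and notify the dispatcher so it can schedule another task.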
void AnalysisSchedule::YieldTask(AsyncInferTask *async_infer_task) {
  MS_EXCEPTION_IF_NULL(async_infer_task);
  {
    std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
    if (async_infer_task->ready() == 0) {
      MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << thread_id()
                    << " async_infer_task thread id:" << async_infer_task->thread_id();
      (void)activate_threads_.erase(thread_id());
    }
  }
  activate_thread_cv_.notify_one();
}

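// Record only the first exception, mark every queued task with the exception so
// all waiting threads can exit, and clear the primitive evaluation cache, which
// may hold results that are invalid after the failure.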
void AnalysisSchedule::HandleException(const std::exception &ex) {
  // Only record the first exception's information.
  if (!StaticAnalysisException::Instance().HasException()) {
    StaticAnalysisException::Instance().SetException();

    // For a Python exception, record the evaluation stack.
    if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
      try {
        MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
        std::ostringstream exceptionStream;
        exceptionStream << ex.what() << std::endl;
        trace::GetTraceStackInfo(exceptionStream);
        if (!trace::GetCNodeDebugStack().empty()) {
          MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
        }
      } catch (const std::exception &e) {
        // Ignored.
      }
    }
  }
  // Free all the locks so that all the threads can continue to run.
  std::lock_guard<std::mutex> lock(activate_thread_lock_);
  for (auto &item : schedule_list_) {
    MS_EXCEPTION_IF_NULL(item);
    item->SetException();
  }
  schedule_list_.clear();
  // The global primitive evaluation cache should be cleared,
  // since it may contain invalid results when an exception is raised.
  auto &prim_eval_cache = AnalysisResultCacheMgr::GetInstance().prim_eval_cache();
  if (prim_eval_cache != nullptr) {
    prim_eval_cache->Clear();
  }
}

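// Enqueue a sentinel task (thread id kStateStop) that makes the dispatcher loop
// exit, then join the dispatcher thread.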
void AnalysisSchedule::Stop() {
  AsyncInferTaskPtr stop_task = AsyncInferTask::MakeShared(std::make_shared<AsyncAbstract>(), kStateStop);
  Add2Schedule(stop_task);
  if (dispatcher_ != nullptr && dispatcher_->joinable()) {
    try {
      dispatcher_->join();
    } catch (const std::exception &e) {
      MS_LOG(WARNING) << " Analysis schedule thread join exception:" << e.what();
    }
  }
  MS_LOG(DEBUG) << "Set analysis schedule to stop";
}

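// Block (with the Python GIL released) until all inference threads have finished,
// then check for any exception recorded during static analysis.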
void AnalysisSchedule::Wait() {
  EnterWaiting();
  if (infer_thread_count_.load() > 0) {
    py::gil_scoped_release infer_gil_release;
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " waiting.";
    std::unique_lock<std::mutex> lock(infer_thread_lock_);
    infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
  }
  MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " active.";
  if (infer_thread_count_.load() < 0) {
    MS_LOG(ERROR) << "Something is wrong. Thread count: " << infer_thread_count_;
  }
  MS_LOG(DEBUG) << "Infer finished.";
  StaticAnalysisException::Instance().CheckException();
}

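// Serialize execution order: enqueue a dummy task that already has a result and
// block until the dispatcher marks it ready.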
void AnalysisSchedule::WaitForRun() const {
  // Control the order in which the threads run.
  AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
  control_run_order->set_result(std::make_shared<AbstractScalar>(1));
  AsyncInferTaskPtr async_task = AsyncInferTask::MakeShared(control_run_order);
  AnalysisSchedule::GetInstance().Add2Schedule(async_task);
  (void)async_task->GetResult();
}

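// Append a task to the schedule queue and wake the dispatcher.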
void AnalysisSchedule::Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr) {
  std::lock_guard<std::mutex> lock(activate_thread_lock_);
  MS_EXCEPTION_IF_NULL(async_infer_task_ptr);
  schedule_list_.push_back(async_infer_task_ptr);
  activate_thread_cv_.notify_one();
  MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->thread_id() << " address: " << async_infer_task_ptr.get()
                << " The active thread count: " << activate_threads_.size()
                << " The infer_thread_count: " << infer_thread_count_
                << " schedule list size: " << schedule_list_.size();
}

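// Dispatch the next task: prefer one that already has a result; failing that, try
// to synthesize a result from the other switch branch; as a last resort, treat the
// situation as an endless loop and raise an exception on the front task.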
void AnalysisSchedule::SetNextReady() {
  if (schedule_list_.empty()) {
    return;
  }
  // Exit flag.
  if (schedule_list_.front() != nullptr && schedule_list_.front()->thread_id() == kStateStop) {
    run_ = false;
    schedule_list_.pop_front();
    return;
  }
  // Check whether an endless loop has been entered.
  auto it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
    MS_EXCEPTION_IF_NULL(item);
    return item->HasResult();
  });
  while (it == schedule_list_.end()) {
    if (IntToSize(infer_thread_count_.load()) > schedule_list_.size()) {
      MS_LOG(DEBUG) << "There are still tasks to be added. Please wait. "
                    << " infer_count: " << infer_thread_count_.load() << " schedule: " << schedule_list_.size();
      return;
    }

    (void)std::for_each(schedule_list_.begin(), schedule_list_.end(),
                        [](const auto &item) { MS_LOG(DEBUG) << "Leave infer thread: " << item->thread_id(); });
    if (enable_waiting_branch_eval()) {
      // Try to set a possible result from a branch whose value can be ignored.
      auto possible_it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
        MS_EXCEPTION_IF_NULL(item);
        return item->SetPossibleResult(true);
      });
      if (possible_it != schedule_list_.end()) {
        MS_EXCEPTION_IF_NULL(*possible_it);
        MS_LOG(DEBUG) << "Try to set one branch result from the other branch, ignore value: true, infer thread: "
                      << (*possible_it)->thread_id() << " , result: " << (*possible_it)->HasResult();
        it = possible_it;
        break;
      }
      // Try to set a possible result from any branch.
      possible_it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
        MS_EXCEPTION_IF_NULL(item);
        return item->SetPossibleResult(false);
      });
      if (possible_it != schedule_list_.end()) {
        MS_EXCEPTION_IF_NULL(*possible_it);
        MS_LOG(DEBUG) << "Try to set one branch result from the other branch, ignore value: false, infer thread: "
                      << (*possible_it)->thread_id() << " , result: " << (*possible_it)->HasResult();
        it = possible_it;
        break;
      }
    }
    MS_EXCEPTION_IF_NULL(schedule_list_.front());
    // No ready result exists, so the analysis has entered an endless loop.
    (void)activate_threads_.insert(schedule_list_.front()->thread_id());
    // Let the first thread trigger the endless-loop exception.
    MS_LOG(DEBUG) << "Enter endless loop if there is no ready result. Set the async task to trigger an exception: "
                  << schedule_list_.front().get() << " The active thread count: " << activate_threads_.size();
    schedule_list_.front()->SetEndLessLoopException();
    schedule_list_.pop_front();
    return;
  }

  auto async_task = *it;
  MS_EXCEPTION_IF_NULL(async_task);
  (void)activate_threads_.insert(async_task->thread_id());
  async_task->SetReady();
  (void)schedule_list_.erase(it);
  MS_LOG(DEBUG) << " SetReady succeeded. The active thread count: " << activate_threads_.size()
                << " The infer_thread_count: " << infer_thread_count_
                << " schedule list size: " << schedule_list_.size() << " async: " << async_task->thread_id()
                << "  address: " << async_task.get();
}

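// Block until this abstract has a real (non-placeholder) result, letting the
// dispatcher schedule other tasks in the meantime.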
AbstractBasePtr AsyncAbstract::GetResult() {
  ClearPossibleResult();
  auto async_task = AsyncInferTask::MakeShared(shared_from_this());
  MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
  AnalysisSchedule::GetInstance().Add2Schedule(async_task);
  auto ret = async_task->GetResult();
  MS_LOG(DEBUG) << GetInferThread() << " successfully got async result: " << async_task.get() << " " << ret->ToString();
  return ret;
}
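// Drop a placeholder result that was copied from the other branch, so that
// GetResult() waits for the real one.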
void AsyncAbstract::ClearPossibleResult() {
  std::lock_guard<std::mutex> lock(lock_);
  if (result_ != nullptr && result_->isa<AsyncAbstractFuncAtom>()) {
    result_ = nullptr;
  }
}

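// Copy a tentative result from the other switch branch when this branch cannot
// make progress. With first == true, only branches whose value can be ignored
// qualify.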
bool AsyncAbstract::SetPossibleResult(bool first) {
  std::lock_guard<std::mutex> lock(lock_);
  bool condition = not_copy_from_other_ && switchAbstract_ != nullptr && switchAbstract_->HasResult();
  if (first && condition) {
    condition = switchAbstract_->ignore_value_;
  }
  if (condition) {
    result_ = switchAbstract_->TryGetResult();
    // Set the result from the other branch's abstract
    // when there are no branches left available to infer.
    // Only copy the type; otherwise the two branches would be optimized into a const value.
    MS_EXCEPTION_IF_NULL(result_->BuildValue());
    if (!result_->BuildValue()->isa<ValueAny>()) {
      result_ = AbstractBroaden(result_);
    }
    if (NeedWaitForBranches(result_)) {
      result_ = AsyncAbstractFuncAtom::MakeShared(shared_from_this(), std::vector<size_t>{0});
    }
    not_copy_from_other_ = false;
    return true;
  }
  return false;
}

namespace {
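// Walk `abs` along the element indices in `index` (starting at `offset`) until an
// AbstractFuncAtom is reached; raise an internal exception on any mismatch.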
AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const std::vector<std::size_t> &index,
                                               const std::size_t offset) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<AbstractFuncAtom>()) {
    return abs->cast<AbstractFuncAtomPtr>();
  }
  if (abs->isa<AbstractSequence>()) {
    auto abs_seq = abs->cast_ptr<AbstractSequence>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    const auto &elements = abs_seq->elements();
    if (offset >= index.size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Offset " << offset << " is greater than or equal to vector size: " << index.size();
    }
    if (index[offset] >= elements.size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "At offset " << offset
                                 << ", elements size of AsyncAbstract result: " << abs->ToString()
                                 << " is less than or equal to index: " << index[offset];
    }
    auto resolved = GetAbstractFuncRecursively(elements[index[offset]], index, offset + 1);
    MS_EXCEPTION_IF_NULL(resolved);
    if (!resolved->isa<AbstractFuncAtom>()) {
      MS_LOG(INTERNAL_EXCEPTION) << "AsyncAbstract result cannot be resolved to AbstractFuncAtom, but: "
                                 << resolved->ToString();
    }
    MS_LOG(DEBUG) << "Return abstract: " << resolved->ToString();
    return resolved;
  }
  MS_LOG(INTERNAL_EXCEPTION) << "AsyncAbstract cannot be resolved to AbstractFuncAtom or AbstractSequence, but: "
                             << abs->ToString();
}
}  // namespace
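// A result that is (or contains) a function abstract still depends on branches
// that have not finished evaluating.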
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  if (abstract->isa<AbstractFunction>()) {
    return true;
  }
  if (abstract->isa<AbstractSequence>()) {
    auto seq = abstract->cast_ptr<AbstractSequence>();
    MS_EXCEPTION_IF_NULL(seq);
    auto elements = seq->elements();
    if (std::any_of(elements.begin(), elements.end(),
                    [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
      return true;
    }
  }
  return false;
}

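// Resolve this atom to a concrete AbstractFuncAtom, waiting for the underlying
// async result on first use and caching it in resolved_.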
AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
  if (resolved_ != nullptr) {
    return resolved_;
  }
  // Release the GIL for C++.
  py::gil_scoped_release infer_gil_release;

  MS_LOG(DEBUG) << "Try to GetResult from async_abstract: " << async_abstract_->ToString();
  const auto &result = async_abstract_->GetResult();
  resolved_ = GetAbstractFuncRecursively(result, index_, 0);
  return resolved_;
}

std::string AsyncAbstractFuncAtom::ToString() const {
  if (resolved_ == nullptr) {
    return "AsyncAbstractFuncAtom(Not Resolved)";
  }

  std::ostringstream buffer;
  buffer << "AsyncAbstractFuncAtom(";
  buffer << resolved_->ToString();
  buffer << ")";

  return buffer.str();
}

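// Drop every cached analysis result, including the primitive evaluation cache.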
void AnalysisResultCacheMgr::Clear() {
  prim_eval_cache_->Clear();
  std::lock_guard<std::mutex> lock(lock_);
  cache_.clear();
  switch_cache_.clear();
  switch_cache_for_check_.clear();
}

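// Ensure a (possibly empty) AsyncAbstract slot exists in the switch cache for
// this node config.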
void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
  std::lock_guard<std::mutex> lock(lock_);
  AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
  if (async_eval_result == nullptr) {
    async_eval_result = std::make_shared<AsyncAbstract>();
    switch_cache_.set(conf, async_eval_result);
  }
}

AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
  // Do not call lock_.lock(): switch_cache_ is protected by its own lock, and GetResult() may block waiting for the
  // result.
  AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
  if (async_eval_result == nullptr) {
    return nullptr;
  }
  return async_eval_result->GetResult();
}

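// Store current_abs for the given config; if the cache already holds a result
// from the other branch, join the two abstracts into one.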
void AnalysisResultCacheMgr::SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &current_abs,
                                           AnalysisConfigAsyncResultCache *cache) {
  MS_EXCEPTION_IF_NULL(conf);
  MS_EXCEPTION_IF_NULL(cache);
  if (current_abs == nullptr) {
    MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
  }
  std::lock_guard<std::mutex> lock(lock_);
  AsyncAbstractPtr async_eval_result = cache->get(conf);
  if (async_eval_result == nullptr) {
    async_eval_result = std::make_shared<AsyncAbstract>();
    async_eval_result->set_result(current_abs);
    cache->set(conf, async_eval_result);
  } else {
    auto previous_abs = async_eval_result->TryGetResult();
    AbstractBasePtrList abstract_list;
    if (previous_abs != nullptr) {
      abstract_list.push_back(previous_abs);
      abstract_list.push_back(current_abs);
      // Join the results of the two branches.
      MS_EXCEPTION_IF_NULL(conf->node());
      MS_LOG(DEBUG) << "Join node: " << conf->node()->DebugString() << ", previous_abs: " << previous_abs->ToString()
                    << ", and current_abs: " << current_abs->ToString();
      auto joined_result = AnalysisEngine::ProcessEvalResults(abstract_list, conf->node());
      async_eval_result->set_result(joined_result->abstract());
    } else {
      async_eval_result->set_result(current_abs);
    }
  }
}

void AnalysisResultCacheMgr::CheckSwitchValueJoinable(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
  SetCacheValue(conf, arg, &switch_cache_for_check_);
}

void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
  SetCacheValue(conf, arg, &switch_cache_);
}

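// Render each argument's type, shape, and value on its own " # "-prefixed line.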
std::string ArgsToString(const AbstractBasePtrList &args_abs_list) {
  std::ostringstream buffer;
  for (const auto &item : args_abs_list) {
    MS_EXCEPTION_IF_NULL(item);
    MS_EXCEPTION_IF_NULL(item->BuildType());
    MS_EXCEPTION_IF_NULL(item->BuildShape());
    MS_EXCEPTION_IF_NULL(item->BuildValue());
    buffer << " # " << item->BuildType()->ToString() << ", " << item->BuildShape()->ToString() << ", "
           << item->BuildValue()->ToString() << "\n";
  }
  return buffer.str();
}
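// Waiting-branch evaluation is enabled by default; it is disabled only when the
// compile config "NOT_WAIT_BRANCH_EVAL" is set to "1".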
bool enable_waiting_branch_eval() {
  static const bool enable_waiting_branch_eval = common::GetCompileConfig("NOT_WAIT_BRANCH_EVAL") != "1";
  return enable_waiting_branch_eval;
}
}  // namespace abstract
}  // namespace mindspore