• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/static_analysis/async_eval_result.h"
18 #include <debug/trace.h>
19 #include "utils/symbolic.h"
20 #include "debug/common.h"
21 #include "pipeline/jit/base.h"
22 #include "utils/utils.h"
23 
24 namespace mindspore {
25 namespace abstract {
26 AnalysisSchedule AnalysisSchedule::instance_;
27 
Schedule()28 void AnalysisSchedule::Schedule() {
29   const auto checkPeriod = std::chrono::seconds(3);
30   while (notExit_ || infer_thread_count_.load() > 0) {
31     std::unique_lock<std::mutex> lock(activate_thread_lock_);
32     if (activate_threads_.size() > 1) {
33       MS_LOG(ERROR) << "There is something wrong."
34                     << " The active thread count: " << activate_threads_.size()
35                     << " The infer_thread_count: " << infer_thread_count_
36                     << " schedule list size: " << scheduleList_.size();
37     }
38 
39     auto ok = activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return activate_threads_.empty(); });
40     if (ok && (!SetNextReady())) {
41       // If schedule list is empty, wait.
42       (void)activate_thread_cv_.wait_for(lock, checkPeriod, [this] { return !scheduleList_.empty(); });
43     }
44   }
45   MS_LOG(DEBUG) << "Success to exit. The active thread count: " << activate_threads_.size()
46                 << " The infer_thread_count: " << infer_thread_count_
47                 << " schedule list size: " << scheduleList_.size();
48 }
49 
Yield(const AsyncInferTask * async_infer_task)50 void AnalysisSchedule::Yield(const AsyncInferTask *async_infer_task) {
51   {
52     std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
53     // Double check ready()
54     if (async_infer_task->Ready() == 0) {
55       MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size() << " thread id: " << GetThreadID()
56                     << " async_infer_task thread id:" << async_infer_task->ThreadID();
57       (void)activate_threads_.erase(GetThreadID());
58     }
59     MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
60                   << " The infer_thread_count: " << infer_thread_count_
61                   << " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
62                   << (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
63   }
64   activate_thread_cv_.notify_one();
65 }
66 
HandleException(const std::exception & ex)67 void AnalysisSchedule::HandleException(const std::exception &ex) {
68   // Just record the first exception information.
69   if (!StaticAnalysisException::Instance().HasException()) {
70     StaticAnalysisException::Instance().SetException();
71 
72     // If python Exception, record the eval stack.
73     if (dynamic_cast<const py::error_already_set *>(&ex) != nullptr) {
74       try {
75         MS_LOG(DEBUG) << "Python exception happened, check the information as below.";
76         std::ostringstream exceptionStream;
77         trace::GetTraceStackInfo(exceptionStream);
78         if (!exceptionStream.str().empty()) {
79           MS_LOG(ERROR) << "Exception happened, check the information as below.\n" << exceptionStream.str();
80         }
81       } catch (const std::exception &e) {
82         // Ignored.
83       }
84     }
85   }
86   // Free all the locks. Let all the threads continue to run.
87   std::lock_guard<std::mutex> lock(activate_thread_lock_);
88   for (auto &item : scheduleList_) {
89     item->SetException();
90   }
91   scheduleList_.clear();
92 }
93 
Wait()94 void AnalysisSchedule::Wait() {
95   EnterWaiting();
96   if (infer_thread_count_.load() > 0) {
97     py::gil_scoped_release infer_gil_release;
98     std::unique_lock<std::mutex> lock(infer_thread_lock_);
99     infer_thread_cv_.wait(lock, [this] { return infer_thread_count_.load() <= 0; });
100   }
101   if (infer_thread_count_.load() < 0) {
102     MS_LOG(ERROR) << "There is something wrong. thread count: " << infer_thread_count_;
103   }
104   if (IS_OUTPUT_ON(DEBUG)) {
105     AnalysisResultCacheMgr::GetInstance().Todo();
106   }
107   MS_LOG(INFO) << "Infer finished.";
108   StaticAnalysisException::Instance().CheckException();
109 }
110 
Add2Schedule(const AsyncInferTaskPtr & async_infer_task_ptr)111 void AnalysisSchedule::Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr) {
112   std::lock_guard<std::mutex> lock(activate_thread_lock_);
113   MS_EXCEPTION_IF_NULL(async_infer_task_ptr);
114   scheduleList_.push_back(async_infer_task_ptr);
115   activate_thread_cv_.notify_one();
116   MS_LOG(DEBUG) << " async: " << async_infer_task_ptr->ThreadID() << " address: " << async_infer_task_ptr.get()
117                 << " The active thread count: " << activate_threads_.size()
118                 << " The infer_thread_count: " << infer_thread_count_
119                 << " schedule list size: " << scheduleList_.size();
120 }
SetNextReady()121 bool AnalysisSchedule::SetNextReady() {
122   if (scheduleList_.empty()) {
123     MS_LOG(DEBUG) << "The schedule list is empty. ";
124     return false;
125   }
126   // Check if enter endless loop
127   auto it = std::find_if(scheduleList_.begin(), scheduleList_.end(), [](const auto &item) {
128     MS_EXCEPTION_IF_NULL(item);
129     return item->HasResult();
130   });
131   if (it == scheduleList_.end()) {
132     if (IntToSize(infer_thread_count_.load()) >= scheduleList_.size()) {
133       MS_LOG(DEBUG) << "There is some task to be added. Please wait.";
134       return false;
135     }
136     MS_LOG(WARNING) << "Enter endless loop. The active thread count: " << activate_threads_.size()
137                     << " The infer_thread_count: " << infer_thread_count_
138                     << " schedule list size: " << scheduleList_.size();
139     // Enter endless loop if there is not ready result.
140     (void)activate_threads_.insert(scheduleList_.front()->ThreadID());
141     // Let the first thread to trigger endless loop exception.
142     MS_LOG(DEBUG) << "Enter endless loop if there is not ready result.Set the async to trigger exception:"
143                   << scheduleList_.front().get() << " The active thread count: " << activate_threads_.size();
144     scheduleList_.front()->SetEndLessLoopException();
145     scheduleList_.pop_front();
146     return true;
147   }
148   auto async_task = *it;
149   (void)activate_threads_.insert(async_task->ThreadID());
150   async_task->SetReady();
151   (void)scheduleList_.erase(it);
152   MS_LOG(DEBUG) << " Success to SetReady. The active thread count: " << activate_threads_.size()
153                 << " The infer_thread_count: " << infer_thread_count_ << " schedule list size: " << scheduleList_.size()
154                 << " async: " << async_task->ThreadID() << "  address: " << async_task.get();
155 
156   return true;
157 }
158 // The thread id format is XXXX.YYYY.ZZZZ
159 thread_local std::string localThreadID = "1";
SetThreadID(const std::string & threadID)160 void AnalysisSchedule::SetThreadID(const std::string &threadID) { localThreadID = threadID; }
161 
GetThreadID()162 std::string &AnalysisSchedule::GetThreadID() { return localThreadID; }
163 
164 AnalysisResultCacheMgr AnalysisResultCacheMgr::instance_;
165 
Clear()166 void AnalysisResultCacheMgr::Clear() {
167   std::lock_guard<std::mutex> lock(lock_);
168   cache_.clear();
169   switch_cache_.clear();
170   todo_.clear();
171 }
172 
PushTodo(const AnfNodeConfigPtr & conf)173 void AnalysisResultCacheMgr::PushTodo(const AnfNodeConfigPtr &conf) {
174   std::lock_guard<std::mutex> lock(todo_lock_);
175   todo_.push_back(conf);
176 }
177 
InitSwitchValue(const AnfNodeConfigPtr & conf)178 void AnalysisResultCacheMgr::InitSwitchValue(const AnfNodeConfigPtr &conf) {
179   std::lock_guard<std::mutex> lock(lock_);
180   AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
181   if (async_eval_result == nullptr) {
182     async_eval_result = std::make_shared<AsyncAbstract>();
183     switch_cache_.set(conf, async_eval_result);
184   }
185 }
186 
TryGetSwitchValue(const AnfNodeConfigPtr & conf)187 AbstractBasePtr AnalysisResultCacheMgr::TryGetSwitchValue(const AnfNodeConfigPtr &conf) {
188   // don't call lock_.lock(). switch_cache is protected. and it waits for result.
189   AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
190   // Conf has been visited and set value.
191   if (async_eval_result != nullptr) {
192     return async_eval_result->TryGetResult();
193   }
194   return nullptr;
195 }
196 
GetSwitchValue(const AnfNodeConfigPtr & conf)197 AbstractBasePtr AnalysisResultCacheMgr::GetSwitchValue(const AnfNodeConfigPtr &conf) {
198   StaticAnalysisException::Instance().CheckException();
199   // don't call lock_.lock(). switch_cache is protected. and it waits for result.
200   AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
201   // Conf has been visited and set value.
202   if (async_eval_result != nullptr) {
203     // Add to schedule
204     auto async_infer_task = AsyncInferTask::MakeShared(async_eval_result);
205     MS_LOG(DEBUG) << " add to schedule: " << async_infer_task.get();
206     AnalysisSchedule::GetInstance().Add2Schedule(async_infer_task);
207     // Maybe blocked for waiting. AsyncAbstract maybe null, if time out.
208     auto result = async_infer_task->GetResult();
209     if (result == nullptr) {
210       result = std::make_shared<AbstractTimeOut>();
211       MS_LOG(ERROR) << "AsyncAbstract of NodeConfig " << conf->node()->ToString()
212                     << " is nullptr. There is something wrong.";
213       StaticAnalysisException::Instance().CheckException();
214     }
215     return result;
216   }
217   return nullptr;
218 }
219 
SetSwitchValue(const AnfNodeConfigPtr & conf,const AbstractBasePtr & arg)220 void AnalysisResultCacheMgr::SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
221   MS_EXCEPTION_IF_NULL(conf);
222   if (arg == nullptr) {
223     MS_LOG(EXCEPTION) << conf->ToString() << " value is nullptr";
224   }
225   std::lock_guard<std::mutex> lock(lock_);
226   AsyncAbstractPtr async_eval_result = switch_cache_.get(conf);
227   if (async_eval_result == nullptr) {
228     async_eval_result = std::make_shared<AsyncAbstract>();
229     async_eval_result->SetResult(arg);
230     switch_cache_.set(conf, async_eval_result);
231   } else {
232     auto ab1 = async_eval_result->TryGetResult();
233     AbstractBasePtrList absList;
234     if (ab1 != nullptr) {
235       absList.push_back(arg);
236       absList.push_back(ab1);
237       // Join two branches's result
238       auto joined_result = AnalysisEngine::ProcessEvalResults(absList, conf->node());
239       async_eval_result->SetResult(joined_result->abstract());
240       if (!(*joined_result == *ab1)) {
241         PushTodo(conf);
242       }
243     } else {
244       async_eval_result->SetResult(arg);
245     }
246   }
247 }
248 
Todo()249 void AnalysisResultCacheMgr::Todo() {
250   std::lock_guard<std::mutex> lock(todo_lock_);
251   while (!todo_.empty()) {
252     AnfNodeConfigPtr conf = todo_.front();
253     MS_EXCEPTION_IF_NULL(conf);
254     todo_.pop_front();
255     if (GetValue(conf) == nullptr) {
256       MS_LOG(INFO) << conf->node()->ToString() << " not in globle cache.";
257       continue;
258     }
259     if (TryGetSwitchValue(conf) == nullptr) {
260       MS_LOG(INFO) << conf->node()->ToString() << " not in switch cache";
261       continue;
262     }
263     auto switch_value = TryGetSwitchValue(conf);
264     auto abstract = GetValue(conf)->abstract();
265     MS_EXCEPTION_IF_NULL(switch_value);
266     MS_EXCEPTION_IF_NULL(abstract);
267     if (!(*abstract == *switch_value)) {
268       MS_LOG(WARNING) << " Switch Value is not eq. "
269                       << " switchCache: " << switch_value->ToString() << " globleCache: " << abstract->ToString()
270                       << "\t\tConf: " << conf->ToString();
271     }
272   }
273 }
274 
ArgsToString(const AbstractBasePtrList & args_spec_list)275 std::string ArgsToString(const AbstractBasePtrList &args_spec_list) {
276   std::ostringstream buffer;
277   buffer << "(";
278   for (const auto &item : args_spec_list) {
279     buffer << item->ToString() << " # ";
280   }
281   buffer << " )";
282   return buffer.str();
283 }
284 }  // namespace abstract
285 }  // namespace mindspore
286