/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_

#include <iostream>
#include <utility>
#include <future>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <memory>
#include <unordered_map>
#include <vector>
#include <string>
#include <sstream>
#include <functional>
#include <list>
#include <set>
#include <fstream>
#include <chrono>

#include "pipeline/jit/static_analysis/static_analysis.h"

namespace mindspore {
namespace abstract {

class AsyncInferTask;
class AsyncAbstract;
using AsyncInferTaskPtr = std::shared_ptr<AsyncInferTask>;
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
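// AnalysisSchedule is the process-wide coordinator (see GetInstance()) for the threads
// spawned during parallel static analysis: it keeps a list of pending infer tasks,
// wakes one runnable task at a time, and tracks the number of live infer threads so
// that Wait() can tell when analysis has finished.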
class AnalysisSchedule {
 public:
  ~AnalysisSchedule() { Stop(); }
  AnalysisSchedule(const AnalysisSchedule &) = delete;
  AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
  static AnalysisSchedule &GetInstance() { return instance_; }
  static void SetThreadID(const std::string &caller);
  static std::string &GetThreadID();
  void HandleException(const std::exception &ex);
  void Stop() {
    notExit_ = false;
    MS_LOG(DEBUG) << "Set AnalysisSchedule to exit. The active thread count: " << activate_threads_.size()
                  << " The infer_thread_count: " << infer_thread_count_
                  << " schedule list size: " << scheduleList_.size();
  }
  void Wait();
  void Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr);
  void Yield(const AsyncInferTask *asyncTask);

  void EnterWaiting() {
    {
      MS_LOG(DEBUG) << "Acquiring activate_thread_lock. The active thread count: " << activate_threads_.size()
                    << " The infer_thread_count: " << infer_thread_count_
                    << " schedule list size: " << scheduleList_.size();
      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
      activate_threads_.clear();
      MS_LOG(DEBUG) << "Acquired activate_thread_lock. The active thread count: " << activate_threads_.size()
                    << " The infer_thread_count: " << infer_thread_count_
                    << " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
                    << (activate_threads_.empty() ? "" : activate_threads_.begin()->c_str());
    }
    activate_thread_cv_.notify_one();
  }

  void IncreaseThreadCount() {
    infer_thread_count_.fetch_add(1);
    MS_LOG(DEBUG) << "The active thread count: " << activate_threads_.size()
                  << " The infer_thread_count: " << infer_thread_count_
                  << " schedule list size: " << scheduleList_.size();
  }

  void DecreaseThreadCount() {
    {
      std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
      infer_thread_count_.fetch_sub(1);
    }
    infer_thread_cv_.notify_one();

    {
      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
      activate_threads_.clear();
      MS_LOG(DEBUG) << "The active thread count: " << activate_threads_.size()
                    << " The infer_thread_count: " << infer_thread_count_
                    << " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
                    << (activate_threads_.empty() ? "" : activate_threads_.begin()->c_str());
    }
    activate_thread_cv_.notify_one();
  }

 private:
  void Schedule();
  bool SetNextReady();
  // The scheduler loop runs on a single detached background thread.
  void Start() {
    auto thread = std::thread([this] { Schedule(); });
    thread.detach();
  }
  AnalysisSchedule() { Start(); }
  static AnalysisSchedule instance_;
  std::atomic<int> infer_thread_count_{0};
  std::atomic<bool> notExit_{true};  // Read by the detached scheduler thread, written by Stop().
  std::mutex infer_thread_lock_;
  std::condition_variable infer_thread_cv_;
  std::mutex activate_thread_lock_;
  std::condition_variable activate_thread_cv_;
  std::list<AsyncInferTaskPtr> scheduleList_;
  std::set<std::string> activate_threads_;
};

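// MultiThreadCache is a mutex-guarded key/value map shared across analysis threads.
// get() returns nullptr on a miss, so ValueType must be a pointer-like type.
// A minimal usage sketch (illustrative; KeyPtr/ValuePtr stand for hypothetical shared_ptr aliases):
//   MultiThreadCache<KeyPtr, ValuePtr, std::unordered_map<KeyPtr, ValuePtr>> cache;
//   cache.set(key, value);
//   if (ValuePtr hit = cache.get(key)) { /* cache hit */ }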
template <typename KeyType, typename ValueType, typename CacheType>
class MultiThreadCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  ValueType get(const KeyType &key) {
    std::lock_guard<std::mutex> lock(lock_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return nullptr;
  }

  void set(const KeyType &key, const ValueType &data) {
    std::lock_guard<std::mutex> lock(lock_);
    cache_[key] = data;
  }

  void clear() {
    std::lock_guard<std::mutex> lock(lock_);
    cache_.clear();
  }

  size_t size() {
    std::lock_guard<std::mutex> lock(lock_);
    return cache_.size();
  }

  bool empty() { return size() == 0; }

  std::string dump() {
    std::lock_guard<std::mutex> lock(lock_);
    std::ostringstream buf;
    for (auto &item : cache_) {
      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  // Iteration is not synchronized; callers must guarantee exclusive access.
  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  std::mutex lock_;
  CacheType cache_;
};

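// NormalCache is the single-threaded counterpart of MultiThreadCache: the same
// interface with no locking. get() likewise returns nullptr on a miss.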
template <typename KeyType, typename ValueType, typename CacheType>
class NormalCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  ValueType get(const KeyType &key) const {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return nullptr;
  }

  void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }

  void clear() { cache_.clear(); }

  size_t size() const { return cache_.size(); }

  bool empty() const { return size() == 0; }

  std::string dump() const {
    std::ostringstream buf;
    for (auto &item : cache_) {
      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  CacheType cache_;
};

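// AsyncAbstract is a thread-safe holder for an inference result that may not exist yet.
// Readers poll with TryGetResult()/HasResult(); the producer publishes once via SetResult().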
class AsyncAbstract {
 public:
  AsyncAbstract() = default;
  ~AsyncAbstract() = default;
  AbstractBasePtr TryGetResult() {
    std::lock_guard<std::mutex> lock(lock_);
    return result_;
  }
  bool HasResult() {
    std::lock_guard<std::mutex> lock(lock_);
    return result_ != nullptr;
  }
  void SetResult(const AbstractBasePtr &result) {
    MS_EXCEPTION_IF_NULL(result);
    std::lock_guard<std::mutex> lock(lock_);
    result_ = result;
  }

  std::string ToString() {
    std::ostringstream buffer;
    std::lock_guard<std::mutex> lock(lock_);
    buffer << (result_ == nullptr ? "NOT SET" : result_->ToString());
    return buffer.str();
  }

 private:
  std::mutex lock_;
  AbstractBasePtr result_{nullptr};
};

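// AsyncInferTask couples an AsyncAbstract with the scheduler: GetResult() blocks,
// yielding the calling thread to AnalysisSchedule, until SetReady(), SetException(),
// or SetEndLessLoopException() flips the corresponding bit of ready_.
// A minimal usage sketch (illustrative only; only the public API below is assumed):
//   auto async_result = std::make_shared<AsyncAbstract>();
//   auto task = AsyncInferTask::MakeShared(async_result);
//   AnalysisSchedule::GetInstance().Add2Schedule(task);
//   // ... on another thread: async_result->SetResult(abs); task->SetReady();
//   AbstractBasePtr result = task->GetResult();  // blocks until a ready bit is set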
class AsyncInferTask {
 public:
  explicit AsyncInferTask(const std::string &threadId, const AsyncAbstractPtr &abstract)
      : threadId_(threadId), abstract_ptr_(abstract) {}
  ~AsyncInferTask() = default;

  static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &threadId = "") {
    std::string thread_id = threadId;
    if (thread_id.empty()) {
      thread_id = AnalysisSchedule::GetThreadID();
    }
    MS_EXCEPTION_IF_NULL(abstract);
    auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);
    MS_EXCEPTION_IF_NULL(ret);
    return ret;
  }

  bool HasResult() { return abstract_ptr_->HasResult(); }
  int Ready() const { return ready_; }
  std::string ThreadID() const { return threadId_; }

  AbstractBasePtr GetResult() {
    std::unique_lock<std::mutex> lock(lock_);
    if (ready_) {
      ProcessResult();
      return abstract_ptr_->TryGetResult();
    }
    // Avoid a deadlock between AsyncAbstract::lock_ and AnalysisSchedule::activate_thread_lock_.
    lock.unlock();
    AnalysisSchedule::GetInstance().Yield(this);

    lock.lock();
    MS_LOG(DEBUG) << this << " after enter waiting ready: " << ready_ << " thread id: " << threadId_
                  << " GetThreadID: " << AnalysisSchedule::GetThreadID();
    condition_var_.wait(lock, [this] { return ready_; });
    MS_LOG(DEBUG) << this << " received notify and woke up: " << ready_ << " thread id: " << threadId_
                  << " GetThreadID: " << AnalysisSchedule::GetThreadID();
    ProcessResult();
    auto ans = abstract_ptr_->TryGetResult();
    MS_EXCEPTION_IF_NULL(ans);
    return ans;
  }

  void SetReady() {
    MS_LOG(DEBUG) << this << " wants to set ready.";
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 1;  // Set the first bit: the result is ready.
      MS_LOG(DEBUG) << this << " notify ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get()
                    << " threadId: " << threadId_;
    }
    condition_var_.notify_one();
  }

  void SetException() {
    MS_LOG(DEBUG) << this << " wants to set exception.";
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 2;  // Set the second bit: an exception occurred.
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

  void SetEndLessLoopException() {
    MS_LOG(DEBUG) << this << " wants to set endless-loop exception.";
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 4;  // Set the third bit: an endless loop was detected.
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

 private:
  void ClearReady() {
    ready_ = ready_ & 6;  // Clear the first bit (the normal ready flag).
    MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get();
  }
  void HandleEndLessLoopException() {
    // Check the third bit (the endless-loop flag).
    if (ready_ & 4) {
      ready_ = ready_ & 3;  // Clear the third bit so this exception triggers only once.
      MS_LOG(EXCEPTION) << "There isn't any branch that can be evaluated.\n"
                        << "Please check whether the code has infinite recursion or an infinite loop.\n"
                        << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
    }
  }
  void ProcessResult() {
    ClearReady();  // Clear the normal ready flag.
    HandleEndLessLoopException();
    StaticAnalysisException::Instance().CheckException();
    MS_LOG(DEBUG) << this << " GetResult succeeded. ready: " << ready_ << " threadId: " << threadId_
                  << " GetThreadID: " << AnalysisSchedule::GetThreadID()
                  << " result: " << abstract_ptr_->TryGetResult().get();
  }
  std::string threadId_;
  AsyncAbstractPtr abstract_ptr_;
  std::mutex lock_;
  std::condition_variable condition_var_;
  size_t ready_{0};  // 0: not ready; bit 1: ready; bit 2: exception; bit 3: endless loop.
};

using EvaluatorCacheMap =
  std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;

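// EvaluatorCacheMgr maps an evaluator's argument list to its cached EvalResult.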
class EvaluatorCacheMgr {
 public:
  EvaluatorCacheMgr() = default;
  ~EvaluatorCacheMgr() = default;

  void Clear() { eval_result_cache_.clear(); }
  const EvalResultCache &GetCache() { return eval_result_cache_; }
  EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
  void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
  size_t GetSize() { return eval_result_cache_.size(); }

 private:
  EvalResultCache eval_result_cache_;
};

// AnalysisResultCacheMgr is a process-wide singleton that holds the per-node analysis
// results and the async switch-value cache shared by all infer threads.
class AnalysisResultCacheMgr {
 public:
  using AnalysisConfigResultMap =
    std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
  using const_iterator = typename AnalysisConfigResultCache::const_iterator;

  ~AnalysisResultCacheMgr() = default;
  AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
  AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
  static AnalysisResultCacheMgr &GetInstance() { return instance_; }
  void Clear();
  inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
  inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
  void PushTodo(const AnfNodeConfigPtr &conf);
  void Todo();
  void InitSwitchValue(const AnfNodeConfigPtr &conf);
  AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
  AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf);
  void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &value);
  const_iterator begin() { return cache_.begin(); }
  const_iterator end() { return cache_.end(); }

 private:
  using AnalysisConfigAsyncResultMap =
    std::unordered_map<AnfNodeConfigPtr, AsyncAbstractPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigAsyncResultCache =
    MultiThreadCache<AnfNodeConfigPtr, AsyncAbstractPtr, AnalysisConfigAsyncResultMap>;
  AnalysisResultCacheMgr() = default;
  static AnalysisResultCacheMgr instance_;
  std::mutex lock_;
  std::mutex todo_lock_;
  std::list<AnfNodeConfigPtr> todo_;
  AnalysisConfigResultCache cache_;
  AnalysisConfigAsyncResultCache switch_cache_;
};

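// Render an abstract argument list as a human-readable string (used in logs/diagnostics).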
std::string ArgsToString(const AbstractBasePtrList &args_spec_list);

inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisSchedule::GetThreadID() + ":"; }

}  // namespace abstract
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_