• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
18 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
19 
20 #include <iostream>
21 #include <utility>
22 #include <future>
23 #include <thread>
24 #include <memory>
25 #include <unordered_map>
26 #include <vector>
27 #include <string>
28 #include <functional>
29 #include <list>
30 #include <set>
31 #include <fstream>
32 #include <chrono>
33 
34 #include "pipeline/jit/static_analysis/static_analysis.h"
35 
36 namespace mindspore {
37 namespace abstract {
38 
39 class AsyncInferTask;
40 class AsyncAbstract;
41 using AsyncInferTaskPtr = std::shared_ptr<AsyncInferTask>;
42 using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
43 class AnalysisSchedule {
44  public:
~AnalysisSchedule()45   ~AnalysisSchedule() { Stop(); }
46   AnalysisSchedule(const AnalysisSchedule &) = delete;
47   AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
GetInstance()48   static AnalysisSchedule &GetInstance() { return instance_; }
49   static void SetThreadID(const std::string &caller);
50   static std::string &GetThreadID();
51   void HandleException(const std::exception &ex);
Stop()52   void Stop() {
53     notExit_ = false;
54     MS_LOG(DEBUG) << " Set AnalysisSchedule::Exit . The active thread count: " << activate_threads_.size()
55                   << " The infer_thread_count: " << infer_thread_count_
56                   << " schedule list size: " << scheduleList_.size();
57   }
58   void Wait();
59   void Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr);
60   void Yield(const AsyncInferTask *asyncTask);
61 
EnterWaiting()62   void EnterWaiting() {
63     {
64       MS_LOG(DEBUG) << " Require activate_thread_lock. The active thread count: " << activate_threads_.size()
65                     << " The infer_thread_count: " << infer_thread_count_
66                     << " schedule list size: " << scheduleList_.size();
67       std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
68       activate_threads_.clear();
69       MS_LOG(DEBUG) << " Get activate_thread_lock. The active thread count: " << activate_threads_.size()
70                     << " The infer_thread_count: " << infer_thread_count_
71                     << " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
72                     << (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
73     }
74     activate_thread_cv_.notify_one();
75   }
76 
IncreaseThreadCount()77   void IncreaseThreadCount() {
78     infer_thread_count_.fetch_add(1);
79     MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
80                   << " The infer_thread_count: " << infer_thread_count_
81                   << " schedule list size: " << scheduleList_.size();
82   }
83 
DecreaseThreadCount()84   void DecreaseThreadCount() {
85     {
86       std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
87       infer_thread_count_.fetch_sub(1);
88     }
89     infer_thread_cv_.notify_one();
90 
91     {
92       std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
93       activate_threads_.clear();
94       MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
95                     << " The infer_thread_count: " << infer_thread_count_
96                     << " schedule list size: " << scheduleList_.size() << " thread: " << GetThreadID() + " "
97                     << (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
98     }
99     activate_thread_cv_.notify_one();
100   }
101 
102  private:
103   void Schedule();
104   bool SetNextReady();
Start()105   void Start() {
106     auto thread = std::thread([this] { Schedule(); });
107     thread.detach();
108   }
AnalysisSchedule()109   AnalysisSchedule() { Start(); }
110   static AnalysisSchedule instance_;
111   std::atomic<int> infer_thread_count_{0};
112   bool notExit_{true};
113   std::mutex infer_thread_lock_;
114   std::condition_variable infer_thread_cv_;
115   std::mutex activate_thread_lock_;
116   std::condition_variable activate_thread_cv_;
117   std::list<AsyncInferTaskPtr> scheduleList_;
118   std::set<std::string> activate_threads_;
119 };
120 
// Mutex-protected key/value cache shared between analysis threads.
// `ValueType` must be nullptr-constructible (get() returns nullptr on miss).
// NOTE: the begin()/end() iterators are NOT protected by the internal mutex;
// callers that iterate must guarantee no concurrent mutation.
template <typename KeyType, typename ValueType, typename CacheType>
class MultiThreadCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  // Returns the value stored under `key`, or nullptr when absent.
  ValueType get(const KeyType &key) {
    std::lock_guard<std::mutex> lock(lock_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return nullptr;
  }

  // Inserts or overwrites the value stored under `key`.
  void set(const KeyType &key, const ValueType &data) {
    std::lock_guard<std::mutex> lock(lock_);
    cache_[key] = data;
  }

  // Removes all entries.
  void clear() {
    std::lock_guard<std::mutex> lock(lock_);
    cache_.clear();
  }

  // Number of entries.
  // Fix: the original read cache_ without holding lock_, which is a data
  // race when another thread calls set()/clear() concurrently.
  size_t size() {
    std::lock_guard<std::mutex> lock(lock_);
    return cache_.size();
  }

  bool empty() { return size() == 0; }

  // Renders one "{key: value}" line per entry.
  // Fix: now holds lock_ while iterating, so a concurrent set() cannot
  // invalidate the iterators mid-dump.
  std::string dump() {
    std::ostringstream buf;
    std::lock_guard<std::mutex> lock(lock_);
    for (auto &item : cache_) {
      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  std::mutex lock_;
  CacheType cache_;
};
171 
// Plain (non-thread-safe) key/value cache.  `ValueType` must be
// nullptr-constructible: a miss in get() is reported as nullptr.
template <typename KeyType, typename ValueType, typename CacheType>
class NormalCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  // Looks up `key`; returns the stored value, or nullptr when absent.
  ValueType get(const KeyType &key) const {
    const auto iter = cache_.find(key);
    return iter == cache_.end() ? nullptr : iter->second;
  }

  // Inserts a new entry or replaces the existing one for `key`.
  void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }

  // Discards every entry.
  void clear() { cache_.clear(); }

  // Number of entries currently stored.
  size_t size() const { return cache_.size(); }

  // True when no entries are stored.
  bool empty() const { return size() == 0; }

  // Renders one "{key: value}" line per entry via the elements' ToString().
  std::string dump() const {
    std::ostringstream out;
    for (const auto &entry : cache_) {
      out << "{" << entry.first->ToString() << ": " << entry.second->ToString() << "}" << std::endl;
    }
    return out.str();
  }

  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  CacheType cache_;
};
214 
215 class AsyncAbstract {
216  public:
217   AsyncAbstract() = default;
218   ~AsyncAbstract() = default;
TryGetResult()219   AbstractBasePtr TryGetResult() {
220     std::lock_guard<std::mutex> lock(lock_);
221     return result_;
222   }
HasResult()223   bool HasResult() {
224     std::lock_guard<std::mutex> lock(lock_);
225     return result_ != nullptr;
226   }
SetResult(const AbstractBasePtr & result)227   void SetResult(const AbstractBasePtr &result) {
228     MS_EXCEPTION_IF_NULL(result);
229     std::lock_guard<std::mutex> lock(lock_);
230     result_ = result;
231   }
232 
ToString()233   std::string ToString() {
234     std::ostringstream buffer;
235     std::lock_guard<std::mutex> lock(lock_);
236     buffer << (result_ == nullptr ? "NOT SET" : result_->ToString());
237     return buffer.str();
238   }
239 
240  private:
241   std::mutex lock_;
242   AbstractBasePtr result_{nullptr};
243 };
244 
// A schedulable unit of inference work: pairs a thread id with a shared
// AsyncAbstract result slot.  One thread blocks in GetResult() while another
// thread publishes a result (SetReady) or an error condition.
//
// ready_ is a bit mask:
//   bit 0 (1): result ready   bit 1 (2): exception   bit 2 (4): endless loop
class AsyncInferTask {
 public:
  explicit AsyncInferTask(const std::string &threadId, const AsyncAbstractPtr &abstract)
      : threadId_(threadId), abstract_ptr_(abstract) {}
  ~AsyncInferTask() = default;

  // Factory: wraps `abstract` in a task.  An empty threadId defaults to the
  // calling thread's id from AnalysisSchedule.
  static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &threadId = "") {
    std::string thread_id = threadId;
    if (thread_id == "") {
      thread_id = AnalysisSchedule::GetInstance().GetThreadID();
    }
    MS_EXCEPTION_IF_NULL(abstract);
    auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);
    MS_EXCEPTION_IF_NULL(ret);
    return ret;
  }

  bool HasResult() { return abstract_ptr_->HasResult(); }
  // Raw flag word; non-zero means some wake-up condition is pending.
  // NOTE(review): ready_ is size_t, so this narrows to int — harmless for
  // the three flag bits, but worth confirming the intent.
  int Ready() const { return ready_; }
  std::string ThreadID() const { return threadId_; }

  // Blocks until a result (or error flag) is available, then returns the
  // result.  May throw via ProcessResult() if an exception/endless-loop flag
  // was raised instead of a normal result.
  AbstractBasePtr GetResult() {
    std::unique_lock<std::mutex> lock(lock_);
    if (ready_) {
      ProcessResult();
      return abstract_ptr_->TryGetResult();
    }
    // Avoid to dead lock between AsyncAbstract::lock and AnalysisSchedule::activate_thread_lock_
    lock.unlock();
    AnalysisSchedule::GetInstance().Yield(this);

    // Re-acquire before waiting; the condition variable releases it while
    // parked and a SetReady/SetException from another thread wakes us.
    lock.lock();
    MS_LOG(DEBUG) << this << " after enter waiting ready: " << ready_ << " thread id:" << threadId_
                  << " GetThreadId: " << AnalysisSchedule::GetInstance().GetThreadID();
    condition_var_.wait(lock, [this] { return ready_; });
    MS_LOG(DEBUG) << this << " received notify and wake up: " << ready_ << " thread id:" << threadId_
                  << " GetThreadId: " << AnalysisSchedule::GetInstance().GetThreadID();
    ProcessResult();
    auto ans = abstract_ptr_->TryGetResult();
    MS_EXCEPTION_IF_NULL(ans);
    return ans;
  }

  // Signals that the normal result is available (bit 0) and wakes the waiter.
  void SetReady() {
    MS_LOG(DEBUG) << this << " want to set ready.";
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 1;  // Set the first bit = 1
      MS_LOG(DEBUG) << this << " notify ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get()
                    << " threadId: " << threadId_;
    }
    condition_var_.notify_one();
  }

  // Signals that an exception is pending (bit 1) and wakes the waiter; the
  // actual exception is re-raised by StaticAnalysisException in ProcessResult().
  void SetException() {
    MS_LOG(DEBUG) << this << " want to set ready.";
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 2;  // Set the second bit = 1
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

  // Signals that no branch can make progress (bit 2) and wakes the waiter,
  // which will raise the endless-loop exception in ProcessResult().
  void SetEndLessLoopException() {
    MS_LOG(DEBUG) << this << " want to set ready.";
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 4;  // Set the third bit = 1
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

 private:
  // Clears the normal-ready bit so the task can be waited on again,
  // preserving any error bits.
  void ClearReady() {
    ready_ = ready_ & 6;  // Set first bit = 0
    MS_LOG(DEBUG) << this << " ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get();
  }
  // Raises the endless-loop exception exactly once if bit 2 is set.
  void HandleEndLessLoopException() {
    // Get third bit
    if (ready_ & 4) {
      ready_ = ready_ & 3;  // Set the third bit = 0 , Only trigger once.
      MS_LOG(EXCEPTION) << "There isn't any branch that can be evaluated. \n"
                        << "Please check the code if it's has the infinite recursion or loop.\n"
                        << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
    }
  }
  // Consumes the wake-up: clears the ready bit, then re-raises any pending
  // endless-loop or stored static-analysis exception.
  void ProcessResult() {
    ClearReady();  // Clear normal ready flag
    HandleEndLessLoopException();
    StaticAnalysisException::Instance().CheckException();
    MS_LOG(DEBUG) << this << " Success to GetResult. ready: " << ready_ << " threadId: " << threadId_
                  << " GetThreadId:" << AnalysisSchedule::GetInstance().GetThreadID()
                  << " result: " << abstract_ptr_->TryGetResult().get();
  }
  std::string threadId_;
  AsyncAbstractPtr abstract_ptr_;
  std::mutex lock_;
  std::condition_variable condition_var_;
  size_t ready_{0};  // 0: not ready, bit 1 = 1: ready, bit 2 = 1: exception, bit 3 = 1: endless loop
};
347 
348 using EvaluatorCacheMap =
349   std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
350 using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
351 
352 class EvaluatorCacheMgr {
353  public:
354   EvaluatorCacheMgr() = default;
355   ~EvaluatorCacheMgr() = default;
356 
Clear()357   void Clear() { eval_result_cache_.clear(); }
GetCache()358   const EvalResultCache &GetCache() { return eval_result_cache_; }
GetValue(const AbstractBasePtrList & key)359   EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
SetValue(const AbstractBasePtrList & key,const EvalResultPtr & arg)360   void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
GetSize()361   size_t GetSize() { return eval_result_cache_.size(); }
362 
363  private:
364   EvalResultCache eval_result_cache_;
365 };
366 
// AnalysisCache
// Process-wide singleton caching per-node (AnfNodeConfig) inference results.
// A separate MultiThreadCache holds async switch-branch values that other
// inference threads may still be producing.
class AnalysisResultCacheMgr {
 public:
  using AnalysisConfigResultMap =
    std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
  using const_iterator = typename AnalysisConfigResultCache::const_iterator;

  ~AnalysisResultCacheMgr() = default;
  AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
  AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
  static AnalysisResultCacheMgr &GetInstance() { return instance_; }
  // Resets cached state (implemented in the .cc file).
  void Clear();
  // Stores the inference result for node config `conf`.
  inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
  // Returns the cached result for `conf`, or nullptr when absent.
  inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
  // Todo-list and switch-value operations below are implemented in the .cc
  // file; switch values go through the mutex-protected switch_cache_.
  void PushTodo(const AnfNodeConfigPtr &conf);
  void Todo();
  void InitSwitchValue(const AnfNodeConfigPtr &conf);
  AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
  AbstractBasePtr TryGetSwitchValue(const AnfNodeConfigPtr &conf);
  // NOTE(review): parameter name "vale" looks like a typo for "value" —
  // harmless in a declaration, but worth fixing alongside the .cc definition.
  void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &vale);
  const_iterator begin() { return cache_.begin(); }
  const_iterator end() { return cache_.end(); }

 private:
  using AnalysisConfigAsyncResultMap =
    std::unordered_map<AnfNodeConfigPtr, AsyncAbstractPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigAsyncResultCache =
    MultiThreadCache<AnfNodeConfigPtr, AsyncAbstractPtr, AnalysisConfigAsyncResultMap>;
  AnalysisResultCacheMgr() = default;
  static AnalysisResultCacheMgr instance_;
  std::mutex lock_;
  std::mutex todo_lock_;
  std::list<AnfNodeConfigPtr> todo_;
  AnalysisConfigResultCache cache_;
  AnalysisConfigAsyncResultCache switch_cache_;
};
404 
405 std::string ArgsToString(const AbstractBasePtrList &args_spec_list);
406 
GetInferThread()407 inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisSchedule::GetThreadID() + ":"; }
408 
409 }  // namespace abstract
410 }  // namespace mindspore
411 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
412