• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
18 #define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
19 
20 #include <iostream>
21 #include <utility>
22 #include <future>
23 #include <thread>
24 #include <memory>
25 #include <vector>
26 #include <string>
27 #include <functional>
28 #include <list>
29 #include <set>
30 #include <unordered_map>
31 #include <fstream>
32 #include <chrono>
33 #include <mutex>
34 
35 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
36 #include "utils/hash_map.h"
37 
38 namespace mindspore {
39 namespace abstract {
40 
41 class AsyncInferTask;
42 class AsyncAbstract;
43 class AsyncAbstractFuncAtom;
44 using AsyncInferTaskPtr = std::shared_ptr<AsyncInferTask>;
45 using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
// Singleton scheduler that coordinates the multi-threaded static analysis.
// It keeps a list of pending AsyncInferTask objects (schedule_list_), tracks
// how many inference threads are alive (infer_thread_count_) and which ones
// are currently active (activate_threads_), and runs a dispatcher thread
// (started in Start()) that wakes tasks one at a time via Schedule().
class AnalysisSchedule {
 public:
  ~AnalysisSchedule() = default;
  AnalysisSchedule(const AnalysisSchedule &) = delete;
  AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
  // Meyers singleton; first call constructs the instance, which launches the
  // dispatcher thread (see the private constructor).
  static AnalysisSchedule &GetInstance() {
    static AnalysisSchedule instance;
    return instance;
  }
  // Set/get the thread-local id string used as the key in activate_threads_.
  static void set_thread_id(const std::string &thread_id) { thread_id_ = thread_id; }
  static std::string &thread_id() { return thread_id_; }
  void HandleException(const std::exception &ex);
  void Stop();
  void Wait();
  // Queue an async inference task for the dispatcher (defined out of line).
  void Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr);
  void WaitForRun() const;
  void YieldTask(AsyncInferTask *asyncTask);

  // Remove the calling thread from the active set and wake the dispatcher;
  // called right before the thread blocks waiting for a result.
  void EnterWaiting() {
    {
      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
      (void)activate_threads_.erase(AnalysisSchedule::thread_id());
      MS_LOG(DEBUG) << "Infer return to main thread.";
    }
    activate_thread_cv_.notify_one();
  }

  // Bump the live inference-thread counter (atomic, so correct).
  // NOTE(review): the container sizes logged below are read WITHOUT holding
  // activate_thread_lock_, so the DEBUG output may be stale or racy; only the
  // counter update itself is synchronized.
  void IncreaseThreadCount() {
    infer_thread_count_.fetch_add(1);
    MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
                  << " The infer_thread_count: " << infer_thread_count_
                  << " schedule list size: " << schedule_list_.size();
  }

  // Decrement the live-thread counter and drop this thread from the active
  // set, notifying both waiters: infer_thread_cv_ (used by Wait()) and
  // activate_thread_cv_ (used by the dispatcher).
  void DecreaseThreadCount() {
    {
      std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
      infer_thread_count_.fetch_sub(1);
    }
    infer_thread_cv_.notify_one();

    {
      std::lock_guard<std::mutex> active_lock(activate_thread_lock_);
      (void)activate_threads_.erase(AnalysisSchedule::thread_id());
      MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
                    << " The infer_thread_count: " << infer_thread_count_
                    << " schedule list size: " << schedule_list_.size() << " thread: " << thread_id() + " "
                    << (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
    }
    activate_thread_cv_.notify_one();
  }

  // Launch the dispatcher thread running Schedule().
  void Start() {
    run_ = true;
    dispatcher_ = std::make_shared<std::thread>([this] { Schedule(); });
  }

 private:
  void Schedule();
  void SetNextReady();
  // Construction immediately starts the dispatcher thread.
  AnalysisSchedule() { Start(); }
  std::atomic<int> infer_thread_count_{0};      // number of live inference threads
  bool run_{true};                              // dispatcher keep-running flag
  std::mutex infer_thread_lock_;                // pairs with infer_thread_cv_
  std::condition_variable infer_thread_cv_;     // signaled when a thread exits
  std::mutex activate_thread_lock_;             // guards activate_threads_
  std::condition_variable activate_thread_cv_;  // signaled on activity changes
  std::list<AsyncInferTaskPtr> schedule_list_;  // tasks pending dispatch
  std::set<std::string> activate_threads_;      // ids of currently active threads
  static thread_local std::string thread_id_;   // per-thread id string
  std::shared_ptr<std::thread> dispatcher_;     // thread executing Schedule()
};
118 
// Mutex-protected key/value cache intended for concurrent use.
//
// KeyType/ValueType are expected to be (smart) pointer-like: get() returns
// nullptr on a miss and dump() dereferences both key and value.
//
// Fix over the previous revision: size(), empty() and dump() used to read
// cache_ without taking lock_, racing with concurrent set()/clear(); they now
// serialize on the same mutex as the other accessors. The iterator accessors
// remain unsynchronized (see note below).
template <typename KeyType, typename ValueType, typename CacheType>
class MultiThreadCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  // Returns the cached value for `key`, or nullptr when absent.
  ValueType get(const KeyType &key) {
    std::lock_guard<std::mutex> lock(lock_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return nullptr;
  }

  // Inserts or overwrites the entry for `key`.
  void set(const KeyType &key, const ValueType &data) {
    std::lock_guard<std::mutex> lock(lock_);
    cache_[key] = data;
  }

  // Removes all entries.
  void clear() {
    std::lock_guard<std::mutex> lock(lock_);
    cache_.clear();
  }

  // Number of entries; locked so the read cannot race with set()/clear().
  size_t size() {
    std::lock_guard<std::mutex> lock(lock_);
    return cache_.size();
  }

  // True when the cache holds no entries.
  bool empty() { return size() == 0; }

  // Renders every "{key: value}" pair, one per line; locked so iteration
  // cannot race with concurrent mutation.
  std::string dump() {
    std::lock_guard<std::mutex> lock(lock_);
    std::ostringstream buf;
    for (auto &item : cache_) {
      MS_EXCEPTION_IF_NULL(item.first);
      MS_EXCEPTION_IF_NULL(item.second);
      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  // NOTE: iterators are NOT protected by the internal mutex; callers must
  // guarantee no concurrent mutation for the lifetime of the iteration.
  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  std::mutex lock_;  // guards cache_ for all non-iterator accessors
  CacheType cache_;
};
171 
// Unsynchronized key/value cache for single-threaded use.
//
// KeyType/ValueType are expected to be (smart) pointer-like: a miss in get()
// yields nullptr and dump() dereferences both key and value.
template <typename KeyType, typename ValueType, typename CacheType>
class NormalCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  // Look up `key`; nullptr when no entry exists.
  ValueType get(const KeyType &key) const {
    auto iter = cache_.find(key);
    if (iter == cache_.end()) {
      return nullptr;
    }
    return iter->second;
  }

  // Iterator-based lookup; end() on a miss.
  const_iterator find(const KeyType &key) const { return cache_.find(key); }

  // Insert or overwrite the entry for `key`.
  void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }

  // Drop every entry.
  void clear() { cache_.clear(); }

  // Number of stored entries.
  size_t size() const { return cache_.size(); }

  // True when no entries are stored.
  bool empty() const { return size() == 0; }

  // Render every "{key: value}" pair, one per line.
  std::string dump() const {
    std::ostringstream stream;
    for (auto &entry : cache_) {
      MS_EXCEPTION_IF_NULL(entry.first);
      MS_EXCEPTION_IF_NULL(entry.second);
      stream << "{" << entry.first->ToString() << ": " << entry.second->ToString() << "}" << std::endl;
    }
    return stream.str();
  }

  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  CacheType cache_;
};
218 
// Holder for an inference result that may be produced by another thread.
// All inline accessors serialize on lock_; the blocking logic (GetResult,
// SetPossibleResult, ClearPossibleResult) is implemented out of line.
class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
 public:
  // Optionally links another AsyncAbstract (used by the out-of-line
  // switch-branch logic; see the .cc for how switchAbstract_ is consumed).
  explicit AsyncAbstract(const std::shared_ptr<AsyncAbstract> &switchAbstract = nullptr)
      : switchAbstract_(switchAbstract) {}
  ~AsyncAbstract() = default;
  // Blocking fetch of the result (implemented out of line).
  AbstractBasePtr GetResult();
  // Non-blocking fetch; nullptr while the result is not yet set.
  AbstractBasePtr TryGetResult() {
    std::lock_guard<std::mutex> lock(lock_);
    return result_;
  }
  // True once a result has been published via set_result().
  bool HasResult() {
    std::lock_guard<std::mutex> lock(lock_);
    return result_ != nullptr;
  }
  // Publish the result (overwrites any previous value).
  void set_result(const AbstractBasePtr &result) {
    std::lock_guard<std::mutex> lock(lock_);
    result_ = result;
  }

  void ClearPossibleResult();

  // Human-readable result, or "NOT SET" while still pending.
  std::string ToString() {
    std::ostringstream buffer;
    std::lock_guard<std::mutex> lock(lock_);
    buffer << (result_ == nullptr ? "NOT SET" : result_->ToString());
    return buffer.str();
  }

  // Implemented out of line; presumably seeds a provisional result so branch
  // evaluation can proceed — confirm semantics in the .cc.
  bool SetPossibleResult(bool first);
  void set_ignore_value(bool ignore_value) { ignore_value_ = ignore_value; }

 private:
  std::mutex lock_;                  // guards result_
  AbstractBasePtr result_{nullptr};  // set once inference finishes
  bool not_copy_from_other_{true};   // consumed by the out-of-line logic
  bool ignore_value_{false};         // consumed by the out-of-line logic
  std::shared_ptr<AsyncAbstract> switchAbstract_;  // optional linked branch result
};
257 
258 // Wrap AsyncAbstract, so it can work with Join method of AbstractFunction.
259 class AsyncAbstractFuncAtom : public AbstractFuncAtom {
260  public:
AsyncAbstractFuncAtom(const AsyncAbstractPtr & async_abstract,const std::vector<std::size_t> & index)261   AsyncAbstractFuncAtom(const AsyncAbstractPtr &async_abstract, const std::vector<std::size_t> &index)
262       : async_abstract_(async_abstract), index_(index) {}
263   ~AsyncAbstractFuncAtom() = default;
264   MS_DECLARE_PARENT(AsyncAbstractFuncAtom, AbstractFuncAtom);
265 
MakeShared(const AsyncAbstractPtr & async_abstract,const std::vector<std::size_t> & index)266   static std::shared_ptr<AsyncAbstractFuncAtom> MakeShared(const AsyncAbstractPtr &async_abstract,
267                                                            const std::vector<std::size_t> &index) {
268     MS_EXCEPTION_IF_NULL(async_abstract);
269     auto ret = std::make_shared<AsyncAbstractFuncAtom>(async_abstract, index);
270     MS_EXCEPTION_IF_NULL(ret);
271     return ret;
272   }
273 
Copy()274   AbstractFunctionPtr Copy() const override { return MakeShared(async_abstract_, index_); }
275 
276   bool operator==(const AbstractFunction &other) const override {
277     if (!other.isa<AsyncAbstractFuncAtom>()) {
278       return false;
279     }
280     auto other_async = static_cast<const AsyncAbstractFuncAtom *>(&other);
281     MS_EXCEPTION_IF_NULL(other_async);
282     if (index_ != other_async->index_) {
283       return false;
284     }
285     if (async_abstract_ == other_async->async_abstract_) {
286       return true;
287     }
288     MS_EXCEPTION_IF_NULL(async_abstract_);
289     auto abs = async_abstract_->TryGetResult();
290     auto other_abs = other_async->async_abstract_->TryGetResult();
291     if (abs != nullptr && other_abs != nullptr) {
292       return *abs == *other_abs;
293     } else {
294       return false;
295     }
296   }
297 
hash()298   std::size_t hash() const override {
299     std::size_t hash_index = 0;
300     for (const auto i : index_) {
301       hash_index = hash_combine(hash_index, std::hash<std::size_t>{}(i));
302     }
303     return hash_index;
304   }
305 
306   AbstractFunctionPtr GetUnique() override;
307 
308   std::string ToString() const override;
309 
310  private:
311   // Resolved AbstractFunction after fully analyzed.
312   AbstractFunctionPtr resolved_{nullptr};
313   // Before resolved, use the following two items to track.
314   const AsyncAbstractPtr async_abstract_;
315   const std::vector<std::size_t> index_;
316 };
317 using AsyncAbstractFuncAtomPtr = std::shared_ptr<AsyncAbstractFuncAtom>;
318 
// A schedulable unit of inference work: couples a thread id with the
// AsyncAbstract that will eventually hold that thread's result. The ready_
// bit-field (see member comment for the layout) coordinates wake-ups between
// AnalysisSchedule and the thread blocked in GetResult().
class AsyncInferTask {
 public:
  explicit AsyncInferTask(const std::string &thread_id, const AsyncAbstractPtr &abstract)
      : thread_id_(thread_id), abstract_ptr_(abstract) {
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " : " << this;
  }
  ~AsyncInferTask() { MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " : " << this; }

  // Factory: defaults the task's thread id to the calling thread's id.
  static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &thread = "") {
    std::string thread_id = thread;
    if (thread_id == "") {
      thread_id = AnalysisSchedule::thread_id();
    }
    MS_EXCEPTION_IF_NULL(abstract);
    auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);
    MS_EXCEPTION_IF_NULL(ret);
    return ret;
  }

  // True once the underlying AsyncAbstract has a published result.
  bool HasResult() {
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    return abstract_ptr_->HasResult();
  }
  bool SetPossibleResult(bool first) {
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    return abstract_ptr_->SetPossibleResult(first);
  }
  // Snapshot of the ready_ bits (0 means the task is still blocked).
  int ready() {
    std::lock_guard<std::mutex> lock(lock_);
    return SizeToInt(ready_);
  }
  const std::string &thread_id() const { return thread_id_; }

  // Blocks until any ready bit is set, then returns the published result.
  // Yields to the scheduler first so another task can run while we wait;
  // ProcessResult() re-raises if the wake-up was an exception/endless-loop.
  AbstractBasePtr GetResult() {
    StaticAnalysisException::Instance().CheckException();
    AnalysisSchedule::GetInstance().YieldTask(this);
    std::unique_lock<std::mutex> lock(lock_);
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " waiting.";
    condition_var_.wait(lock, [this] { return ready_; });
    MS_LOG(DEBUG) << this << " received notify and wake up: " << ready_ << " thread id:" << thread_id_;
    ProcessResult();
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    auto ans = abstract_ptr_->TryGetResult();
    MS_EXCEPTION_IF_NULL(ans);
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " active.";
    return ans;
  }

  // Wake the waiter: a normal result is available.
  void SetReady() {
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 0b001;  // Set the first bit = 1
      MS_EXCEPTION_IF_NULL(abstract_ptr_);
      MS_LOG(DEBUG) << this << " notify ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get()
                    << " thread_id: " << thread_id_;
    }
    condition_var_.notify_one();
  }

  // Wake the waiter: an exception occurred elsewhere (re-raised by
  // CheckException() inside ProcessResult()).
  void SetException() {
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 0b010;  // Set the second bit = 1
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

  // Wake the waiter: the scheduler detected an endless evaluation loop.
  void SetEndLessLoopException() {
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 0b100;  // Set the third bit = 1
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

 private:
  // Throws once (clearing the flag so it fires only a single time) if the
  // endless-loop bit is set. Called with lock_ held by GetResult().
  void HandleEndLessLoopException() {
    // Get third bit
    if (ready_ & 0b100) {
      ready_ = ready_ & 0b011;  // Set the third bit = 0 , Only trigger once.
      MS_LOG(EXCEPTION) << "There isn't any branch that can be evaluated. \n"
                        << "Please check the code if it has the infinite recursion or loop.\n"
                        << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
    }
  }
  // Called with lock_ held by GetResult(): re-raises any pending exception
  // before the caller reads the result.
  void ProcessResult() {
    HandleEndLessLoopException();
    StaticAnalysisException::Instance().CheckException();
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    MS_LOG(DEBUG) << this << " Success to GetResult. ready: " << ready_ << " thread_id: " << thread_id_
                  << " result: " << abstract_ptr_->TryGetResult().get();
  }
  std::string thread_id_;             // owning thread's id (for logging/scheduling)
  AsyncAbstractPtr abstract_ptr_;     // where the result is published
  std::mutex lock_;                   // guards ready_ and pairs with condition_var_
  std::condition_variable condition_var_;
  size_t ready_{0};  // 0: not ready, bit 1 = 1: ready, bit 2 = 1: exception, bit 3 = 1: endless loop
};
419 
420 using EvaluatorCacheMap =
421   std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
422 using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;
423 
424 class EvaluatorCacheMgr {
425  public:
426   EvaluatorCacheMgr() = default;
427   ~EvaluatorCacheMgr() = default;
428 
Clear()429   void Clear() { eval_result_cache_.clear(); }
GetCache()430   const EvalResultCache &GetCache() const { return eval_result_cache_; }
GetValue(const AbstractBasePtrList & key)431   EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
SetValue(const AbstractBasePtrList & key,const EvalResultPtr & arg)432   void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
GetSize()433   size_t GetSize() const { return eval_result_cache_.size(); }
434 
435  private:
436   EvalResultCache eval_result_cache_;
437 };
438 
// AnalysisCache: process-wide singleton that stores the evaluation result for
// each AnfNodeConfig, plus auxiliary async caches used while evaluating
// switch branches (filled by the out-of-line methods in the .cc).
class AnalysisResultCacheMgr {
 public:
  using AnalysisConfigResultMap =
    mindspore::HashMap<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
  using const_iterator = typename AnalysisConfigResultCache::const_iterator;

  ~AnalysisResultCacheMgr() = default;
  AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
  AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
  // Meyers singleton accessor.
  static AnalysisResultCacheMgr &GetInstance() {
    static AnalysisResultCacheMgr instance;
    return instance;
  }
  void Clear();
  const AnalysisConfigResultCache &GetCache() const { return cache_; }
  // Store/fetch the evaluation result for a node configuration.
  inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
  inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
  // Switch-branch helpers, implemented out of line.
  void InitSwitchValue(const AnfNodeConfigPtr &conf);
  AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
  void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
  const_iterator begin() { return cache_.begin(); }
  const_iterator end() { return cache_.end(); }
  void CheckSwitchValueJoinable(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
  const PrimitiveEvalCachePtr &prim_eval_cache() const { return prim_eval_cache_; }

 private:
  using AnalysisConfigAsyncResultMap =
    mindspore::HashMap<AnfNodeConfigPtr, AsyncAbstractPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigAsyncResultCache =
    MultiThreadCache<AnfNodeConfigPtr, AsyncAbstractPtr, AnalysisConfigAsyncResultMap>;
  AnalysisResultCacheMgr() = default;
  void SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &current_abs,
                     AnalysisConfigAsyncResultCache *cache);

  std::mutex lock_;  // used by the out-of-line methods; presumably serializes the switch caches — confirm in the .cc
  AnalysisConfigResultCache cache_;         // node config -> evaluation result
  AnalysisConfigAsyncResultCache switch_cache_;            // pending switch-branch results
  AnalysisConfigAsyncResultCache switch_cache_for_check_;  // results for the joinability check
  PrimitiveEvalCachePtr prim_eval_cache_ = std::make_shared<PrimitiveEvalCache>();
};
481 
482 std::string ArgsToString(const AbstractBasePtrList &args_abs_list);
483 bool enable_waiting_branch_eval();
484 bool NeedWaitForBranches(const AbstractBasePtr &abstract);
485 
GetInferThread()486 inline std::string GetInferThread() { return std::string("[tid: ") + AnalysisSchedule::thread_id() + "] "; }
487 }  // namespace abstract
488 }  // namespace mindspore
489 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
490