/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_

#include <iostream>
#include <utility>
#include <future>
#include <thread>
#include <memory>
#include <vector>
#include <string>
#include <functional>
#include <list>
#include <set>
#include <unordered_map>
#include <fstream>
#include <chrono>
#include <mutex>
#include <atomic>
#include <condition_variable>
#include <sstream>

#include "pipeline/jit/ps/static_analysis/static_analysis.h"
#include "utils/hash_map.h"

namespace mindspore {
namespace abstract {

class AsyncInferTask;
class AsyncAbstract;
class AsyncAbstractFuncAtom;
using AsyncInferTaskPtr = std::shared_ptr<AsyncInferTask>;
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
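
// Singleton scheduler that coordinates the parallel static-analysis threads.
// Pending AsyncInferTask objects are queued via Add2Schedule() and dispatched by
// a background thread launched in Start(); the thread-count and active-thread
// bookkeeping below keeps the main analysis thread and the workers in step.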
class AnalysisSchedule {
 public:
  ~AnalysisSchedule() = default;
  AnalysisSchedule(const AnalysisSchedule &) = delete;
  AnalysisSchedule &operator=(const AnalysisSchedule &) = delete;
  static AnalysisSchedule &GetInstance() {
    static AnalysisSchedule instance;
    return instance;
  }
  static void set_thread_id(const std::string &thread_id) { thread_id_ = thread_id; }
  static std::string &thread_id() { return thread_id_; }
  void HandleException(const std::exception &ex);
  void Stop();
  void Wait();
  void Add2Schedule(const AsyncInferTaskPtr &async_infer_task_ptr);
  void WaitForRun() const;
  void YieldTask(AsyncInferTask *asyncTask);

  void EnterWaiting() {
    {
      std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
      (void)activate_threads_.erase(AnalysisSchedule::thread_id());
      MS_LOG(DEBUG) << "Infer return to main thread.";
    }
    activate_thread_cv_.notify_one();
  }

  void IncreaseThreadCount() {
    infer_thread_count_.fetch_add(1);
    MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
                  << " The infer_thread_count: " << infer_thread_count_
                  << " schedule list size: " << schedule_list_.size();
  }

  void DecreaseThreadCount() {
    {
      std::lock_guard<std::mutex> threadNumLock(infer_thread_lock_);
      infer_thread_count_.fetch_sub(1);
    }
    infer_thread_cv_.notify_one();

    {
      std::lock_guard<std::mutex> active_lock(activate_thread_lock_);
      (void)activate_threads_.erase(AnalysisSchedule::thread_id());
      MS_LOG(DEBUG) << " The active thread count: " << activate_threads_.size()
                    << " The infer_thread_count: " << infer_thread_count_
                    << " schedule list size: " << schedule_list_.size() << " thread: " << thread_id() + " "
                    << (activate_threads_.size() > 0 ? activate_threads_.begin()->c_str() : "");
    }
    activate_thread_cv_.notify_one();
  }

  void Start() {
    run_ = true;
    dispatcher_ = std::make_shared<std::thread>([this] { Schedule(); });
  }

 private:
  void Schedule();
  void SetNextReady();
  AnalysisSchedule() { Start(); }
  std::atomic<int> infer_thread_count_{0};
  bool run_{true};
  std::mutex infer_thread_lock_;
  std::condition_variable infer_thread_cv_;
  std::mutex activate_thread_lock_;
  std::condition_variable activate_thread_cv_;
  std::list<AsyncInferTaskPtr> schedule_list_;
  std::set<std::string> activate_threads_;
  static thread_local std::string thread_id_;
  std::shared_ptr<std::thread> dispatcher_;
};

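// Key/value cache in which every operation takes an internal mutex, so a single
// instance can be shared by the parallel analysis threads. ValueType is expected
// to be pointer-like, since get() returns nullptr on a miss. A minimal usage
// sketch (the key, value and map types here are chosen for illustration only):
//   MultiThreadCache<std::string, std::shared_ptr<int>,
//                    std::unordered_map<std::string, std::shared_ptr<int>>> cache;
//   cache.set("key", std::make_shared<int>(1));
//   auto hit = cache.get("key");  // nullptr on a miss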
template <typename KeyType, typename ValueType, typename CacheType>
class MultiThreadCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  ValueType get(const KeyType &key) {
    std::lock_guard<std::mutex> lock(lock_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return nullptr;
  }

  void set(const KeyType &key, const ValueType &data) {
    std::lock_guard<std::mutex> lock(lock_);
    cache_[key] = data;
  }

  void clear() {
    std::lock_guard<std::mutex> lock(lock_);
    cache_.clear();
  }

  size_t size() { return cache_.size(); }

  bool empty() { return size() == 0; }

  std::string dump() {
    std::ostringstream buf;
    for (auto &item : cache_) {
      MS_EXCEPTION_IF_NULL(item.first);
      MS_EXCEPTION_IF_NULL(item.second);
      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  std::mutex lock_;
  CacheType cache_;
};

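// Same interface as MultiThreadCache but without any locking; intended for data
// that is only touched from a single thread or is otherwise externally synchronized.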
template <typename KeyType, typename ValueType, typename CacheType>
class NormalCache {
 public:
  using iterator = typename CacheType::iterator;
  using const_iterator = typename CacheType::const_iterator;

  ValueType get(const KeyType &key) const {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    return nullptr;
  }

  const_iterator find(const KeyType &key) const { return cache_.find(key); }

  void set(const KeyType &key, const ValueType &data) { cache_[key] = data; }

  void clear() { cache_.clear(); }

  size_t size() const { return cache_.size(); }

  bool empty() const { return size() == 0; }

  std::string dump() const {
    std::ostringstream buf;
    for (auto &item : cache_) {
      MS_EXCEPTION_IF_NULL(item.first);
      MS_EXCEPTION_IF_NULL(item.second);
      buf << "{" << item.first->ToString() << ": " << item.second->ToString() << "}" << std::endl;
    }
    return buf.str();
  }

  iterator begin() { return cache_.begin(); }
  iterator end() { return cache_.end(); }

  const_iterator begin() const { return cache_.cbegin(); }
  const_iterator end() const { return cache_.cend(); }

  const_iterator cbegin() const { return cache_.cbegin(); }
  const_iterator cend() const { return cache_.cend(); }

 private:
  CacheType cache_;
};

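// Holder for an abstract value that may be produced asynchronously by another
// analysis thread. set_result() publishes the value, TryGetResult() peeks at it
// without waiting, and GetResult() (defined in the .cc file) retrieves it once
// it becomes available.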
class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
 public:
  explicit AsyncAbstract(const std::shared_ptr<AsyncAbstract> &switchAbstract = nullptr)
      : switchAbstract_(switchAbstract) {}
  ~AsyncAbstract() = default;
  AbstractBasePtr GetResult();
  AbstractBasePtr TryGetResult() {
    std::lock_guard<std::mutex> lock(lock_);
    return result_;
  }
  bool HasResult() {
    std::lock_guard<std::mutex> lock(lock_);
    return result_ != nullptr;
  }
  void set_result(const AbstractBasePtr &result) {
    std::lock_guard<std::mutex> lock(lock_);
    result_ = result;
  }

  void ClearPossibleResult();

  std::string ToString() {
    std::ostringstream buffer;
    std::lock_guard<std::mutex> lock(lock_);
    buffer << (result_ == nullptr ? "NOT SET" : result_->ToString());
    return buffer.str();
  }

  bool SetPossibleResult(bool first);
  void set_ignore_value(bool ignore_value) { ignore_value_ = ignore_value; }

 private:
  std::mutex lock_;
  AbstractBasePtr result_{nullptr};
  bool not_copy_from_other_{true};
  bool ignore_value_{false};
  std::shared_ptr<AsyncAbstract> switchAbstract_;
};

// Wraps an AsyncAbstract so that it can work with the Join method of AbstractFunction.
class AsyncAbstractFuncAtom : public AbstractFuncAtom {
 public:
  AsyncAbstractFuncAtom(const AsyncAbstractPtr &async_abstract, const std::vector<std::size_t> &index)
      : async_abstract_(async_abstract), index_(index) {}
  ~AsyncAbstractFuncAtom() = default;
  MS_DECLARE_PARENT(AsyncAbstractFuncAtom, AbstractFuncAtom);

  static std::shared_ptr<AsyncAbstractFuncAtom> MakeShared(const AsyncAbstractPtr &async_abstract,
                                                           const std::vector<std::size_t> &index) {
    MS_EXCEPTION_IF_NULL(async_abstract);
    auto ret = std::make_shared<AsyncAbstractFuncAtom>(async_abstract, index);
    MS_EXCEPTION_IF_NULL(ret);
    return ret;
  }

  AbstractFunctionPtr Copy() const override { return MakeShared(async_abstract_, index_); }

  bool operator==(const AbstractFunction &other) const override {
    if (!other.isa<AsyncAbstractFuncAtom>()) {
      return false;
    }
    auto other_async = static_cast<const AsyncAbstractFuncAtom *>(&other);
    MS_EXCEPTION_IF_NULL(other_async);
    if (index_ != other_async->index_) {
      return false;
    }
    if (async_abstract_ == other_async->async_abstract_) {
      return true;
    }
    MS_EXCEPTION_IF_NULL(async_abstract_);
    auto abs = async_abstract_->TryGetResult();
    auto other_abs = other_async->async_abstract_->TryGetResult();
    if (abs != nullptr && other_abs != nullptr) {
      return *abs == *other_abs;
    } else {
      return false;
    }
  }

  std::size_t hash() const override {
    std::size_t hash_index = 0;
    for (const auto i : index_) {
      hash_index = hash_combine(hash_index, std::hash<std::size_t>{}(i));
    }
    return hash_index;
  }

  AbstractFunctionPtr GetUnique() override;

  std::string ToString() const override;

 private:
  // The resolved AbstractFunction once analysis has completed.
  AbstractFunctionPtr resolved_{nullptr};
  // Until it is resolved, the following two members track the pending value.
  const AsyncAbstractPtr async_abstract_;
  const std::vector<std::size_t> index_;
};
using AsyncAbstractFuncAtomPtr = std::shared_ptr<AsyncAbstractFuncAtom>;

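// A schedulable unit of inference work: it pairs a waiting thread (thread_id_)
// with the AsyncAbstract it is waiting on. GetResult() yields the task to the
// AnalysisSchedule and blocks on a condition variable until SetReady(),
// SetException() or SetEndLessLoopException() sets the corresponding bit in ready_.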
class AsyncInferTask {
 public:
  explicit AsyncInferTask(const std::string &thread_id, const AsyncAbstractPtr &abstract)
      : thread_id_(thread_id), abstract_ptr_(abstract) {
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " : " << this;
  }
  ~AsyncInferTask() { MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " : " << this; }

  static AsyncInferTaskPtr MakeShared(const AsyncAbstractPtr &abstract, const std::string &thread = "") {
    std::string thread_id = thread;
    if (thread_id == "") {
      thread_id = AnalysisSchedule::thread_id();
    }
    MS_EXCEPTION_IF_NULL(abstract);
    auto ret = std::make_shared<AsyncInferTask>(thread_id, abstract);
    MS_EXCEPTION_IF_NULL(ret);
    return ret;
  }

  bool HasResult() {
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    return abstract_ptr_->HasResult();
  }
  bool SetPossibleResult(bool first) {
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    return abstract_ptr_->SetPossibleResult(first);
  }
  int ready() {
    std::lock_guard<std::mutex> lock(lock_);
    return SizeToInt(ready_);
  }
  const std::string &thread_id() const { return thread_id_; }

  AbstractBasePtr GetResult() {
    StaticAnalysisException::Instance().CheckException();
    AnalysisSchedule::GetInstance().YieldTask(this);
    std::unique_lock<std::mutex> lock(lock_);
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " waiting.";
    condition_var_.wait(lock, [this] { return ready_; });
    MS_LOG(DEBUG) << this << " received notify and wake up: " << ready_ << " thread id:" << thread_id_;
    ProcessResult();
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    auto ans = abstract_ptr_->TryGetResult();
    MS_EXCEPTION_IF_NULL(ans);
    MS_LOG(DEBUG) << AnalysisSchedule::thread_id() << " active.";
    return ans;
  }

  void SetReady() {
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 0b001;  // Set the first bit = 1
      MS_EXCEPTION_IF_NULL(abstract_ptr_);
      MS_LOG(DEBUG) << this << " notify ready: " << ready_ << " result: " << abstract_ptr_->TryGetResult().get()
                    << " thread_id: " << thread_id_;
    }
    condition_var_.notify_one();
  }

  void SetException() {
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 0b010;  // Set the second bit = 1
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

  void SetEndLessLoopException() {
    {
      std::lock_guard<std::mutex> lock(lock_);
      ready_ = ready_ | 0b100;  // Set the third bit = 1
      MS_LOG(DEBUG) << this << " notify ready: " << ready_;
    }
    condition_var_.notify_one();
  }

 private:
  void HandleEndLessLoopException() {
    // Check the third bit.
    if (ready_ & 0b100) {
      ready_ = ready_ & 0b011;  // Clear the third bit so the exception is triggered only once.
      MS_LOG(EXCEPTION) << "There isn't any branch that can be evaluated.\n"
                        << "Please check whether the code contains infinite recursion or an infinite loop.\n"
                        << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
    }
  }
  void ProcessResult() {
    HandleEndLessLoopException();
    StaticAnalysisException::Instance().CheckException();
    MS_EXCEPTION_IF_NULL(abstract_ptr_);
    MS_LOG(DEBUG) << this << " GetResult succeeded. ready: " << ready_ << " thread_id: " << thread_id_
                  << " result: " << abstract_ptr_->TryGetResult().get();
  }
  std::string thread_id_;
  AsyncAbstractPtr abstract_ptr_;
  std::mutex lock_;
  std::condition_variable condition_var_;
  size_t ready_{0};  // 0: not ready, bit 1 = 1: ready, bit 2 = 1: exception, bit 3 = 1: endless loop
};

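// Evaluator-level cache: maps an argument list (AbstractBasePtrList) to the
// evaluation result previously computed for it.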
using EvaluatorCacheMap =
  std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvalResultCache = NormalCache<AbstractBasePtrList, EvalResultPtr, EvaluatorCacheMap>;

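// Thin wrapper around EvalResultCache used to look up and record EvalResults
// keyed by the abstract argument list.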
class EvaluatorCacheMgr {
 public:
  EvaluatorCacheMgr() = default;
  ~EvaluatorCacheMgr() = default;

  void Clear() { eval_result_cache_.clear(); }
  const EvalResultCache &GetCache() const { return eval_result_cache_; }
  EvalResultPtr GetValue(const AbstractBasePtrList &key) { return eval_result_cache_.get(key); }
  void SetValue(const AbstractBasePtrList &key, const EvalResultPtr &arg) { eval_result_cache_.set(key, arg); }
  size_t GetSize() const { return eval_result_cache_.size(); }

 private:
  EvalResultCache eval_result_cache_;
};

// AnalysisCache: process-wide manager for analysis results. It caches the
// EvalResult of every AnfNodeConfig, keeps separate asynchronous caches for
// switch (control-flow) values, and holds a cache of primitive evaluation results.
class AnalysisResultCacheMgr {
 public:
  using AnalysisConfigResultMap =
    mindspore::HashMap<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigResultCache = NormalCache<AnfNodeConfigPtr, EvalResultPtr, AnalysisConfigResultMap>;
  using const_iterator = typename AnalysisConfigResultCache::const_iterator;

  ~AnalysisResultCacheMgr() = default;
  AnalysisResultCacheMgr(const AnalysisResultCacheMgr &) = delete;
  AnalysisResultCacheMgr &operator=(const AnalysisResultCacheMgr &) = delete;
  static AnalysisResultCacheMgr &GetInstance() {
    static AnalysisResultCacheMgr instance;
    return instance;
  }
  void Clear();
  const AnalysisConfigResultCache &GetCache() const { return cache_; }
  inline void SetValue(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg) { cache_.set(conf, arg); }
  inline EvalResultPtr GetValue(const AnfNodeConfigPtr &conf) { return cache_.get(conf); }
  void InitSwitchValue(const AnfNodeConfigPtr &conf);
  AbstractBasePtr GetSwitchValue(const AnfNodeConfigPtr &conf);
  void SetSwitchValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
  const_iterator begin() { return cache_.begin(); }
  const_iterator end() { return cache_.end(); }
  void CheckSwitchValueJoinable(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
  const PrimitiveEvalCachePtr &prim_eval_cache() const { return prim_eval_cache_; }

 private:
  using AnalysisConfigAsyncResultMap =
    mindspore::HashMap<AnfNodeConfigPtr, AsyncAbstractPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
  using AnalysisConfigAsyncResultCache =
    MultiThreadCache<AnfNodeConfigPtr, AsyncAbstractPtr, AnalysisConfigAsyncResultMap>;
  AnalysisResultCacheMgr() = default;
  void SetCacheValue(const AnfNodeConfigPtr &conf, const AbstractBasePtr &current_abs,
                     AnalysisConfigAsyncResultCache *cache);

  std::mutex lock_;
  AnalysisConfigResultCache cache_;
  AnalysisConfigAsyncResultCache switch_cache_;
  AnalysisConfigAsyncResultCache switch_cache_for_check_;
  PrimitiveEvalCachePtr prim_eval_cache_ = std::make_shared<PrimitiveEvalCache>();
};

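// Free helper functions shared by the static analysis: ArgsToString() formats an
// argument list for logging, enable_waiting_branch_eval() reports whether branch
// evaluation should wait, and NeedWaitForBranches() checks whether an abstract's
// branches must be waited on before the result can be used.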
std::string ArgsToString(const AbstractBasePtrList &args_abs_list);
bool enable_waiting_branch_eval();
bool NeedWaitForBranches(const AbstractBasePtr &abstract);

inline std::string GetInferThread() { return std::string("[tid: ") + AnalysisSchedule::thread_id() + "] "; }
}  // namespace abstract
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ASYNC_EVAL_RESULT_H_