• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PI_JIT_ASYNC_TASK_MANAGERER_H_
18 #define MINDSPORE_PI_JIT_ASYNC_TASK_MANAGERER_H_
19 
#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "runtime/pipeline/async_rqueue.h"
#include "pybind11/stl.h"
29 
30 namespace mindspore {
31 namespace pijit {
32 namespace py = pybind11;
33 using RecordFunc = std::function<void(const py::object &prim, const py::object &out, const py::list &inputs)>;
34 class AsyncTaskMultiWorker : public std::enable_shared_from_this<AsyncTaskMultiWorker> {
35  public:
AsyncTaskMultiWorker(runtime::TaskType task_type)36   explicit AsyncTaskMultiWorker(runtime::TaskType task_type) : comp_count_(0), task_type_(task_type), done_(false) {}
37   virtual ~AsyncTaskMultiWorker() = default;
38   virtual void Run() = 0;
task_type()39   runtime::TaskType task_type() const { return task_type_; }
40   void Depend(std::shared_ptr<AsyncTaskMultiWorker> task);
41   void DependOn(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks);
42   void Notify();
43   void NotifyTo(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks);
44   bool Available();
45   void Reset();
46   void RunWrapper();
Done()47   bool Done() const { return done_; }
48 
49  protected:
50   std::vector<std::shared_ptr<AsyncTaskMultiWorker>> depends_;
51   std::vector<std::shared_ptr<AsyncTaskMultiWorker>> notifies_;
52   std::atomic<size_t> comp_count_;
53   runtime::TaskType task_type_;
54   bool done_;
55 };
56 #ifdef USE_ASYNC_SINGLE_WORKER
57 using AsyncTask = runtime::AsyncTask;
58 using AsyncTaskPtr = std::shared_ptr<runtime::AsyncTask>;
59 #else
60 using AsyncTask = AsyncTaskMultiWorker;
61 using AsyncTaskPtr = std::shared_ptr<AsyncTaskMultiWorker>;
62 #endif
63 
64 class AsyncQueueMultiWorker {
65  public:
66   AsyncQueueMultiWorker(std::string name, runtime::kThreadWaitLevel wait_level, size_t worker_count = 8);
67   virtual ~AsyncQueueMultiWorker();
68 
69   void Push(const AsyncTaskPtr &task);
70   void Wait();
71   bool Empty();
72   void Clear();
73   void WorkerJoin();
74 
75  protected:
76   bool Available();
77   AsyncTaskPtr PopAvailable();
78   AsyncTaskPtr Pop();
79   void WorkerLoop();
80   std::vector<std::unique_ptr<std::thread>> workers_;
81   std::mutex mutex_;
82   std::condition_variable ready_cv_;
83   std::condition_variable task_cv_;
84   std::vector<AsyncTaskPtr> tasks_queue_;
85   std::vector<AsyncTaskPtr> wait_queue_;
86   std::string name_;
87   runtime::kThreadWaitLevel wait_level_;
88   size_t worker_cnt_;
89   size_t ready_cnt_;
90   bool terminate_;
91 };
92 #ifdef USE_ASYNC_SINGLE_WORKER
93 using AsyncQueue = runtime::AsyncRQueue;
94 using AsyncQueuePtr = AsyncRQueuePtr;
95 #else
96 using AsyncQueue = AsyncQueueMultiWorker;
97 using AsyncQueuePtr = std::shared_ptr<AsyncQueueMultiWorker>;
98 #endif
99 
100 class RecordTask : public mindspore::pijit::AsyncTask {
101  public:
RecordTask(RecordFunc task,const py::object & prim,const py::object & out,const py::list & inputs)102   explicit RecordTask(RecordFunc task, const py::object &prim, const py::object &out, const py::list &inputs)
103       : mindspore::pijit::AsyncTask(runtime::kBpropTask),
104         run_task_(std::move(task)),
105         prim_(prim),
106         out_(out),
107         inputs_(inputs) {}
108   ~RecordTask() override = default;
109   void Run() override;
110 
111  private:
112   RecordFunc run_task_;
113   py::object prim_;
114   py::object out_;
115   py::list inputs_;
116 };
117 
118 using RecordTaskPtr = std::shared_ptr<RecordTask>;
119 
120 class RunGenerateBpropTask : public mindspore::pijit::AsyncTask {
121  public:
RunGenerateBpropTask(std::function<void ()> task)122   explicit RunGenerateBpropTask(std::function<void()> task)
123       : mindspore::pijit::AsyncTask(runtime::kBpropTask), run_task_(std::move(task)) {}
124   ~RunGenerateBpropTask() override = default;
125   void Run() override;
126 
127  private:
128   std::function<void()> run_task_;
129 };
130 
131 using RunGenerateBpropTaskPtr = std::shared_ptr<RunGenerateBpropTask>;
132 
133 class RunBpropTask : public mindspore::pijit::AsyncTask {
134  public:
RunBpropTask(std::function<void (const ValuePtr & value)> task,const ValuePtr & value)135   explicit RunBpropTask(std::function<void(const ValuePtr &value)> task, const ValuePtr &value)
136       : mindspore::pijit::AsyncTask(runtime::kBpropTask), run_task_(std::move(task)), value_(value) {}
137   ~RunBpropTask() override = default;
138   void Run() override;
139 
140  private:
141   std::function<void(const ValuePtr &value)> run_task_;
142   ValuePtr value_;
143 };
144 
145 using RunBpropTaskPtr = std::shared_ptr<RunBpropTask>;
146 
147 using Level = runtime::kThreadWaitLevel;
148 class AsyncTaskManager {
149  public:
AsyncTaskManager()150   AsyncTaskManager()
151       : record_task_queue_(std::make_shared<AsyncQueue>("record_task_queue", Level::kLevelGrad)),
152         generate_task_queue_(std::make_shared<AsyncQueue>("generate_task_queue", Level::kLevelGrad)),
153         run_task_queue_(std::make_shared<AsyncQueue>("run_task_queue", Level::kLevelGrad)) {}
154   virtual ~AsyncTaskManager() = default;
155 
GetRecordTaskQueue()156   const AsyncQueuePtr &GetRecordTaskQueue() const { return record_task_queue_; }
GetGenerateTaskQueue()157   const AsyncQueuePtr &GetGenerateTaskQueue() const { return generate_task_queue_; }
GetRunTaskQueue()158   const AsyncQueuePtr &GetRunTaskQueue() const { return run_task_queue_; }
DispatchRecordTask(const AsyncTaskPtr & task)159   void DispatchRecordTask(const AsyncTaskPtr &task) const { record_task_queue_->Push(task); }
DispatchGenerateTask(const AsyncTaskPtr & task)160   void DispatchGenerateTask(const AsyncTaskPtr &task) const { generate_task_queue_->Push(task); }
DispatchRunTask(const AsyncTaskPtr & task)161   void DispatchRunTask(const AsyncTaskPtr &task) const { run_task_queue_->Push(task); }
162 
163  private:
164   AsyncQueuePtr record_task_queue_;
165   AsyncQueuePtr generate_task_queue_;
166   AsyncQueuePtr run_task_queue_;
167 };
168 
169 using AsyncTaskManagerPtr = std::shared_ptr<AsyncTaskManager>;
170 }  // namespace pijit
171 }  // namespace mindspore
172 
173 #endif  // MINDSPORE_PI_JIT_ASYNC_TASK_MANAGERER_H_
174