1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PI_JIT_ASYNC_TASK_MANAGERER_H_ 18 #define MINDSPORE_PI_JIT_ASYNC_TASK_MANAGERER_H_ 19 20 #include <functional> 21 #include <atomic> 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 #include "ir/anf.h" 27 #include "runtime/pipeline/async_rqueue.h" 28 #include "pybind11/stl.h" 29 30 namespace mindspore { 31 namespace pijit { 32 namespace py = pybind11; 33 using RecordFunc = std::function<void(const py::object &prim, const py::object &out, const py::list &inputs)>; 34 class AsyncTaskMultiWorker : public std::enable_shared_from_this<AsyncTaskMultiWorker> { 35 public: AsyncTaskMultiWorker(runtime::TaskType task_type)36 explicit AsyncTaskMultiWorker(runtime::TaskType task_type) : comp_count_(0), task_type_(task_type), done_(false) {} 37 virtual ~AsyncTaskMultiWorker() = default; 38 virtual void Run() = 0; task_type()39 runtime::TaskType task_type() const { return task_type_; } 40 void Depend(std::shared_ptr<AsyncTaskMultiWorker> task); 41 void DependOn(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks); 42 void Notify(); 43 void NotifyTo(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks); 44 bool Available(); 45 void Reset(); 46 void RunWrapper(); Done()47 bool Done() const { return done_; } 48 49 protected: 50 std::vector<std::shared_ptr<AsyncTaskMultiWorker>> depends_; 51 std::vector<std::shared_ptr<AsyncTaskMultiWorker>> notifies_; 52 std::atomic<size_t> comp_count_; 53 runtime::TaskType task_type_; 54 bool done_; 55 }; 56 #ifdef USE_ASYNC_SINGLE_WORKER 57 using AsyncTask = runtime::AsyncTask; 58 using AsyncTaskPtr = std::shared_ptr<runtime::AsyncTask>; 59 #else 60 using AsyncTask = AsyncTaskMultiWorker; 61 using AsyncTaskPtr = std::shared_ptr<AsyncTaskMultiWorker>; 62 #endif 63 64 class AsyncQueueMultiWorker { 65 public: 66 AsyncQueueMultiWorker(std::string name, runtime::kThreadWaitLevel wait_level, size_t worker_count = 8); 67 virtual ~AsyncQueueMultiWorker(); 68 69 void Push(const AsyncTaskPtr &task); 70 void Wait(); 71 bool Empty(); 72 void Clear(); 73 void WorkerJoin(); 74 75 protected: 76 bool Available(); 77 AsyncTaskPtr PopAvailable(); 78 AsyncTaskPtr Pop(); 79 void WorkerLoop(); 80 std::vector<std::unique_ptr<std::thread>> workers_; 81 std::mutex mutex_; 82 std::condition_variable ready_cv_; 83 std::condition_variable task_cv_; 84 std::vector<AsyncTaskPtr> tasks_queue_; 85 std::vector<AsyncTaskPtr> wait_queue_; 86 std::string name_; 87 runtime::kThreadWaitLevel wait_level_; 88 size_t worker_cnt_; 89 size_t ready_cnt_; 90 bool terminate_; 91 }; 92 #ifdef USE_ASYNC_SINGLE_WORKER 93 using AsyncQueue = runtime::AsyncRQueue; 94 using AsyncQueuePtr = AsyncRQueuePtr; 95 #else 96 using AsyncQueue = AsyncQueueMultiWorker; 97 using AsyncQueuePtr = std::shared_ptr<AsyncQueueMultiWorker>; 98 #endif 99 100 class RecordTask : public mindspore::pijit::AsyncTask { 101 public: RecordTask(RecordFunc task,const py::object & prim,const py::object & out,const py::list & inputs)102 explicit RecordTask(RecordFunc task, const py::object &prim, const py::object &out, const py::list &inputs) 103 : mindspore::pijit::AsyncTask(runtime::kBpropTask), 104 run_task_(std::move(task)), 105 prim_(prim), 106 out_(out), 107 inputs_(inputs) {} 108 ~RecordTask() override = default; 109 void Run() override; 110 111 private: 112 RecordFunc run_task_; 113 py::object prim_; 114 py::object out_; 115 py::list inputs_; 116 }; 117 118 using RecordTaskPtr = std::shared_ptr<RecordTask>; 119 120 class RunGenerateBpropTask : public mindspore::pijit::AsyncTask { 121 public: RunGenerateBpropTask(std::function<void ()> task)122 explicit RunGenerateBpropTask(std::function<void()> task) 123 : mindspore::pijit::AsyncTask(runtime::kBpropTask), run_task_(std::move(task)) {} 124 ~RunGenerateBpropTask() override = default; 125 void Run() override; 126 127 private: 128 std::function<void()> run_task_; 129 }; 130 131 using RunGenerateBpropTaskPtr = std::shared_ptr<RunGenerateBpropTask>; 132 133 class RunBpropTask : public mindspore::pijit::AsyncTask { 134 public: RunBpropTask(std::function<void (const ValuePtr & value)> task,const ValuePtr & value)135 explicit RunBpropTask(std::function<void(const ValuePtr &value)> task, const ValuePtr &value) 136 : mindspore::pijit::AsyncTask(runtime::kBpropTask), run_task_(std::move(task)), value_(value) {} 137 ~RunBpropTask() override = default; 138 void Run() override; 139 140 private: 141 std::function<void(const ValuePtr &value)> run_task_; 142 ValuePtr value_; 143 }; 144 145 using RunBpropTaskPtr = std::shared_ptr<RunBpropTask>; 146 147 using Level = runtime::kThreadWaitLevel; 148 class AsyncTaskManager { 149 public: AsyncTaskManager()150 AsyncTaskManager() 151 : record_task_queue_(std::make_shared<AsyncQueue>("record_task_queue", Level::kLevelGrad)), 152 generate_task_queue_(std::make_shared<AsyncQueue>("generate_task_queue", Level::kLevelGrad)), 153 run_task_queue_(std::make_shared<AsyncQueue>("run_task_queue", Level::kLevelGrad)) {} 154 virtual ~AsyncTaskManager() = default; 155 GetRecordTaskQueue()156 const AsyncQueuePtr &GetRecordTaskQueue() const { return record_task_queue_; } GetGenerateTaskQueue()157 const AsyncQueuePtr &GetGenerateTaskQueue() const { return generate_task_queue_; } GetRunTaskQueue()158 const AsyncQueuePtr &GetRunTaskQueue() const { return run_task_queue_; } DispatchRecordTask(const AsyncTaskPtr & task)159 void DispatchRecordTask(const AsyncTaskPtr &task) const { record_task_queue_->Push(task); } DispatchGenerateTask(const AsyncTaskPtr & task)160 void DispatchGenerateTask(const AsyncTaskPtr &task) const { generate_task_queue_->Push(task); } DispatchRunTask(const AsyncTaskPtr & task)161 void DispatchRunTask(const AsyncTaskPtr &task) const { run_task_queue_->Push(task); } 162 163 private: 164 AsyncQueuePtr record_task_queue_; 165 AsyncQueuePtr generate_task_queue_; 166 AsyncQueuePtr run_task_queue_; 167 }; 168 169 using AsyncTaskManagerPtr = std::shared_ptr<AsyncTaskManager>; 170 } // namespace pijit 171 } // namespace mindspore 172 173 #endif // MINDSPORE_PI_JIT_ASYNC_TASK_MANAGERER_H_ 174