/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pipeline/jit/pi/auto_grad/async_task_manager.h"
#include <algorithm>
#include <string>
#include <vector>
#include "include/common/profiler.h"

namespace mindspore {
namespace pijit {

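// Registers `task` as a dependency of this task and records the reverse
// notification edge so the dependency can report its completion. If the
// dependency has already finished, its completion is counted right away.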
void AsyncTaskMultiWorker::Depend(std::shared_ptr<AsyncTaskMultiWorker> task) {
  depends_.push_back(task);
  task->notifies_.push_back(shared_from_this());
  if (task->Done()) {
    comp_count_++;
  }
}

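// Copies the current dependency list into `tasks` (cleared first).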
void AsyncTaskMultiWorker::DependOn(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks) {
  if (tasks != nullptr) {
    tasks->clear();
    tasks->insert(tasks->begin(), depends_.begin(), depends_.end());
  }
}

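// Reports this task's completion to every dependent task by bumping its
// completed-dependency counter.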
void AsyncTaskMultiWorker::Notify() {
  for (auto task : notifies_) {
    task->comp_count_++;
  }
}

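// Copies the list of dependent tasks into `tasks` (cleared first).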
void AsyncTaskMultiWorker::NotifyTo(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks) {
  if (tasks != nullptr) {
    tasks->clear();
    tasks->insert(tasks->begin(), notifies_.begin(), notifies_.end());
  }
}

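// A task becomes schedulable once every registered dependency has completed.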
bool AsyncTaskMultiWorker::Available() { return comp_count_ == depends_.size(); }

void AsyncTaskMultiWorker::Reset() {
  comp_count_ = 0;
  done_ = false;
}

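// Executes the task body, marks the task as done, and notifies its dependents.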
void AsyncTaskMultiWorker::RunWrapper() {
  Run();
  done_ = true;
  Notify();
}

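// Multi-worker task queue: worker threads are created lazily on the first
// Push() and keep pulling schedulable tasks until WorkerJoin() requests
// termination.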
AsyncQueueMultiWorker::AsyncQueueMultiWorker(std::string name, runtime::kThreadWaitLevel wait_level,
                                             size_t worker_count)
    : name_(name), wait_level_(wait_level), worker_cnt_(worker_count), ready_cnt_(0), terminate_(false) {}

AsyncQueueMultiWorker::~AsyncQueueMultiWorker() { WorkerJoin(); }

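// Enqueues a task. Tasks whose dependencies are already satisfied go to the
// runnable queue; the rest stay in wait_queue_ until they become available.
// Worker threads are spawned lazily here, up to worker_cnt_.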
void AsyncQueueMultiWorker::Push(const AsyncTaskPtr &task) {
  while (workers_.size() < worker_cnt_) {
    workers_.emplace_back(std::make_unique<std::thread>(&AsyncQueueMultiWorker::WorkerLoop, this));
  }
  std::unique_lock<std::mutex> lock(mutex_);
  if (task->Available()) {
    tasks_queue_.push_back(task);
  } else {
    wait_queue_.push_back(task);
  }
  lock.unlock();
  task_cv_.notify_one();
}

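// Blocks until the runnable queue is drained and every worker is idle.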
void AsyncQueueMultiWorker::Wait() {
  if (workers_.size() == 0) {
    return;
  }
  std::unique_lock<std::mutex> lock(mutex_);
  ready_cv_.wait(lock, [this] { return tasks_queue_.size() == 0 && ready_cnt_ == worker_cnt_; });
}

bool AsyncQueueMultiWorker::Empty() { return tasks_queue_.size() == 0; }

void AsyncQueueMultiWorker::Clear() {
  std::unique_lock<std::mutex> lock(mutex_);
  tasks_queue_.clear();
}

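// Drops all pending tasks, signals termination, and joins every worker thread.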
void AsyncQueueMultiWorker::WorkerJoin() {
  std::unique_lock<std::mutex> lock(mutex_);
  tasks_queue_.clear();
  terminate_ = true;
  lock.unlock();
  task_cv_.notify_all();
  for (size_t w = 0; w < workers_.size(); ++w) {
    if (workers_[w]->joinable()) {
      workers_[w]->join();
    }
  }
}

bool AsyncQueueMultiWorker::Available() { return tasks_queue_.size() > 0; }

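// Removes and returns the first schedulable task in the runnable queue, or
// nullptr if none is schedulable. The caller must hold mutex_.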
AsyncTaskPtr AsyncQueueMultiWorker::PopAvailable() {
  auto iter = std::find_if(tasks_queue_.begin(), tasks_queue_.end(), [](auto &task) { return task->Available(); });
  if (iter != tasks_queue_.end()) {
    AsyncTaskPtr ret = *iter;
    tasks_queue_.erase(iter);
    return ret;
  } else {
    return nullptr;
  }
}

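// Blocks until a schedulable task is available or termination is requested.
// ready_cnt_ tracks idle workers so Wait() can detect when the queue is drained.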
AsyncTaskPtr AsyncQueueMultiWorker::Pop() {
  std::unique_lock<std::mutex> lock(mutex_);
  auto task = PopAvailable();
  if (task != nullptr) {
    return task;
  } else {
    ready_cnt_++;
    if (ready_cnt_ == worker_cnt_) {
      ready_cv_.notify_one();
    }
    task_cv_.wait(lock, [this] { return Available() || terminate_; });
    AsyncTaskPtr ret = PopAvailable();
    if (ret != nullptr) {
      ready_cnt_--;
    } else {
      if (ready_cnt_ == worker_cnt_) {
        lock.unlock();
        ready_cv_.notify_one();
      }
    }
    return ret;
  }
}

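// Worker main loop: pops and runs schedulable tasks, then promotes tasks from
// wait_queue_ whose dependencies have been satisfied in the meantime.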
void AsyncQueueMultiWorker::WorkerLoop() {
  while (!terminate_) {
    auto task = Pop();
    if (task != nullptr) {
      task->RunWrapper();
    }
    std::unique_lock<std::mutex> lock(mutex_);
    if (tasks_queue_.size() != 0) {
      // Runnable tasks are already queued; skip the wait-queue scan and pop the next one.
      continue;
    }
    for (auto iter = wait_queue_.begin(); iter != wait_queue_.end();) {
      if (!(*iter)->Available()) {
        iter++;
      } else {
        tasks_queue_.push_back((*iter));
        iter = wait_queue_.erase(iter);
      }
    }
  }
}

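// Executes the deferred gradient-recording callback with the primitive,
// output, and inputs captured at construction time.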
void RecordTask::Run() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeBpropTask,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_LOG(DEBUG) << "Gradient record task start...";
  run_task_(prim_, out_, inputs_);
  run_task_ = nullptr;
  MS_LOG(DEBUG) << "Gradient record task finished.";
}

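// Executes the deferred callback that generates the bprop graph.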
void RunGenerateBpropTask::Run() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeBpropTask,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_LOG(DEBUG) << "Generate bprop graph task start...";
  run_task_();
  run_task_ = nullptr;
  MS_LOG(DEBUG) << "Generate bprop graph task finished.";
}

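// Executes the deferred callback that runs the bprop graph on the captured value.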
void RunBpropTask::Run() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeBpropTask,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_LOG(DEBUG) << "Run gradient bprop graph task start...";
  run_task_(value_);
  run_task_ = nullptr;
  MS_LOG(DEBUG) << "Run gradient bprop graph task finished.";
}
}  // namespace pijit
}  // namespace mindspore