• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/auto_grad/async_task_manager.h"
17 #include <string>
18 #include <vector>
19 #include "include/common/profiler.h"
20 
21 namespace mindspore {
22 namespace pijit {
23 
Depend(std::shared_ptr<AsyncTaskMultiWorker> task)24 void AsyncTaskMultiWorker::Depend(std::shared_ptr<AsyncTaskMultiWorker> task) {
25   depends_.push_back(task);
26   task->notifies_.push_back(shared_from_this());
27   if (task->Done()) {
28     comp_count_++;
29   }
30 }
31 
DependOn(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> * tasks)32 void AsyncTaskMultiWorker::DependOn(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks) {
33   if (tasks != nullptr) {
34     tasks->clear();
35     tasks->insert(tasks->begin(), depends_.begin(), depends_.end());
36   }
37 }
38 
Notify()39 void AsyncTaskMultiWorker::Notify() {
40   for (auto task : notifies_) {
41     task->comp_count_++;
42   }
43 }
44 
NotifyTo(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> * tasks)45 void AsyncTaskMultiWorker::NotifyTo(std::vector<std::shared_ptr<AsyncTaskMultiWorker>> *tasks) {
46   if (tasks != nullptr) {
47     tasks->clear();
48     tasks->insert(tasks->begin(), notifies_.begin(), notifies_.end());
49   }
50 }
51 
Available()52 bool AsyncTaskMultiWorker::Available() { return comp_count_ == depends_.size(); }
53 
// Return the task to its "not yet run" state so it can be executed again.
void AsyncTaskMultiWorker::Reset() {
  comp_count_ = 0;  // forget how many prerequisites have completed
  done_ = false;    // mark the task as not yet executed
}
58 
// Execute the task body, then publish completion. `done_` is set before
// Notify() so that anyone calling Depend() on us afterwards sees the
// finished state via task->Done() and counts us as already completed.
void AsyncTaskMultiWorker::RunWrapper() {
  Run();         // subclass-specific work
  done_ = true;  // observed by Depend()'s task->Done() check
  Notify();      // bump comp_count_ on every dependent task
}
64 
AsyncQueueMultiWorker(std::string name,runtime::kThreadWaitLevel wait_level,size_t worker_count)65 AsyncQueueMultiWorker::AsyncQueueMultiWorker(std::string name, runtime::kThreadWaitLevel wait_level,
66                                              size_t worker_count)
67     : name_(name), wait_level_(wait_level), worker_cnt_(worker_count), ready_cnt_(0), terminate_(false) {}
68 
~AsyncQueueMultiWorker()69 AsyncQueueMultiWorker::~AsyncQueueMultiWorker() { WorkerJoin(); }
70 
Push(const AsyncTaskPtr & task)71 void AsyncQueueMultiWorker::Push(const AsyncTaskPtr &task) {
72   while (workers_.size() < worker_cnt_) {
73     workers_.emplace_back(std::make_unique<std::thread>(&AsyncQueueMultiWorker::WorkerLoop, this));
74   }
75   std::unique_lock<std::mutex> lock(mutex_);
76   if (task->Available()) {
77     tasks_queue_.push_back(task);
78   } else {
79     wait_queue_.push_back(task);
80   }
81   lock.unlock();
82   task_cv_.notify_one();
83 }
84 
// Block the caller until the runnable queue is drained and every worker is
// idle (ready_cnt_ == worker_cnt_). If no worker was ever spawned (nothing
// has been pushed yet) there is nothing to wait for.
void AsyncQueueMultiWorker::Wait() {
  if (workers_.size() == 0) {
    return;
  }
  std::unique_lock<std::mutex> lock(mutex_);
  // NOTE(review): the predicate does not inspect wait_queue_; presumably
  // blocked tasks are always promoted before all workers go idle — confirm.
  ready_cv_.wait(lock, [this] { return tasks_queue_.size() == 0 && ready_cnt_ == worker_cnt_; });
}
92 
Empty()93 bool AsyncQueueMultiWorker::Empty() { return tasks_queue_.size() == 0; }
94 
Clear()95 void AsyncQueueMultiWorker::Clear() {
96   std::unique_lock<std::mutex> lock(mutex_);
97   tasks_queue_.clear();
98 }
99 
WorkerJoin()100 void AsyncQueueMultiWorker::WorkerJoin() {
101   std::unique_lock<std::mutex> lock(mutex_);
102   tasks_queue_.clear();
103   terminate_ = true;
104   lock.unlock();
105   task_cv_.notify_all();
106   for (size_t w = 0; w < workers_.size(); ++w) {
107     if (workers_[w]->joinable()) {
108       workers_[w]->join();
109     }
110   }
111 }
112 
Available()113 bool AsyncQueueMultiWorker::Available() { return tasks_queue_.size() > 0; }
114 
PopAvailable()115 AsyncTaskPtr AsyncQueueMultiWorker::PopAvailable() {
116   auto iter = std::find_if(tasks_queue_.begin(), tasks_queue_.end(), [](auto &task) { return task->Available(); });
117   if (iter != tasks_queue_.end()) {
118     AsyncTaskPtr ret = *iter;
119     tasks_queue_.erase(iter);
120     return ret;
121   } else {
122     return nullptr;
123   }
124 }
125 
// Fetch the next runnable task for a worker thread. If none is runnable,
// the worker advertises itself as idle (ready_cnt_) — waking Wait() once
// the whole pool is idle — then blocks until work arrives or the queue
// terminates. May return nullptr (termination, or another worker won the
// race for the task that triggered the wake-up).
AsyncTaskPtr AsyncQueueMultiWorker::Pop() {
  std::unique_lock<std::mutex> lock(mutex_);
  auto task = PopAvailable();
  if (task != nullptr) {
    return task;
  } else {
    ready_cnt_++;  // this worker is now idle
    if (ready_cnt_ == worker_cnt_) {
      ready_cv_.notify_one();  // entire pool idle: let Wait() re-check
    }
    task_cv_.wait(lock, [this] { return Available() || terminate_; });
    AsyncTaskPtr ret = PopAvailable();
    if (ret != nullptr) {
      ready_cnt_--;  // back to working state
    } else {
      // Woken without a task: the worker stays counted as idle, and Wait()
      // is re-notified in case the pool just became fully idle.
      if (ready_cnt_ == worker_cnt_) {
        lock.unlock();
        ready_cv_.notify_one();
      }
    }
    return ret;
  }
}
149 
WorkerLoop()150 void AsyncQueueMultiWorker::WorkerLoop() {
151   while (!terminate_) {
152     auto task = Pop();
153     if (task != nullptr) {
154       task->RunWrapper();
155     }
156     std::unique_lock<std::mutex> lock(mutex_);
157     if (tasks_queue_.size() != 0) {
158       return;
159     }
160     for (auto iter = wait_queue_.begin(); iter != wait_queue_.end();) {
161       if (!(*iter)->Available()) {
162         iter++;
163       } else {
164         tasks_queue_.push_back((*iter));
165         iter = wait_queue_.erase(iter);
166       }
167     }
168   }
169 }
170 
// Record one gradient operation: invokes the stored callback with the
// primitive, its output and its inputs, under a pynative bprop profiler scope.
void RecordTask::Run() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeBpropTask,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_LOG(DEBUG) << "Gradient record task start...";
  run_task_(prim_, out_, inputs_);
  run_task_ = nullptr;  // drop the callback so captured state is released promptly
  MS_LOG(DEBUG) << "Gradient record task finished.";
}
179 
// Build the bprop graph: invokes the stored zero-argument callback under a
// pynative bprop profiler scope.
void RunGenerateBpropTask::Run() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeBpropTask,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_LOG(DEBUG) << "Generate bprop graph task start...";
  run_task_();
  run_task_ = nullptr;  // drop the callback so captured state is released promptly
  MS_LOG(DEBUG) << "Generate bprop graph task finished.";
}
188 
// Execute the bprop graph: invokes the stored callback with the captured
// value under a pynative bprop profiler scope.
void RunBpropTask::Run() {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeBpropTask,
                                     runtime::ProfilerRecorder::kNoName, false);
  MS_LOG(DEBUG) << "Run gradient bprop graph task start...";
  run_task_(value_);
  run_task_ = nullptr;  // drop the callback so captured state is released promptly
  MS_LOG(DEBUG) << "Run gradient bprop graph task finished.";
}
197 }  // namespace pijit
198 }  // namespace mindspore
199