1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/session/executor.h"
18 #include "backend/session/executor_manager.h"
19 #include <algorithm>
20 #include <exception>
21 #include <set>
22 #include "runtime/device/kernel_runtime_manager.h"
23 #include "utils/comm_manager.h"
24 #include "utils/scoped_long_running.h"
25 #include "pybind_api/ir/tensor_py.h"
26 #if ((defined ENABLE_CPU) && (!defined _WIN32))
27 #include "ps/ps_cache/ps_cache_manager.h"
28 #endif
29
30 using mindspore::tensor::TensorPy;
31 namespace mindspore {
32 namespace session {
33 namespace {
GetNeedNotifyTensors(const VectorRef * outputs,std::set<TensorPtr> * result)34 void GetNeedNotifyTensors(const VectorRef *outputs, std::set<TensorPtr> *result) {
35 MS_EXCEPTION_IF_NULL(outputs);
36 MS_EXCEPTION_IF_NULL(result);
37 for (auto &item : *outputs) {
38 if (utils::isa<VectorRefPtr>(item)) {
39 auto vector_ref = utils::cast<VectorRef>(item);
40 GetNeedNotifyTensors(&vector_ref, result);
41 } else if (utils::isa<tensor::TensorPtr>(item)) {
42 auto tensor = utils::cast<tensor::TensorPtr>(item);
43 result->emplace(tensor);
44 }
45 }
46 }
47
TensorInVector(const VectorRef * outputs)48 bool TensorInVector(const VectorRef *outputs) {
49 MS_EXCEPTION_IF_NULL(outputs);
50 for (auto &item : *outputs) {
51 if (utils::isa<VectorRefPtr>(item)) {
52 auto vector_ref = utils::cast<VectorRef>(item);
53 if (TensorInVector(&vector_ref)) {
54 return true;
55 }
56 } else if (utils::isa<tensor::TensorPtr>(item)) {
57 return true;
58 }
59 }
60 return false;
61 }
62
IsTaskReady(const std::shared_ptr<RunGraphTask> & task)63 bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
64 MS_EXCEPTION_IF_NULL(task);
65 for (auto &input : task->input_need_wait_tensors_) {
66 MS_EXCEPTION_IF_NULL(input);
67 if (input->NeedWait()) {
68 return false;
69 }
70 }
71 auto session = task->session_;
72 MS_EXCEPTION_IF_NULL(session);
73 auto graph = session->GetGraph(task->graph_id_);
74 if (graph != nullptr) {
75 return graph->IsPreGraphFinished();
76 }
77 return true;
78 }
79
// Block until all non-graph-output inputs of `task` are available, then flag
// the task's lockable inputs as busy for the upcoming run.
// Graph-output inputs are not waited on here; they are deferred into
// input_need_wait_tensors_ and checked later by IsTaskReady.
void WaitLockedInputs(const std::shared_ptr<RunGraphTask> &task) {
  bool need_lock = false;
  // First pass: classify inputs that still need waiting.
  for (auto &tensor : task->input_tensors_) {
    if (tensor->NeedWait()) {
      if (tensor->IsGraphOutput()) {
        task->input_need_wait_tensors_.emplace_back(tensor);
      } else {
        need_lock = true;
      }
    }
  }
  if (need_lock) {
    // ScopedLongRunning releases the Python GIL (if held) while we block.
    mindspore::ScopedLongRunning long_running;
    for (auto &input_tensor : task->input_tensors_) {
      if (input_tensor->NeedWait() && !input_tensor->IsGraphOutput()) {
        // Surface any pending async exception before blocking on this tensor.
        MsException::Instance().CheckException();
        input_tensor->Wait();
      }
    }
    MsException::Instance().CheckException();
  }
  // need lock input parameters for optimizer
  for (auto &need_lock_tensor : task->input_need_lock_tensors_) {
    need_lock_tensor->SetNeedWait(true);
  }
}
106 } // namespace
107
// Compile the segment's node list into a kernel graph; stores the resulting
// graph id in graph_id_ for the caller to read back.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
113
// Compile a whole FuncGraph; stores the resulting graph id in graph_id_.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
118
// Build (e.g. allocate/assign resources for) the graph identified by graph_id_.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
123
// Execute graph graph_id_: load inputs, run, and update output tensors.
// On exception the error is recorded via MsException and execution falls
// through so that the cleanup below (marking the graph finished and waking
// waiters on the output tensors) still happens — otherwise consumers would
// block forever on tensors that will never be produced.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_LOG(INFO) << "Start run graph " << graph_id_;
  auto graph = session_->GetGraph(graph_id_);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Invalid graph id " << graph_id_;
    return;
  }
  graph->ResetGraphRunningStatus();
  try {
    session_->LoadInputs(graph_id_, input_tensors_);
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    // new_to_old_device_address is filled by UpdateOutputTensors but not read
    // here; presumably required by its signature — confirm against the callee.
    std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
    session_->UpdateOutputTensors(&outputs_, tensor_to_node_, &new_to_old_device_address);
  } catch (const std::exception &e) {
    session_->ReportErrorMessage();
    ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
    MsException::Instance().SetException();
  }
  MS_LOG(INFO) << "End run graph " << graph_id_;
  graph->OnRunGraphFinished();
  // Wake everything waiting on this run: locked input parameters plus every
  // tensor reachable from the outputs.
  std::set<TensorPtr> need_notify_tensors(input_need_lock_tensors_.begin(), input_need_lock_tensors_.end());
  GetNeedNotifyTensors(&outputs_, &need_notify_tensors);
  for (auto &tensor : need_notify_tensors) {
    if (tensor != nullptr) {
      tensor->SetNeedWait(false);
    }
  }
  ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
154
// Run a single op; results are written into outputs_.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
}
159
// Run all ops of graph graph_id_ one by one; results are written into outputs_.
void RunOpsInGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
164
// Synchronously create communication group group_name_ over ranks_; success is stored in result_.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
166
// Destroy communication group group_name_; success is stored in result_.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
168
// Construct an executor bound to (device_name, device_id) and spawn the
// dedicated worker thread that drains ready_tasks_ (see WorkerLoop).
Executor::Executor(const std::string &device_name, uint32_t device_id) {
  device_name_ = device_name;
  device_id_ = device_id;
  worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
174
~Executor()175 Executor::~Executor() {
176 try {
177 WorkerJoin();
178 } catch (const std::exception &e) {
179 MS_LOG(ERROR) << "Executor call destructor failed: " << e.what();
180 } catch (...) {
181 MS_LOG(ERROR) << "KernelGraph call destructor failed";
182 }
183 }
184
// Stop the worker thread by enqueueing an ExitTask and joining it.
void Executor::WorkerJoin() {
  // Avoid worker thread join itself which will cause deadlock
  if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
    {
      std::lock_guard<std::mutex> lock(task_mutex_);
      auto task = std::make_shared<ExitTask>();
      ready_tasks_.push(task);
      // Notify while holding the lock so the exit task cannot be missed.
      task_cond_var_.notify_all();
    }
    worker_->join();
  }
}
197
// Worker thread main loop: pop tasks from ready_tasks_ and run them until an
// ExitTask (kExit) arrives. Exceptions from a task are captured into
// MsException and broadcast as a kException event rather than escaping the
// thread. Finished tasks are parked in done_tasks_ so their shared state stays
// alive until a synchronous caller collects results (see RunTask).
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    MS_EXCEPTION_IF_NULL(task);
    // Snapshot type/sync flag before Run(); task is moved into done_tasks_ below.
    enum TaskType task_type = task->type_;
    bool task_sync_flag = task->sync_run_;
    if (task_type == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      if (task->session_ != nullptr) {
        task->session_->SetThreadContext();
      }
      task->Run();
      if (task->session_ != nullptr) {
        task->session_->ReportWarningMessage();
      }
    } catch (const std::exception &e) {
      if (task->session_ != nullptr) {
        task->session_->ReportErrorMessage();
      }
      ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
      MsException::Instance().SetException();
    }
    {
      std::lock_guard<std::mutex> lock(done_task_mutex_);
      done_tasks_.emplace_back(std::move(task));
    }
    // Every non-RunGraph task (and any sync RunGraph) has a caller blocked in
    // RunTask waiting on sync_cond_var_ — wake it.
    if (task_type != kRunGraph || task_sync_flag) {
      std::lock_guard<std::mutex> lock(task_mutex_);
      sync_run_task_finished_ = true;
      sync_cond_var_.notify_all();
    }
  }
}
240
GetReadyTasksFromPendingList()241 std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
242 std::vector<std::shared_ptr<RunGraphTask>> ready_tasks;
243 std::lock_guard<std::mutex> lock(pending_task_mutex_);
244 for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
245 auto task = *iter;
246 if (IsTaskReady(task)) {
247 (void)ready_tasks.emplace_back(task);
248 pending_tasks_.erase(iter++);
249 } else {
250 ++iter;
251 }
252 }
253 return ready_tasks;
254 }
255
OnEvent(const ExecutorEvent & event)256 void Executor::OnEvent(const ExecutorEvent &event) {
257 if (event == ExecutorEvent::kRunGraphFinished) {
258 OnRunGraphFinished();
259 } else if (event == ExecutorEvent::kClear) {
260 OnClear();
261 } else if (event == ExecutorEvent::kException) {
262 OnException();
263 }
264 }
265
// Shut down the worker thread and drop all completed tasks.
void Executor::OnClear() {
  {
    // Release the GIL (if held) while blocking on the worker join.
    mindspore::ScopedLongRunning long_running;
    WorkerJoin();
  }
  ClearDoneTasks();
}
273
OnException()274 void Executor::OnException() {
275 std::vector<std::shared_ptr<Task>> done_tasks;
276 {
277 std::lock_guard<std::mutex> lock(task_mutex_);
278 while (!ready_tasks_.empty()) {
279 (void)done_tasks.emplace_back(ready_tasks_.front());
280 ready_tasks_.pop();
281 }
282 }
283 {
284 std::lock_guard<std::mutex> lock(pending_task_mutex_);
285 (void)std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(done_tasks));
286 pending_tasks_.clear();
287 }
288 {
289 std::lock_guard<std::mutex> lock(done_task_mutex_);
290 (void)done_tasks_.insert(done_tasks_.end(), done_tasks.begin(), done_tasks.end());
291 }
292 }
293
// Called when a graph run completes: promote any now-ready pending tasks to
// the ready queue and wake both the worker and graph-reentry waiters.
void Executor::OnRunGraphFinished() {
  auto ready_tasks = GetReadyTasksFromPendingList();
  std::lock_guard<std::mutex> lock(task_mutex_);
  for (auto &task : ready_tasks) {
    ready_tasks_.push(task);
  }
  if (!ready_tasks.empty()) {
    task_cond_var_.notify_all();
  }
  // NOTE(review): reenter_cond_var_ is notified without holding reenter_mutex_
  // (waiters in RunGraphAsync lock reenter_mutex_). Notifying unlocked is
  // legal, but whether a wakeup can be lost here depends on the predicate —
  // confirm the graph-status flags are updated before this call.
  reenter_cond_var_.notify_all();
}
305
// Release all retained finished tasks (and with them their captured outputs).
void Executor::ClearDoneTasks() {
  std::lock_guard<std::mutex> lock(done_task_mutex_);
  done_tasks_.clear();
}
310
RunTask(const std::shared_ptr<Task> & task,bool sync,bool long_run)311 void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) {
312 if (sync) {
313 ClearDoneTasks();
314 }
315 {
316 std::lock_guard<std::mutex> lock(task_mutex_);
317 sync_run_task_finished_ = false;
318 ready_tasks_.push(task);
319 }
320 task_cond_var_.notify_all();
321 if (sync && !sync_run_task_finished_) {
322 std::unique_lock<std::mutex> lock(task_mutex_);
323 if (sync && long_run) {
324 mindspore::ScopedLongRunning long_running;
325 sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
326 } else {
327 sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
328 }
329 }
330 ClearDoneTasks();
331 MsException::Instance().CheckException();
332 }
333
// Synchronously compile a graph segment on the worker thread and return the
// new graph id.
GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
                               const AnfNodePtrList &outputs) {
  auto task = std::make_shared<CompileNodesTask>();
  task->session_ = session;
  task->segment_ = segment;
  task->output_nodes_ = outputs;
  RunTask(task, true);
  return task->graph_id_;
}
343
// Synchronously compile a whole FuncGraph on the worker thread and return the
// new graph id.
GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  auto task = std::make_shared<CompileGraphTask>();
  task->session_ = session;
  task->func_graph_ = func_graph.get();
  RunTask(task, true);
  return task->graph_id_;
}
351
// Synchronously build the compiled graph `graphId` on the worker thread.
void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  auto task = std::make_shared<BuildGraphTask>();
  task->session_ = session;
  task->graph_id_ = graphId;
  RunTask(task, true);
}
358
// Run graph `graph_id` synchronously. Output tensors are created up front and
// written into *outputs before the task runs; the task holds a copy of the
// output VectorRef so the worker fills the same tensor objects.
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
                        const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  task->outputs_ = *outputs;
  task->sync_run_ = true;
  RunTask(task, true, true);
}
372
RunGraphAsync(const SessionPtr & session,const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs,VectorRef * outputs)373 void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
374 const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
375 MS_EXCEPTION_IF_NULL(session);
376 MS_EXCEPTION_IF_NULL(outputs);
377 auto task = std::make_shared<RunGraphTask>();
378 task->session_ = session;
379 task->graph_id_ = graph_id;
380 task->input_tensors_ = inputs;
381 task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
382 auto graph = session->GetGraph(task->graph_id_);
383 if (graph != nullptr && !graph->IsPostGraphFinished()) {
384 mindspore::ScopedLongRunning long_running;
385 std::unique_lock<std::mutex> lock(reenter_mutex_);
386 reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
387 MsException::Instance().CheckException();
388 }
389 session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
390 // maintain a copy of output vector
391 task->outputs_ = *outputs;
392
393 // Run graph synchronously when the graph require gil.
394 if (graph != nullptr && graph->is_need_gil()) {
395 std::unique_lock<std::mutex> lock(reenter_mutex_);
396 reenter_cond_var_.wait(lock, [&graph] { return graph->IsPreGraphFinished(); });
397 MsException::Instance().CheckException();
398 task->sync_run_ = true;
399 RunTask(task, true, true);
400 return;
401 }
402
403 // sync run graph without output tensor(int dataset graph)
404 if ((!TensorInVector(outputs) && !graph->HasPostGraph())) {
405 task->sync_run_ = true;
406 RunTask(task, true, true);
407 return;
408 }
409 WaitLockedInputs(task);
410 for (auto &tensor_node : task->tensor_to_node_) {
411 tensor_node.first->SetNeedWait(true);
412 }
413 {
414 std::lock_guard<std::mutex> lock(pending_task_mutex_);
415 if (!IsTaskReady(task)) {
416 ClearDoneTasks();
417 pending_tasks_.push_back(task);
418 return;
419 }
420 }
421 RunTask(task, false);
422 }
423
// Run a single op. On GPU the op runs on the calling thread (with the GIL
// released when Python is initialized); on other targets it is dispatched to
// the worker thread synchronously. In both paths, all input tensors are waited
// on first.
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                     std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                     const std::vector<int64_t> &tensors_mask) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (target == kGPUDevice) {
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    {
      // Release GIL before calling into (potentially long-running) C++ code.
      // gil_scoped_release is only legal once the interpreter exists, hence
      // the Py_IsInitialized branch; both branches make the same call.
      if (Py_IsInitialized()) {
        py::gil_scoped_release release;
        session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
      } else {
        session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
      }
    }
  } else {
    auto task = std::make_shared<RunOpTask>();
    task->session_ = session;
    task->op_run_info_ = op_run_info;
    task->graph_info_ = graph_info;
    task->input_tensors_ = input_tensors;
    task->tensors_mask_ = tensors_mask;
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    RunTask(task, true, true);
    // Copy results back from the task once the worker has finished.
    *outputs = task->outputs_;
  }
}
464
// Synchronously run all ops of graph `graph_id` on the worker thread and copy
// the results back into *outputs.
void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunOpsInGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  RunTask(task, true, true);
  *outputs = task->outputs_;
}
476
// Synchronously create a communication group on the worker thread; returns
// whether creation succeeded.
bool Executor::CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks) {
  auto task = std::make_shared<CreateCommGroupTask>();
  task->group_name_ = group_name;
  task->ranks_ = ranks;
  RunTask(task, true);
  return task->result_;
}
484
// Synchronously destroy a communication group on the worker thread; returns
// whether destruction succeeded.
bool Executor::DestroyCommGroup(const std::string &group_name) {
  auto task = std::make_shared<DestroyCommGroupTask>();
  task->group_name_ = group_name;
  RunTask(task, true);
  return task->result_;
}
491
// Worker-thread teardown hook: release the Ascend kernel runtime owned by
// this device before the worker thread exits. Other device targets need no
// per-thread cleanup here.
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
497 } // namespace session
498 } // namespace mindspore
499