• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/session/executor.h"
18 #include "backend/session/executor_manager.h"
19 #include <algorithm>
20 #include <exception>
21 #include <set>
22 #include "runtime/device/kernel_runtime_manager.h"
23 #include "utils/comm_manager.h"
24 #include "utils/scoped_long_running.h"
25 #include "pybind_api/ir/tensor_py.h"
26 #if ((defined ENABLE_CPU) && (!defined _WIN32))
27 #include "ps/ps_cache/ps_cache_manager.h"
28 #endif
29 
30 using mindspore::tensor::TensorPy;
31 namespace mindspore {
32 namespace session {
33 namespace {
GetNeedNotifyTensors(const VectorRef * outputs,std::set<TensorPtr> * result)34 void GetNeedNotifyTensors(const VectorRef *outputs, std::set<TensorPtr> *result) {
35   MS_EXCEPTION_IF_NULL(outputs);
36   MS_EXCEPTION_IF_NULL(result);
37   for (auto &item : *outputs) {
38     if (utils::isa<VectorRefPtr>(item)) {
39       auto vector_ref = utils::cast<VectorRef>(item);
40       GetNeedNotifyTensors(&vector_ref, result);
41     } else if (utils::isa<tensor::TensorPtr>(item)) {
42       auto tensor = utils::cast<tensor::TensorPtr>(item);
43       result->emplace(tensor);
44     }
45   }
46 }
47 
TensorInVector(const VectorRef * outputs)48 bool TensorInVector(const VectorRef *outputs) {
49   MS_EXCEPTION_IF_NULL(outputs);
50   for (auto &item : *outputs) {
51     if (utils::isa<VectorRefPtr>(item)) {
52       auto vector_ref = utils::cast<VectorRef>(item);
53       if (TensorInVector(&vector_ref)) {
54         return true;
55       }
56     } else if (utils::isa<tensor::TensorPtr>(item)) {
57       return true;
58     }
59   }
60   return false;
61 }
62 
IsTaskReady(const std::shared_ptr<RunGraphTask> & task)63 bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
64   MS_EXCEPTION_IF_NULL(task);
65   for (auto &input : task->input_need_wait_tensors_) {
66     MS_EXCEPTION_IF_NULL(input);
67     if (input->NeedWait()) {
68       return false;
69     }
70   }
71   auto session = task->session_;
72   MS_EXCEPTION_IF_NULL(session);
73   auto graph = session->GetGraph(task->graph_id_);
74   if (graph != nullptr) {
75     return graph->IsPreGraphFinished();
76   }
77   return true;
78 }
79 
WaitLockedInputs(const std::shared_ptr<RunGraphTask> & task)80 void WaitLockedInputs(const std::shared_ptr<RunGraphTask> &task) {
81   bool need_lock = false;
82   for (auto &tensor : task->input_tensors_) {
83     if (tensor->NeedWait()) {
84       if (tensor->IsGraphOutput()) {
85         task->input_need_wait_tensors_.emplace_back(tensor);
86       } else {
87         need_lock = true;
88       }
89     }
90   }
91   if (need_lock) {
92     mindspore::ScopedLongRunning long_running;
93     for (auto &input_tensor : task->input_tensors_) {
94       if (input_tensor->NeedWait() && !input_tensor->IsGraphOutput()) {
95         MsException::Instance().CheckException();
96         input_tensor->Wait();
97       }
98     }
99     MsException::Instance().CheckException();
100   }
101   // need lock input parameters for optimizer
102   for (auto &need_lock_tensor : task->input_need_lock_tensors_) {
103     need_lock_tensor->SetNeedWait(true);
104   }
105 }
106 }  // namespace
107 
// Compile the node segment into a graph; the resulting graph id is stored in
// graph_id_ for the submitting thread to read after the task completes.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
113 
// Compile a whole FuncGraph; the resulting graph id is stored in graph_id_.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
118 
// Build (finalize) the previously compiled graph identified by graph_id_.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
123 
// Execute one compiled graph on the worker thread: load inputs, run, refresh output
// tensors, then release every tensor that waiters may be blocked on.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_LOG(INFO) << "Start run graph " << graph_id_;
  auto graph = session_->GetGraph(graph_id_);
  if (graph == nullptr) {
    // NOTE(review): returning here skips the SetNeedWait(false) notification below,
    // so any tensor already marked need-wait would block its waiters — confirm
    // callers cannot reach this state with locked tensors.
    MS_LOG(ERROR) << "Invalid graph id " << graph_id_;
    return;
  }
  graph->ResetGraphRunningStatus();
  try {
    session_->LoadInputs(graph_id_, input_tensors_);
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
    session_->UpdateOutputTensors(&outputs_, tensor_to_node_, &new_to_old_device_address);
  } catch (const std::exception &e) {
    // Forward the failure: wake any pending tasks and stash the exception so the
    // submitting thread rethrows it on its next CheckException().
    session_->ReportErrorMessage();
    ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
    MsException::Instance().SetException();
  }
  MS_LOG(INFO) << "End run graph " << graph_id_;
  graph->OnRunGraphFinished();
  // Unlock both the optimizer-locked inputs and all produced output tensors — even
  // after a failure — so no waiter blocks forever.
  std::set<TensorPtr> need_notify_tensors(input_need_lock_tensors_.begin(), input_need_lock_tensors_.end());
  GetNeedNotifyTensors(&outputs_, &need_notify_tensors);
  for (auto &tensor : need_notify_tensors) {
    if (tensor != nullptr) {
      tensor->SetNeedWait(false);
    }
  }
  ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
154 
// Execute a single op on the worker thread; results land in outputs_.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
}
159 
// Execute all ops of the graph one by one on the worker thread; results land in outputs_.
void RunOpsInGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
164 
// Create the communication group synchronously; success flag is stored in result_.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
166 
// Destroy the communication group; success flag is stored in result_.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
168 
Executor(const std::string & device_name,uint32_t device_id)169 Executor::Executor(const std::string &device_name, uint32_t device_id) {
170   device_name_ = device_name;
171   device_id_ = device_id;
172   worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
173 }
174 
~Executor()175 Executor::~Executor() {
176   try {
177     WorkerJoin();
178   } catch (const std::exception &e) {
179     MS_LOG(ERROR) << "Executor call destructor failed: " << e.what();
180   } catch (...) {
181     MS_LOG(ERROR) << "KernelGraph call destructor failed";
182   }
183 }
184 
WorkerJoin()185 void Executor::WorkerJoin() {
186   // Avoid worker thread join itself which will cause deadlock
187   if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
188     {
189       std::lock_guard<std::mutex> lock(task_mutex_);
190       auto task = std::make_shared<ExitTask>();
191       ready_tasks_.push(task);
192       task_cond_var_.notify_all();
193     }
194     worker_->join();
195   }
196 }
197 
// Worker thread main loop: pops tasks off ready_tasks_ and executes them until an
// ExitTask arrives. Exceptions thrown by a task are forwarded to the submitting
// thread via MsException and never terminate the worker.
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    MS_EXCEPTION_IF_NULL(task);
    // Capture type and sync flag before the task is moved into done_tasks_ below.
    enum TaskType task_type = task->type_;
    bool task_sync_flag = task->sync_run_;
    if (task_type == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      if (task->session_ != nullptr) {
        task->session_->SetThreadContext();
      }
      task->Run();
      if (task->session_ != nullptr) {
        task->session_->ReportWarningMessage();
      }
    } catch (const std::exception &e) {
      if (task->session_ != nullptr) {
        task->session_->ReportErrorMessage();
      }
      // Drain pending tasks and stash the exception; the submitting thread
      // rethrows it from RunTask's CheckException().
      ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
      MsException::Instance().SetException();
    }
    {
      std::lock_guard<std::mutex> lock(done_task_mutex_);
      done_tasks_.emplace_back(std::move(task));
    }
    // Async RunGraph tasks do not wake a submitter; every other task completes a
    // synchronous RunTask() call waiting on sync_cond_var_.
    if (task_type != kRunGraph || task_sync_flag) {
      std::lock_guard<std::mutex> lock(task_mutex_);
      sync_run_task_finished_ = true;
      sync_cond_var_.notify_all();
    }
  }
}
240 
GetReadyTasksFromPendingList()241 std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
242   std::vector<std::shared_ptr<RunGraphTask>> ready_tasks;
243   std::lock_guard<std::mutex> lock(pending_task_mutex_);
244   for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
245     auto task = *iter;
246     if (IsTaskReady(task)) {
247       (void)ready_tasks.emplace_back(task);
248       pending_tasks_.erase(iter++);
249     } else {
250       ++iter;
251     }
252   }
253   return ready_tasks;
254 }
255 
OnEvent(const ExecutorEvent & event)256 void Executor::OnEvent(const ExecutorEvent &event) {
257   if (event == ExecutorEvent::kRunGraphFinished) {
258     OnRunGraphFinished();
259   } else if (event == ExecutorEvent::kClear) {
260     OnClear();
261   } else if (event == ExecutorEvent::kException) {
262     OnException();
263   }
264 }
265 
// Shut the executor down: join the worker thread, then release all retained
// finished tasks.
void Executor::OnClear() {
  {
    // Enter long-running scope while blocking on the join — presumably releases
    // the GIL so worker tasks needing Python can finish; confirm ScopedLongRunning
    // semantics.
    mindspore::ScopedLongRunning long_running;
    WorkerJoin();
  }
  ClearDoneTasks();
}
273 
OnException()274 void Executor::OnException() {
275   std::vector<std::shared_ptr<Task>> done_tasks;
276   {
277     std::lock_guard<std::mutex> lock(task_mutex_);
278     while (!ready_tasks_.empty()) {
279       (void)done_tasks.emplace_back(ready_tasks_.front());
280       ready_tasks_.pop();
281     }
282   }
283   {
284     std::lock_guard<std::mutex> lock(pending_task_mutex_);
285     (void)std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(done_tasks));
286     pending_tasks_.clear();
287   }
288   {
289     std::lock_guard<std::mutex> lock(done_task_mutex_);
290     (void)done_tasks_.insert(done_tasks_.end(), done_tasks.begin(), done_tasks.end());
291   }
292 }
293 
// Invoked when a RunGraphTask completes: promote now-ready pending tasks onto the
// ready queue, wake the worker, and wake RunGraphAsync callers blocked on
// reenter_cond_var_ so they re-check graph status.
void Executor::OnRunGraphFinished() {
  auto ready_tasks = GetReadyTasksFromPendingList();
  std::lock_guard<std::mutex> lock(task_mutex_);
  for (auto &task : ready_tasks) {
    ready_tasks_.push(task);
  }
  if (!ready_tasks.empty()) {
    task_cond_var_.notify_all();
  }
  // NOTE(review): reenter_cond_var_ waiters synchronize on reenter_mutex_, but this
  // notify runs under task_mutex_ — it relies on the waiters' predicate re-check;
  // verify the graph-status update happens-before this notify.
  reenter_cond_var_.notify_all();
}
305 
// Drop all references to finished tasks (and the tensors/outputs they hold).
void Executor::ClearDoneTasks() {
  std::lock_guard<std::mutex> lock(done_task_mutex_);
  done_tasks_.clear();
}
310 
RunTask(const std::shared_ptr<Task> & task,bool sync,bool long_run)311 void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) {
312   if (sync) {
313     ClearDoneTasks();
314   }
315   {
316     std::lock_guard<std::mutex> lock(task_mutex_);
317     sync_run_task_finished_ = false;
318     ready_tasks_.push(task);
319   }
320   task_cond_var_.notify_all();
321   if (sync && !sync_run_task_finished_) {
322     std::unique_lock<std::mutex> lock(task_mutex_);
323     if (sync && long_run) {
324       mindspore::ScopedLongRunning long_running;
325       sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
326     } else {
327       sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
328     }
329   }
330   ClearDoneTasks();
331   MsException::Instance().CheckException();
332 }
333 
CompileGraph(const SessionPtr & session,const GraphSegmentPtr & segment,const AnfNodePtrList & outputs)334 GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
335                                const AnfNodePtrList &outputs) {
336   auto task = std::make_shared<CompileNodesTask>();
337   task->session_ = session;
338   task->segment_ = segment;
339   task->output_nodes_ = outputs;
340   RunTask(task, true);
341   return task->graph_id_;
342 }
343 
CompileGraph(const SessionPtr & session,NotNull<FuncGraphPtr> func_graph)344 GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
345   auto task = std::make_shared<CompileGraphTask>();
346   task->session_ = session;
347   task->func_graph_ = func_graph.get();
348   RunTask(task, true);
349   return task->graph_id_;
350 }
351 
BuildGraph(const SessionPtr & session,GraphId graphId)352 void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
353   auto task = std::make_shared<BuildGraphTask>();
354   task->session_ = session;
355   task->graph_id_ = graphId;
356   RunTask(task, true);
357 }
358 
RunGraph(const SessionPtr & session,const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs,VectorRef * outputs)359 void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
360                         const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
361   MS_EXCEPTION_IF_NULL(session);
362   MS_EXCEPTION_IF_NULL(outputs);
363   auto task = std::make_shared<RunGraphTask>();
364   task->session_ = session;
365   task->graph_id_ = graph_id;
366   task->input_tensors_ = inputs;
367   session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
368   task->outputs_ = *outputs;
369   task->sync_run_ = true;
370   RunTask(task, true, true);
371 }
372 
RunGraphAsync(const SessionPtr & session,const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs,VectorRef * outputs)373 void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
374                              const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
375   MS_EXCEPTION_IF_NULL(session);
376   MS_EXCEPTION_IF_NULL(outputs);
377   auto task = std::make_shared<RunGraphTask>();
378   task->session_ = session;
379   task->graph_id_ = graph_id;
380   task->input_tensors_ = inputs;
381   task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
382   auto graph = session->GetGraph(task->graph_id_);
383   if (graph != nullptr && !graph->IsPostGraphFinished()) {
384     mindspore::ScopedLongRunning long_running;
385     std::unique_lock<std::mutex> lock(reenter_mutex_);
386     reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
387     MsException::Instance().CheckException();
388   }
389   session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
390   // maintain a copy of output vector
391   task->outputs_ = *outputs;
392 
393   // Run graph synchronously when the graph require gil.
394   if (graph != nullptr && graph->is_need_gil()) {
395     std::unique_lock<std::mutex> lock(reenter_mutex_);
396     reenter_cond_var_.wait(lock, [&graph] { return graph->IsPreGraphFinished(); });
397     MsException::Instance().CheckException();
398     task->sync_run_ = true;
399     RunTask(task, true, true);
400     return;
401   }
402 
403   // sync run graph without output tensor(int dataset graph)
404   if ((!TensorInVector(outputs) && !graph->HasPostGraph())) {
405     task->sync_run_ = true;
406     RunTask(task, true, true);
407     return;
408   }
409   WaitLockedInputs(task);
410   for (auto &tensor_node : task->tensor_to_node_) {
411     tensor_node.first->SetNeedWait(true);
412   }
413   {
414     std::lock_guard<std::mutex> lock(pending_task_mutex_);
415     if (!IsTaskReady(task)) {
416       ClearDoneTasks();
417       pending_tasks_.push_back(task);
418       return;
419     }
420   }
421   RunTask(task, false);
422 }
423 
// Run a single op. On GPU the op executes directly on the calling thread
// (releasing the GIL when Python is initialized); on other targets it is
// dispatched to the worker thread and the caller blocks for the result.
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                     std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                     const std::vector<int64_t> &tensors_mask) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (target == kGPUDevice) {
    // Inputs may still be produced by an in-flight graph; block until ready.
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    {
      // Release GIL before calling into (potentially long-running) C++ code
      if (Py_IsInitialized()) {
        py::gil_scoped_release release;
        session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
      } else {
        // No Python interpreter — nothing to release; run directly.
        session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
      }
    }
  } else {
    auto task = std::make_shared<RunOpTask>();
    task->session_ = session;
    task->op_run_info_ = op_run_info;
    task->graph_info_ = graph_info;
    task->input_tensors_ = input_tensors;
    task->tensors_mask_ = tensors_mask;
    // Wait for pending inputs before queuing, mirroring the GPU path above.
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    RunTask(task, true, true);
    *outputs = task->outputs_;
  }
}
464 
RunOpsInGraph(const SessionPtr & session,const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs,VectorRef * outputs)465 void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
466                              const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
467   MS_EXCEPTION_IF_NULL(session);
468   MS_EXCEPTION_IF_NULL(outputs);
469   auto task = std::make_shared<RunOpsInGraphTask>();
470   task->session_ = session;
471   task->graph_id_ = graph_id;
472   task->input_tensors_ = inputs;
473   RunTask(task, true, true);
474   *outputs = task->outputs_;
475 }
476 
CreateCommGroup(const std::string & group_name,const std::vector<uint32_t> & ranks)477 bool Executor::CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks) {
478   auto task = std::make_shared<CreateCommGroupTask>();
479   task->group_name_ = group_name;
480   task->ranks_ = ranks;
481   RunTask(task, true);
482   return task->result_;
483 }
484 
DestroyCommGroup(const std::string & group_name)485 bool Executor::DestroyCommGroup(const std::string &group_name) {
486   auto task = std::make_shared<DestroyCommGroupTask>();
487   task->group_name_ = group_name;
488   RunTask(task, true);
489   return task->result_;
490 }
491 
// Runs on the worker thread just before it exits (on receiving an ExitTask).
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    // Release the Ascend kernel runtime bound to this executor's device id.
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
497 }  // namespace session
498 }  // namespace mindspore
499