/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/op_executor.h"
#include "pybind_api/gil_scoped_long_running.h"
#include "runtime/pipeline/pipeline.h"

namespace mindspore::runtime {
GetInstance()22 OpExecutor &OpExecutor::GetInstance() {
23   static OpExecutor instance;
24   return instance;
25 }
26 
// Defaulted special members: OpExecutor holds no directly-owned resources; all
// task queues live in the runtime::Pipeline stages it delegates to.
OpExecutor::OpExecutor() = default;

OpExecutor::~OpExecutor() = default;

RegisterForwardCallback(const std::function<void ()> & callback)31 void OpExecutor::RegisterForwardCallback(const std::function<void()> &callback) {
32   forward_callback_ = callback;
33   tensor::Tensor::RegisterLazyCallback([]() { OpExecutor::GetInstance().WaitAll(); });
34 }
35 
Reset()36 void OpExecutor::Reset() {
37   runtime::Pipeline::Get().backend_stage()->Reset();
38   runtime::Pipeline::Get().launch_stage()->Reset();
39 }
40 
WaitForRun()41 void OpExecutor::WaitForRun() {
42   MS_LOG(DEBUG) << "Start";
43   runtime::Pipeline::Get().backend_stage()->Wait();
44   runtime::Pipeline::Get().launch_stage()->Wait();
45   MS_LOG(DEBUG) << "All task finish";
46 }
47 
Wait()48 void OpExecutor::Wait() {
49   GilReleaseWithCheck gil_release;
50   WaitForRun();
51 }
52 
WaitAll()53 void OpExecutor::WaitAll() {
54   GilReleaseWithCheck gil_release;
55   if (forward_callback_ != nullptr) {
56     forward_callback_();
57   }
58   WaitForRun();
59 }
60 
PushOpRunTask(const std::shared_ptr<DeviceOpRunTask> & op_run_task)61 void OpExecutor::PushOpRunTask(const std::shared_ptr<DeviceOpRunTask> &op_run_task) {
62   MS_EXCEPTION_IF_NULL(op_run_task);
63   MS_EXCEPTION_IF_NULL(op_run_task->context());
64   runtime::Pipeline::Get().backend_stage()->Push(op_run_task);
65 }
66 
PushOpRunTask(const std::shared_ptr<PyBoostDeviceTask> & op_run_task)67 void OpExecutor::PushOpRunTask(const std::shared_ptr<PyBoostDeviceTask> &op_run_task) {
68   MS_EXCEPTION_IF_NULL(op_run_task);
69   runtime::Pipeline::Get().backend_stage()->Push(op_run_task);
70 }
71 
PushSimpleOpRunTask(const std::shared_ptr<AsyncTask> & op_run_task)72 void OpExecutor::PushSimpleOpRunTask(const std::shared_ptr<AsyncTask> &op_run_task) {
73   runtime::Pipeline::Get().backend_stage()->Push(op_run_task);
74 }
75 
RunQueueEmpty()76 bool OpExecutor::RunQueueEmpty() { return runtime::Pipeline::Get().backend_stage()->Empty(); }
77 
WorkerJoin()78 void OpExecutor::WorkerJoin() {
79   GilReleaseWithCheck release_gil;
80   runtime::Pipeline::Get().backend_stage()->WorkerJoin();
81   runtime::Pipeline::Get().launch_stage()->WorkerJoin();
82 }
83 
DispatchLaunchTask(const std::function<void ()> & func)84 void OpExecutor::DispatchLaunchTask(const std::function<void()> &func) {
85   if (NeedSync()) {
86     runtime::OpExecutor::GetInstance().WaitAll();
87     func();
88   } else {
89     auto task = std::make_shared<runtime::DeviceLaunchTask>([=]() { func(); });
90     runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id());
91     runtime::Pipeline::Get().launch_stage()->Push(task);
92   }
93 }
94 
NeedSync()95 bool OpExecutor::NeedSync() {
96   auto context = MsContext::GetInstance();
97   MS_EXCEPTION_IF_NULL(context);
98   return context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) ||
99          (context->get_param<int>(MS_CTX_EXECUTION_MODE) == mindspore::kGraphMode && !async_for_graph_);
100 }
101 
ChildAfterFork()102 void OpExecutor::ChildAfterFork() {
103   MS_LOG(DEBUG) << "OpExecutor reinitialize after fork";
104   MS_LOG(DEBUG) << "Reinitialize async_queue_.";
105   runtime::Pipeline::Get().backend_stage()->ChildAfterFork();
106   runtime::Pipeline::Get().launch_stage()->ChildAfterFork();
107   // Refresh the lazy callback in Tensor.
108   tensor::Tensor::RegisterLazyCallback([]() { OpExecutor::GetInstance().WaitAll(); });
109   MS_LOG(DEBUG) << "OpExecutor reinitialize after fork done.";
110 }
}  // namespace mindspore::runtime