1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "runtime/pynative/op_executor.h" 18 #include "pybind_api/gil_scoped_long_running.h" 19 #include "runtime/pipeline/pipeline.h" 20 21 namespace mindspore::runtime { GetInstance()22OpExecutor &OpExecutor::GetInstance() { 23 static OpExecutor instance; 24 return instance; 25 } 26 27 OpExecutor::OpExecutor() = default; 28 29 OpExecutor::~OpExecutor() = default; 30 RegisterForwardCallback(const std::function<void ()> & callback)31void OpExecutor::RegisterForwardCallback(const std::function<void()> &callback) { 32 forward_callback_ = callback; 33 tensor::Tensor::RegisterLazyCallback([]() { OpExecutor::GetInstance().WaitAll(); }); 34 } 35 Reset()36void OpExecutor::Reset() { 37 runtime::Pipeline::Get().backend_stage()->Reset(); 38 runtime::Pipeline::Get().launch_stage()->Reset(); 39 } 40 WaitForRun()41void OpExecutor::WaitForRun() { 42 MS_LOG(DEBUG) << "Start"; 43 runtime::Pipeline::Get().backend_stage()->Wait(); 44 runtime::Pipeline::Get().launch_stage()->Wait(); 45 MS_LOG(DEBUG) << "All task finish"; 46 } 47 Wait()48void OpExecutor::Wait() { 49 GilReleaseWithCheck gil_release; 50 WaitForRun(); 51 } 52 WaitAll()53void OpExecutor::WaitAll() { 54 GilReleaseWithCheck gil_release; 55 if (forward_callback_ != nullptr) { 56 forward_callback_(); 57 } 58 WaitForRun(); 59 } 60 PushOpRunTask(const std::shared_ptr<DeviceOpRunTask> & op_run_task)61void OpExecutor::PushOpRunTask(const std::shared_ptr<DeviceOpRunTask> &op_run_task) { 62 MS_EXCEPTION_IF_NULL(op_run_task); 63 MS_EXCEPTION_IF_NULL(op_run_task->context()); 64 runtime::Pipeline::Get().backend_stage()->Push(op_run_task); 65 } 66 PushOpRunTask(const std::shared_ptr<PyBoostDeviceTask> & op_run_task)67void OpExecutor::PushOpRunTask(const std::shared_ptr<PyBoostDeviceTask> &op_run_task) { 68 MS_EXCEPTION_IF_NULL(op_run_task); 69 runtime::Pipeline::Get().backend_stage()->Push(op_run_task); 70 } 71 PushSimpleOpRunTask(const std::shared_ptr<AsyncTask> & op_run_task)72void OpExecutor::PushSimpleOpRunTask(const std::shared_ptr<AsyncTask> &op_run_task) { 73 runtime::Pipeline::Get().backend_stage()->Push(op_run_task); 74 } 75 RunQueueEmpty()76bool OpExecutor::RunQueueEmpty() { return runtime::Pipeline::Get().backend_stage()->Empty(); } 77 WorkerJoin()78void OpExecutor::WorkerJoin() { 79 GilReleaseWithCheck release_gil; 80 runtime::Pipeline::Get().backend_stage()->WorkerJoin(); 81 runtime::Pipeline::Get().launch_stage()->WorkerJoin(); 82 } 83 DispatchLaunchTask(const std::function<void ()> & func)84void OpExecutor::DispatchLaunchTask(const std::function<void()> &func) { 85 if (NeedSync()) { 86 runtime::OpExecutor::GetInstance().WaitAll(); 87 func(); 88 } else { 89 auto task = std::make_shared<runtime::DeviceLaunchTask>([=]() { func(); }); 90 runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id()); 91 runtime::Pipeline::Get().launch_stage()->Push(task); 92 } 93 } 94 NeedSync()95bool OpExecutor::NeedSync() { 96 auto context = MsContext::GetInstance(); 97 MS_EXCEPTION_IF_NULL(context); 98 return context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) || 99 (context->get_param<int>(MS_CTX_EXECUTION_MODE) == mindspore::kGraphMode && !async_for_graph_); 100 } 101 ChildAfterFork()102void OpExecutor::ChildAfterFork() { 103 MS_LOG(DEBUG) << "OpExecutor reinitialize after fork"; 104 MS_LOG(DEBUG) << "Reinitialize async_queue_."; 105 runtime::Pipeline::Get().backend_stage()->ChildAfterFork(); 106 runtime::Pipeline::Get().launch_stage()->ChildAfterFork(); 107 // Refresh the lazy callback in Tensor. 108 tensor::Tensor::RegisterLazyCallback([]() { OpExecutor::GetInstance().WaitAll(); }); 109 MS_LOG(DEBUG) << "OpExecutor reinitialize after fork done."; 110 } 111 } // namespace mindspore::runtime 112