1/** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "runtime/pynative/op_function/pyboost_grad_functions.h" 18#include "runtime/pynative/op_executor.h" 19#include "runtime/pynative/op_function/value_converter.h" 20#include "kernel/pyboost/pyboost_utils.h" 21#include "runtime/pynative/op_function/pyboost_grad_functions.h" 22#include "backend/graph_compiler/vmimpl.h" 23#include "include/common/utils/python_adapter.h" 24#include "pybind_api/gil_scoped_long_running.h" 25${include_op_header} 26 27namespace mindspore::runtime { 28namespace { 29session::BackendOpRunInfoPtr GetBackendOpRunInfo(OpRunnerInfo *op_runner_info) { 30 MS_EXCEPTION_IF_NULL(op_runner_info); 31 MS_EXCEPTION_IF_NULL(op_runner_info->prim); 32 pynative::BaseOpRunInfo base_op_run_info; 33 base_op_run_info.op_name = op_runner_info->prim->name(); 34 base_op_run_info.device_target = op_runner_info->device_target; 35 base_op_run_info.expanded_input_values = op_runner_info->inputs; 36 base_op_run_info.input_types = op_runner_info->inputs_mask; 37 // Do infer and refresh output abstract 38 op_runner_info->output_abs = kernel::pyboost::PyBoostUtils::InferByOpDef(op_runner_info->prim, op_runner_info->inputs_abs); 39 base_op_run_info.abstract = op_runner_info->output_abs ; 40 return std::make_shared<BackendOpRunInfo>(base_op_run_info, op_runner_info->prim, false, false); 41} 42} 43 44PyBoostOpExecute& PyBoostOpExecute::GetInstance() { 45 static PyBoostOpExecute instance; 46 return instance; 47} 48 49bool PyBoostOpExecute::IsPyBoostOpRegistered(const std::string &op_name) { 50 return grad_op_func_map_.find(op_name) != grad_op_func_map_.end(); 51} 52 53void PyBoostOpExecute::Execute(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) { 54 #ifndef ENABLE_TEST 55 GilReleaseWithCheck release_gil; 56 MS_EXCEPTION_IF_NULL(op_runner_info); 57 const auto it = grad_op_func_map_.find(op_runner_info->prim->name()); 58 // Run op by pyboost 59 if (it != grad_op_func_map_.end() && 60 (kernel::pyboost::PyBoostUtils::IsKernelModRegistered(op_runner_info->device_target, op_runner_info->prim->name()) 61 || kernel::pyboost::PyBoostUtils::IsPyBoostCustomRegistered(op_runner_info->device_target, op_runner_info->prim->name()))) { 62 const auto &func = FuncCast<Func>(it->second); 63 MS_EXCEPTION_IF_NULL(func); 64 func(op_runner_info, op_outputs); 65 return; 66 } 67 // Run op by single op graph 68 RunOpDeprecated(op_runner_info, op_outputs); 69#else 70 RunOpInVm(op_runner_info, op_outputs); 71#endif 72} 73 74void PyBoostOpExecute::RunPyBoostCall(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) { 75 MS_EXCEPTION_IF_NULL(op_runner_info); 76 const auto &func = FuncCast<Func>(grad_op_func_map_.at(op_runner_info->prim->name())); 77 MS_EXCEPTION_IF_NULL(func); 78 func(op_runner_info, op_outputs); 79} 80 81void PyBoostOpExecute::RunOpDeprecated(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) { 82 // For call runop 83 const auto &backend_op_run_info = GetBackendOpRunInfo(op_runner_info); 84 // Do infer and refresh output abstract 85 op_runner_info->output_abs = kernel::pyboost::PyBoostUtils::InferByOpDef(op_runner_info->prim, op_runner_info->inputs_abs); 86 backend_op_run_info->base_op_run_info.abstract = op_runner_info->output_abs ; 87 // Call single op graph run 88 GetMindRtBackend(op_runner_info->device_target); 89 backend_op_run_info->base_op_run_info.use_dynamic_shape_process = true; 90 backend_op_run_info->op_prim = std::make_shared<Primitive>(*op_runner_info->prim); 91 AnfAlgo::SetDynamicAttrToPrim(backend_op_run_info->op_prim); 92 backend_->RunOpDynamic(backend_op_run_info, op_outputs); 93} 94 95void PyBoostOpExecute::RunOpInVm(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) { 96 VectorRef args; 97 std::transform(op_runner_info->inputs.begin(), op_runner_info->inputs.end(), std::back_inserter(args), 98 [](const auto &value) { return value; }); 99 py::gil_scoped_acquire gil; 100 auto result = compile::RunOperation(op_runner_info->prim, args); 101 if (utils::isa<PyObjectRef>(result)) { 102 PyObjectRef py_ref = utils::cast<PyObjectRef>(result); 103 py::object value = py_ref.object_; 104 auto result_v = python_adapter::PyAdapterCallback::PyDataToValue(value); 105 if (!result_v->isa<ValueSequence>()) { 106 (void)op_outputs->emplace_back(result_v); 107 } else { 108 auto seq = result_v->cast<ValueSequencePtr>(); 109 std::transform(seq->value().begin(), seq->value().end(), std::back_inserter(*op_outputs), 110 [](const auto &value) { return value; }); 111 } 112 op_runner_info->output_abs = result_v->ToAbstract()->Broaden(); 113 return; 114 } 115 116 MS_LOG(EXCEPTION) << "prim: " << op_runner_info->prim->name() << "did not has vm op!"; 117} 118 119void PyBoostOpExecute::GetMindRtBackend(const string &cur_device_target) { 120 if (backend_ != nullptr) { 121 return; 122 } 123 auto ms_context = MsContext::GetInstance(); 124 MS_EXCEPTION_IF_NULL(ms_context); 125 auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); 126 auto backend = std::make_shared<compile::MindRTBackend>("ms", cur_device_target, device_id); 127 backend_ = backend; 128} 129 130${function_body} 131 132${register_function_body} 133 134} // namespace mindspore::pynative 135