• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "runtime/pynative/op_function/pyboost_grad_functions.h"
18#include "runtime/pynative/op_executor.h"
19#include "runtime/pynative/op_function/value_converter.h"
20#include "kernel/pyboost/pyboost_utils.h"
21#include "runtime/pynative/op_function/pyboost_grad_functions.h"
22#include "backend/graph_compiler/vmimpl.h"
23#include "include/common/utils/python_adapter.h"
24#include "pybind_api/gil_scoped_long_running.h"
25${include_op_header}
26
27namespace mindspore::runtime {
28namespace {
29session::BackendOpRunInfoPtr GetBackendOpRunInfo(OpRunnerInfo *op_runner_info) {
30  MS_EXCEPTION_IF_NULL(op_runner_info);
31  MS_EXCEPTION_IF_NULL(op_runner_info->prim);
32  pynative::BaseOpRunInfo base_op_run_info;
33  base_op_run_info.op_name = op_runner_info->prim->name();
34  base_op_run_info.device_target = op_runner_info->device_target;
35  base_op_run_info.expanded_input_values = op_runner_info->inputs;
36  base_op_run_info.input_types = op_runner_info->inputs_mask;
37  // Do infer and refresh output abstract
38  op_runner_info->output_abs = kernel::pyboost::PyBoostUtils::InferByOpDef(op_runner_info->prim, op_runner_info->inputs_abs);
39  base_op_run_info.abstract = op_runner_info->output_abs ;
40  return std::make_shared<BackendOpRunInfo>(base_op_run_info, op_runner_info->prim, false, false);
41}
42}
43
44PyBoostOpExecute& PyBoostOpExecute::GetInstance() {
45  static PyBoostOpExecute instance;
46  return instance;
47}
48
49bool PyBoostOpExecute::IsPyBoostOpRegistered(const std::string &op_name) {
50  return grad_op_func_map_.find(op_name) != grad_op_func_map_.end();
51}
52
53void PyBoostOpExecute::Execute(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) {
54 #ifndef ENABLE_TEST
55  GilReleaseWithCheck release_gil;
56  MS_EXCEPTION_IF_NULL(op_runner_info);
57  const auto it = grad_op_func_map_.find(op_runner_info->prim->name());
58  // Run op by pyboost
59  if (it != grad_op_func_map_.end() &&
60      (kernel::pyboost::PyBoostUtils::IsKernelModRegistered(op_runner_info->device_target, op_runner_info->prim->name())
61       || kernel::pyboost::PyBoostUtils::IsPyBoostCustomRegistered(op_runner_info->device_target, op_runner_info->prim->name()))) {
62    const auto &func = FuncCast<Func>(it->second);
63    MS_EXCEPTION_IF_NULL(func);
64    func(op_runner_info, op_outputs);
65    return;
66  }
67  // Run op by single op graph
68  RunOpDeprecated(op_runner_info, op_outputs);
69#else
70  RunOpInVm(op_runner_info, op_outputs);
71#endif
72}
73
74void PyBoostOpExecute::RunPyBoostCall(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) {
75  MS_EXCEPTION_IF_NULL(op_runner_info);
76  const auto &func = FuncCast<Func>(grad_op_func_map_.at(op_runner_info->prim->name()));
77  MS_EXCEPTION_IF_NULL(func);
78  func(op_runner_info, op_outputs);
79}
80
81void PyBoostOpExecute::RunOpDeprecated(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) {
82  // For call runop
83  const auto &backend_op_run_info = GetBackendOpRunInfo(op_runner_info);
84  // Do infer and refresh output abstract
85  op_runner_info->output_abs = kernel::pyboost::PyBoostUtils::InferByOpDef(op_runner_info->prim, op_runner_info->inputs_abs);
86  backend_op_run_info->base_op_run_info.abstract = op_runner_info->output_abs ;
87  // Call single op graph run
88  GetMindRtBackend(op_runner_info->device_target);
89  backend_op_run_info->base_op_run_info.use_dynamic_shape_process = true;
90  backend_op_run_info->op_prim = std::make_shared<Primitive>(*op_runner_info->prim);
91  AnfAlgo::SetDynamicAttrToPrim(backend_op_run_info->op_prim);
92  backend_->RunOpDynamic(backend_op_run_info, op_outputs);
93}
94
95void PyBoostOpExecute::RunOpInVm(OpRunnerInfo *op_runner_info, VectorRef *op_outputs) {
96  VectorRef args;
97  std::transform(op_runner_info->inputs.begin(), op_runner_info->inputs.end(), std::back_inserter(args),
98                 [](const auto &value) { return value; });
99  py::gil_scoped_acquire gil;
100  auto result = compile::RunOperation(op_runner_info->prim, args);
101  if (utils::isa<PyObjectRef>(result)) {
102    PyObjectRef py_ref = utils::cast<PyObjectRef>(result);
103    py::object value = py_ref.object_;
104    auto result_v = python_adapter::PyAdapterCallback::PyDataToValue(value);
105    if (!result_v->isa<ValueSequence>()) {
106      (void)op_outputs->emplace_back(result_v);
107    } else {
108      auto seq = result_v->cast<ValueSequencePtr>();
109          std::transform(seq->value().begin(), seq->value().end(), std::back_inserter(*op_outputs),
110                         [](const auto &value) { return value; });
111    }
112    op_runner_info->output_abs = result_v->ToAbstract()->Broaden();
113    return;
114  }
115
116  MS_LOG(EXCEPTION) << "prim: " << op_runner_info->prim->name() << "did not has vm op!";
117}
118
119void PyBoostOpExecute::GetMindRtBackend(const string &cur_device_target) {
120  if (backend_ != nullptr) {
121    return;
122  }
123  auto ms_context = MsContext::GetInstance();
124  MS_EXCEPTION_IF_NULL(ms_context);
125  auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
126  auto backend = std::make_shared<compile::MindRTBackend>("ms", cur_device_target, device_id);
127  backend_ = backend;
128}
129
130${function_body}
131
132${register_function_body}
133
} // namespace mindspore::runtime
135