1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/common/utils/primitive_utils.h"
18
19 #include <memory>
20
21 #include "include/common/utils/python_adapter.h"
22 #include "utils/log_adapter.h"
23 #include "utils/ms_utils.h"
24 #include "include/common/utils/convert_utils_py.h"
25 #include "pybind_api/ir/base_ref_py.h"
26
27 namespace mindspore {
// Fetches the Python helper `get_bprop_fn` from mindspore.ops._grad_experimental and calls it.
// @param obj          Forwarded as the first argument of `get_bprop_fn` (GetBpropFunction passes a py::str name).
// @param get_closure  Forwarded unchanged as the second argument of `get_bprop_fn`.
// @return The object returned by `get_bprop_fn(obj, get_closure)`, converted to py::function.
py::function GetBpropFunctionByObj(const py::object &obj, bool get_closure) {
  static const std::string kGradExperimentalModule = "mindspore.ops._grad_experimental";
  static const std::string kBpropGetterName = "get_bprop_fn";
  py::function bprop_getter = python_adapter::GetPyFn(kGradExperimentalModule, kBpropGetterName);
  return bprop_getter(obj, get_closure);
}
34
GetBpropFunction(const std::string & name)35 py::function GetBpropFunction(const std::string &name) {
36 auto fn = GetBpropFunctionByObj(py::str(name));
37 return fn;
38 }
39
// Fetches the Python helper `get_taylor_fprop_fn` from mindspore.ops._grad_experimental and calls it.
// @param obj  Forwarded as the sole argument of `get_taylor_fprop_fn`.
// @return The object returned by `get_taylor_fprop_fn(obj)`, converted to py::function.
py::function GetTaylorRuleFunctionByObj(const py::object &obj) {
  static const std::string kTaylorModule = "mindspore.ops._grad_experimental";
  static const std::string kTaylorGetterName = "get_taylor_fprop_fn";
  auto taylor_getter = python_adapter::GetPyFn(kTaylorModule, kTaylorGetterName);
  return taylor_getter(obj);
}
46
GetTaylorRuleFunction(const std::string & name)47 py::function GetTaylorRuleFunction(const std::string &name) {
48 auto fn = GetTaylorRuleFunctionByObj(py::str(name));
49 return fn;
50 }
51
GetComputeFunction(const std::string & name)52 py::function GetComputeFunction(const std::string &name) {
53 static const std::string module = "mindspore._extends.builtin_operations";
54 py::module mod = py::module::import(common::SafeCStr(module));
55 if (!py::hasattr(mod, common::SafeCStr(name))) {
56 PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name));
57 // If raise AttributeError, user can't understand. This case need raise NotImplementedError.
58 throw(py::error_already_set());
59 }
60 py::object fn = mod.attr(common::SafeCStr(name));
61 return fn;
62 }
63
ConvertDatatoPyTuple(const VectorRef & args)64 py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
65 auto py_args = py::tuple(args.size());
66 size_t i = 0;
67 for (auto &arg : args) {
68 py_args[i] = BaseRefToPyData(arg);
69 MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString();
70 i++;
71 }
72 return py_args;
73 }
74
GetComputeFunctionWithoutPyObj(const std::string & name)75 py::function GetComputeFunctionWithoutPyObj(const std::string &name) {
76 static const std::string vm_module = "mindspore.ops.vm_impl_registry";
77 static const std::string get_vm_impl_fn = "get_vm_impl_fn";
78 py::function get_fn = python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
79 if (py::isinstance<py::none>(get_fn)) {
80 MS_LOG(DEBUG) << "Failed to get the function 'get_vm_impl_fn'";
81 return py::none();
82 }
83 py::function vm_fn = get_fn(py::str(name));
84 return vm_fn;
85 }
86
RunComputeFunctionWithoutPyObj(const PrimitivePtr & prim,const VectorRef & args)87 BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args) {
88 auto func = GetComputeFunctionWithoutPyObj(prim->name());
89 if (py::isinstance<py::none>(func)) {
90 return nullptr;
91 }
92 auto py_args = ConvertDatatoPyTuple(args);
93 py::object obj = func(*py_args);
94 if (py::isinstance<py::none>(obj)) {
95 return nullptr;
96 }
97 return std::make_shared<PyObjectRef>(obj);
98 }
99
RunComputeFunction(const PrimitivePtr & prim,const VectorRef & args)100 BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
101 auto func = GetComputeFunction(prim->name());
102 if (py::isinstance<py::none>(func)) {
103 MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented";
104 }
105 auto py_args = ConvertDatatoPyTuple(args);
106 py::object obj = func(*py_args);
107 return std::make_shared<PyObjectRef>(obj);
108 }
109
// Fetches the Python helper `get_vmap_rule` from mindspore.ops._vmap and calls it.
// @param obj        Forwarded as the first argument of `get_vmap_rule`.
// @param axis_size  Forwarded unchanged as the second argument of `get_vmap_rule`.
// @return The object returned by `get_vmap_rule(obj, axis_size)`, converted to py::function.
py::function GetVmapRuleFunctionByObj(const py::object &obj, int axis_size) {
  constexpr char vmap_module[] = "mindspore.ops._vmap";
  constexpr char get_vmap_rule_fn[] = "get_vmap_rule";
  auto rule_getter = python_adapter::GetPyFn(vmap_module, get_vmap_rule_fn);
  py::function vmap_rule = rule_getter(obj, axis_size);
  return vmap_rule;
}
116
GetVmapRuleFunction(const std::string & name,int axis_size)117 py::function GetVmapRuleFunction(const std::string &name, int axis_size) {
118 auto fn = GetVmapRuleFunctionByObj(py::str(name), axis_size);
119 return fn;
120 }
121 } // namespace mindspore
122