1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/primitive_utils.h"
18
19 #include <memory>
20
21 #include "pipeline/jit/parse/python_adapter.h"
22 #include "utils/log_adapter.h"
23 #include "utils/ms_utils.h"
24 #include "utils/convert_utils_py.h"
25 #include "pybind_api/ir/base_ref_py.h"
26
27 namespace mindspore {
/**
 * Looks up the bprop function for `obj`.
 *
 * Calls `get_bprop_fn(obj)` from `mindspore.ops._grad` first. If that yields a null handle
 * or None, it retries with `get_bprop_fn(obj)` from `mindspore.ops._grad_experimental`.
 *
 * @param obj Key passed to `get_bprop_fn` (GetBpropFunction passes a py::str name).
 * @return The object returned by the last `get_bprop_fn` call made. It may wrap None if
 *         neither module provides a function, so callers should test it with
 *         py::isinstance<py::none>.
 */
py::function GetBpropFunctionByObj(const py::object &obj) {
  static const std::string get_bprop_fn = "get_bprop_fn";
  static const std::string ad_module = "mindspore.ops._grad";
  static const std::string ad_experimental_module = "mindspore.ops._grad_experimental";
  // Keep the lookup result as a plain py::object. Implicitly converting None to
  // py::function throws type_error in pybind11 >= 2.6 (None fails PyCallable_Check),
  // which would stop the experimental-module fallback below from ever running.
  py::object fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj);
  if (!fn || py::isinstance<py::none>(fn)) {
    fn = parse::python_adapter::GetPyFn(ad_experimental_module, get_bprop_fn)(obj);
  }
  // reinterpret_borrow skips the callable check, so a None result reaches the caller
  // unchanged (the same as with older pybind11) and can still be tested there.
  return py::reinterpret_borrow<py::function>(fn);
}
38
GetBpropFunction(const std::string & name)39 py::function GetBpropFunction(const std::string &name) {
40 auto fn = GetBpropFunctionByObj(py::str(name));
41 return fn;
42 }
43
GetComputeFunction(const std::string & name)44 py::function GetComputeFunction(const std::string &name) {
45 static const std::string module = "mindspore._extends.builtin_operations";
46 py::module mod = py::module::import(common::SafeCStr(module));
47 if (!py::hasattr(mod, common::SafeCStr(name))) {
48 PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name));
49 // If raise AttributeError, user can't understand. This case need raise NotImplementedError.
50 throw(py::error_already_set());
51 }
52 py::object fn = mod.attr(common::SafeCStr(name));
53 return fn;
54 }
55
ConvertDatatoPyTuple(const VectorRef & args)56 py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
57 auto py_args = py::tuple(args.size());
58 size_t i = 0;
59 for (auto &arg : args) {
60 py_args[i] = BaseRefToPyData(arg);
61 MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString();
62 i++;
63 }
64 return py_args;
65 }
66
GetComputeFunctionWithoutPyObj(const std::string & name)67 py::function GetComputeFunctionWithoutPyObj(const std::string &name) {
68 static const std::string module = "tests.vm_impl.vm_impl_function";
69 py::module mod = py::module::import(common::SafeCStr(module));
70 if (!py::hasattr(mod, common::SafeCStr(name))) {
71 return py::none();
72 }
73 py::object fn = mod.attr(common::SafeCStr(name));
74 return fn;
75 }
76
RunComputeFunctionWithoutPyObj(const PrimitivePtr & prim,const VectorRef & args)77 BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args) {
78 auto func = GetComputeFunctionWithoutPyObj(prim->name());
79 if (py::isinstance<py::none>(func)) {
80 return nullptr;
81 }
82 auto py_args = ConvertDatatoPyTuple(args);
83 py::object obj = func(*py_args);
84 if (py::isinstance<py::none>(obj)) {
85 return nullptr;
86 }
87 return std::make_shared<PyObjectRef>(obj);
88 }
89
RunComputeFunction(const PrimitivePtr & prim,const VectorRef & args)90 BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
91 auto func = GetComputeFunction(prim->name());
92 if (py::isinstance<py::none>(func)) {
93 MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented";
94 }
95 auto py_args = ConvertDatatoPyTuple(args);
96 py::object obj = func(*py_args);
97 return std::make_shared<PyObjectRef>(obj);
98 }
99 } // namespace mindspore
100