1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
19 #include <map>
20 #include <memory>
21 #include <string>
22
23 #include "pybind11/embed.h"
24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
26 #include "pybind11/numpy.h"
27
28 #include "utils/log_adapter.h"
29 #include "ir/tensor.h"
30 #include "base/base_ref.h"
31 #include "include/common/visible.h"
32 #include "utils/callback_handler.h"
33
34 namespace py = pybind11;
35 namespace mindspore {
// Utilities for calling into Python (modules, module functions, object methods) from C++.
37 namespace python_adapter {
// Returns the Python module named `module`.
// NOTE(review): definition is in the .cc file; presumably imports the module — confirm failure behavior there.
COMMON_EXPORT py::module GetPyModule(const std::string &module);
// Returns attribute `attr` of Python object `obj`.
// NOTE(review): behavior for a missing attribute / None `obj` is defined in the .cc file — confirm before relying on it.
COMMON_EXPORT py::object GetPyObjAttr(const py::object &obj, const std::string &attr);
40 template <class... T>
CallPyObjMethod(const py::object & obj,const std::string & method,T...args)41 py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) {
42 if (!method.empty() && !py::isinstance<py::none>(obj)) {
43 return obj.attr(method.c_str())(args...);
44 }
45 return py::none();
46 }
47
48 // call python function of module
49 template <class... T>
CallPyModFn(const py::module & mod,const std::string & function,T...args)50 py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) {
51 if (!function.empty() && !py::isinstance<py::none>(mod)) {
52 return mod.attr(function.c_str())(args...);
53 }
54 return py::none();
55 }
56
// Toggle whether signatures are used during resolve; unit tests turn this off when they use the parser
// to construct a graph.
COMMON_EXPORT void set_use_signature_in_resolve(bool use_signature) noexcept;
// NOTE(review): presumably returns the flag set by set_use_signature_in_resolve() — confirm in the .cc file.
COMMON_EXPORT bool UseSignatureInResolve();
60
// Returns a shared handle to a py::scoped_interpreter. CallPyFn calls this and discards the result.
// NOTE(review): presumably creates the embedded interpreter on first use — confirm in the implementation.
COMMON_EXPORT std::shared_ptr<py::scoped_interpreter> set_python_scoped();
// NOTE(review): presumably releases the interpreter handle created by set_python_scoped() — confirm.
COMMON_EXPORT void ResetPythonScope();
// Reports whether the process is running in a Python environment.
// NOTE(review): presumably the flag written by set_python_env_flag() — confirm.
COMMON_EXPORT bool IsPythonEnv();
// Configures the Python module search path using `path`.
// NOTE(review): whether this appends to or replaces sys.path is defined in the .cc file — confirm.
COMMON_EXPORT void SetPythonPath(const std::string &path);
// Records whether the process is running in a Python environment (see IsPythonEnv()).
COMMON_EXPORT void set_python_env_flag(bool python_env) noexcept;
// Returns the Python callable `name` from module `module`.
// NOTE(review): presumably looks it up without calling it (unlike CallPyFn) — confirm in the .cc file.
COMMON_EXPORT py::object GetPyFn(const std::string &module, const std::string &name);
67
68 // Call the python function
69 template <class... T>
CallPyFn(const std::string & module,const std::string & name,T...args)70 py::object CallPyFn(const std::string &module, const std::string &name, T... args) {
71 (void)set_python_scoped();
72 if (!module.empty() && !name.empty()) {
73 py::module mod = py::module::import(module.c_str());
74 py::object fn = mod.attr(name.c_str())(args...);
75 return fn;
76 }
77 return py::none();
78 }
79
80 // Cast shared_ptr to py::object.
81 template <typename T>
CastToPyObj(const std::shared_ptr<T> & ptr)82 py::object CastToPyObj(const std::shared_ptr<T> &ptr) {
83 // Use a pybind11 typecaster to create a PyObject from a shared_ptr<T> pointer.
84 py::detail::type_caster<std::shared_ptr<T>> shared_ptr_caster;
85 py::handle cast_handle = shared_ptr_caster.cast(ptr, py::return_value_policy::take_ownership, py::handle());
86 py::object obj = py::cast<py::object>(cast_handle);
87 MS_LOG(DEBUG) << ptr << ", obj: " << obj << ", handle ptr: " << cast_handle.ptr();
88 return obj;
89 }
90
91 // Cast pointer to py::object.
92 template <typename T>
CastToPyObj(const T * ptr)93 py::object CastToPyObj(const T *ptr) {
94 // Use a pybind11 typecaster to create a PyObject from a T* pointer.
95 py::detail::type_caster<T *> ptr_caster;
96 py::handle cast_handle = ptr_caster.cast(ptr, py::return_value_policy::take_ownership, py::handle());
97 py::object obj = py::cast<py::object>(cast_handle);
98 MS_LOG(DEBUG) << ptr << ", obj: " << obj << ", handle ptr: " << cast_handle.ptr();
99 return obj;
100 }
101
// Set of callback handlers declared with HANDLER_DEFINE (from utils/callback_handler.h).
// NOTE(review): the macro's expansion is not visible here; the arguments appear to be
// (return type, handler name, argument types...) — confirm against utils/callback_handler.h.
class COMMON_EXPORT PyAdapterCallback {
  // py::object -> ValuePtr; presumably converts Python data into a MindSpore Value (confirm).
  HANDLER_DEFINE(ValuePtr, PyDataToValue, py::object);
  // (PrimitivePtr, VectorRef) -> BaseRef; presumably runs a primitive's Python hook function (confirm).
  HANDLER_DEFINE(BaseRef, RunPrimitivePyHookFunction, PrimitivePtr, VectorRef);
  // tensor::Tensor -> py::array; presumably exposes a tensor as a NumPy array (confirm copy vs. view).
  HANDLER_DEFINE(py::array, TensorToNumpy, tensor::Tensor);
  // bool -> void; presumably handles cell hooks left unpaired (confirm meaning of the flag).
  HANDLER_DEFINE(void, ProcessUnPairedCellHook, bool);
};
108 } // namespace python_adapter
109 } // namespace mindspore
110 #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
111