• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
19 #include <map>
20 #include <memory>
21 #include <string>
22 
23 #include "pybind11/embed.h"
24 #include "pybind11/pybind11.h"
25 #include "pybind11/stl.h"
26 #include "pybind11/numpy.h"
27 
28 #include "utils/log_adapter.h"
29 #include "ir/tensor.h"
30 #include "base/base_ref.h"
31 #include "include/common/visible.h"
32 #include "utils/callback_handler.h"
33 
34 namespace py = pybind11;
35 namespace mindspore {
36 // A utility to call python interface
37 namespace python_adapter {
// Presumably imports/returns the Python module named `module`.
// NOTE(review): implementation is not visible here; failure behavior (throw vs. None) is unconfirmed.
COMMON_EXPORT py::module GetPyModule(const std::string &module);
// Presumably reads attribute `attr` from the Python object `obj`.
// NOTE(review): implementation is not visible here; behavior for a missing attribute is unconfirmed.
COMMON_EXPORT py::object GetPyObjAttr(const py::object &obj, const std::string &attr);
40 template <class... T>
CallPyObjMethod(const py::object & obj,const std::string & method,T...args)41 py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) {
42   if (!method.empty() && !py::isinstance<py::none>(obj)) {
43     return obj.attr(method.c_str())(args...);
44   }
45   return py::none();
46 }
47 
48 // call python function of module
49 template <class... T>
CallPyModFn(const py::module & mod,const std::string & function,T...args)50 py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) {
51   if (!function.empty() && !py::isinstance<py::none>(mod)) {
52     return mod.attr(function.c_str())(args...);
53   }
54   return py::none();
55 }
56 
// Enable or disable the use of signatures during resolve.
// Unit tests turn this off when they use the parser to construct a graph.
COMMON_EXPORT void set_use_signature_in_resolve(bool use_signature) noexcept;
// Presumably returns the flag set by set_use_signature_in_resolve() — implementation not visible; confirm.
COMMON_EXPORT bool UseSignatureInResolve();
60 
// Returns a shared handle to a pybind11 scoped (embedded) interpreter.
// NOTE(review): whether this creates a new interpreter or reuses a cached one is not visible here — confirm.
COMMON_EXPORT std::shared_ptr<py::scoped_interpreter> set_python_scoped();
// Presumably releases the interpreter scope obtained via set_python_scoped() — implementation not visible.
COMMON_EXPORT void ResetPythonScope();
// Presumably reports the flag set by set_python_env_flag() — implementation not visible; confirm.
COMMON_EXPORT bool IsPythonEnv();
// Presumably sets/extends the Python module search path with `path` — implementation not visible.
COMMON_EXPORT void SetPythonPath(const std::string &path);
// Sets a flag recording whether the process is running in a Python environment.
COMMON_EXPORT void set_python_env_flag(bool python_env) noexcept;
// Presumably looks up attribute `name` on Python module `module` without calling it — confirm.
COMMON_EXPORT py::object GetPyFn(const std::string &module, const std::string &name);
67 
68 // Call the python function
69 template <class... T>
CallPyFn(const std::string & module,const std::string & name,T...args)70 py::object CallPyFn(const std::string &module, const std::string &name, T... args) {
71   (void)set_python_scoped();
72   if (!module.empty() && !name.empty()) {
73     py::module mod = py::module::import(module.c_str());
74     py::object fn = mod.attr(name.c_str())(args...);
75     return fn;
76   }
77   return py::none();
78 }
79 
80 // Cast shared_ptr to py::object.
81 template <typename T>
CastToPyObj(const std::shared_ptr<T> & ptr)82 py::object CastToPyObj(const std::shared_ptr<T> &ptr) {
83   // Use a pybind11 typecaster to create a PyObject from a shared_ptr<T> pointer.
84   py::detail::type_caster<std::shared_ptr<T>> shared_ptr_caster;
85   py::handle cast_handle = shared_ptr_caster.cast(ptr, py::return_value_policy::take_ownership, py::handle());
86   py::object obj = py::cast<py::object>(cast_handle);
87   MS_LOG(DEBUG) << ptr << ", obj: " << obj << ", handle ptr: " << cast_handle.ptr();
88   return obj;
89 }
90 
91 // Cast pointer to py::object.
92 template <typename T>
CastToPyObj(const T * ptr)93 py::object CastToPyObj(const T *ptr) {
94   // Use a pybind11 typecaster to create a PyObject from a T* pointer.
95   py::detail::type_caster<T *> ptr_caster;
96   py::handle cast_handle = ptr_caster.cast(ptr, py::return_value_policy::take_ownership, py::handle());
97   py::object obj = py::cast<py::object>(cast_handle);
98   MS_LOG(DEBUG) << ptr << ", obj: " << obj << ", handle ptr: " << cast_handle.ptr();
99   return obj;
100 }
101 
// Table of Python-dependent callbacks declared with HANDLER_DEFINE(ReturnType, Name, ArgTypes...)
// from "utils/callback_handler.h".
// NOTE(review): the macro expansion is not visible here. Presumably each entry declares a settable
// handler `Name` with signature ReturnType(ArgTypes...) that another module registers — confirm.
class COMMON_EXPORT PyAdapterCallback {
  // Presumably converts a Python object into a ValuePtr.
  HANDLER_DEFINE(ValuePtr, PyDataToValue, py::object);
  // Presumably runs a primitive's Python hook function on the given argument vector.
  HANDLER_DEFINE(BaseRef, RunPrimitivePyHookFunction, PrimitivePtr, VectorRef);
  // Presumably converts a tensor into a numpy array.
  HANDLER_DEFINE(py::array, TensorToNumpy, tensor::Tensor);
  // Presumably handles cell hooks left unpaired; the meaning of the bool argument is not visible here.
  HANDLER_DEFINE(void, ProcessUnPairedCellHook, bool);
};
108 }  // namespace python_adapter
109 }  // namespace mindspore
110 #endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_PYTHON_ADAPTER_H_
111