• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/common/utils/primitive_utils.h"
18 
19 #include <memory>
20 
21 #include "include/common/utils/python_adapter.h"
22 #include "utils/log_adapter.h"
23 #include "utils/ms_utils.h"
24 #include "include/common/utils/convert_utils_py.h"
25 #include "pybind_api/ir/base_ref_py.h"
26 
27 namespace mindspore {
// Calls the Python helper `get_bprop_fn` from `mindspore.ops._grad_experimental` with
// (obj, get_closure) and returns whatever it yields as a py::function.
py::function GetBpropFunctionByObj(const py::object &obj, bool get_closure) {
  constexpr char kGradExperimentalModule[] = "mindspore.ops._grad_experimental";
  constexpr char kGetBpropFn[] = "get_bprop_fn";
  // Resolve the Python-side getter first, then invoke it with the caller's arguments.
  const auto bprop_getter = python_adapter::GetPyFn(kGradExperimentalModule, kGetBpropFn);
  py::function bprop_fn = bprop_getter(obj, get_closure);
  return bprop_fn;
}
34 
GetBpropFunction(const std::string & name)35 py::function GetBpropFunction(const std::string &name) {
36   auto fn = GetBpropFunctionByObj(py::str(name));
37   return fn;
38 }
39 
// Calls the Python helper `get_taylor_fprop_fn` from `mindspore.ops._grad_experimental`
// with `obj` and returns its result as a py::function.
py::function GetTaylorRuleFunctionByObj(const py::object &obj) {
  constexpr char kAdModule[] = "mindspore.ops._grad_experimental";
  constexpr char kGetTaylorFpropFn[] = "get_taylor_fprop_fn";
  // Resolve the Python-side getter first, then invoke it with the caller's object.
  const auto taylor_getter = python_adapter::GetPyFn(kAdModule, kGetTaylorFpropFn);
  py::function taylor_fn = taylor_getter(obj);
  return taylor_fn;
}
46 
GetTaylorRuleFunction(const std::string & name)47 py::function GetTaylorRuleFunction(const std::string &name) {
48   auto fn = GetTaylorRuleFunctionByObj(py::str(name));
49   return fn;
50 }
51 
GetComputeFunction(const std::string & name)52 py::function GetComputeFunction(const std::string &name) {
53   static const std::string module = "mindspore._extends.builtin_operations";
54   py::module mod = py::module::import(common::SafeCStr(module));
55   if (!py::hasattr(mod, common::SafeCStr(name))) {
56     PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name));
57     // If raise AttributeError, user can't understand. This case need raise NotImplementedError.
58     throw(py::error_already_set());
59   }
60   py::object fn = mod.attr(common::SafeCStr(name));
61   return fn;
62 }
63 
ConvertDatatoPyTuple(const VectorRef & args)64 py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
65   auto py_args = py::tuple(args.size());
66   size_t i = 0;
67   for (auto &arg : args) {
68     py_args[i] = BaseRefToPyData(arg);
69     MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString();
70     i++;
71   }
72   return py_args;
73 }
74 
GetComputeFunctionWithoutPyObj(const std::string & name)75 py::function GetComputeFunctionWithoutPyObj(const std::string &name) {
76   static const std::string vm_module = "mindspore.ops.vm_impl_registry";
77   static const std::string get_vm_impl_fn = "get_vm_impl_fn";
78   py::function get_fn = python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
79   if (py::isinstance<py::none>(get_fn)) {
80     MS_LOG(DEBUG) << "Failed to get the function 'get_vm_impl_fn'";
81     return py::none();
82   }
83   py::function vm_fn = get_fn(py::str(name));
84   return vm_fn;
85 }
86 
RunComputeFunctionWithoutPyObj(const PrimitivePtr & prim,const VectorRef & args)87 BaseRef RunComputeFunctionWithoutPyObj(const PrimitivePtr &prim, const VectorRef &args) {
88   auto func = GetComputeFunctionWithoutPyObj(prim->name());
89   if (py::isinstance<py::none>(func)) {
90     return nullptr;
91   }
92   auto py_args = ConvertDatatoPyTuple(args);
93   py::object obj = func(*py_args);
94   if (py::isinstance<py::none>(obj)) {
95     return nullptr;
96   }
97   return std::make_shared<PyObjectRef>(obj);
98 }
99 
RunComputeFunction(const PrimitivePtr & prim,const VectorRef & args)100 BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
101   auto func = GetComputeFunction(prim->name());
102   if (py::isinstance<py::none>(func)) {
103     MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented";
104   }
105   auto py_args = ConvertDatatoPyTuple(args);
106   py::object obj = func(*py_args);
107   return std::make_shared<PyObjectRef>(obj);
108 }
109 
// Calls the Python helper `get_vmap_rule` from `mindspore.ops._vmap` with (obj, axis_size)
// and returns its result as a py::function.
py::function GetVmapRuleFunctionByObj(const py::object &obj, int axis_size) {
  constexpr char kVmapModule[] = "mindspore.ops._vmap";
  constexpr char kGetVmapRuleFn[] = "get_vmap_rule";
  // Resolve the Python-side getter first, then invoke it with the caller's arguments.
  const auto vmap_rule_getter = python_adapter::GetPyFn(kVmapModule, kGetVmapRuleFn);
  py::function vmap_rule_fn = vmap_rule_getter(obj, axis_size);
  return vmap_rule_fn;
}
116 
GetVmapRuleFunction(const std::string & name,int axis_size)117 py::function GetVmapRuleFunction(const std::string &name, int axis_size) {
118   auto fn = GetVmapRuleFunctionByObj(py::str(name), axis_size);
119   return fn;
120 }
121 }  // namespace mindspore
122