1
2 /**
3 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
4 *
5 * Copyright 2020 Huawei Technologies Co., Ltd
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include "frontend/operator/composite/multitype_funcgraph.h"
21 #include <utility>
22 #include <sstream>
23
24 #include "abstract/abstract_function.h"
25 #include "abstract/dshape.h"
26 #include "frontend/optimizer/opt.h"
27 #include "utils/ms_context.h"
28 #include "pybind_api/api_register.h"
29 #include "ir/signature.h"
30 #include "ir/dtype.h"
31 #include "debug/trace.h"
32
33 namespace mindspore {
34 // namespace to support composite operators definition
35 namespace prim {
MultitypeFuncGraph(const std::string & name)36 MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
37 fn_cache_.clear();
38 // def multitype(*args:ref):
39 signatures_ = std::vector<Signature>({{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
40 }
41
Register(const TypePtrList & types,specialize_fn s_fn)42 void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
43 MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
44 auto fn = fn_cache_.find(types);
45 if (fn != fn_cache_.end()) {
46 MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
47 }
48 fn_cache_[types] = s_fn;
49 }
50
Register(const TypePtrList & types,const py::function & py_fn)51 void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
52 MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << py::str(py_fn.cast<py::object>())
53 << ").";
54 auto fn = fn_cache_.find(types);
55 if (fn != fn_cache_.end()) {
56 MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
57 }
58 fn_cache_py_[types] = py_fn;
59 }
60
PyRegister(const py::tuple & tuple,const py::function & py_fn)61 void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
62 TypePtrList types;
63 for (size_t it = 0; it < tuple.size(); ++it) {
64 py::object type_in = tuple[it];
65 TypePtr type_ptr = nullptr;
66 if (py::isinstance<py::str>(type_in)) {
67 auto type_name = type_in.cast<std::string>();
68 type_ptr = StringToType(type_name);
69 if (type_ptr == nullptr) {
70 MS_LOG(EXCEPTION) << type_name << " convert from string error ";
71 }
72 } else if (py::isinstance<Type>(type_in)) {
73 type_ptr = type_in.cast<TypePtr>();
74 } else {
75 MS_LOG(EXCEPTION) << "Register must be string or `mindspore.dtype.Type`";
76 }
77 types.push_back(type_ptr);
78 }
79 Register(types, py_fn);
80 }
81
82 namespace {
HasUMonadType(const TypePtrList & types)83 bool HasUMonadType(const TypePtrList &types) {
84 auto types_size = types.size();
85 // If UMonad is the only type, ignore it.
86 if (types_size > 1) {
87 auto last_type = types[types_size - 1];
88 if (IsIdentidityOrSubclass(last_type, kUMonadType)) {
89 MS_LOG(DEBUG) << "Have Extra UMonad type";
90 return true;
91 }
92 }
93 return false;
94 }
95 } // namespace
96
97 // Return Exact match if exists, else return non ambiguous sub class match
98 // Return py::none() if matching is ambiguous
SignMatch(const TypePtrList & types)99 const std::pair<py::function, bool> MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
100 // Exact match
101 for (auto &item : fn_cache_py_) {
102 bool has_extra_u_monad = false;
103 TypePtrList sign = item.first;
104 auto types_size = types.size();
105 if (sign.size() != types_size) {
106 // Don't take the UMonad type into account.
107 has_extra_u_monad = (types_size > 1) && (sign.size() == (types_size - 1)) && HasUMonadType(types);
108 if (!has_extra_u_monad) {
109 continue;
110 }
111 }
112 auto match = true;
113 for (size_t i = 0; i < sign.size(); ++i) {
114 if (!IsIdentidityOrSubclass(types[i], sign[i])) {
115 match = false;
116 break;
117 }
118 }
119 if (!match) {
120 continue;
121 }
122 return std::pair(item.second, has_extra_u_monad);
123 }
124 return std::pair(py::none(), false);
125 }
126
GenerateFromTypes(const TypePtrList & types)127 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
128 auto py_fn_pair = SignMatch(types);
129 auto py_fn = py_fn_pair.first;
130 std::ostringstream buffer;
131 buffer << types;
132 if (!py_fn.is_none()) {
133 FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
134 if (func_graph == nullptr) {
135 MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
136 }
137 MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
138 if (py_fn_pair.second) {
139 MS_LOG(DEBUG) << "Add extra UMoand type for func_graph: " << func_graph->ToString();
140 func_graph->add_parameter();
141 }
142 return func_graph;
143 }
144 auto stub = GenerateStubFunc(types);
145 if (stub != nullptr) {
146 MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString();
147 return stub;
148 }
149 std::ostringstream oss;
150 oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
151 << "`, corresponding location info:\n";
152 int64_t idx = 0;
153 for (auto &item : fn_cache_py_) {
154 FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
155 if (func_graph == nullptr) {
156 MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
157 continue;
158 }
159 oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
160 }
161 MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n"
162 << oss.str();
163 }
164
__anonecae05cb0202(const py::module *m) 165 REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
166 (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
167 *m, "MultitypeFuncGraph_")
168 .def(py::init<std::string &>())
169 .def("register_fn", &MultitypeFuncGraph::PyRegister);
170 }));
171 } // namespace prim
172 } // namespace mindspore
173