1
2 /**
3 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
4 *
5 * Copyright 2020 Huawei Technologies Co., Ltd
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include "frontend/operator/composite/multitype_funcgraph.h"
21
22 #include "abstract/abstract_function.h"
23 #include "abstract/dshape.h"
24 #include "frontend/optimizer/opt.h"
25 #include "utils/ms_context.h"
26 #include "pipeline/jit/ps/fallback.h"
27 #include "include/common/pybind_api/api_register.h"
28 #include "include/common/fallback.h"
29 #include "ir/signature.h"
30 #include "ir/dtype.h"
31 #include "pipeline/jit/ps/debug/trace.h"
32
33 namespace mindspore {
34 // namespace to support composite operators definition
35 namespace prim {
MultitypeFuncGraph(const std::string & name)36 MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
37 fn_cache_.clear();
38 // def multitype(*args:ref):
39 signatures_ = std::vector<Signature>({{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
40 }
41
Register(const TypePtrList & types,specialize_fn s_fn)42 void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
43 MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
44 auto result = fn_cache_.emplace(types, s_fn);
45 if (!result.second) {
46 MS_LOG(INTERNAL_EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
47 }
48 }
49
Register(const TypePtrList & types,const py::function & py_fn)50 void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
51 MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << py::str(py_fn.cast<py::object>())
52 << ").";
53 if (fn_cache_.find(types) != fn_cache_.end()) {
54 MS_LOG(INTERNAL_EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
55 }
56 fn_cache_py_[types] = py_fn;
57 }
58
PyRegister(const py::tuple & tuple,const py::function & py_fn)59 void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
60 TypePtrList types;
61 for (size_t it = 0; it < tuple.size(); ++it) {
62 py::object type_in = tuple[it];
63 TypePtr type_ptr = nullptr;
64 if (py::isinstance<py::str>(type_in)) {
65 auto type_name = type_in.cast<std::string>();
66 type_ptr = StringToType(type_name);
67 if (type_ptr == nullptr) {
68 MS_LOG(INTERNAL_EXCEPTION) << type_name << " convert from string error ";
69 }
70 } else if (py::isinstance<Type>(type_in)) {
71 type_ptr = type_in.cast<TypePtr>();
72 } else {
73 MS_LOG(INTERNAL_EXCEPTION) << "Register must be string or `mindspore.dtype.Type`";
74 }
75 types.push_back(type_ptr);
76 }
77 Register(types, py_fn);
78 }
79
80 namespace {
HasUMonadType(const TypePtrList & types)81 bool HasUMonadType(const TypePtrList &types) {
82 auto types_size = types.size();
83 // If UMonad is the only type, ignore it.
84 if (types_size > 1) {
85 auto last_type = types[types_size - 1];
86 if (IsIdentidityOrSubclass(last_type, kUMonadType)) {
87 MS_LOG(DEBUG) << "Have Extra UMonad type";
88 return true;
89 }
90 }
91 return false;
92 }
93
GetTypesPrefixMatchedNum(const TypePtrList & types,const TypePtrList & sign)94 size_t GetTypesPrefixMatchedNum(const TypePtrList &types, const TypePtrList &sign) {
95 for (size_t i = 0; i < sign.size(); ++i) {
96 if (!IsIdentidityOrSubclass(types[i], sign[i])) {
97 return i;
98 }
99 }
100 return sign.size();
101 }
102
IntToNumber(const std::string & v)103 std::string IntToNumber(const std::string &v) {
104 static mindspore::HashMap<std::string, std::string> int_to_number{
105 {"Int64", "Number"}, {"Int32", "Number"}, {"Int8", "Number"}};
106 auto iter = int_to_number.find(v);
107 if (iter != int_to_number.end()) {
108 return iter->second;
109 } else {
110 return v;
111 }
112 }
113
GetSortedCache(const TypeListMap<py::function> & fn_cache_py_,const TypePtrList & types,size_t match_max_idx)114 std::vector<mindspore::TypePtrList> GetSortedCache(const TypeListMap<py::function> &fn_cache_py_,
115 const TypePtrList &types, size_t match_max_idx) {
116 std::vector<mindspore::TypePtrList> cache_vec;
117 (void)std::transform(fn_cache_py_.begin(), fn_cache_py_.end(), back_inserter(cache_vec),
118 [](const auto &fcp) { return fcp.first; });
119
120 for (auto it = cache_vec.begin(); it != cache_vec.end();) {
121 if (GetTypesPrefixMatchedNum(types, *it) != match_max_idx) {
122 it = cache_vec.erase(it);
123 } else {
124 ++it;
125 }
126 }
127
128 auto comparator = [match_max_idx](const mindspore::TypePtrList &a, const mindspore::TypePtrList &b) {
129 if (a.size() > b.size()) {
130 return false;
131 }
132 if (a.size() < b.size()) {
133 return true;
134 }
135 for (size_t i = match_max_idx; i < a.size(); ++i) {
136 if (a[i]->type_id() == b[i]->type_id()) {
137 continue;
138 }
139 return a[i]->type_id() < b[i]->type_id();
140 }
141 return false;
142 };
143 std::sort(cache_vec.begin(), cache_vec.end(), comparator);
144 return cache_vec;
145 }
146 } // namespace
147
148 // Return Exact match if exists, else return non ambiguous sub class match
149 // Return py::none() if matching is ambiguous
SignMatch(const TypePtrList & types)150 const std::tuple<py::function, bool, size_t> MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
151 // Exact match
152 size_t match_max_idx = 0;
153 for (auto &item : fn_cache_py_) {
154 bool has_extra_u_monad = false;
155 TypePtrList sign = item.first;
156 auto types_size = types.size();
157 if (sign.size() != types_size) {
158 // Don't take the UMonad type into account.
159 has_extra_u_monad = (types_size > 1) && (sign.size() == (types_size - 1)) && HasUMonadType(types);
160 if (!has_extra_u_monad) {
161 continue;
162 }
163 }
164 size_t match_idx = GetTypesPrefixMatchedNum(types, sign);
165 if (match_idx > match_max_idx) {
166 match_max_idx = match_idx;
167 }
168 if (match_idx == sign.size()) {
169 return std::make_tuple(item.second, has_extra_u_monad, sign.size());
170 }
171 }
172 return std::make_tuple(py::none(), false, match_max_idx);
173 }
174
PrintMatchFailLog(const TypeListMap<py::function>,const TypePtrList & types,size_t match_max_idx,bool has_any)175 const std::string MultitypeFuncGraph::PrintMatchFailLog(const TypeListMap<py::function>, const TypePtrList &types,
176 size_t match_max_idx, bool has_any) {
177 std::ostringstream buffer1;
178 py::list types_list;
179 bool external_flag = false;
180 buffer1 << "<";
181 for (size_t i = 0; i < types.size(); ++i) {
182 if (types[i]->type_id() == kMetaTypeExternal) {
183 external_flag = true;
184 }
185 std::string types_to_int = IntToNumber(TypeIdLabel(types[i]->type_id()));
186 types_list.append(types_to_int);
187 if (i != types.size() - 1) {
188 buffer1 << types_to_int << ", ";
189 } else {
190 buffer1 << types_to_int << ">";
191 }
192 }
193 if (has_any && match_max_idx >= types_list.size()) {
194 MS_LOG(EXCEPTION)
195 << "In the inputs of operation '" << name_
196 << "', there are unsupported syntax in graph mode. Those codes would be fallen back to python interpreter, "
197 << "which is not supported for operation '" << name_ << "'.";
198 }
199
200 std::ostringstream buffer2;
201 if (match_max_idx == 1) {
202 buffer2 << "When first argument is '" << types_list[0].str() << "', ";
203 }
204 if (match_max_idx > 1) {
205 buffer2 << "When arguments are given as ";
206 for (size_t i = 0; i < match_max_idx; ++i) {
207 buffer2 << "'" << types_list[i].str() << "', ";
208 }
209 }
210
211 std::ostringstream oss;
212 oss << "For operation '" << name_ << "', current input arguments types are " << buffer1.str() << ". The "
213 << (match_max_idx + 1) << "-th argument type '" << types_list[match_max_idx].str() << "' is not supported now.\n"
214 << buffer2.str() << "the support arguments types of '" << name_ << "' operation as follows:\n";
215 const std::vector<mindspore::TypePtrList> cache_vec = GetSortedCache(fn_cache_py_, types, match_max_idx);
216 for (auto &item : cache_vec) {
217 oss << "<";
218 for (size_t i = 0; i < item.size(); ++i) {
219 std::string item_str = item[i]->ToString();
220 (void)item_str.erase(std::remove(item_str.begin(), item_str.end(), ' '), item_str.end());
221 if (i != item.size() - 1) {
222 oss << item_str << ", ";
223 } else {
224 oss << item_str << ">\n";
225 }
226 }
227 }
228
229 if (!doc_url_.empty()) {
230 oss << "For more details with '" << name_ << "', please refer to " << doc_url_ << "\n";
231 } else if (external_flag) {
232 oss << "For more details with 'External', please refer to "
233 "https://www.mindspore.cn/search?inputValue=%27External%27%20TypeError\n";
234 }
235
236 return oss.str();
237 }
238
CheckDictContainsAny(const std::vector<std::pair<mindspore::ValuePtr,mindspore::TypePtr>> & key_values)239 bool CheckDictContainsAny(const std::vector<std::pair<mindspore::ValuePtr, mindspore::TypePtr>> &key_values) {
240 for (const auto &pair : key_values) {
241 const auto &type = pair.second;
242 if (type->isa<AnyType>()) {
243 return true;
244 }
245 bool res = false;
246 if (type->isa<Tuple>()) {
247 const auto &elements = type->cast<TuplePtr>()->elements();
248 res = CheckContainsAny(elements);
249 } else if (type->isa<List>()) {
250 const auto &elements = type->cast<ListPtr>()->elements();
251 res = CheckContainsAny(elements);
252 } else if (type->isa<Dictionary>()) {
253 const auto &elements = type->cast<DictionaryPtr>()->key_values();
254 res = CheckDictContainsAny(elements);
255 }
256 if (res) {
257 return true;
258 }
259 }
260 return false;
261 }
262
CheckContainsAny(const TypePtrList & types)263 bool CheckContainsAny(const TypePtrList &types) {
264 for (const auto &type : types) {
265 if (type->isa<AnyType>()) {
266 return true;
267 }
268 bool res = false;
269 if (type->isa<Tuple>()) {
270 const auto &elements = type->cast<TuplePtr>()->elements();
271 res = CheckContainsAny(elements);
272 } else if (type->isa<List>()) {
273 const auto &elements = type->cast<ListPtr>()->elements();
274 res = CheckContainsAny(elements);
275 } else if (type->isa<Dictionary>()) {
276 const auto &elements = type->cast<DictionaryPtr>()->key_values();
277 res = CheckDictContainsAny(elements);
278 }
279 if (res) {
280 return true;
281 }
282 }
283 return false;
284 }
285
GenerateFromTypes(const TypePtrList & types)286 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
287 auto [py_fn, has_extra_u_monad, match_max_idx] = SignMatch(types);
288 std::ostringstream buffer;
289 buffer << types;
290 bool need_convert = false;
291 if (name_ == "getitem" || name_ == "setitem") {
292 need_convert = CheckContainsAny(types);
293 } else {
294 need_convert = std::any_of(types.begin(), types.end(), [](const TypePtr &type) { return type->isa<AnyType>(); });
295 }
296 if (!py_fn.is_none() && !need_convert) {
297 FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
298 if (func_graph == nullptr) {
299 MS_LOG(INTERNAL_EXCEPTION) << "Fail to parse overload function " << buffer.str() << ".";
300 }
301 MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString() << ".";
302 if (has_extra_u_monad) {
303 MS_LOG(DEBUG) << "Add extra UMoand type for func_graph: " << func_graph->ToString() << ".";
304 func_graph->add_parameter();
305 }
306 return func_graph;
307 }
308 auto stub = GenerateStubFunc(types);
309 if (stub != nullptr) {
310 MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString() << ".";
311 return stub;
312 }
313
314 bool has_dic = std::any_of(types.begin(), types.end(), [](const TypePtr &type) { return type->isa<Dictionary>(); });
315 if (!need_raise_ || !has_dic) {
316 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
317 AnfNodePtrList node_inputs{};
318 for (auto type : types) {
319 node_inputs.push_back(func_graph->add_parameter());
320 }
321 auto ret_node =
322 fallback::GeneratePyInterpretNodeFromMetaFuncGraph(func_graph, node_inputs, meta_obj_, types, name_);
323 func_graph->set_output(ret_node);
324 return func_graph;
325 }
326
327 auto match_fail_log = PrintMatchFailLog(fn_cache_py_, types, match_max_idx, need_convert);
328 MS_LOG(EXCEPTION) << match_fail_log;
329 }
330 } // namespace prim
331 } // namespace mindspore
332