• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
19 
20 #include "frontend/operator/composite/multitype_funcgraph.h"
21 
22 #include "abstract/abstract_function.h"
23 #include "abstract/dshape.h"
24 #include "frontend/optimizer/opt.h"
25 #include "utils/ms_context.h"
26 #include "pipeline/jit/ps/fallback.h"
27 #include "include/common/pybind_api/api_register.h"
28 #include "include/common/fallback.h"
29 #include "ir/signature.h"
30 #include "ir/dtype.h"
31 #include "pipeline/jit/ps/debug/trace.h"
32 
33 namespace mindspore {
34 // namespace to support composite operators definition
35 namespace prim {
MultitypeFuncGraph(const std::string & name)36 MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
37   fn_cache_.clear();
38   // def multitype(*args:ref):
39   signatures_ = std::vector<Signature>({{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
40 }
41 
Register(const TypePtrList & types,specialize_fn s_fn)42 void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
43   MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
44   auto result = fn_cache_.emplace(types, s_fn);
45   if (!result.second) {
46     MS_LOG(INTERNAL_EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
47   }
48 }
49 
Register(const TypePtrList & types,const py::function & py_fn)50 void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
51   MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << py::str(py_fn.cast<py::object>())
52                 << ").";
53   if (fn_cache_.find(types) != fn_cache_.end()) {
54     MS_LOG(INTERNAL_EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
55   }
56   fn_cache_py_[types] = py_fn;
57 }
58 
PyRegister(const py::tuple & tuple,const py::function & py_fn)59 void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
60   TypePtrList types;
61   for (size_t it = 0; it < tuple.size(); ++it) {
62     py::object type_in = tuple[it];
63     TypePtr type_ptr = nullptr;
64     if (py::isinstance<py::str>(type_in)) {
65       auto type_name = type_in.cast<std::string>();
66       type_ptr = StringToType(type_name);
67       if (type_ptr == nullptr) {
68         MS_LOG(INTERNAL_EXCEPTION) << type_name << " convert from string error ";
69       }
70     } else if (py::isinstance<Type>(type_in)) {
71       type_ptr = type_in.cast<TypePtr>();
72     } else {
73       MS_LOG(INTERNAL_EXCEPTION) << "Register must be string or `mindspore.dtype.Type`";
74     }
75     types.push_back(type_ptr);
76   }
77   Register(types, py_fn);
78 }
79 
80 namespace {
HasUMonadType(const TypePtrList & types)81 bool HasUMonadType(const TypePtrList &types) {
82   auto types_size = types.size();
83   // If UMonad is the only type, ignore it.
84   if (types_size > 1) {
85     auto last_type = types[types_size - 1];
86     if (IsIdentidityOrSubclass(last_type, kUMonadType)) {
87       MS_LOG(DEBUG) << "Have Extra UMonad type";
88       return true;
89     }
90   }
91   return false;
92 }
93 
GetTypesPrefixMatchedNum(const TypePtrList & types,const TypePtrList & sign)94 size_t GetTypesPrefixMatchedNum(const TypePtrList &types, const TypePtrList &sign) {
95   for (size_t i = 0; i < sign.size(); ++i) {
96     if (!IsIdentidityOrSubclass(types[i], sign[i])) {
97       return i;
98     }
99   }
100   return sign.size();
101 }
102 
// Maps the type labels "Int64", "Int32" and "Int8" to the generic label "Number";
// any other label is returned unchanged.
std::string IntToNumber(const std::string &v) {
  const bool is_listed_int = (v == "Int64") || (v == "Int32") || (v == "Int8");
  if (is_listed_int) {
    return "Number";
  }
  return v;
}
113 
GetSortedCache(const TypeListMap<py::function> & fn_cache_py_,const TypePtrList & types,size_t match_max_idx)114 std::vector<mindspore::TypePtrList> GetSortedCache(const TypeListMap<py::function> &fn_cache_py_,
115                                                    const TypePtrList &types, size_t match_max_idx) {
116   std::vector<mindspore::TypePtrList> cache_vec;
117   (void)std::transform(fn_cache_py_.begin(), fn_cache_py_.end(), back_inserter(cache_vec),
118                        [](const auto &fcp) { return fcp.first; });
119 
120   for (auto it = cache_vec.begin(); it != cache_vec.end();) {
121     if (GetTypesPrefixMatchedNum(types, *it) != match_max_idx) {
122       it = cache_vec.erase(it);
123     } else {
124       ++it;
125     }
126   }
127 
128   auto comparator = [match_max_idx](const mindspore::TypePtrList &a, const mindspore::TypePtrList &b) {
129     if (a.size() > b.size()) {
130       return false;
131     }
132     if (a.size() < b.size()) {
133       return true;
134     }
135     for (size_t i = match_max_idx; i < a.size(); ++i) {
136       if (a[i]->type_id() == b[i]->type_id()) {
137         continue;
138       }
139       return a[i]->type_id() < b[i]->type_id();
140     }
141     return false;
142   };
143   std::sort(cache_vec.begin(), cache_vec.end(), comparator);
144   return cache_vec;
145 }
146 }  // namespace
147 
148 // Return Exact match if exists,  else return non ambiguous sub class match
149 // Return py::none() if matching is ambiguous
SignMatch(const TypePtrList & types)150 const std::tuple<py::function, bool, size_t> MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
151   // Exact match
152   size_t match_max_idx = 0;
153   for (auto &item : fn_cache_py_) {
154     bool has_extra_u_monad = false;
155     TypePtrList sign = item.first;
156     auto types_size = types.size();
157     if (sign.size() != types_size) {
158       // Don't take the UMonad type into account.
159       has_extra_u_monad = (types_size > 1) && (sign.size() == (types_size - 1)) && HasUMonadType(types);
160       if (!has_extra_u_monad) {
161         continue;
162       }
163     }
164     size_t match_idx = GetTypesPrefixMatchedNum(types, sign);
165     if (match_idx > match_max_idx) {
166       match_max_idx = match_idx;
167     }
168     if (match_idx == sign.size()) {
169       return std::make_tuple(item.second, has_extra_u_monad, sign.size());
170     }
171   }
172   return std::make_tuple(py::none(), false, match_max_idx);
173 }
174 
PrintMatchFailLog(const TypeListMap<py::function>,const TypePtrList & types,size_t match_max_idx,bool has_any)175 const std::string MultitypeFuncGraph::PrintMatchFailLog(const TypeListMap<py::function>, const TypePtrList &types,
176                                                         size_t match_max_idx, bool has_any) {
177   std::ostringstream buffer1;
178   py::list types_list;
179   bool external_flag = false;
180   buffer1 << "<";
181   for (size_t i = 0; i < types.size(); ++i) {
182     if (types[i]->type_id() == kMetaTypeExternal) {
183       external_flag = true;
184     }
185     std::string types_to_int = IntToNumber(TypeIdLabel(types[i]->type_id()));
186     types_list.append(types_to_int);
187     if (i != types.size() - 1) {
188       buffer1 << types_to_int << ", ";
189     } else {
190       buffer1 << types_to_int << ">";
191     }
192   }
193   if (has_any && match_max_idx >= types_list.size()) {
194     MS_LOG(EXCEPTION)
195       << "In the inputs of operation '" << name_
196       << "', there are unsupported syntax in graph mode. Those codes would be fallen back to python interpreter, "
197       << "which is not supported for operation '" << name_ << "'.";
198   }
199 
200   std::ostringstream buffer2;
201   if (match_max_idx == 1) {
202     buffer2 << "When first argument is '" << types_list[0].str() << "', ";
203   }
204   if (match_max_idx > 1) {
205     buffer2 << "When arguments are given as ";
206     for (size_t i = 0; i < match_max_idx; ++i) {
207       buffer2 << "'" << types_list[i].str() << "', ";
208     }
209   }
210 
211   std::ostringstream oss;
212   oss << "For operation '" << name_ << "', current input arguments types are " << buffer1.str() << ". The "
213       << (match_max_idx + 1) << "-th argument type '" << types_list[match_max_idx].str() << "' is not supported now.\n"
214       << buffer2.str() << "the support arguments types of '" << name_ << "' operation as follows:\n";
215   const std::vector<mindspore::TypePtrList> cache_vec = GetSortedCache(fn_cache_py_, types, match_max_idx);
216   for (auto &item : cache_vec) {
217     oss << "<";
218     for (size_t i = 0; i < item.size(); ++i) {
219       std::string item_str = item[i]->ToString();
220       (void)item_str.erase(std::remove(item_str.begin(), item_str.end(), ' '), item_str.end());
221       if (i != item.size() - 1) {
222         oss << item_str << ", ";
223       } else {
224         oss << item_str << ">\n";
225       }
226     }
227   }
228 
229   if (!doc_url_.empty()) {
230     oss << "For more details with '" << name_ << "', please refer to " << doc_url_ << "\n";
231   } else if (external_flag) {
232     oss << "For more details with 'External', please refer to "
233            "https://www.mindspore.cn/search?inputValue=%27External%27%20TypeError\n";
234   }
235 
236   return oss.str();
237 }
238 
CheckDictContainsAny(const std::vector<std::pair<mindspore::ValuePtr,mindspore::TypePtr>> & key_values)239 bool CheckDictContainsAny(const std::vector<std::pair<mindspore::ValuePtr, mindspore::TypePtr>> &key_values) {
240   for (const auto &pair : key_values) {
241     const auto &type = pair.second;
242     if (type->isa<AnyType>()) {
243       return true;
244     }
245     bool res = false;
246     if (type->isa<Tuple>()) {
247       const auto &elements = type->cast<TuplePtr>()->elements();
248       res = CheckContainsAny(elements);
249     } else if (type->isa<List>()) {
250       const auto &elements = type->cast<ListPtr>()->elements();
251       res = CheckContainsAny(elements);
252     } else if (type->isa<Dictionary>()) {
253       const auto &elements = type->cast<DictionaryPtr>()->key_values();
254       res = CheckDictContainsAny(elements);
255     }
256     if (res) {
257       return true;
258     }
259   }
260   return false;
261 }
262 
CheckContainsAny(const TypePtrList & types)263 bool CheckContainsAny(const TypePtrList &types) {
264   for (const auto &type : types) {
265     if (type->isa<AnyType>()) {
266       return true;
267     }
268     bool res = false;
269     if (type->isa<Tuple>()) {
270       const auto &elements = type->cast<TuplePtr>()->elements();
271       res = CheckContainsAny(elements);
272     } else if (type->isa<List>()) {
273       const auto &elements = type->cast<ListPtr>()->elements();
274       res = CheckContainsAny(elements);
275     } else if (type->isa<Dictionary>()) {
276       const auto &elements = type->cast<DictionaryPtr>()->key_values();
277       res = CheckDictContainsAny(elements);
278     }
279     if (res) {
280       return true;
281     }
282   }
283   return false;
284 }
285 
GenerateFromTypes(const TypePtrList & types)286 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
287   auto [py_fn, has_extra_u_monad, match_max_idx] = SignMatch(types);
288   std::ostringstream buffer;
289   buffer << types;
290   bool need_convert = false;
291   if (name_ == "getitem" || name_ == "setitem") {
292     need_convert = CheckContainsAny(types);
293   } else {
294     need_convert = std::any_of(types.begin(), types.end(), [](const TypePtr &type) { return type->isa<AnyType>(); });
295   }
296   if (!py_fn.is_none() && !need_convert) {
297     FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
298     if (func_graph == nullptr) {
299       MS_LOG(INTERNAL_EXCEPTION) << "Fail to parse overload function " << buffer.str() << ".";
300     }
301     MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString() << ".";
302     if (has_extra_u_monad) {
303       MS_LOG(DEBUG) << "Add extra UMoand type for func_graph: " << func_graph->ToString() << ".";
304       func_graph->add_parameter();
305     }
306     return func_graph;
307   }
308   auto stub = GenerateStubFunc(types);
309   if (stub != nullptr) {
310     MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString() << ".";
311     return stub;
312   }
313 
314   bool has_dic = std::any_of(types.begin(), types.end(), [](const TypePtr &type) { return type->isa<Dictionary>(); });
315   if (!need_raise_ || !has_dic) {
316     FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
317     AnfNodePtrList node_inputs{};
318     for (auto type : types) {
319       node_inputs.push_back(func_graph->add_parameter());
320     }
321     auto ret_node =
322       fallback::GeneratePyInterpretNodeFromMetaFuncGraph(func_graph, node_inputs, meta_obj_, types, name_);
323     func_graph->set_output(ret_node);
324     return func_graph;
325   }
326 
327   auto match_fail_log = PrintMatchFailLog(fn_cache_py_, types, match_max_idx, need_convert);
328   MS_LOG(EXCEPTION) << match_fail_log;
329 }
330 }  // namespace prim
331 }  // namespace mindspore
332