• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /**
3  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
4  *
5  * Copyright 2020 Huawei Technologies Co., Ltd
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #include "frontend/operator/composite/multitype_funcgraph.h"
21 #include <utility>
22 #include <sstream>
23 
24 #include "abstract/abstract_function.h"
25 #include "abstract/dshape.h"
26 #include "frontend/optimizer/opt.h"
27 #include "utils/ms_context.h"
28 #include "pybind_api/api_register.h"
29 #include "ir/signature.h"
30 #include "ir/dtype.h"
31 #include "debug/trace.h"
32 
33 namespace mindspore {
34 // namespace to support composite operators definition
35 namespace prim {
MultitypeFuncGraph(const std::string & name)36 MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) {
37   fn_cache_.clear();
38   // def multitype(*args:ref):
39   signatures_ = std::vector<Signature>({{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
40 }
41 
Register(const TypePtrList & types,specialize_fn s_fn)42 void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) {
43   MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ".";
44   auto fn = fn_cache_.find(types);
45   if (fn != fn_cache_.end()) {
46     MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
47   }
48   fn_cache_[types] = s_fn;
49 }
50 
Register(const TypePtrList & types,const py::function & py_fn)51 void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) {
52   MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << py::str(py_fn.cast<py::object>())
53                 << ").";
54   auto fn = fn_cache_.find(types);
55   if (fn != fn_cache_.end()) {
56     MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered.";
57   }
58   fn_cache_py_[types] = py_fn;
59 }
60 
PyRegister(const py::tuple & tuple,const py::function & py_fn)61 void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) {
62   TypePtrList types;
63   for (size_t it = 0; it < tuple.size(); ++it) {
64     py::object type_in = tuple[it];
65     TypePtr type_ptr = nullptr;
66     if (py::isinstance<py::str>(type_in)) {
67       auto type_name = type_in.cast<std::string>();
68       type_ptr = StringToType(type_name);
69       if (type_ptr == nullptr) {
70         MS_LOG(EXCEPTION) << type_name << " convert from string error ";
71       }
72     } else if (py::isinstance<Type>(type_in)) {
73       type_ptr = type_in.cast<TypePtr>();
74     } else {
75       MS_LOG(EXCEPTION) << "Register must be string or `mindspore.dtype.Type`";
76     }
77     types.push_back(type_ptr);
78   }
79   Register(types, py_fn);
80 }
81 
82 namespace {
HasUMonadType(const TypePtrList & types)83 bool HasUMonadType(const TypePtrList &types) {
84   auto types_size = types.size();
85   // If UMonad is the only type, ignore it.
86   if (types_size > 1) {
87     auto last_type = types[types_size - 1];
88     if (IsIdentidityOrSubclass(last_type, kUMonadType)) {
89       MS_LOG(DEBUG) << "Have Extra UMonad type";
90       return true;
91     }
92   }
93   return false;
94 }
95 }  // namespace
96 
97 // Return Exact match if exists,  else return non ambiguous sub class match
98 // Return py::none() if matching is ambiguous
SignMatch(const TypePtrList & types)99 const std::pair<py::function, bool> MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
100   // Exact match
101   for (auto &item : fn_cache_py_) {
102     bool has_extra_u_monad = false;
103     TypePtrList sign = item.first;
104     auto types_size = types.size();
105     if (sign.size() != types_size) {
106       // Don't take the UMonad type into account.
107       has_extra_u_monad = (types_size > 1) && (sign.size() == (types_size - 1)) && HasUMonadType(types);
108       if (!has_extra_u_monad) {
109         continue;
110       }
111     }
112     auto match = true;
113     for (size_t i = 0; i < sign.size(); ++i) {
114       if (!IsIdentidityOrSubclass(types[i], sign[i])) {
115         match = false;
116         break;
117       }
118     }
119     if (!match) {
120       continue;
121     }
122     return std::pair(item.second, has_extra_u_monad);
123   }
124   return std::pair(py::none(), false);
125 }
126 
GenerateFromTypes(const TypePtrList & types)127 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
128   auto py_fn_pair = SignMatch(types);
129   auto py_fn = py_fn_pair.first;
130   std::ostringstream buffer;
131   buffer << types;
132   if (!py_fn.is_none()) {
133     FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
134     if (func_graph == nullptr) {
135       MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
136     }
137     MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
138     if (py_fn_pair.second) {
139       MS_LOG(DEBUG) << "Add extra UMoand type for func_graph: " << func_graph->ToString();
140       func_graph->add_parameter();
141     }
142     return func_graph;
143   }
144   auto stub = GenerateStubFunc(types);
145   if (stub != nullptr) {
146     MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString();
147     return stub;
148   }
149   std::ostringstream oss;
150   oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
151       << "`, corresponding location info:\n";
152   int64_t idx = 0;
153   for (auto &item : fn_cache_py_) {
154     FuncGraphPtr func_graph = parse::ParsePythonCode(item.second);
155     if (func_graph == nullptr) {
156       MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`.";
157       continue;
158     }
159     oss << ++idx << ". " << item.first << "\n  " << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
160   }
161   MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n"
162                     << oss.str();
163 }
164 
__anonecae05cb0202(const py::module *m) 165 REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) {
166                          (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>(
167                            *m, "MultitypeFuncGraph_")
168                            .def(py::init<std::string &>())
169                            .def("register_fn", &MultitypeFuncGraph::PyRegister);
170                        }));
171 }  // namespace prim
172 }  // namespace mindspore
173