1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2020 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include "ir/func_graph.h"

#include <algorithm>
#include <sstream>
#include <unordered_set>

#include "ir/manager.h"
#include "base/core_ops.h"
#include "utils/ordered_set.h"
#include "abstract/abstract_value.h"
#include "abstract/abstract_function.h"
#include "utils/flags.h"
30
31 namespace mindspore {
32 using mindspore::abstract::AbstractFunction;
33 using mindspore::abstract::AbstractFunctionPtr;
34 using mindspore::abstract::AnalysisContextPtr;
35 using mindspore::abstract::PrimitiveAbstractClosure;
36 using mindspore::abstract::VirtualAbstractClosure;
37
abstract()38 AbstractFunctionPtr FuncGraph::abstract() {
39 AbstractBasePtrList args_spec_list;
40
41 for (auto &p : parameters_) {
42 MS_EXCEPTION_IF_NULL(p);
43 if (p->abstract() == nullptr) {
44 MS_LOG(ERROR) << "Error!!";
45 return nullptr;
46 }
47 args_spec_list.push_back(p->abstract());
48 }
49
50 if (output() == nullptr) {
51 MS_LOG(ERROR) << "Error func graph no output";
52 return nullptr;
53 }
54
55 return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
56 }
57
set_output(const AnfNodePtr & value,bool force_new_ret)58 void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
59 MS_EXCEPTION_IF_NULL(value);
60 if (force_new_ret || return_ == nullptr) {
61 std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
62 FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
63 return_ = this_graph->NewCNodeInOrder(params);
64 } else {
65 if (manager_.lock()) {
66 manager_.lock()->SetEdge(return_, 1, value);
67 } else {
68 constexpr auto first_data_index = 1;
69 return_->set_input(first_data_index, value);
70 }
71 }
72
73 return_->set_abstract(value->abstract());
74 AnfNodePtr input0 = return_->input(0);
75 auto f = std::make_shared<PrimitiveAbstractClosure>(prim::kPrimReturn, input0);
76 input0->set_abstract(f);
77 }
78
DumpFuncGraph(const std::string & path)79 void FuncGraph::DumpFuncGraph(const std::string &path) {
80 if (drawer_) {
81 drawer_(path + ".dot", shared_from_base<FuncGraph>());
82 }
83 }
84
GenerateVarParams(const FuncGraphPtr & specialized_graph,int variable_args_count,int pos_args_input_count,std::vector<AnfNodePtr> * specialized_parameter_list,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const85 void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count,
86 int pos_args_input_count, std::vector<AnfNodePtr> *specialized_parameter_list,
87 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
88 // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
89 MS_EXCEPTION_IF_NULL(specialized_graph);
90 if (specialized_graph->has_vararg()) {
91 TraceGuard trace_guard(
92 std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
93 std::vector<AnfNodePtr> var_param_tuple_nodes;
94 var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
95
96 if (variable_args_count < 0) {
97 MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count
98 << " were given.";
99 }
100 auto varg_name = specialized_graph->GetVariableArgName();
101 // for python variable argument input , there is no upper limit
102 for (int i = 0; i < variable_args_count; ++i) {
103 ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
104 std::string param_name = varg_name + std::to_string(i);
105 p->set_name(param_name);
106 MS_EXCEPTION_IF_NULL(p->debug_info());
107 p->debug_info()->set_name(param_name);
108 var_param_tuple_nodes.push_back(p);
109 MS_EXCEPTION_IF_NULL(specialized_parameter_list);
110 specialized_parameter_list->push_back(p);
111 }
112 auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes);
113 (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param);
114 } else if (variable_args_count > 0) {
115 MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount()
116 << " positional arguments, but " << pos_args_input_count << " were given.";
117 }
118 }
119
GenerateKwParams(const FuncGraphPtr & specialized_graph,const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list,std::vector<AnfNodePtr> * specialized_parameter_list,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const120 void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
121 const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
122 std::vector<AnfNodePtr> *specialized_parameter_list,
123 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
124 std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
125 std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
126
127 std::set<AnfNodePtr> kwarg_nodes;
128 for (const auto &kwarg : kwarg_list) {
129 MS_EXCEPTION_IF_NULL(kwarg);
130 std::string kw_param_name = kwarg->get_key();
131 MS_EXCEPTION_IF_NULL(specialized_graph);
132 AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
133 // if not find corresponding parameter node
134 if (param_node == nullptr) {
135 if (!has_kwarg()) {
136 MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
137 } else {
138 ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
139 std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
140 MS_EXCEPTION_IF_NULL(specialized_parameter_list);
141 auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
142 [param_name](const AnfNodePtr &node) {
143 MS_EXCEPTION_IF_NULL(node);
144 auto param = node->cast<ParameterPtr>();
145 return param != nullptr && param->name() == param_name;
146 });
147 if (find_kw_arg_in_list) {
148 MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name;
149 }
150 p->set_name(param_name);
151 p->debug_info()->set_name(param_name);
152 kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name));
153 auto extract_node =
154 specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p});
155 kwarg_values_tuple_nodes.push_back(extract_node);
156 specialized_parameter_list->push_back(p);
157 }
158 } else {
159 auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
160 // multiply values found given for parameter
161 if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) {
162 MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
163 } else {
164 specialized_parameter_list->push_back(param_node);
165 auto extract_node = specialized_graph->NewCNode(
166 {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
167 kwarg_nodes.insert(param_node);
168 (void)repl_nodes->emplace(param_node, extract_node);
169 }
170 }
171 }
172
173 GenerateKwargReplNode(specialized_graph, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes, repl_nodes);
174 }
175
GenerateKwargReplNode(const FuncGraphPtr & specialized_graph,const std::vector<AnfNodePtr> & kwarg_keys_tuple_nodes,const std::vector<AnfNodePtr> & kwarg_values_tuple_nodes,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const176 void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
177 const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
178 const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes,
179 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
180 if (has_kwarg()) {
181 MS_EXCEPTION_IF_NULL(specialized_graph);
182 TraceGuard guard(
183 std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
184 auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes);
185 auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes);
186 auto make_dict_node =
187 specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values});
188 MS_EXCEPTION_IF_NULL(repl_nodes);
189 (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node);
190 }
191 }
192
NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list)193 bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
194 // if the function does not have any vararg/kwarg/kwonly/default value/kw args input
195 // return the original graph
196 if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
197 return false;
198 }
199
200 // if the graph is generated for specific input, do not need to generate again
201 return !is_generated();
202 }
203
GenerateDefaultValue(const FuncGraphPtr & specialized_graph,const std::vector<AnfNodePtr> & specialized_parameter_list,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const204 void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
205 const std::vector<AnfNodePtr> &specialized_parameter_list,
206 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
207 MS_EXCEPTION_IF_NULL(specialized_graph);
208 for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
209 MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]);
210 auto param_node = specialized_graph->parameters()[i]->cast<ParameterPtr>();
211 MS_EXCEPTION_IF_NULL(param_node);
212 auto param_name = param_node->name();
213 auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node);
214 if (node_itr != specialized_parameter_list.end()) {
215 continue;
216 }
217 if (param_name == specialized_graph->GetVariableArgName() ||
218 param_name == specialized_graph->GetVariableKwargName()) {
219 continue;
220 }
221 auto default_value = specialized_graph->GetDefaultValueByName(param_name);
222 if (default_value == nullptr) {
223 MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name;
224 }
225 MS_EXCEPTION_IF_NULL(repl_nodes);
226 (void)repl_nodes->emplace(param_node, default_value);
227 }
228 }
229
GenerateGraph(const AbstractBasePtrList & args_spec_list)230 FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
231 std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
232 std::vector<size_t> pos_arg_indexes;
233 size_t arguments_count = args_spec_list.size();
234 if (hyper_param_count_ > arguments_count) {
235 MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
236 }
237 for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
238 MS_EXCEPTION_IF_NULL(args_spec_list[i]);
239 if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
240 kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
241 } else {
242 pos_arg_indexes.push_back(i);
243 }
244 }
245
246 if (!NeedGenerate(kwarg_list)) {
247 return shared_from_base<FuncGraph>();
248 }
249 auto iter = func_graph_cache_.find(args_spec_list);
250 if (iter != func_graph_cache_.end()) {
251 return iter->second;
252 }
253 FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
254 size_t kwarg_count = kwarg_list.size();
255 int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
256 int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
257 int variable_args_count = pos_args_input_count - pos_args_count;
258 std::vector<AnfNodePtr> specialized_parameter_list;
259 std::unordered_map<AnfNodePtr, AnfNodePtr> repl_nodes;
260 // the parameters that has arg input, copy from original parameters
261 for (size_t i = 0; i < IntToSize(pos_args_count); ++i) {
262 specialized_parameter_list.push_back(specialized_graph->parameters()[i]);
263 }
264
265 GenerateVarParams(specialized_graph, variable_args_count, pos_args_input_count, &specialized_parameter_list,
266 &repl_nodes);
267
268 GenerateKwParams(specialized_graph, kwarg_list, &specialized_parameter_list, &repl_nodes);
269
270 GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);
271
272 // append hyper parameter to specialized_parameter_list
273 MS_EXCEPTION_IF_NULL(specialized_graph);
274 auto params = specialized_graph->parameters();
275 specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
276 params.end());
277 std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
278 specialized_parameter_list.end());
279 for (size_t i = 0; i < pos_arg_indexes.size(); i++) {
280 specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i],
281 specialized_parameter_list[i]);
282 }
283
284 std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
285 auto tr = manager->Transact();
286 for (auto &node_pair : repl_nodes) {
287 MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
288 << node_pair.second->DebugString();
289 (void)tr.Replace(node_pair.first, node_pair.second);
290 }
291 tr.SetParameters(specialized_graph, specialized_parameter_list_update);
292 tr.Commit();
293 specialized_graph->set_has_kwarg(false);
294 specialized_graph->set_has_vararg(false);
295 specialized_graph->set_kwonlyargs_count(0);
296 specialized_graph->ClearDefaultValues();
297 specialized_graph->set_is_generate(true);
298 func_graph_cache_[args_spec_list] = specialized_graph;
299 return specialized_graph;
300 }
301
FindRoots(const std::vector<CNodePtr> & segment)302 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
303 std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
304 for (const auto &node : segment) {
305 if (roots->size() == 1) {
306 return roots;
307 }
308 auto input_size = node->size();
309 for (size_t i = 0; i < input_size; i++) {
310 auto in_node = node->input(i);
311 auto in_cnode = in_node->cast<CNodePtr>();
312 if (in_cnode != nullptr) {
313 (void)roots->erase(in_cnode);
314 }
315 }
316 }
317 return roots;
318 }
319
FindLeaves(const std::vector<CNodePtr> & segment)320 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
321 std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
322 for (const auto &node : segment) {
323 if (nodes->size() == 1) {
324 return nodes;
325 }
326 if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
327 (void)nodes->erase(node);
328 continue;
329 }
330 auto input_size = node->size();
331 for (size_t i = 0; i < input_size; i++) {
332 auto in_node = node->input(i);
333 if (!in_node->isa<CNode>()) {
334 continue;
335 }
336 auto in_cnode = in_node->cast<CNodePtr>();
337 if (in_cnode != nullptr) {
338 if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) {
339 (void)nodes->erase(node);
340 break;
341 }
342 }
343 }
344 }
345 return nodes;
346 }
347 } // namespace mindspore
348