1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2020-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/func_graph.h"
20
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "ir/manager.h"
24 #include "utils/ordered_set.h"
25 #include "abstract/abstract_value.h"
26 #include "abstract/abstract_function.h"
27 #include "ir/func_graph_cloner.h"
28
29 namespace mindspore {
30 using mindspore::abstract::AbstractFunction;
31 using mindspore::abstract::AbstractFunctionPtr;
32 using mindspore::abstract::AnalysisContextPtr;
33 using mindspore::abstract::PrimitiveAbstractClosure;
34 using mindspore::abstract::VirtualAbstractClosure;
35
abstract()36 AbstractFunctionPtr FuncGraph::abstract() {
37 AbstractBasePtrList args_abs_list;
38
39 for (auto ¶ : parameters_) {
40 MS_EXCEPTION_IF_NULL(para);
41 if (para->abstract() == nullptr) {
42 MS_LOG(ERROR) << "Error!!";
43 return nullptr;
44 }
45 args_abs_list.push_back(para->abstract());
46 }
47
48 if (output() == nullptr) {
49 MS_LOG(ERROR) << "Error func graph no output";
50 return nullptr;
51 }
52 MS_EXCEPTION_IF_NULL(output());
53 return std::make_shared<VirtualAbstractClosure>(args_abs_list, output()->abstract());
54 }
55
set_output(const AnfNodePtr & value,bool force_new_ret)56 void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
57 MS_EXCEPTION_IF_NULL(value);
58 if (force_new_ret || return_node() == nullptr) {
59 AnfNodePtrList params({NewValueNode(prim::kPrimReturn), value});
60 FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
61 set_return(this_graph->NewCNodeInOrder(std::move(params)));
62 } else {
63 if (manager_.lock()) {
64 manager_.lock()->SetEdge(return_node(), 1, value);
65 } else {
66 constexpr auto first_data_index = 1;
67 return_node()->set_input(first_data_index, value);
68 }
69 }
70
71 return_node()->set_abstract(value->abstract());
72 AnfNodePtr input0 = return_node()->input(0);
73 auto f = std::make_shared<PrimitiveAbstractClosure>(prim::kPrimReturn, input0);
74 input0->set_abstract(f);
75 }
76
GenerateVarParams(const FuncGraphPtr & specialized_graph,int variable_args_count,int pos_args_input_count,AnfNodePtrList * specialized_parameter_list,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const77 void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count,
78 int pos_args_input_count, AnfNodePtrList *specialized_parameter_list,
79 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
80 MS_EXCEPTION_IF_NULL(specialized_graph);
81 if (!specialized_graph->has_vararg()) {
82 if (variable_args_count > 0) {
83 MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << GetPositionalArgsCount()
84 << " positional arguments, but " << pos_args_input_count << " were given.";
85 }
86 // Only copy parameters other than default arguments.
87 for (size_t i = 0; i < IntToSize(pos_args_input_count); ++i) {
88 specialized_parameter_list->push_back(specialized_graph->parameters()[i]);
89 }
90 return;
91 }
92
93 // If there is variable argument.
94 if (variable_args_count < 0) {
95 MS_LOG(EXCEPTION) << "For function:" << this->ToString() << ", its argument size: " << pos_args_input_count
96 << " is less or equal to parameter size: " << GetPositionalArgsCount();
97 }
98 // Copy other parameters than vararg's firstly.
99 for (size_t i = 0; i < IntToSize(GetPositionalArgsCount()); ++i) {
100 specialized_parameter_list->push_back(specialized_graph->parameters()[i]);
101 }
102 MS_EXCEPTION_IF_NULL(specialized_graph->GetVariableArgParameter());
103 TraceGuard trace_guard(
104 std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
105 AnfNodePtrList var_param_tuple_nodes;
106 var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
107
108 auto varg_name = specialized_graph->GetVariableArgName();
109 // For python variable argument input, there is no upper limit.
110 for (int i = 0; i < variable_args_count; ++i) {
111 ParameterPtr para = std::make_shared<Parameter>(specialized_graph);
112 std::string param_name = varg_name + std::to_string(i);
113 para->set_name(param_name);
114 MS_EXCEPTION_IF_NULL(para->debug_info());
115 para->debug_info()->set_name(param_name);
116 var_param_tuple_nodes.push_back(para);
117 MS_EXCEPTION_IF_NULL(specialized_parameter_list);
118 specialized_parameter_list->push_back(para);
119 }
120 auto var_tuple_param = specialized_graph->NewCNode(std::move(var_param_tuple_nodes));
121 MS_EXCEPTION_IF_NULL(repl_nodes);
122 (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param);
123 }
124
GenerateKwParams(const FuncGraphPtr & specialized_graph,const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list,int pos_args_input_count,AnfNodePtrList * specialized_parameter_list,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const125 void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
126 const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
127 int pos_args_input_count, AnfNodePtrList *specialized_parameter_list,
128 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
129 MS_EXCEPTION_IF_NULL(specialized_parameter_list);
130 MS_EXCEPTION_IF_NULL(repl_nodes);
131 MS_EXCEPTION_IF_NULL(specialized_graph);
132 AnfNodePtrList kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
133 AnfNodePtrList kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
134
135 std::set<AnfNodePtr> kwarg_nodes;
136 for (size_t i = 0; i < kwarg_list.size(); ++i) {
137 auto kwarg = kwarg_list[i];
138 MS_EXCEPTION_IF_NULL(kwarg);
139 std::string kw_param_name = kwarg->get_key();
140 AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
141 // If not find corresponding parameter node.
142 if (param_node == nullptr) {
143 if (!has_kwarg()) {
144 MS_LOG(DEBUG) << "Not found parameter by name '" << kw_param_name << "'";
145 if (IntToSize(pos_args_input_count) + i < specialized_graph->parameters().size()) {
146 auto kw_param = dyn_cast<Parameter>(specialized_graph->parameters()[IntToSize(pos_args_input_count) + i]);
147 if (kw_param != nullptr && (specialized_graph->has_flag(FUNC_GRAPH_FLAG_ARGS_NO_EXPAND) ||
148 kw_param->name() == "kwargs[" + kw_param_name + "]")) {
149 specialized_parameter_list->push_back(kw_param);
150 continue;
151 }
152 }
153 MS_LOG(EXCEPTION) << "Got an unexpected keyword argument '" << kw_param_name << "'";
154 } else {
155 ParameterPtr para = std::make_shared<Parameter>(specialized_graph);
156 std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
157 auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
158 [param_name](const AnfNodePtr &node) {
159 MS_EXCEPTION_IF_NULL(node);
160 auto param = node->cast_ptr<Parameter>();
161 return param != nullptr && param->name() == param_name;
162 });
163 if (find_kw_arg_in_list) {
164 MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name;
165 }
166 para->set_name(param_name);
167 MS_EXCEPTION_IF_NULL(para->debug_info());
168 para->debug_info()->set_name(param_name);
169 kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name));
170 auto extract_node =
171 specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), para});
172 kwarg_values_tuple_nodes.push_back(extract_node);
173 specialized_parameter_list->push_back(para);
174 }
175 } else {
176 auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
177 // Multiply values found given for parameter.
178 if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) {
179 MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
180 } else {
181 specialized_parameter_list->push_back(param_node);
182 auto extract_node = specialized_graph->NewCNode(
183 {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
184 kwarg_nodes.insert(param_node);
185 (void)repl_nodes->emplace(param_node, extract_node);
186 }
187 }
188 }
189
190 GenerateKwargReplNode(specialized_graph, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes, repl_nodes);
191 }
192
GenerateKwargReplNode(const FuncGraphPtr & specialized_graph,const AnfNodePtrList & kwarg_keys_tuple_nodes,const AnfNodePtrList & kwarg_values_tuple_nodes,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const193 void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
194 const AnfNodePtrList &kwarg_keys_tuple_nodes,
195 const AnfNodePtrList &kwarg_values_tuple_nodes,
196 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
197 if (has_kwarg() && !kwarg_keys_tuple_nodes.empty()) {
198 MS_EXCEPTION_IF_NULL(specialized_graph);
199 TraceGuard guard(
200 std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
201 auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes);
202 auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes);
203 auto make_dict_node =
204 specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values});
205 MS_EXCEPTION_IF_NULL(repl_nodes);
206 (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node);
207 }
208 }
209
NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list)210 bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
211 // If the function does not have any vararg/kwarg/kwonly/default value/kw args input
212 // return the original graph
213 if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
214 return false;
215 }
216
217 // If the graph is generated for specific input, do not need to generate again
218 return !is_generated();
219 }
220
GenerateDefaultValue(const FuncGraphPtr & specialized_graph,const AnfNodePtrList & specialized_parameter_list,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const221 void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
222 const AnfNodePtrList &specialized_parameter_list,
223 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
224 MS_EXCEPTION_IF_NULL(specialized_graph);
225 for (size_t i = 0; i < specialized_graph->parameters().size() - fv_param_count(); ++i) {
226 MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]);
227 auto param_node = specialized_graph->parameters()[i]->cast<ParameterPtr>();
228 MS_EXCEPTION_IF_NULL(param_node);
229 auto param_name = param_node->name();
230 auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node);
231 if (node_itr != specialized_parameter_list.end()) {
232 continue;
233 }
234 if (param_name == specialized_graph->GetVariableArgName() ||
235 param_name == specialized_graph->GetVariableKwargName()) {
236 continue;
237 }
238 auto default_value = specialized_graph->GetDefaultValueByName(param_name);
239 if (default_value == nullptr) {
240 MS_LOG(INTERNAL_EXCEPTION) << "Miss argument input for parameter:" << param_name;
241 }
242 MS_EXCEPTION_IF_NULL(repl_nodes);
243 (void)repl_nodes->emplace(param_node, default_value);
244 }
245 }
246
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)247 FuncGraphPtr FuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
248 if (has_attr(FUNC_GRAPH_FLAG_PROXY_GRAPH)) {
249 auto original_params_size = parameters().size();
250 auto args_size = args_abs_list.size();
251 if (args_size == original_params_size) {
252 MS_LOG(DEBUG) << "proxy function graph: " << ToString();
253 return shared_from_base<FuncGraph>();
254 } else if (args_size < original_params_size) {
255 auto new_params = parameters();
256 new_params.resize(args_size);
257 auto call_node = output()->cast<CNodePtr>();
258 MS_EXCEPTION_IF_NULL(call_node);
259 auto new_inputs = call_node->inputs();
260 new_inputs.resize(new_inputs.size() + args_size - original_params_size);
261
262 set_parameters(new_params);
263 auto new_out = NewCNodeInOrder(new_inputs);
264 set_output(new_out);
265 MS_LOG(INFO) << "The proxy truncates the parameters to match the size. fg: " << ToString()
266 << ", original args: " << original_params_size << ", call args: " << args_size
267 << ", new args: " << parameters().size() << ", call inputs: " << new_out->inputs().size();
268 return shared_from_base<FuncGraph>();
269 }
270 MS_LOG(WARNING) << "The number of parameter is wrong. The number of the construct function's parameter is "
271 << original_params_size << ", but the number of call parameter is " << args_size
272 << ". graph:" << ToString();
273 return shared_from_base<FuncGraph>();
274 }
275
276 std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
277 std::vector<size_t> pos_arg_indexes;
278 size_t arguments_count = args_abs_list.size();
279 if (fv_param_count_ > arguments_count) {
280 MS_LOG(INTERNAL_EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
281 }
282 for (size_t i = 0; i < arguments_count - fv_param_count_; i++) {
283 MS_EXCEPTION_IF_NULL(args_abs_list[i]);
284 if (args_abs_list[i]->isa<abstract::AbstractKeywordArg>()) {
285 kwarg_list.push_back(args_abs_list[i]->cast<abstract::AbstractKeywordArgPtr>());
286 } else {
287 pos_arg_indexes.push_back(i);
288 }
289 }
290
291 if (!NeedGenerate(kwarg_list)) {
292 MS_LOG(DEBUG) << "No need generate, " << ToString();
293 return shared_from_base<FuncGraph>();
294 }
295 auto iter = func_graph_cache_.find(args_abs_list);
296 if (iter != func_graph_cache_.end()) {
297 MS_EXCEPTION_IF_NULL(iter->second);
298 MS_LOG(DEBUG) << "Found in cache, " << iter->second->ToString() << ", for " << ToString();
299 return iter->second;
300 }
301 FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
302 size_t kwarg_count = kwarg_list.size();
303 // Get the variable args count from caller.
304 int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - fv_param_count_);
305 int variable_args_count = pos_args_input_count - GetPositionalArgsCount();
306 AnfNodePtrList specialized_parameter_list;
307 mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl_nodes;
308 MS_LOG(DEBUG) << "specialized_graph: " << specialized_graph->ToString()
309 << ", variable_args_count: " << variable_args_count
310 << ", pos_args_input_count: " << pos_args_input_count
311 << ", GetPositionalArgsCount: " << GetPositionalArgsCount() << ", arguments_count: " << arguments_count
312 << ", kwarg_count: " << kwarg_count << ", fv_param_count_: " << fv_param_count_;
313 GenerateVarParams(specialized_graph, variable_args_count, pos_args_input_count, &specialized_parameter_list,
314 &repl_nodes);
315 GenerateKwParams(specialized_graph, kwarg_list, pos_args_input_count, &specialized_parameter_list, &repl_nodes);
316
317 GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);
318
319 // Append hyper parameter to specialized_parameter_list
320 MS_EXCEPTION_IF_NULL(specialized_graph);
321 auto params = specialized_graph->parameters();
322 (void)specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(fv_param_count_),
323 params.end());
324 AnfNodePtrList specialized_parameter_list_update(
325 specialized_parameter_list.begin() + SizeToLong(pos_arg_indexes.size()), specialized_parameter_list.end());
326 for (size_t i = 0; i < pos_arg_indexes.size(); i++) {
327 (void)specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i],
328 specialized_parameter_list[i]);
329 }
330
331 std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
332 auto tr = manager->Transact();
333 for (auto &node_pair : repl_nodes) {
334 MS_EXCEPTION_IF_NULL(node_pair.first);
335 MS_EXCEPTION_IF_NULL(node_pair.second);
336 MS_LOG(DEBUG) << "GenerateFuncGraph replace:" << node_pair.first->DebugString() << "-"
337 << node_pair.second->DebugString();
338 (void)tr.Replace(node_pair.first, node_pair.second);
339 }
340 tr.SetParameters(specialized_graph, specialized_parameter_list_update);
341 tr.Commit();
342 specialized_graph->set_has_kwarg(false);
343 specialized_graph->set_has_vararg(false);
344 specialized_graph->set_kwonlyargs_count(0);
345 specialized_graph->ClearDefaultValues();
346 specialized_graph->set_is_generate(true);
347 func_graph_cache_[args_abs_list] = specialized_graph;
348 MS_LOG(DEBUG) << "Generated, " << specialized_graph->ToString() << ", for " << ToString();
349 return specialized_graph;
350 }
351
FindRoots(const std::vector<CNodePtr> & segment)352 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
353 std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
354 for (const auto &node : segment) {
355 if (roots->size() == 1) {
356 return roots;
357 }
358 auto input_size = node->size();
359 for (size_t i = 0; i < input_size; i++) {
360 auto in_node = node->input(i);
361 auto in_cnode = in_node->cast<CNodePtr>();
362 if (in_cnode != nullptr) {
363 (void)roots->erase(in_cnode);
364 }
365 }
366 }
367 return roots;
368 }
369
FindLeaves(const std::vector<CNodePtr> & segment)370 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
371 std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
372 for (const auto &node : segment) {
373 if (nodes->size() == 1) {
374 return nodes;
375 }
376 if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
377 (void)nodes->erase(node);
378 continue;
379 }
380 auto input_size = node->size();
381 for (size_t i = 0; i < input_size; i++) {
382 auto in_node = node->input(i);
383 if (!in_node->isa<CNode>()) {
384 continue;
385 }
386 auto in_cnode = in_node->cast<CNodePtr>();
387 if (in_cnode != nullptr) {
388 if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) {
389 (void)nodes->erase(node);
390 break;
391 }
392 }
393 }
394 }
395 return nodes;
396 }
397 } // namespace mindspore
398