• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2020-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
#include "ir/func_graph.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ir/manager.h"
#include "utils/ordered_set.h"
#include "abstract/abstract_value.h"
#include "abstract/abstract_function.h"
#include "ir/func_graph_cloner.h"
28 
29 namespace mindspore {
30 using mindspore::abstract::AbstractFunction;
31 using mindspore::abstract::AbstractFunctionPtr;
32 using mindspore::abstract::AnalysisContextPtr;
33 using mindspore::abstract::PrimitiveAbstractClosure;
34 using mindspore::abstract::VirtualAbstractClosure;
35 
abstract()36 AbstractFunctionPtr FuncGraph::abstract() {
37   AbstractBasePtrList args_abs_list;
38 
39   for (auto &para : parameters_) {
40     MS_EXCEPTION_IF_NULL(para);
41     if (para->abstract() == nullptr) {
42       MS_LOG(ERROR) << "Error!!";
43       return nullptr;
44     }
45     args_abs_list.push_back(para->abstract());
46   }
47 
48   if (output() == nullptr) {
49     MS_LOG(ERROR) << "Error func graph no output";
50     return nullptr;
51   }
52   MS_EXCEPTION_IF_NULL(output());
53   return std::make_shared<VirtualAbstractClosure>(args_abs_list, output()->abstract());
54 }
55 
set_output(const AnfNodePtr & value,bool force_new_ret)56 void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
57   MS_EXCEPTION_IF_NULL(value);
58   if (force_new_ret || return_node() == nullptr) {
59     AnfNodePtrList params({NewValueNode(prim::kPrimReturn), value});
60     FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
61     set_return(this_graph->NewCNodeInOrder(std::move(params)));
62   } else {
63     if (manager_.lock()) {
64       manager_.lock()->SetEdge(return_node(), 1, value);
65     } else {
66       constexpr auto first_data_index = 1;
67       return_node()->set_input(first_data_index, value);
68     }
69   }
70 
71   return_node()->set_abstract(value->abstract());
72   AnfNodePtr input0 = return_node()->input(0);
73   auto f = std::make_shared<PrimitiveAbstractClosure>(prim::kPrimReturn, input0);
74   input0->set_abstract(f);
75 }
76 
GenerateVarParams(const FuncGraphPtr & specialized_graph,int variable_args_count,int pos_args_input_count,AnfNodePtrList * specialized_parameter_list,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const77 void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count,
78                                   int pos_args_input_count, AnfNodePtrList *specialized_parameter_list,
79                                   mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
80   MS_EXCEPTION_IF_NULL(specialized_graph);
81   if (!specialized_graph->has_vararg()) {
82     if (variable_args_count > 0) {
83       MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << GetPositionalArgsCount()
84                         << " positional arguments, but " << pos_args_input_count << " were given.";
85     }
86     // Only copy parameters other than default arguments.
87     for (size_t i = 0; i < IntToSize(pos_args_input_count); ++i) {
88       specialized_parameter_list->push_back(specialized_graph->parameters()[i]);
89     }
90     return;
91   }
92 
93   // If there is variable argument.
94   if (variable_args_count < 0) {
95     MS_LOG(EXCEPTION) << "For function:" << this->ToString() << ", its argument size: " << pos_args_input_count
96                       << " is less or equal to parameter size: " << GetPositionalArgsCount();
97   }
98   // Copy other parameters than vararg's firstly.
99   for (size_t i = 0; i < IntToSize(GetPositionalArgsCount()); ++i) {
100     specialized_parameter_list->push_back(specialized_graph->parameters()[i]);
101   }
102   MS_EXCEPTION_IF_NULL(specialized_graph->GetVariableArgParameter());
103   TraceGuard trace_guard(
104     std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
105   AnfNodePtrList var_param_tuple_nodes;
106   var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
107 
108   auto varg_name = specialized_graph->GetVariableArgName();
109   // For python variable argument input, there is no upper limit.
110   for (int i = 0; i < variable_args_count; ++i) {
111     ParameterPtr para = std::make_shared<Parameter>(specialized_graph);
112     std::string param_name = varg_name + std::to_string(i);
113     para->set_name(param_name);
114     MS_EXCEPTION_IF_NULL(para->debug_info());
115     para->debug_info()->set_name(param_name);
116     var_param_tuple_nodes.push_back(para);
117     MS_EXCEPTION_IF_NULL(specialized_parameter_list);
118     specialized_parameter_list->push_back(para);
119   }
120   auto var_tuple_param = specialized_graph->NewCNode(std::move(var_param_tuple_nodes));
121   MS_EXCEPTION_IF_NULL(repl_nodes);
122   (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param);
123 }
124 
GenerateKwParams(const FuncGraphPtr & specialized_graph,const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list,int pos_args_input_count,AnfNodePtrList * specialized_parameter_list,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const125 void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
126                                  const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
127                                  int pos_args_input_count, AnfNodePtrList *specialized_parameter_list,
128                                  mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
129   MS_EXCEPTION_IF_NULL(specialized_parameter_list);
130   MS_EXCEPTION_IF_NULL(repl_nodes);
131   MS_EXCEPTION_IF_NULL(specialized_graph);
132   AnfNodePtrList kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
133   AnfNodePtrList kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
134 
135   std::set<AnfNodePtr> kwarg_nodes;
136   for (size_t i = 0; i < kwarg_list.size(); ++i) {
137     auto kwarg = kwarg_list[i];
138     MS_EXCEPTION_IF_NULL(kwarg);
139     std::string kw_param_name = kwarg->get_key();
140     AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
141     // If not find corresponding parameter node.
142     if (param_node == nullptr) {
143       if (!has_kwarg()) {
144         MS_LOG(DEBUG) << "Not found parameter by name '" << kw_param_name << "'";
145         if (IntToSize(pos_args_input_count) + i < specialized_graph->parameters().size()) {
146           auto kw_param = dyn_cast<Parameter>(specialized_graph->parameters()[IntToSize(pos_args_input_count) + i]);
147           if (kw_param != nullptr && (specialized_graph->has_flag(FUNC_GRAPH_FLAG_ARGS_NO_EXPAND) ||
148                                       kw_param->name() == "kwargs[" + kw_param_name + "]")) {
149             specialized_parameter_list->push_back(kw_param);
150             continue;
151           }
152         }
153         MS_LOG(EXCEPTION) << "Got an unexpected keyword argument '" << kw_param_name << "'";
154       } else {
155         ParameterPtr para = std::make_shared<Parameter>(specialized_graph);
156         std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
157         auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
158                                                [param_name](const AnfNodePtr &node) {
159                                                  MS_EXCEPTION_IF_NULL(node);
160                                                  auto param = node->cast_ptr<Parameter>();
161                                                  return param != nullptr && param->name() == param_name;
162                                                });
163         if (find_kw_arg_in_list) {
164           MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name;
165         }
166         para->set_name(param_name);
167         MS_EXCEPTION_IF_NULL(para->debug_info());
168         para->debug_info()->set_name(param_name);
169         kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name));
170         auto extract_node =
171           specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), para});
172         kwarg_values_tuple_nodes.push_back(extract_node);
173         specialized_parameter_list->push_back(para);
174       }
175     } else {
176       auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
177       // Multiply values found given for parameter.
178       if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) {
179         MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
180       } else {
181         specialized_parameter_list->push_back(param_node);
182         auto extract_node = specialized_graph->NewCNode(
183           {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
184         kwarg_nodes.insert(param_node);
185         (void)repl_nodes->emplace(param_node, extract_node);
186       }
187     }
188   }
189 
190   GenerateKwargReplNode(specialized_graph, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes, repl_nodes);
191 }
192 
GenerateKwargReplNode(const FuncGraphPtr & specialized_graph,const AnfNodePtrList & kwarg_keys_tuple_nodes,const AnfNodePtrList & kwarg_values_tuple_nodes,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const193 void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
194                                       const AnfNodePtrList &kwarg_keys_tuple_nodes,
195                                       const AnfNodePtrList &kwarg_values_tuple_nodes,
196                                       mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
197   if (has_kwarg() && !kwarg_keys_tuple_nodes.empty()) {
198     MS_EXCEPTION_IF_NULL(specialized_graph);
199     TraceGuard guard(
200       std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
201     auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes);
202     auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes);
203     auto make_dict_node =
204       specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values});
205     MS_EXCEPTION_IF_NULL(repl_nodes);
206     (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node);
207   }
208 }
209 
NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list)210 bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
211   // If the function does not have any vararg/kwarg/kwonly/default value/kw args input
212   // return the original graph
213   if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
214     return false;
215   }
216 
217   // If the graph is generated for specific input, do not need to generate again
218   return !is_generated();
219 }
220 
GenerateDefaultValue(const FuncGraphPtr & specialized_graph,const AnfNodePtrList & specialized_parameter_list,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl_nodes) const221 void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
222                                      const AnfNodePtrList &specialized_parameter_list,
223                                      mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
224   MS_EXCEPTION_IF_NULL(specialized_graph);
225   for (size_t i = 0; i < specialized_graph->parameters().size() - fv_param_count(); ++i) {
226     MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]);
227     auto param_node = specialized_graph->parameters()[i]->cast<ParameterPtr>();
228     MS_EXCEPTION_IF_NULL(param_node);
229     auto param_name = param_node->name();
230     auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node);
231     if (node_itr != specialized_parameter_list.end()) {
232       continue;
233     }
234     if (param_name == specialized_graph->GetVariableArgName() ||
235         param_name == specialized_graph->GetVariableKwargName()) {
236       continue;
237     }
238     auto default_value = specialized_graph->GetDefaultValueByName(param_name);
239     if (default_value == nullptr) {
240       MS_LOG(INTERNAL_EXCEPTION) << "Miss argument input for parameter:" << param_name;
241     }
242     MS_EXCEPTION_IF_NULL(repl_nodes);
243     (void)repl_nodes->emplace(param_node, default_value);
244   }
245 }
246 
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)247 FuncGraphPtr FuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
248   if (has_attr(FUNC_GRAPH_FLAG_PROXY_GRAPH)) {
249     auto original_params_size = parameters().size();
250     auto args_size = args_abs_list.size();
251     if (args_size == original_params_size) {
252       MS_LOG(DEBUG) << "proxy function graph: " << ToString();
253       return shared_from_base<FuncGraph>();
254     } else if (args_size < original_params_size) {
255       auto new_params = parameters();
256       new_params.resize(args_size);
257       auto call_node = output()->cast<CNodePtr>();
258       MS_EXCEPTION_IF_NULL(call_node);
259       auto new_inputs = call_node->inputs();
260       new_inputs.resize(new_inputs.size() + args_size - original_params_size);
261 
262       set_parameters(new_params);
263       auto new_out = NewCNodeInOrder(new_inputs);
264       set_output(new_out);
265       MS_LOG(INFO) << "The proxy truncates the parameters to match the size. fg: " << ToString()
266                    << ", original args: " << original_params_size << ", call args: " << args_size
267                    << ", new args: " << parameters().size() << ", call inputs: " << new_out->inputs().size();
268       return shared_from_base<FuncGraph>();
269     }
270     MS_LOG(WARNING) << "The number of parameter is wrong. The number of the construct function's parameter is "
271                     << original_params_size << ", but the number of call parameter is " << args_size
272                     << ". graph:" << ToString();
273     return shared_from_base<FuncGraph>();
274   }
275 
276   std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
277   std::vector<size_t> pos_arg_indexes;
278   size_t arguments_count = args_abs_list.size();
279   if (fv_param_count_ > arguments_count) {
280     MS_LOG(INTERNAL_EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
281   }
282   for (size_t i = 0; i < arguments_count - fv_param_count_; i++) {
283     MS_EXCEPTION_IF_NULL(args_abs_list[i]);
284     if (args_abs_list[i]->isa<abstract::AbstractKeywordArg>()) {
285       kwarg_list.push_back(args_abs_list[i]->cast<abstract::AbstractKeywordArgPtr>());
286     } else {
287       pos_arg_indexes.push_back(i);
288     }
289   }
290 
291   if (!NeedGenerate(kwarg_list)) {
292     MS_LOG(DEBUG) << "No need generate, " << ToString();
293     return shared_from_base<FuncGraph>();
294   }
295   auto iter = func_graph_cache_.find(args_abs_list);
296   if (iter != func_graph_cache_.end()) {
297     MS_EXCEPTION_IF_NULL(iter->second);
298     MS_LOG(DEBUG) << "Found in cache, " << iter->second->ToString() << ", for " << ToString();
299     return iter->second;
300   }
301   FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
302   size_t kwarg_count = kwarg_list.size();
303   // Get the variable args count from caller.
304   int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - fv_param_count_);
305   int variable_args_count = pos_args_input_count - GetPositionalArgsCount();
306   AnfNodePtrList specialized_parameter_list;
307   mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl_nodes;
308   MS_LOG(DEBUG) << "specialized_graph: " << specialized_graph->ToString()
309                 << ", variable_args_count: " << variable_args_count
310                 << ", pos_args_input_count: " << pos_args_input_count
311                 << ", GetPositionalArgsCount: " << GetPositionalArgsCount() << ", arguments_count: " << arguments_count
312                 << ", kwarg_count: " << kwarg_count << ", fv_param_count_: " << fv_param_count_;
313   GenerateVarParams(specialized_graph, variable_args_count, pos_args_input_count, &specialized_parameter_list,
314                     &repl_nodes);
315   GenerateKwParams(specialized_graph, kwarg_list, pos_args_input_count, &specialized_parameter_list, &repl_nodes);
316 
317   GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);
318 
319   // Append hyper parameter to specialized_parameter_list
320   MS_EXCEPTION_IF_NULL(specialized_graph);
321   auto params = specialized_graph->parameters();
322   (void)specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(fv_param_count_),
323                                           params.end());
324   AnfNodePtrList specialized_parameter_list_update(
325     specialized_parameter_list.begin() + SizeToLong(pos_arg_indexes.size()), specialized_parameter_list.end());
326   for (size_t i = 0; i < pos_arg_indexes.size(); i++) {
327     (void)specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i],
328                                                    specialized_parameter_list[i]);
329   }
330 
331   std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
332   auto tr = manager->Transact();
333   for (auto &node_pair : repl_nodes) {
334     MS_EXCEPTION_IF_NULL(node_pair.first);
335     MS_EXCEPTION_IF_NULL(node_pair.second);
336     MS_LOG(DEBUG) << "GenerateFuncGraph replace:" << node_pair.first->DebugString() << "-"
337                   << node_pair.second->DebugString();
338     (void)tr.Replace(node_pair.first, node_pair.second);
339   }
340   tr.SetParameters(specialized_graph, specialized_parameter_list_update);
341   tr.Commit();
342   specialized_graph->set_has_kwarg(false);
343   specialized_graph->set_has_vararg(false);
344   specialized_graph->set_kwonlyargs_count(0);
345   specialized_graph->ClearDefaultValues();
346   specialized_graph->set_is_generate(true);
347   func_graph_cache_[args_abs_list] = specialized_graph;
348   MS_LOG(DEBUG) << "Generated, " << specialized_graph->ToString() << ", for " << ToString();
349   return specialized_graph;
350 }
351 
FindRoots(const std::vector<CNodePtr> & segment)352 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
353   std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
354   for (const auto &node : segment) {
355     if (roots->size() == 1) {
356       return roots;
357     }
358     auto input_size = node->size();
359     for (size_t i = 0; i < input_size; i++) {
360       auto in_node = node->input(i);
361       auto in_cnode = in_node->cast<CNodePtr>();
362       if (in_cnode != nullptr) {
363         (void)roots->erase(in_cnode);
364       }
365     }
366   }
367   return roots;
368 }
369 
FindLeaves(const std::vector<CNodePtr> & segment)370 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
371   std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
372   for (const auto &node : segment) {
373     if (nodes->size() == 1) {
374       return nodes;
375     }
376     if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
377       (void)nodes->erase(node);
378       continue;
379     }
380     auto input_size = node->size();
381     for (size_t i = 0; i < input_size; i++) {
382       auto in_node = node->input(i);
383       if (!in_node->isa<CNode>()) {
384         continue;
385       }
386       auto in_cnode = in_node->cast<CNodePtr>();
387       if (in_cnode != nullptr) {
388         if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) {
389           (void)nodes->erase(node);
390           break;
391         }
392       }
393     }
394   }
395   return nodes;
396 }
397 }  // namespace mindspore
398