• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/func_graph.h"
20 
21 #include <algorithm>
22 #include <sstream>
23 
24 #include "ir/manager.h"
25 #include "base/core_ops.h"
26 #include "utils/ordered_set.h"
27 #include "abstract/abstract_value.h"
28 #include "abstract/abstract_function.h"
29 #include "utils/flags.h"
30 
31 namespace mindspore {
32 using mindspore::abstract::AbstractFunction;
33 using mindspore::abstract::AbstractFunctionPtr;
34 using mindspore::abstract::AnalysisContextPtr;
35 using mindspore::abstract::PrimitiveAbstractClosure;
36 using mindspore::abstract::VirtualAbstractClosure;
37 
abstract()38 AbstractFunctionPtr FuncGraph::abstract() {
39   AbstractBasePtrList args_spec_list;
40 
41   for (auto &p : parameters_) {
42     MS_EXCEPTION_IF_NULL(p);
43     if (p->abstract() == nullptr) {
44       MS_LOG(ERROR) << "Error!!";
45       return nullptr;
46     }
47     args_spec_list.push_back(p->abstract());
48   }
49 
50   if (output() == nullptr) {
51     MS_LOG(ERROR) << "Error func graph no output";
52     return nullptr;
53   }
54 
55   return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
56 }
57 
set_output(const AnfNodePtr & value,bool force_new_ret)58 void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
59   MS_EXCEPTION_IF_NULL(value);
60   if (force_new_ret || return_ == nullptr) {
61     std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
62     FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
63     return_ = this_graph->NewCNodeInOrder(params);
64   } else {
65     if (manager_.lock()) {
66       manager_.lock()->SetEdge(return_, 1, value);
67     } else {
68       constexpr auto first_data_index = 1;
69       return_->set_input(first_data_index, value);
70     }
71   }
72 
73   return_->set_abstract(value->abstract());
74   AnfNodePtr input0 = return_->input(0);
75   auto f = std::make_shared<PrimitiveAbstractClosure>(prim::kPrimReturn, input0);
76   input0->set_abstract(f);
77 }
78 
DumpFuncGraph(const std::string & path)79 void FuncGraph::DumpFuncGraph(const std::string &path) {
80   if (drawer_) {
81     drawer_(path + ".dot", shared_from_base<FuncGraph>());
82   }
83 }
84 
GenerateVarParams(const FuncGraphPtr & specialized_graph,int variable_args_count,int pos_args_input_count,std::vector<AnfNodePtr> * specialized_parameter_list,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const85 void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count,
86                                   int pos_args_input_count, std::vector<AnfNodePtr> *specialized_parameter_list,
87                                   std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
88   // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple
89   MS_EXCEPTION_IF_NULL(specialized_graph);
90   if (specialized_graph->has_vararg()) {
91     TraceGuard trace_guard(
92       std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
93     std::vector<AnfNodePtr> var_param_tuple_nodes;
94     var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
95 
96     if (variable_args_count < 0) {
97       MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count
98                         << " were given.";
99     }
100     auto varg_name = specialized_graph->GetVariableArgName();
101     // for python variable argument input , there is no upper limit
102     for (int i = 0; i < variable_args_count; ++i) {
103       ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
104       std::string param_name = varg_name + std::to_string(i);
105       p->set_name(param_name);
106       MS_EXCEPTION_IF_NULL(p->debug_info());
107       p->debug_info()->set_name(param_name);
108       var_param_tuple_nodes.push_back(p);
109       MS_EXCEPTION_IF_NULL(specialized_parameter_list);
110       specialized_parameter_list->push_back(p);
111     }
112     auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes);
113     (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param);
114   } else if (variable_args_count > 0) {
115     MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount()
116                       << " positional arguments, but " << pos_args_input_count << " were given.";
117   }
118 }
119 
GenerateKwParams(const FuncGraphPtr & specialized_graph,const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list,std::vector<AnfNodePtr> * specialized_parameter_list,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const120 void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
121                                  const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
122                                  std::vector<AnfNodePtr> *specialized_parameter_list,
123                                  std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
124   std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
125   std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
126 
127   std::set<AnfNodePtr> kwarg_nodes;
128   for (const auto &kwarg : kwarg_list) {
129     MS_EXCEPTION_IF_NULL(kwarg);
130     std::string kw_param_name = kwarg->get_key();
131     MS_EXCEPTION_IF_NULL(specialized_graph);
132     AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
133     // if not find corresponding parameter node
134     if (param_node == nullptr) {
135       if (!has_kwarg()) {
136         MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
137       } else {
138         ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
139         std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
140         MS_EXCEPTION_IF_NULL(specialized_parameter_list);
141         auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
142                                                [param_name](const AnfNodePtr &node) {
143                                                  MS_EXCEPTION_IF_NULL(node);
144                                                  auto param = node->cast<ParameterPtr>();
145                                                  return param != nullptr && param->name() == param_name;
146                                                });
147         if (find_kw_arg_in_list) {
148           MS_EXCEPTION(TypeError) << "Multiply values for keyword argument: " << kw_param_name;
149         }
150         p->set_name(param_name);
151         p->debug_info()->set_name(param_name);
152         kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name));
153         auto extract_node =
154           specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p});
155         kwarg_values_tuple_nodes.push_back(extract_node);
156         specialized_parameter_list->push_back(p);
157       }
158     } else {
159       auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node);
160       // multiply values found given for parameter
161       if (node_itr != specialized_parameter_list->end() && kwarg_nodes.find(param_node) == kwarg_nodes.end()) {
162         MS_EXCEPTION(TypeError) << "Multiply values for specific argument: " << kw_param_name;
163       } else {
164         specialized_parameter_list->push_back(param_node);
165         auto extract_node = specialized_graph->NewCNode(
166           {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
167         kwarg_nodes.insert(param_node);
168         (void)repl_nodes->emplace(param_node, extract_node);
169       }
170     }
171   }
172 
173   GenerateKwargReplNode(specialized_graph, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes, repl_nodes);
174 }
175 
GenerateKwargReplNode(const FuncGraphPtr & specialized_graph,const std::vector<AnfNodePtr> & kwarg_keys_tuple_nodes,const std::vector<AnfNodePtr> & kwarg_values_tuple_nodes,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const176 void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
177                                       const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
178                                       const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes,
179                                       std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
180   if (has_kwarg()) {
181     MS_EXCEPTION_IF_NULL(specialized_graph);
182     TraceGuard guard(
183       std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
184     auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes);
185     auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes);
186     auto make_dict_node =
187       specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values});
188     MS_EXCEPTION_IF_NULL(repl_nodes);
189     (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node);
190   }
191 }
192 
NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> & kwarg_list)193 bool FuncGraph::NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list) {
194   // if the function does not have any vararg/kwarg/kwonly/default value/kw args input
195   // return the original graph
196   if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) {
197     return false;
198   }
199 
200   // if the graph is generated for specific input, do not need to generate again
201   return !is_generated();
202 }
203 
GenerateDefaultValue(const FuncGraphPtr & specialized_graph,const std::vector<AnfNodePtr> & specialized_parameter_list,std::unordered_map<AnfNodePtr,AnfNodePtr> * repl_nodes) const204 void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
205                                      const std::vector<AnfNodePtr> &specialized_parameter_list,
206                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
207   MS_EXCEPTION_IF_NULL(specialized_graph);
208   for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
209     MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]);
210     auto param_node = specialized_graph->parameters()[i]->cast<ParameterPtr>();
211     MS_EXCEPTION_IF_NULL(param_node);
212     auto param_name = param_node->name();
213     auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node);
214     if (node_itr != specialized_parameter_list.end()) {
215       continue;
216     }
217     if (param_name == specialized_graph->GetVariableArgName() ||
218         param_name == specialized_graph->GetVariableKwargName()) {
219       continue;
220     }
221     auto default_value = specialized_graph->GetDefaultValueByName(param_name);
222     if (default_value == nullptr) {
223       MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name;
224     }
225     MS_EXCEPTION_IF_NULL(repl_nodes);
226     (void)repl_nodes->emplace(param_node, default_value);
227   }
228 }
229 
GenerateGraph(const AbstractBasePtrList & args_spec_list)230 FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) {
231   std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
232   std::vector<size_t> pos_arg_indexes;
233   size_t arguments_count = args_spec_list.size();
234   if (hyper_param_count_ > arguments_count) {
235     MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
236   }
237   for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
238     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
239     if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
240       kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
241     } else {
242       pos_arg_indexes.push_back(i);
243     }
244   }
245 
246   if (!NeedGenerate(kwarg_list)) {
247     return shared_from_base<FuncGraph>();
248   }
249   auto iter = func_graph_cache_.find(args_spec_list);
250   if (iter != func_graph_cache_.end()) {
251     return iter->second;
252   }
253   FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
254   size_t kwarg_count = kwarg_list.size();
255   int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count_);
256   int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
257   int variable_args_count = pos_args_input_count - pos_args_count;
258   std::vector<AnfNodePtr> specialized_parameter_list;
259   std::unordered_map<AnfNodePtr, AnfNodePtr> repl_nodes;
260   // the parameters that has arg input, copy from original parameters
261   for (size_t i = 0; i < IntToSize(pos_args_count); ++i) {
262     specialized_parameter_list.push_back(specialized_graph->parameters()[i]);
263   }
264 
265   GenerateVarParams(specialized_graph, variable_args_count, pos_args_input_count, &specialized_parameter_list,
266                     &repl_nodes);
267 
268   GenerateKwParams(specialized_graph, kwarg_list, &specialized_parameter_list, &repl_nodes);
269 
270   GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);
271 
272   // append hyper parameter to specialized_parameter_list
273   MS_EXCEPTION_IF_NULL(specialized_graph);
274   auto params = specialized_graph->parameters();
275   specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
276                                     params.end());
277   std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
278                                                             specialized_parameter_list.end());
279   for (size_t i = 0; i < pos_arg_indexes.size(); i++) {
280     specialized_parameter_list_update.insert(specialized_parameter_list_update.begin() + pos_arg_indexes[i],
281                                              specialized_parameter_list[i]);
282   }
283 
284   std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(specialized_graph, false);
285   auto tr = manager->Transact();
286   for (auto &node_pair : repl_nodes) {
287     MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-"
288                   << node_pair.second->DebugString();
289     (void)tr.Replace(node_pair.first, node_pair.second);
290   }
291   tr.SetParameters(specialized_graph, specialized_parameter_list_update);
292   tr.Commit();
293   specialized_graph->set_has_kwarg(false);
294   specialized_graph->set_has_vararg(false);
295   specialized_graph->set_kwonlyargs_count(0);
296   specialized_graph->ClearDefaultValues();
297   specialized_graph->set_is_generate(true);
298   func_graph_cache_[args_spec_list] = specialized_graph;
299   return specialized_graph;
300 }
301 
FindRoots(const std::vector<CNodePtr> & segment)302 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment) {
303   std::shared_ptr<OrderedSet<CNodePtr>> roots = std::make_shared<OrderedSet<CNodePtr>>(segment);
304   for (const auto &node : segment) {
305     if (roots->size() == 1) {
306       return roots;
307     }
308     auto input_size = node->size();
309     for (size_t i = 0; i < input_size; i++) {
310       auto in_node = node->input(i);
311       auto in_cnode = in_node->cast<CNodePtr>();
312       if (in_cnode != nullptr) {
313         (void)roots->erase(in_cnode);
314       }
315     }
316   }
317   return roots;
318 }
319 
FindLeaves(const std::vector<CNodePtr> & segment)320 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment) {
321   std::shared_ptr<OrderedSet<CNodePtr>> nodes = std::make_shared<OrderedSet<CNodePtr>>(segment);
322   for (const auto &node : segment) {
323     if (nodes->size() == 1) {
324       return nodes;
325     }
326     if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
327       (void)nodes->erase(node);
328       continue;
329     }
330     auto input_size = node->size();
331     for (size_t i = 0; i < input_size; i++) {
332       auto in_node = node->input(i);
333       if (!in_node->isa<CNode>()) {
334         continue;
335       }
336       auto in_cnode = in_node->cast<CNodePtr>();
337       if (in_cnode != nullptr) {
338         if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) {
339           (void)nodes->erase(node);
340           break;
341         }
342       }
343     }
344   }
345   return nodes;
346 }
347 }  // namespace mindspore
348