• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/func_graph.h"
20 
21 #include <algorithm>
22 #include <sstream>
23 #include <utility>
24 
25 #include "utils/trace_base.h"
26 #include "ir/manager.h"
27 #include "utils/flags.h"
28 #include "utils/ordered_set.h"
29 #include "utils/convert_utils_base.h"
30 #include "abstract/abstract_function.h"
31 
32 namespace mindspore {
33 /*
34  * Methods of Graph
35  */
FuncGraph()36 FuncGraph::FuncGraph()
37     : attrs_(),
38       transforms_(),
39       parameter_default_value_(),
40       seen_(0),
41       parameters_(),
42       has_vararg_(false),
43       has_kwarg_(false),
44       kwonlyargs_count_(0),
45       hyper_param_count_(0),
46       is_generated_(false),
47       is_bprop_(false),
48       return_(nullptr),
49       manager_(std::weak_ptr<FuncGraphManager>()),
50       stub_(false),
51       stage_(-1) {
52   debug_info_ = std::make_shared<GraphDebugInfo>();
53   switch_input_ = std::make_shared<bool>(false);
54   switch_layer_input_ = std::make_shared<bool>(false);
55 }
56 
ToAbstract()57 abstract::AbstractBasePtr FuncGraph::ToAbstract() {
58   auto temp_context = abstract::AnalysisContext::DummyContext();
59   return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
60 }
61 
output() const62 AnfNodePtr FuncGraph::output() const {
63   constexpr size_t return_input_num = 2;
64   // If return value is set, return should have two inputs.
65   if (return_ != nullptr && return_->inputs().size() == return_input_num) {
66     return return_->input(1);
67   } else {
68     // If not set yet, return nullptr.
69     return nullptr;
70   }
71 }
72 
get_inputs() const73 const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
74   std::vector<AnfNodePtr> input_params;
75   for (auto const &node : parameters_) {
76     MS_EXCEPTION_IF_NULL(node);
77     auto parameter = dyn_cast<Parameter>(node);
78     MS_EXCEPTION_IF_NULL(parameter);
79     if (!parameter->has_default()) {
80       input_params.push_back(parameter);
81     }
82   }
83   return input_params;
84 }
85 
add_parameter()86 ParameterPtr FuncGraph::add_parameter() {
87   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
88   ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
89   add_parameter(p);
90   return p;
91 }
92 
add_parameter(const ParameterPtr & p)93 void FuncGraph::add_parameter(const ParameterPtr &p) {
94   if (manager_.lock()) {
95     manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
96   } else {
97     parameters_.push_back(p);
98   }
99 }
100 
InsertFrontParameter()101 ParameterPtr FuncGraph::InsertFrontParameter() {
102   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
103   ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
104   InsertFrontParameter(p);
105   return p;
106 }
107 
InsertFrontParameter(const ParameterPtr & p)108 void FuncGraph::InsertFrontParameter(const ParameterPtr &p) {
109   if (manager_.lock()) {
110     manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p);
111   } else {
112     PrependParameter(p);
113   }
114 }
115 
AddWeightParameter(const std::string & name)116 ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
117   FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
118   ParameterPtr p = std::make_shared<Parameter>(this_graph);
119   p->set_name(name);
120   p->debug_info()->set_name(name);
121 
122   if (manager_.lock()) {
123     manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
124   } else {
125     parameters_.push_back(p);
126   }
127   hyper_param_count_++;
128   return p;
129 }
130 
has_flag(const std::string & key)131 bool FuncGraph::has_flag(const std::string &key) {
132   auto iter = attrs_.find(key);
133   if (iter != attrs_.cend()) {
134     MS_EXCEPTION_IF_NULL(iter->second);
135     if (iter->second->isa<BoolImm>()) {
136       return GetValue<bool>(iter->second);
137     }
138     MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
139   }
140   return false;
141 }
142 
has_attr(const std::string & key) const143 bool FuncGraph::has_attr(const std::string &key) const {
144   auto iter = attrs_.find(key);
145   return !(iter == attrs_.cend());
146 }
147 
get_attr(const std::string & key) const148 ValuePtr FuncGraph::get_attr(const std::string &key) const {
149   auto iter = attrs_.find(key);
150   return iter == attrs_.cend() ? nullptr : iter->second;
151 }
152 
NewCNode(const std::vector<AnfNodePtr> & inputs)153 CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
154   return std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
155 }
156 
NewCNodeInOrder(const std::vector<AnfNodePtr> & inputs)157 CNodePtr FuncGraph::NewCNodeInOrder(const std::vector<AnfNodePtr> &inputs) {
158   CNodePtr cnode = NewCNode(inputs);
159   order_.push_back(cnode);
160   return cnode;
161 }
162 
NewCNodeInFront(const std::vector<AnfNodePtr> & inputs)163 CNodePtr FuncGraph::NewCNodeInFront(const std::vector<AnfNodePtr> &inputs) {
164   CNodePtr cnode = NewCNode(inputs);
165   order_.push_front(cnode);
166   return cnode;
167 }
168 
NewCNodeBefore(const AnfNodePtr & position,const std::vector<AnfNodePtr> & inputs)169 CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
170   CNodePtr cnode = NewCNode(inputs);
171   CNodePtr pos_cnode = dyn_cast<CNode>(position);
172   auto iter = order_.find(pos_cnode);
173   order_.insert(iter, cnode);
174   return cnode;
175 }
176 
NewCNodeAfter(const AnfNodePtr & position,const std::vector<AnfNodePtr> & inputs)177 CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
178   CNodePtr cnode = NewCNode(inputs);
179   CNodePtr pos_cnode = dyn_cast<CNode>(position);
180   auto iter = order_.find(pos_cnode);
181   if (iter == order_.end()) {
182     order_.push_front(cnode);
183   } else {
184     order_.insert(std::next(iter), cnode);
185   }
186   return cnode;
187 }
188 
DumpCNodeList()189 void FuncGraph::DumpCNodeList() {
190   MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
191   for (const auto &cnode : order_) {
192     MS_LOG(INFO) << cnode->DebugString();
193   }
194 }
195 
ToString() const196 std::string FuncGraph::ToString() const {
197   std::ostringstream buffer;
198   auto debug_info = const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info();
199   buffer << mindspore::label_manage::Label(debug_info);
200   buffer << "." << debug_info->get_id();
201   return buffer.str();
202 }
203 
debug_info()204 GraphDebugInfoPtr FuncGraph::debug_info() {
205   MS_EXCEPTION_IF_NULL(this->debug_info_);
206   if (this->debug_info_->get_graph() == nullptr) {
207     this->debug_info_->set_graph(shared_from_base<FuncGraph>());
208   }
209   return this->debug_info_;
210 }
211 
nodes() const212 const AnfNodeSet &FuncGraph::nodes() const { return nodes_; }
213 
CopyNodes(const FuncGraphPtr & source)214 void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_.update(source->nodes()); }
215 
ClearNodes()216 void FuncGraph::ClearNodes() { nodes_.clear(); }
217 
AddNode(const AnfNodePtr & node)218 void FuncGraph::AddNode(const AnfNodePtr &node) { nodes_.add(node); }
219 
DropNode(const AnfNodePtr & node)220 void FuncGraph::DropNode(const AnfNodePtr &node) {
221   nodes_.erase(node);
222   if (node == nullptr) {
223     MS_LOG(ERROR) << "Node is nullptr";
224     return;
225   }
226   auto graph = node->func_graph();
227   if (node->isa<Parameter>()) {
228     (void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
229   }
230   // Remove the node from order list.
231   if (graph) {
232     graph->EraseUnusedNodeInOrder(node);
233   }
234 }
235 
value_nodes() const236 const AnfNodeCounterMap &FuncGraph::value_nodes() const { return value_nodes_; }
237 
CopyValueNodes(const FuncGraphPtr & source)238 void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
239   auto &others = source->value_nodes();
240   for (auto it = others.begin(); it != others.end(); ++it) {
241     AddValueNode(it->first, it->second);
242   }
243 }
244 
ClearValueNodes()245 void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
246 
AddValueNode(const AnfNodePtr & node,int count)247 void FuncGraph::AddValueNode(const AnfNodePtr &node, int count) {
248   if (value_nodes_.count(node) == 0) {
249     value_nodes_[node] = count;
250   } else {
251     value_nodes_[node] += count;
252   }
253 }
254 
DropValueNode(const AnfNodePtr & node)255 void FuncGraph::DropValueNode(const AnfNodePtr &node) {
256   if (value_nodes_.count(node) != 0) {
257     if (value_nodes_[node] == 1) {
258       (void)value_nodes_.erase(node);
259     } else {
260       value_nodes_[node]--;
261       if (value_nodes_[node] < 0) {
262         MS_LOG(EXCEPTION) << "Count of ValueNode '" << node
263                           << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
264       }
265     }
266   }
267 }
268 
free_variables() const269 const AnfNodeCounterMap &FuncGraph::free_variables() const { return free_variables_; }
270 
CopyFreeVariables(const FuncGraphPtr & source)271 void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
272   auto &others = source->free_variables();
273   for (auto it = others.begin(); it != others.end(); ++it) {
274     const auto &free_var = it->first;
275     MS_EXCEPTION_IF_NULL(free_var);
276     if (free_var->func_graph().get() != this) {
277       (void)AddFreeVariable(free_var, it->second);
278     }
279   }
280 }
281 
ClearFreeVariables()282 void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }
283 
AddFreeVariable(const AnfNodePtr & node,int count)284 bool FuncGraph::AddFreeVariable(const AnfNodePtr &node, int count) {
285   if (free_variables_.count(node) == 0) {
286     free_variables_[node] = count;
287     return true;
288   } else {
289     free_variables_[node] += count;
290     return false;
291   }
292 }
293 
DropFreeVariable(const AnfNodePtr & node)294 bool FuncGraph::DropFreeVariable(const AnfNodePtr &node) {
295   if (free_variables_.count(node) != 0) {
296     if (free_variables_[node] == 1) {
297       (void)free_variables_.erase(node);
298       return true;
299     } else {
300       free_variables_[node]--;
301       if (free_variables_[node] < 0) {
302         MS_LOG(EXCEPTION) << "Count of free variable '" << node
303                           << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
304       }
305     }
306   }
307   return false;
308 }
309 
free_variables_total()310 const BaseRefCounterMap &FuncGraph::free_variables_total() {
311   auto mng = manager_.lock();
312   MS_EXCEPTION_IF_NULL(mng);
313   auto &fv_total = mng->free_variables_total();
314   return fv_total[shared_from_base<FuncGraph>()];
315 }
316 
free_variables_nodes()317 std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
318   std::vector<AnfNodePtr> nodes;
319   const auto &fv_total = this->free_variables_total();
320   for (auto &p : fv_total) {
321     auto key = p.first;
322     if (utils::isa<AnfNodePtr>(key)) {
323       nodes.push_back(utils::cast<AnfNodePtr>(key));
324     }
325   }
326   return nodes;
327 }
328 
free_variables_func_graphs()329 std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
330   std::vector<FuncGraphPtr> func_graphs;
331   const auto &fv_total = this->free_variables_total();
332   for (auto &p : fv_total) {
333     auto key = p.first;
334     if (utils::isa<FuncGraphPtr>(key)) {
335       func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
336     }
337   }
338 
339   return func_graphs;
340 }
341 
func_graphs_used() const342 const FuncGraphCounterMap &FuncGraph::func_graphs_used() const { return func_graphs_used_; }
343 
CopyFuncGraphsUsed(const FuncGraphPtr & source)344 void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
345   auto &others = source->func_graphs_used();
346   for (auto it = others.begin(); it != others.end(); ++it) {
347     (void)AddFuncGraphUsed(it->first, it->second);
348   }
349   func_graphs_used_.erase(source);
350 }
351 
ClearFuncGraphsUsed()352 void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
353 
AddFuncGraphUsed(const FuncGraphPtr & fg,int count)354 bool FuncGraph::AddFuncGraphUsed(const FuncGraphPtr &fg, int count) {
355   if (func_graphs_used_.count(fg) == 0) {
356     func_graphs_used_[fg] = count;
357     return true;
358   } else {
359     func_graphs_used_[fg] += count;
360     return false;
361   }
362 }
363 
DropFuncGraphUsed(const FuncGraphPtr & fg)364 bool FuncGraph::DropFuncGraphUsed(const FuncGraphPtr &fg) {
365   if (func_graphs_used_.count(fg) != 0) {
366     if (func_graphs_used_[fg] == 1) {
367       (void)func_graphs_used_.erase(fg);
368       return true;
369     } else {
370       func_graphs_used_[fg]--;
371       if (func_graphs_used_[fg] < 0) {
372         MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
373                           << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
374       }
375     }
376   }
377   return false;
378 }
379 
func_graphs_used_total()380 const FuncGraphSet &FuncGraph::func_graphs_used_total() {
381   auto mng = manager_.lock();
382   MS_EXCEPTION_IF_NULL(mng);
383   auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
384   return used;
385 }
386 
func_graph_cnodes_index() const387 const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() const { return func_graph_cnodes_index_; }
388 
CopyFuncGraphCNodesIndex(const FuncGraphPtr & source)389 void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
390   auto &others = source->func_graph_cnodes_index();
391   for (auto it = others.begin(); it != others.end(); ++it) {
392     // Ignore the user graph who may own itself.
393     auto fg = it->first->first->func_graph();
394     MS_EXCEPTION_IF_NULL(fg);
395     if (fg.get() != this) {
396       AddFuncGraphCNodeIndex(it->first, it->second);
397     }
398   }
399 }
400 
ClearFuncGraphCNodesIndex()401 void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }
402 
AddFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair,int count)403 void FuncGraph::AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair, int count) {
404   if (func_graph_cnodes_index_.count(pair) == 0) {
405     func_graph_cnodes_index_[pair] = count;
406   } else {
407     func_graph_cnodes_index_[pair] += count;
408   }
409 }
410 
DropFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair)411 void FuncGraph::DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair) {
412   if (func_graph_cnodes_index_.count(pair) != 0) {
413     if (func_graph_cnodes_index_[pair] == 1) {
414       (void)func_graph_cnodes_index_.erase(pair);
415     } else {
416       func_graph_cnodes_index_[pair]--;
417       if (func_graph_cnodes_index_[pair] < 0) {
418         MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
419                           << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
420       }
421     }
422   }
423 }
424 
j_value_nodes() const425 const std::unordered_map<AnfNodePtr, int> &FuncGraph::j_value_nodes() const { return j_value_nodes_; }
426 
CopyJValueNodes(const FuncGraphPtr & source)427 void FuncGraph::CopyJValueNodes(const FuncGraphPtr &source) {
428   MS_EXCEPTION_IF_NULL(source);
429   auto &others = source->j_value_nodes();
430   for (const auto &other : others) {
431     AddJValueNode(other.first, other.second);
432   }
433 }
434 
ClearJValueNodes()435 void FuncGraph::ClearJValueNodes() { j_value_nodes_.clear(); }
436 
AddJValueNode(const AnfNodePtr & value_node,int count)437 void FuncGraph::AddJValueNode(const AnfNodePtr &value_node, int count) {
438   if (j_value_nodes_.count(value_node) == 0) {
439     j_value_nodes_[value_node] = count;
440   } else {
441     j_value_nodes_[value_node] += count;
442   }
443 }
444 
DropJValueNode(const AnfNodePtr & value_node)445 void FuncGraph::DropJValueNode(const AnfNodePtr &value_node) {
446   if (j_value_nodes_.count(value_node) != 0) {
447     if (j_value_nodes_[value_node] == 1) {
448       (void)j_value_nodes_.erase(value_node);
449     } else {
450       j_value_nodes_[value_node]--;
451       if (j_value_nodes_[value_node] < 0) {
452         MS_LOG(EXCEPTION) << "Count of J ValueNode '" << value_node->DebugString()
453                           << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
454       }
455     }
456   }
457 }
458 
parent()459 FuncGraphPtr FuncGraph::parent() {
460   // report the bug early.
461   if (manager_.lock() == nullptr) {
462     MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
463                       << " NodeInfo: " << trace::GetDebugInfo(debug_info());
464   }
465   auto mng = manager_.lock();
466   MS_EXCEPTION_IF_NULL(mng);
467   return mng->parent(shared_from_base<FuncGraph>());
468 }
469 
children()470 const FuncGraphSet &FuncGraph::children() {
471   auto mng = manager_.lock();
472   MS_EXCEPTION_IF_NULL(mng);
473   return mng->children(shared_from_base<FuncGraph>());
474 }
475 
scope()476 const FuncGraphSet &FuncGraph::scope() {
477   auto mng = manager_.lock();
478   MS_EXCEPTION_IF_NULL(mng);
479   return mng->scopes(shared_from_base<FuncGraph>());
480 }
481 
recursive()482 bool FuncGraph::recursive() {
483   auto mng = manager_.lock();
484   MS_EXCEPTION_IF_NULL(mng);
485   return mng->recursive(shared_from_base<FuncGraph>());
486 }
487 
recursive_graphs()488 std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
489   auto mng = manager_.lock();
490   MS_EXCEPTION_IF_NULL(mng);
491   return mng->recursive_graphs(shared_from_base<FuncGraph>());
492 }
493 
ClearAllManagerInfo()494 void FuncGraph::ClearAllManagerInfo() {
495   ClearNodes();
496   ClearValueNodes();
497   ClearFuncGraphCNodesIndex();
498   ClearFreeVariables();
499   ClearFuncGraphsUsed();
500   ClearJValueNodes();
501 }
502 
GetDefaultValueByName(const std::string & name)503 AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
504   auto itr = this->parameter_default_value_.find(name);
505   if (itr == parameter_default_value_.end()) {
506     return nullptr;
507   }
508   auto default_value = itr->second;
509   if (default_value == nullptr) {
510     MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist";
511   }
512   if (IsValueNode<Null>(default_value)) {
513     return nullptr;
514   }
515   return default_value;
516 }
517 
518 // set the default values
SetDefaultValues(const std::vector<std::string> & name_list,const std::vector<AnfNodePtr> & value_list)519 void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
520   auto all_is_null =
521     std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode<Null>(node); });
522   if (value_list.empty()) {
523     all_is_null = true;
524   }
525   for (size_t i = 0; i < name_list.size(); ++i) {
526     if (!all_is_null) {
527       this->parameter_default_value_[name_list[i]] = value_list[i];
528     }
529   }
530 }
531 
ClearDefaultValues()532 void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
533 
GetDefaultValueCount()534 size_t FuncGraph::GetDefaultValueCount() {
535   int64_t null_count =
536     std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
537                   [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<Null>(pair.second); });
538   return parameter_default_value_.size() - LongToSize(null_count);
539 }
540 
GetVariableArgParameter()541 AnfNodePtr FuncGraph::GetVariableArgParameter() {
542   if (!has_vararg_) {
543     return nullptr;
544   }
545 
546   // one vararg + kwarg so the min param num is 2;
547   constexpr size_t min_param_num = 2;
548   if (has_kwarg_) {
549     if (parameters_.size() < hyper_param_count_ + min_param_num) {
550       MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
551                         << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
552     }
553     return parameters_[parameters_.size() - hyper_param_count_ - min_param_num];
554   }
555 
556   if (parameters_.size() < hyper_param_count_ + 1) {
557     MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
558                       << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
559   }
560   return parameters_[parameters_.size() - hyper_param_count_ - 1];
561 }
562 
GetVariableArgName()563 std::string FuncGraph::GetVariableArgName() {
564   if (!has_vararg_) {
565     return "";
566   }
567 
568   // one vararg + kwarg so the min param num is 2;
569   constexpr size_t min_param_num = 2;
570   if (has_kwarg_) {
571     if (parameters_.size() < hyper_param_count_ + min_param_num) {
572       MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
573                         << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
574     }
575     const auto &parameter = parameters_[parameters_.size() - hyper_param_count_ - min_param_num]->cast<ParameterPtr>();
576     MS_EXCEPTION_IF_NULL(parameter);
577     return parameter->name();
578   }
579 
580   if (parameters_.size() < hyper_param_count_ + 1) {
581     MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
582                       << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
583   }
584   const auto &parameter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>();
585   MS_EXCEPTION_IF_NULL(parameter);
586   return parameter->name();
587 }
588 
GetVariableKwargParameter()589 AnfNodePtr FuncGraph::GetVariableKwargParameter() {
590   if (has_kwarg_) {
591     if (parameters_.size() < hyper_param_count_ + 1) {
592       MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
593                         << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
594     }
595     return parameters_[parameters_.size() - hyper_param_count_ - 1];
596   }
597   return nullptr;
598 }
599 
GetVariableKwargName()600 std::string FuncGraph::GetVariableKwargName() {
601   if (has_kwarg_) {
602     if (parameters_.size() < hyper_param_count_ + 1) {
603       MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
604                         << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
605     }
606     const auto &parameter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>();
607     MS_EXCEPTION_IF_NULL(parameter);
608     return parameter->name();
609   }
610   return "";
611 }
612 
GetPositionalArgsCount() const613 int FuncGraph::GetPositionalArgsCount() const {
614   int count = SizeToInt(parameters_.size());
615   if (has_kwarg_) {
616     count--;
617   }
618   if (has_vararg_) {
619     count--;
620   }
621   return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
622 }
623 
GetParameterByName(const std::string & name)624 AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
625   for (size_t i = 0; i < parameters_.size(); ++i) {
626     MS_EXCEPTION_IF_NULL(parameters_[i]);
627     auto param_cast = parameters_[i]->cast<ParameterPtr>();
628     MS_EXCEPTION_IF_NULL(param_cast);
629     if (param_cast->name() == name) {
630       return parameters_[i];
631     }
632   }
633   return nullptr;
634 }
635 
GetOrderedCnodes()636 std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
637   auto this_ptr = shared_from_base<FuncGraph>();
638   auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1);
639   auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1);
640 
641   std::list<CNodePtr> cnodes;
642   auto nodes = mindspore::TopoSort(get_return(), SuccDepends, BelongSameGraph);
643   for (const auto &node : nodes) {
644     auto cnode = dyn_cast<CNode>(node);
645     if (cnode) {
646       cnodes.push_back(cnode);
647     }
648   }
649   return cnodes;
650 }
651 
EraseUnusedNodeInOrder()652 void FuncGraph::EraseUnusedNodeInOrder() {
653   auto mng = manager_.lock();
654   if (mng) {
655     auto &all_nodes = nodes();
656     // Erase unused cnode.
657     for (auto it = order_.begin(); it != order_.end();) {
658       if (!all_nodes.contains(*it)) {
659         MS_LOG(DEBUG) << "Remove node: " << (*it)->ToString() << " in graph " << ToString() << " order.";
660         it = order_.erase(it);
661         continue;
662       }
663       (void)it++;
664     }
665   }
666 }
667 
EraseUnusedNodeInOrder(const AnfNodePtr & node)668 void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) {
669   if (node) {
670     auto cnode = node->cast<CNodePtr>();
671     if (cnode) {
672       order_.erase(cnode);
673       MS_LOG(DEBUG) << "Remove node: " << node->DebugString() << " from order list.";
674     }
675   }
676 }
677 
678 // Maintain cnode order list when a cnode is replaced by a new one.
ReplaceInOrder(const AnfNodePtr & old_node,const AnfNodePtr & new_node)679 void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
680   MS_EXCEPTION_IF_NULL(old_node);
681   MS_EXCEPTION_IF_NULL(new_node);
682   if (order_.empty()) {
683     // Skip if order list is empty.
684     return;
685   }
686   auto old_cnode = old_node->cast<CNodePtr>();
687   if (old_cnode == nullptr) {
688     // Skip if old node is not cnode, since order list contains cnode only.
689     return;
690   }
691   // Search old node in order list.
692   auto iter = order_.find(old_cnode);
693   if (iter == order_.end()) {
694     // Skip if old node not found in order list.
695     return;
696   }
697   auto new_cnode = new_node->cast<CNodePtr>();
698   if (new_cnode != nullptr) {
699     // Insert new node just before the old node.
700     order_.insert(iter, new_cnode);
701   }
702   // Remove old node from order list.
703   // Unused children nodes can be cleared by EraseUnusedNodeInOrder().
704   order_.erase(iter);
705 }
706 
MakeInputNodes(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & inputs)707 static std::vector<AnfNodePtr> MakeInputNodes(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) {
708   std::vector<AnfNodePtr> input_node_list;
709   input_node_list.reserve(inputs.size() + 1);
710   input_node_list.emplace_back(std::make_shared<ValueNode>(primitive));
711   input_node_list.insert(input_node_list.end(), inputs.begin(), inputs.end());
712   return input_node_list;
713 }
714 
NewCNode(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & inputs)715 CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) {
716   auto input_node_list = MakeInputNodes(primitive, inputs);
717   return NewCNode(input_node_list);
718 }
719 
NewCNodeInOrder(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & inputs)720 CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) {
721   auto input_node_list = MakeInputNodes(primitive, inputs);
722   return NewCNodeInOrder(input_node_list);
723 }
724 
add_weight(const tensor::MetaTensorPtr & meta_tensor)725 ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
726   auto parameter = add_parameter();
727   parameter->set_default_param(MakeValue(meta_tensor));
728   parameter->set_abstract(meta_tensor->ToAbstract());
729   return parameter;
730 }
731 
ContainMultiTarget() const732 bool FuncGraph::ContainMultiTarget() const {
733   auto graph_manager = manager();
734   MS_EXCEPTION_IF_NULL(graph_manager);
735   FuncGraphSet graphs = graph_manager->func_graphs();
736   for (auto &g : graphs) {
737     auto nodes = mindspore::TopoSort(g->get_return());
738     if (mindspore::ContainMultiTarget(nodes)) {
739       return true;
740     }
741   }
742   return false;
743 }
744 
set_used_forward_nodes(const std::vector<AnfNodePtr> & used_forward_nodes)745 void FuncGraph::set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes) {
746   (void)std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(), [this](const AnfNodePtr &node) {
747     MS_EXCEPTION_IF_NULL(node);
748     (void)used_forward_nodes_.emplace(node);
749   });
750 }
751 
// Returns the next value of a process-wide generation counter: the first call
// yields 1, and each later call yields the previous value plus one.
// The counter is a plain function-local static with no synchronization.
size_t NewFgSeenGeneration() {
  static size_t fg_seen_generation = 0;
  fg_seen_generation += 1;
  return fg_seen_generation;
}
756 
757 // Implement TopoSort api.
TopoSort(const AnfNodePtr & node)758 std::vector<AnfNodePtr> api::FuncGraph::TopoSort(const AnfNodePtr &node) { return mindspore::TopoSort(node); }
759 
760 // Create an api::FuncGraph instance.
Create()761 api::FuncGraphPtr api::FuncGraph::Create() { return std::make_shared<mindspore::FuncGraph>(); }
762 
MakeValueNode(const api::FuncGraphPtr & func_graph)763 AnfNodePtr api::FuncGraph::MakeValueNode(const api::FuncGraphPtr &func_graph) {
764   auto fg = std::dynamic_pointer_cast<mindspore::FuncGraph>(func_graph);
765   return NewValueNode(fg);
766 }
767 
GetFuncGraphFromAnfNode(const AnfNodePtr & input)768 api::FuncGraphPtr api::FuncGraph::GetFuncGraphFromAnfNode(const AnfNodePtr &input) {
769   auto fg = GetValueNode<mindspore::FuncGraphPtr>(input);
770   return fg;
771 }
772 
773 const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
774 }  // namespace mindspore
775