• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/func_graph.h"
20 #include <algorithm>
21 #include "mindspore/core/ops/framework_ops.h"
22 #include "utils/trace_base.h"
23 #include "ir/manager.h"
24 #include "utils/ordered_set.h"
25 #include "utils/convert_utils_base.h"
26 #include "abstract/abstract_function.h"
27 #include "ir/func_graph_cloner.h"
28 #include "utils/phase.h"
29 
30 namespace mindspore {
31 /*
32  * Methods of Graph
33  */
FuncGraph()34 FuncGraph::FuncGraph() : FuncGraph(std::make_shared<GraphDebugInfo>()) {}
35 
FuncGraph(GraphDebugInfoPtr && debug_info)36 FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
37     : attrs_(),
38       transforms_(),
39       parameter_default_value_(),
40       seen_(0),
41       parameters_(),
42       has_vararg_(false),
43       has_kwarg_(false),
44       exist_multi_target_(false),
45       kw_only_args_count_(0),
46       fv_param_count_(0),
47       is_generated_(false),
48       manager_(),
49       debug_info_(std::move(debug_info)),
50       stub_(false),
51       stage_(-1),
52       segment_(1),
53       phase_(PhaseManager::GetInstance().phase()) {}
54 
~FuncGraph()55 FuncGraph::~FuncGraph() { subclass_destruct_flag_ = true; }
56 
DoBreakLoop()57 void FuncGraph::DoBreakLoop() {
58   if (attached_mng_cnt() > 0) {
59     MS_LOG(INFO) << "Current Graph is holding by FuncGraphManager, can't DoBreakLoop now.";
60     return;
61   }
62   ClearOrderList();
63   python_obj_ = nullptr;
64   used_forward_nodes_.clear();
65   func_graph_cache_.clear();
66   parameters_.clear();
67   parameter_obj_nodes_.clear();
68   set_dropped(true);
69 }
70 
ToAbstract()71 abstract::AbstractBasePtr FuncGraph::ToAbstract() {
72   auto temp_context = abstract::AnalysisContext::DummyContext();
73   return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
74 }
75 
output() const76 AnfNodePtr FuncGraph::output() const {
77   constexpr size_t return_input_num = 2;
78   // If return value is set, return should have two inputs.
79   if (return_node() != nullptr && return_node()->size() == return_input_num) {
80     return return_node()->input(1);
81   } else {
82     // If not set yet, return nullptr.
83     return nullptr;
84   }
85 }
86 
get_inputs() const87 const AnfNodePtrList FuncGraph::get_inputs() const {
88   AnfNodePtrList input_params;
89   for (auto const &node : parameters_) {
90     MS_EXCEPTION_IF_NULL(node);
91     auto parameter = dyn_cast<Parameter>(node);
92     MS_EXCEPTION_IF_NULL(parameter);
93     if (!parameter->has_default()) {
94       input_params.push_back(parameter);
95     }
96   }
97   return input_params;
98 }
99 
add_parameter()100 ParameterPtr FuncGraph::add_parameter() {
101   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
102   ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
103   add_parameter(param);
104   return param;
105 }
106 
add_parameter(NodeDebugInfoPtr && debug_info)107 ParameterPtr FuncGraph::add_parameter(NodeDebugInfoPtr &&debug_info) {
108   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
109   ParameterPtr param = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
110   add_parameter(param);
111   return param;
112 }
113 
add_parameter(const ParameterPtr & param)114 void FuncGraph::add_parameter(const ParameterPtr &param) {
115   if (manager_.lock()) {
116     manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
117   } else {
118     parameters_.push_back(param);
119   }
120 }
121 
InsertFrontParameter()122 ParameterPtr FuncGraph::InsertFrontParameter() {
123   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
124   ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
125   InsertFrontParameter(param);
126   return param;
127 }
128 
InsertFrontParameter(const ParameterPtr & param)129 void FuncGraph::InsertFrontParameter(const ParameterPtr &param) {
130   if (manager_.lock()) {
131     manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), param);
132   } else {
133     PrependParameter(param);
134   }
135 }
136 
AddFvParameter(const std::string & name,const ValuePtr & default_value)137 ParameterPtr FuncGraph::AddFvParameter(const std::string &name, const ValuePtr &default_value) {
138   FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
139   ParameterPtr param = std::make_shared<Parameter>(this_graph);
140   param->set_name(name);
141   MS_EXCEPTION_IF_NULL(param->debug_info());
142   param->debug_info()->set_name(name);
143   param->debug_info()->set_trace_info(nullptr);
144   MS_EXCEPTION_IF_NULL(default_value);
145   param->set_default_param(default_value);
146   param->set_abstract(default_value->ToAbstract());
147   if (manager_.lock()) {
148     manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
149   } else {
150     parameters_.push_back(param);
151   }
152   ++fv_param_count_;
153   return param;
154 }
155 
has_flag(const std::string & key) const156 bool FuncGraph::has_flag(const std::string &key) const {
157   auto iter = attrs_.find(key);
158   if (iter != attrs_.cend()) {
159     MS_EXCEPTION_IF_NULL(iter->second);
160     if (iter->second->isa<BoolImm>()) {
161       return GetValue<bool>(iter->second);
162     }
163     MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
164   }
165   return false;
166 }
167 
has_attr(const std::string & key) const168 bool FuncGraph::has_attr(const std::string &key) const {
169   auto iter = attrs_.find(key);
170   return !(iter == attrs_.cend());
171 }
172 
get_attr(const std::string & key) const173 ValuePtr FuncGraph::get_attr(const std::string &key) const {
174   auto iter = attrs_.find(key);
175   return iter == attrs_.cend() ? nullptr : iter->second;
176 }
177 
NewCNodeWeak(AnfNodeWeakPtrList && weak_inputs)178 CNodePtr FuncGraph::NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs) {
179   return std::make_shared<CNode>(std::move(weak_inputs), shared_from_base<FuncGraph>());
180 }
181 
NewCNodeWeak(const AnfNodeWeakPtrList & weak_inputs)182 CNodePtr FuncGraph::NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs) {
183   return std::make_shared<CNode>(weak_inputs, shared_from_base<FuncGraph>());
184 }
185 
NewCNode(AnfNodePtrList && inputs)186 CNodePtr FuncGraph::NewCNode(AnfNodePtrList &&inputs) {
187   std::vector<AnfNodeWeakPtr> weak_inputs;
188   weak_inputs.reserve(inputs.size());
189   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
190                  [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
191   return std::make_shared<CNode>(std::move(weak_inputs), shared_from_base<FuncGraph>());
192 }
193 
NewCNode(const AnfNodePtrList & inputs)194 CNodePtr FuncGraph::NewCNode(const AnfNodePtrList &inputs) {
195   std::vector<AnfNodeWeakPtr> weak_inputs;
196   weak_inputs.reserve(inputs.size());
197   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
198                  [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
199   return std::make_shared<CNode>(std::move(weak_inputs), shared_from_base<FuncGraph>());
200 }
201 
NewCNodeInOrderWeak(AnfNodeWeakPtrList && weak_inputs)202 CNodePtr FuncGraph::NewCNodeInOrderWeak(AnfNodeWeakPtrList &&weak_inputs) {
203   CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
204   (void)order_.emplace_back(CNodeWeakPtr(cnode));
205   return cnode;
206 }
207 
NewCNodeInOrderWeak(const AnfNodeWeakPtrList & weak_inputs)208 CNodePtr FuncGraph::NewCNodeInOrderWeak(const AnfNodeWeakPtrList &weak_inputs) {
209   CNodePtr cnode = NewCNodeWeak(weak_inputs);
210   (void)order_.emplace_back(CNodeWeakPtr(cnode));
211   return cnode;
212 }
213 
NewCNodeInOrder(AnfNodePtrList && inputs)214 CNodePtr FuncGraph::NewCNodeInOrder(AnfNodePtrList &&inputs) { return NewCNodeInOrder(inputs); }
215 
NewCNodeInOrder(const AnfNodePtrList & inputs)216 CNodePtr FuncGraph::NewCNodeInOrder(const AnfNodePtrList &inputs) {
217   std::vector<AnfNodeWeakPtr> weak_inputs;
218   weak_inputs.reserve(inputs.size());
219   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
220                  [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
221   CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
222   (void)order_.emplace_back(CNodeWeakPtr(cnode));
223   return cnode;
224 }
225 
NewCNodeInFront(const AnfNodePtrList & inputs)226 CNodePtr FuncGraph::NewCNodeInFront(const AnfNodePtrList &inputs) {
227   std::vector<AnfNodeWeakPtr> weak_inputs;
228   weak_inputs.reserve(inputs.size());
229   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
230                  [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
231   CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
232   (void)order_.emplace_front(CNodeWeakPtr(cnode));
233   return cnode;
234 }
235 
NewCNodeBefore(const AnfNodePtr & position,const AnfNodePtrList & inputs)236 CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const AnfNodePtrList &inputs) {
237   std::vector<AnfNodeWeakPtr> weak_inputs;
238   weak_inputs.reserve(inputs.size());
239   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
240                  [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
241   CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
242   CNodePtr pos_cnode = dyn_cast<CNode>(position);
243   auto iter = std::find_if(order_.cbegin(), order_.cend(), [&pos_cnode](const CNodeWeakPtr &node) {
244     return node.lock() != nullptr && node.lock() == pos_cnode;
245   });
246   (void)order_.insert(iter, CNodeWeakPtr(cnode));
247   return cnode;
248 }
249 
NewCNodeAfter(const AnfNodePtr & position,const AnfNodePtrList & inputs)250 CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const AnfNodePtrList &inputs) {
251   std::vector<AnfNodeWeakPtr> weak_inputs;
252   weak_inputs.reserve(inputs.size());
253   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
254                  [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
255   CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
256   CNodePtr pos_cnode = dyn_cast<CNode>(position);
257   auto iter = std::find_if(order_.cbegin(), order_.cend(), [&pos_cnode](const CNodeWeakPtr &node) {
258     return node.lock() != nullptr && node.lock() == pos_cnode;
259   });
260   if (iter == order_.cend()) {
261     order_.push_front(CNodeWeakPtr(cnode));
262   } else {
263     (void)order_.insert(std::next(iter), CNodeWeakPtr(cnode));
264   }
265   return cnode;
266 }
267 
own_nodes() const268 const std::list<AnfNodePtr> &FuncGraph::own_nodes() const { return own_nodes_; }
269 
AddOwnNode(const AnfNodePtr & node)270 void FuncGraph::AddOwnNode(const AnfNodePtr &node) { (void)own_nodes_.emplace_back(node); }
271 
AddOwnNode(const AnfNodePtrList & nodes)272 void FuncGraph::AddOwnNode(const AnfNodePtrList &nodes) {
273   (void)own_nodes_.insert(own_nodes_.end(), nodes.cbegin(), nodes.cend());
274 }
275 
AddOwnNode(const AnfNodeWeakPtrList & weak_nodes)276 void FuncGraph::AddOwnNode(const AnfNodeWeakPtrList &weak_nodes) {
277   std::transform(weak_nodes.cbegin(), weak_nodes.cend(), std::back_inserter(own_nodes_),
278                  [](const AnfNodeWeakPtr &weak_node) -> AnfNodePtr { return weak_node.lock(); });
279 }
280 
RemoveOwnNode(const AnfNodePtr & node)281 void FuncGraph::RemoveOwnNode(const AnfNodePtr &node) {
282   auto iter = std::find(own_nodes_.cbegin(), own_nodes_.cend(), node);
283   if (iter != own_nodes_.cend()) {
284     own_nodes_.erase(iter);
285   }
286 }
287 
ResetOwnNodes()288 void FuncGraph::ResetOwnNodes() { own_nodes_.clear(); }
289 
DumpCNodeList()290 void FuncGraph::DumpCNodeList() {
291   MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
292   for (const auto &weak_cnode : order_) {
293     const auto &cnode = weak_cnode.lock();
294     if (cnode == nullptr) {
295       continue;
296     }
297     MS_LOG(INFO) << cnode->DebugString();
298   }
299 }
300 
ToString() const301 std::string FuncGraph::ToString() const {
302   std::ostringstream buffer;
303   auto debug_info = const_cast<FuncGraph *>(this)->debug_info();
304   buffer << mindspore::trace::Label(debug_info);
305   buffer << "_" << debug_info->get_id();
306   return buffer.str();
307 }
308 
debug_info()309 GraphDebugInfoPtr FuncGraph::debug_info() {
310   MS_EXCEPTION_IF_NULL(this->debug_info_);
311   if (this->debug_info_->get_graph() == nullptr) {
312     this->debug_info_->set_graph(shared_from_base<FuncGraph>());
313   }
314   return this->debug_info_;
315 }
316 
nodes() const317 const AnfNodeSet &FuncGraph::nodes() const { return nodes_; }
318 
switch_nodes() const319 const AnfNodeSet &FuncGraph::switch_nodes() const { return switch_nodes_; }
320 
CopyNodes(const FuncGraphPtr & source)321 void FuncGraph::CopyNodes(const FuncGraphPtr &source) {
322   nodes_.update(source->nodes());
323   switch_nodes_.update(source->switch_nodes());
324 }
325 
ClearNodes()326 void FuncGraph::ClearNodes() {
327   nodes_.clear();
328   switch_nodes_.clear();
329 }
330 
AddNode(const AnfNodePtr & node)331 void FuncGraph::AddNode(const AnfNodePtr &node) {
332   nodes_.add(node);
333   if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
334     switch_nodes_.add(node);
335   }
336 }
337 
DropNode(const AnfNodePtr & node)338 void FuncGraph::DropNode(const AnfNodePtr &node) {
339   if (node == nullptr) {
340     MS_LOG(ERROR) << "Node is nullptr";
341     return;
342   }
343   (void)nodes_.erase(node);
344   if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
345     switch_nodes_.erase(node);
346   }
347   auto graph = node->func_graph();
348   if (node->isa<Parameter>()) {
349     (void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
350   }
351   // Remove the node from order list.
352   if (graph != nullptr) {
353     graph->EraseUnusedNodeInOrder(node);
354   }
355 }
356 
value_nodes() const357 const AnfNodeCounterMap &FuncGraph::value_nodes() const { return value_nodes_; }
358 
CopyValueNodes(const FuncGraphPtr & source)359 void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
360   MS_EXCEPTION_IF_NULL(source);
361   auto &others = source->value_nodes();
362   for (auto it = others.begin(); it != others.end(); ++it) {
363     AddValueNode(it->first, it->second);
364   }
365 }
366 
ClearValueNodes()367 void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
368 
AddValueNode(const AnfNodePtr & node,int count)369 void FuncGraph::AddValueNode(const AnfNodePtr &node, int count) {
370   if (value_nodes_.count(node) == 0) {
371     value_nodes_[node] = count;
372   } else {
373     value_nodes_[node] += count;
374   }
375 }
376 
DropValueNode(const AnfNodePtr & node)377 void FuncGraph::DropValueNode(const AnfNodePtr &node) {
378   if (value_nodes_.count(node) != 0) {
379     if (value_nodes_[node] == 1) {
380       (void)value_nodes_.erase(node);
381     } else {
382       value_nodes_[node]--;
383       if (value_nodes_[node] < 0) {
384         MS_LOG(INTERNAL_EXCEPTION) << "Count of ValueNode '" << node
385                                    << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
386       }
387     }
388   }
389 }
390 
free_variables() const391 const AnfNodeCounterMap &FuncGraph::free_variables() const { return free_variables_; }
392 
CopyFreeVariables(const FuncGraphPtr & source)393 void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
394   MS_EXCEPTION_IF_NULL(source);
395   auto &others = source->free_variables();
396   for (auto it = others.begin(); it != others.end(); ++it) {
397     const auto &free_var = it->first;
398     MS_EXCEPTION_IF_NULL(free_var);
399     if (free_var->func_graph().get() != this) {
400       (void)AddFreeVariable(free_var, it->second);
401     }
402   }
403 }
404 
ClearFreeVariables()405 void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }
406 
AddFreeVariable(const AnfNodePtr & node,int count)407 bool FuncGraph::AddFreeVariable(const AnfNodePtr &node, int count) {
408   if (free_variables_.count(node) == 0) {
409     free_variables_[node] = count;
410     return true;
411   } else {
412     free_variables_[node] += count;
413     return false;
414   }
415 }
416 
DropFreeVariable(const AnfNodePtr & node)417 bool FuncGraph::DropFreeVariable(const AnfNodePtr &node) {
418   if (free_variables_.count(node) != 0) {
419     if (free_variables_[node] == 1) {
420       (void)free_variables_.erase(node);
421       return true;
422     } else {
423       free_variables_[node]--;
424       if (free_variables_[node] < 0) {
425         MS_LOG(INTERNAL_EXCEPTION) << "Count of free variable '" << node
426                                    << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
427       }
428     }
429   }
430   return false;
431 }
432 
free_variables_total()433 const BaseRefCounterMap &FuncGraph::free_variables_total() {
434   auto mng = manager_.lock();
435   MS_EXCEPTION_IF_NULL(mng);
436   auto &fv_total = mng->free_variables_total();
437   return fv_total[shared_from_base<FuncGraph>()];
438 }
439 
free_variables_nodes()440 AnfNodePtrList FuncGraph::free_variables_nodes() {
441   AnfNodePtrList nodes;
442   const auto &fv_total = this->free_variables_total();
443   for (auto &p : fv_total) {
444     auto key = p.first;
445     if (utils::isa<AnfNodePtr>(key)) {
446       nodes.push_back(utils::cast<AnfNodePtr>(key));
447     }
448   }
449   return nodes;
450 }
451 
free_variables_func_graphs()452 std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
453   std::vector<FuncGraphPtr> func_graphs;
454   const auto &fv_total = this->free_variables_total();
455   for (auto &p : fv_total) {
456     auto key = p.first;
457     if (utils::isa<FuncGraphPtr>(key)) {
458       func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
459     }
460   }
461 
462   return func_graphs;
463 }
464 
func_graphs_used() const465 const FuncGraphCounterMap &FuncGraph::func_graphs_used() const { return func_graphs_used_; }
466 
CopyFuncGraphsUsed(const FuncGraphPtr & source)467 void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
468   auto &others = source->func_graphs_used();
469   for (auto it = others.begin(); it != others.end(); ++it) {
470     (void)AddFuncGraphUsed(it->first, it->second);
471   }
472   (void)func_graphs_used_.erase(source);
473 }
474 
ClearFuncGraphsUsed()475 void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
476 
AddFuncGraphUsed(const FuncGraphPtr & fg,int count)477 bool FuncGraph::AddFuncGraphUsed(const FuncGraphPtr &fg, int count) {
478   if (func_graphs_used_.count(fg) == 0) {
479     func_graphs_used_[fg] = count;
480     return true;
481   } else {
482     func_graphs_used_[fg] += count;
483     return false;
484   }
485 }
486 
DropFuncGraphUsed(const FuncGraphPtr & fg)487 bool FuncGraph::DropFuncGraphUsed(const FuncGraphPtr &fg) {
488   if (func_graphs_used_.count(fg) != 0) {
489     if (func_graphs_used_[fg] == 1) {
490       (void)func_graphs_used_.erase(fg);
491       return true;
492     } else {
493       func_graphs_used_[fg]--;
494       if (func_graphs_used_[fg] < 0) {
495         MS_LOG(INTERNAL_EXCEPTION) << "Count of FuncGraph '" << fg
496                                    << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
497       }
498     }
499   }
500   return false;
501 }
502 
func_graphs_used_total()503 const FuncGraphSet &FuncGraph::func_graphs_used_total() {
504   auto mng = manager_.lock();
505   MS_EXCEPTION_IF_NULL(mng);
506   auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
507   return used;
508 }
509 
func_graph_cnodes_index() const510 const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() const { return func_graph_cnodes_index_; }
511 
CopyFuncGraphCNodesIndex(const FuncGraphPtr & source)512 void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
513   MS_EXCEPTION_IF_NULL(source);
514   auto &others = source->func_graph_cnodes_index();
515   for (auto it = others.begin(); it != others.end(); ++it) {
516     // Ignore the user graph who may own itself.
517     MS_EXCEPTION_IF_NULL(it->first);
518     MS_EXCEPTION_IF_NULL(it->first->first);
519     auto fg = it->first->first->func_graph();
520     MS_EXCEPTION_IF_NULL(fg);
521     if (fg.get() != this) {
522       AddFuncGraphCNodeIndex(it->first, it->second);
523     }
524   }
525 }
526 
ClearFuncGraphCNodesIndex()527 void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }
528 
AddFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair,int count)529 void FuncGraph::AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair, int count) {
530   if (func_graph_cnodes_index_.count(pair) == 0) {
531     func_graph_cnodes_index_[pair] = count;
532   } else {
533     func_graph_cnodes_index_[pair] += count;
534   }
535 }
536 
DropFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair)537 void FuncGraph::DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair) {
538   if (func_graph_cnodes_index_.count(pair) != 0) {
539     if (func_graph_cnodes_index_[pair] == 1) {
540       (void)func_graph_cnodes_index_.erase(pair);
541     } else {
542       func_graph_cnodes_index_[pair]--;
543       if (func_graph_cnodes_index_[pair] < 0) {
544         MS_LOG(INTERNAL_EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
545                                    << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
546       }
547     }
548   }
549 }
550 
meta_fg_prim_value_nodes() const551 const mindspore::HashMap<AnfNodePtr, int> &FuncGraph::meta_fg_prim_value_nodes() const {
552   return meta_fg_prim_value_nodes_;
553 }
554 
CopyMetaFgPrimValueNodes(const FuncGraphPtr & source)555 void FuncGraph::CopyMetaFgPrimValueNodes(const FuncGraphPtr &source) {
556   MS_EXCEPTION_IF_NULL(source);
557   auto &others = source->meta_fg_prim_value_nodes();
558   for (const auto &other : others) {
559     AddMetaFgPrimValueNode(other.first, other.second);
560   }
561 }
562 
ClearMetaFgPrimValueNodes()563 void FuncGraph::ClearMetaFgPrimValueNodes() { meta_fg_prim_value_nodes_.clear(); }
564 
AddMetaFgPrimValueNode(const AnfNodePtr & value_node,int count)565 void FuncGraph::AddMetaFgPrimValueNode(const AnfNodePtr &value_node, int count) {
566   if (meta_fg_prim_value_nodes_.count(value_node) == 0) {
567     meta_fg_prim_value_nodes_[value_node] = count;
568   } else {
569     meta_fg_prim_value_nodes_[value_node] += count;
570   }
571 }
572 
DropMetaFgPrimValueNode(const AnfNodePtr & value_node)573 void FuncGraph::DropMetaFgPrimValueNode(const AnfNodePtr &value_node) {
574   if (meta_fg_prim_value_nodes_.count(value_node) != 0) {
575     if (meta_fg_prim_value_nodes_[value_node] == 1) {
576       (void)meta_fg_prim_value_nodes_.erase(value_node);
577     } else {
578       meta_fg_prim_value_nodes_[value_node]--;
579       if (meta_fg_prim_value_nodes_[value_node] < 0) {
580         MS_LOG(INTERNAL_EXCEPTION) << "Count of MetaFgPrim ValueNode '" << value_node->DebugString()
581                                    << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
582       }
583     }
584   }
585 }
586 
parent()587 FuncGraphPtr FuncGraph::parent() {
588   // report the bug early.
589   if (manager_.lock() == nullptr) {
590     MS_LOG(INTERNAL_EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
591                                << " NodeInfo: " << trace::GetDebugInfoStr(debug_info());
592   }
593   auto mng = manager_.lock();
594   MS_EXCEPTION_IF_NULL(mng);
595   return mng->parent(shared_from_base<FuncGraph>());
596 }
597 
children()598 const FuncGraphSet &FuncGraph::children() {
599   auto mng = manager_.lock();
600   MS_EXCEPTION_IF_NULL(mng);
601   return mng->children(shared_from_base<FuncGraph>());
602 }
603 
scope()604 const FuncGraphSet &FuncGraph::scope() {
605   auto mng = manager_.lock();
606   MS_EXCEPTION_IF_NULL(mng);
607   return mng->scopes(shared_from_base<FuncGraph>());
608 }
609 
recursive()610 bool FuncGraph::recursive() {
611   auto mng = manager_.lock();
612   MS_EXCEPTION_IF_NULL(mng);
613   return mng->recursive(shared_from_base<FuncGraph>());
614 }
615 
recursive_graphs()616 std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
617   auto mng = manager_.lock();
618   MS_EXCEPTION_IF_NULL(mng);
619   return mng->recursive_graphs(shared_from_base<FuncGraph>());
620 }
621 
ClearAllResource()622 void FuncGraph::ClearAllResource() {
623   ClearNodes();
624   ClearValueNodes();
625   ClearFuncGraphCNodesIndex();
626   ClearFreeVariables();
627   ClearFuncGraphsUsed();
628   ClearMetaFgPrimValueNodes();
629 }
630 
GetDefaultValueByName(const std::string & name)631 AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
632   auto itr = this->parameter_default_value_.find(name);
633   if (itr == parameter_default_value_.end()) {
634     return nullptr;
635   }
636   auto default_value = itr->second;
637   if (default_value == nullptr) {
638     MS_LOG(INTERNAL_EXCEPTION) << "Graph parameter " << name << " not exist";
639   }
640   if (IsValueNode<Null>(default_value)) {
641     return nullptr;
642   }
643   return default_value;
644 }
645 
646 // set the default values
SetDefaultValues(const std::vector<std::string> & name_list,const AnfNodePtrList & value_list)647 void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const AnfNodePtrList &value_list) {
648   auto all_is_null =
649     std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode<Null>(node); });
650   if (value_list.empty()) {
651     all_is_null = true;
652   }
653   for (size_t i = 0; i < name_list.size(); ++i) {
654     if (!all_is_null) {
655       this->parameter_default_value_[name_list[i]] = value_list[i];
656     }
657   }
658 }
659 
ClearDefaultValues()660 void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
661 
GetDefaultValueCount()662 size_t FuncGraph::GetDefaultValueCount() {
663   int64_t null_count =
664     std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
665                   [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<Null>(pair.second); });
666   return parameter_default_value_.size() - LongToSize(null_count);
667 }
668 
GetVariableArgParameter()669 AnfNodePtr FuncGraph::GetVariableArgParameter() {
670   if (!has_vararg_) {
671     return nullptr;
672   }
673 
674   size_t min_param_num = 1;
675   if (has_kwarg_) {
676     min_param_num += 1;
677   }
678   min_param_num += IntToSize(kw_only_args_count_);
679   min_param_num += fv_param_count_;
680 
681   if (parameters_.size() < min_param_num) {
682     MS_LOG(INTERNAL_EXCEPTION) << "Length of parameters is " << parameters_.size()
683                                << " which less than the sum of following: fv_param_count: " << fv_param_count_
684                                << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
685                                << ", kw_only_args_count_: " << kw_only_args_count_;
686   }
687   return parameters_[parameters_.size() - min_param_num];
688 }
689 
GetVariableArgName()690 std::string FuncGraph::GetVariableArgName() {
691   if (!has_vararg_) {
692     return "";
693   }
694 
695   const auto &param_node = GetVariableArgParameter();
696   MS_EXCEPTION_IF_NULL(param_node);
697   auto parameter = param_node->cast_ptr<Parameter>();
698   MS_EXCEPTION_IF_NULL(parameter);
699   return parameter->name();
700 }
701 
GetVariableKwargParameter()702 AnfNodePtr FuncGraph::GetVariableKwargParameter() {
703   if (has_kwarg_) {
704     if (parameters_.size() < fv_param_count_ + 1) {
705       MS_LOG(INTERNAL_EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is "
706                                  << fv_param_count_ << ", parameters is less than 1 + fv_param_count";
707     }
708     return parameters_[(parameters_.size() - fv_param_count_) - 1];
709   }
710   return nullptr;
711 }
712 
GetVariableKwargName()713 std::string FuncGraph::GetVariableKwargName() {
714   auto kwarg_param = GetVariableKwargParameter();
715   if (kwarg_param != nullptr) {
716     auto parameter = kwarg_param->cast_ptr<Parameter>();
717     MS_EXCEPTION_IF_NULL(parameter);
718     return parameter->name();
719   }
720   return "";
721 }
722 
GetKwOnlyArgsParameters()723 AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() {
724   AnfNodePtrList kw_only_args;
725   if (kw_only_args_count_ == 0) {
726     return kw_only_args;
727   }
728 
729   size_t min_param_num = 0;
730   size_t varargs_kwargs_num = 0;
731   if (has_kwarg_) {
732     min_param_num += 1;
733     varargs_kwargs_num += 1;
734   }
735   min_param_num += IntToSize(kw_only_args_count_);
736   min_param_num += fv_param_count_;
737 
738   if (parameters_.size() < min_param_num) {
739     MS_LOG(INTERNAL_EXCEPTION) << "Length of parameters is " << parameters_.size()
740                                << " which less than the sum of following: fv_param_count: " << fv_param_count_
741                                << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
742                                << ", kw_only_args_count: " << kw_only_args_count_;
743   }
744   size_t kw_only_args_start_offset = parameters_.size() - min_param_num;
745   std::copy(parameters_.cbegin() + kw_only_args_start_offset, parameters_.cend() - fv_param_count_ - varargs_kwargs_num,
746             std::back_inserter(kw_only_args));
747   return kw_only_args;
748 }
749 
GetPositionalArgsCount() const750 int FuncGraph::GetPositionalArgsCount() const {
751   int count = SizeToInt(parameters_.size());
752   if (has_kwarg_) {
753     count--;
754   }
755   if (has_vararg_) {
756     count--;
757   }
758   return (count - kw_only_args_count_) - SizeToInt(fv_param_count_);
759 }
760 
GetParameterByName(const std::string & name)761 AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
762   for (size_t i = 0; i < parameters_.size(); ++i) {
763     MS_EXCEPTION_IF_NULL(parameters_[i]);
764     auto param_cast = parameters_[i]->cast_ptr<Parameter>();
765     MS_EXCEPTION_IF_NULL(param_cast);
766     if (param_cast->name() == name) {
767       return parameters_[i];
768     }
769   }
770   return nullptr;
771 }
772 
GetOrderedCnodes()773 std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
774   auto this_ptr = shared_from_base<FuncGraph>();
775   auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1);
776   auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1);
777 
778   std::list<CNodePtr> cnodes;
779   auto nodes = mindspore::TopoSort(return_node(), SuccDepends, BelongSameGraph);
780   for (const auto &node : nodes) {
781     auto cnode = dyn_cast<CNode>(node);
782     if (cnode != nullptr) {
783       (void)cnodes.emplace_back(std::move(cnode));
784     }
785   }
786   return cnodes;
787 }
788 
EraseUnusedNodeInOrder()789 void FuncGraph::EraseUnusedNodeInOrder() {
790   auto mng = manager_.lock();
791   if (mng != nullptr) {
792     auto &all_nodes = nodes();
793     // Erase unused cnode.
794     for (auto it = order_.begin(); it != order_.cend();) {
795       const auto &cnode = it->lock();
796       if (cnode == nullptr) {
797         it = order_.erase(it);
798         continue;
799       }
800       if (!all_nodes.contains(cnode)) {
801         MS_EXCEPTION_IF_NULL(cnode);
802         MS_LOG(DEBUG) << "Remove node: " << cnode->DebugString() << " in graph " << ToString() << " order.";
803         it = order_.erase(it);
804         continue;
805       }
806       (void)++it;
807     }
808   }
809 }
810 
EraseUnusedNodeInOrder(const AnfNodePtr & node)811 void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) {
812   if (node == nullptr) {
813     return;
814   }
815   auto cnode = node->cast<CNodePtr>();
816   if (cnode != nullptr) {
817     auto iter = std::find_if(order_.cbegin(), order_.cend(), [&cnode](const CNodeWeakPtr &node) {
818       return node.lock() != nullptr && node.lock() == cnode;
819     });
820     if (iter != order_.cend()) {
821       (void)order_.erase(iter);
822       MS_LOG(DEBUG) << "Remove node: " << node->DebugString() << " from order list.";
823     }
824   }
825 }
826 
827 // Maintain cnode order list when a cnode is replaced by a new one.
ReplaceInOrder(const AnfNodePtr & old_node,const AnfNodePtr & new_node)828 void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
829   MS_EXCEPTION_IF_NULL(old_node);
830   MS_EXCEPTION_IF_NULL(new_node);
831   if (order_.empty()) {
832     // Skip if order list is empty.
833     return;
834   }
835   auto old_cnode = old_node->cast<CNodePtr>();
836   if (old_cnode == nullptr) {
837     // Skip if old node is not cnode, since order list contains cnode only.
838     return;
839   }
840   // Search old node in order list.
841   auto iter = std::find_if(order_.cbegin(), order_.cend(), [&old_cnode](const CNodeWeakPtr &node) {
842     return node.lock() != nullptr && node.lock() == old_cnode;
843   });
844   if (iter == order_.cend()) {
845     // Skip if old node not found in order list.
846     return;
847   }
848   auto new_cnode = new_node->cast<CNodePtr>();
849   if (new_cnode != nullptr) {
850     // Insert new node just before the old node.
851     (void)order_.insert(iter, CNodeWeakPtr(new_cnode));
852   }
853   // Remove old node from order list.
854   // Unused children nodes can be cleared by EraseUnusedNodeInOrder().
855   (void)order_.erase(iter);
856 }
857 
// Build a CNode input list: a new ValueNode wrapping `primitive` comes first,
// followed by every element of `inputs` in their original order.
static AnfNodePtrList MakeInputNodes(const PrimitivePtr &primitive, const AnfNodePtrList &inputs) {
  AnfNodePtrList all_inputs;
  all_inputs.reserve(inputs.size() + 1);
  all_inputs.emplace_back(std::make_shared<ValueNode>(primitive));
  all_inputs.insert(all_inputs.end(), inputs.begin(), inputs.end());
  return all_inputs;
}
865 
NewCNode(const PrimitivePtr & primitive,const AnfNodePtrList & inputs)866 CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const AnfNodePtrList &inputs) {
867   auto input_node_list = MakeInputNodes(primitive, inputs);
868   return NewCNode(std::move(input_node_list));
869 }
870 
NewCNodeInOrder(const PrimitivePtr & primitive,const AnfNodePtrList & inputs)871 CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const AnfNodePtrList &inputs) {
872   auto input_node_list = MakeInputNodes(primitive, inputs);
873   return NewCNodeInOrder(std::move(input_node_list));
874 }
875 
SetMultiTarget() const876 void FuncGraph::SetMultiTarget() const {
877   auto graph_manager = manager();
878   MS_EXCEPTION_IF_NULL(graph_manager);
879   FuncGraphSet graphs = graph_manager->func_graphs();
880   AnfNodePtrList all_nodes;
881   for (auto &g : graphs) {
882     auto nodes = mindspore::TopoSort(g->get_return());
883     (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(all_nodes));
884   }
885 
886   bool exist_multi_target = false;
887   if (mindspore::ContainMultiTarget(all_nodes)) {
888     exist_multi_target = true;
889     MS_LOG(INFO) << "The graph " << ToString() << " exists the multi target.";
890   }
891 
892   for (auto &g : graphs) {
893     g->set_exist_multi_target(exist_multi_target);
894   }
895 }
896 
set_used_forward_nodes(const AnfNodePtrList & used_forward_nodes)897 void FuncGraph::set_used_forward_nodes(const AnfNodePtrList &used_forward_nodes) {
898   (void)std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(), [this](const AnfNodePtr &node) {
899     MS_EXCEPTION_IF_NULL(node);
900     (void)used_forward_nodes_.insert(node);
901   });
902 }
903 
TopoSort(const AnfNodePtr & node)904 AnfNodePtrList FuncGraph::TopoSort(const AnfNodePtr &node) { return mindspore::TopoSort(node); }
905 
NewFgSeenGeneration()906 SeenNum NewFgSeenGeneration() {
907   static SeenNum fg_seen_generation = 0;
908   ++fg_seen_generation;
909   // 0 is invalid number.
910   if (fg_seen_generation == 0) {
911     ++fg_seen_generation;
912   }
913   return fg_seen_generation;
914 }
915 }  // namespace mindspore
916