1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/func_graph.h"
20 #include <algorithm>
21 #include "mindspore/core/ops/framework_ops.h"
22 #include "utils/trace_base.h"
23 #include "ir/manager.h"
24 #include "utils/ordered_set.h"
25 #include "utils/convert_utils_base.h"
26 #include "abstract/abstract_function.h"
27 #include "ir/func_graph_cloner.h"
28 #include "utils/phase.h"
29
30 namespace mindspore {
31 /*
32 * Methods of Graph
33 */
FuncGraph()34 FuncGraph::FuncGraph() : FuncGraph(std::make_shared<GraphDebugInfo>()) {}
35
// Constructs an empty graph and takes ownership of the supplied debug info (moved into debug_info_).
// Members written with an empty initializer are value-initialized, the boolean flags start as false,
// the counters start at 0, stage_ starts at -1 and segment_ at 1.
// phase_ is read from PhaseManager at construction time.
FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
    : attrs_(),
      transforms_(),
      parameter_default_value_(),
      seen_(0),
      parameters_(),
      has_vararg_(false),
      has_kwarg_(false),
      exist_multi_target_(false),
      kw_only_args_count_(0),
      fv_param_count_(0),
      is_generated_(false),
      manager_(),
      debug_info_(std::move(debug_info)),
      stub_(false),
      stage_(-1),
      segment_(1),
      phase_(PhaseManager::GetInstance().phase()) {}
54
~FuncGraph()55 FuncGraph::~FuncGraph() { subclass_destruct_flag_ = true; }
56
DoBreakLoop()57 void FuncGraph::DoBreakLoop() {
58 if (attached_mng_cnt() > 0) {
59 MS_LOG(INFO) << "Current Graph is holding by FuncGraphManager, can't DoBreakLoop now.";
60 return;
61 }
62 ClearOrderList();
63 python_obj_ = nullptr;
64 used_forward_nodes_.clear();
65 func_graph_cache_.clear();
66 parameters_.clear();
67 parameter_obj_nodes_.clear();
68 set_dropped(true);
69 }
70
ToAbstract()71 abstract::AbstractBasePtr FuncGraph::ToAbstract() {
72 auto temp_context = abstract::AnalysisContext::DummyContext();
73 return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
74 }
75
output() const76 AnfNodePtr FuncGraph::output() const {
77 constexpr size_t return_input_num = 2;
78 // If return value is set, return should have two inputs.
79 if (return_node() != nullptr && return_node()->size() == return_input_num) {
80 return return_node()->input(1);
81 } else {
82 // If not set yet, return nullptr.
83 return nullptr;
84 }
85 }
86
get_inputs() const87 const AnfNodePtrList FuncGraph::get_inputs() const {
88 AnfNodePtrList input_params;
89 for (auto const &node : parameters_) {
90 MS_EXCEPTION_IF_NULL(node);
91 auto parameter = dyn_cast<Parameter>(node);
92 MS_EXCEPTION_IF_NULL(parameter);
93 if (!parameter->has_default()) {
94 input_params.push_back(parameter);
95 }
96 }
97 return input_params;
98 }
99
add_parameter()100 ParameterPtr FuncGraph::add_parameter() {
101 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
102 ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
103 add_parameter(param);
104 return param;
105 }
106
add_parameter(NodeDebugInfoPtr && debug_info)107 ParameterPtr FuncGraph::add_parameter(NodeDebugInfoPtr &&debug_info) {
108 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
109 ParameterPtr param = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
110 add_parameter(param);
111 return param;
112 }
113
add_parameter(const ParameterPtr & param)114 void FuncGraph::add_parameter(const ParameterPtr ¶m) {
115 if (manager_.lock()) {
116 manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
117 } else {
118 parameters_.push_back(param);
119 }
120 }
121
InsertFrontParameter()122 ParameterPtr FuncGraph::InsertFrontParameter() {
123 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
124 ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
125 InsertFrontParameter(param);
126 return param;
127 }
128
InsertFrontParameter(const ParameterPtr & param)129 void FuncGraph::InsertFrontParameter(const ParameterPtr ¶m) {
130 if (manager_.lock()) {
131 manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), param);
132 } else {
133 PrependParameter(param);
134 }
135 }
136
AddFvParameter(const std::string & name,const ValuePtr & default_value)137 ParameterPtr FuncGraph::AddFvParameter(const std::string &name, const ValuePtr &default_value) {
138 FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
139 ParameterPtr param = std::make_shared<Parameter>(this_graph);
140 param->set_name(name);
141 MS_EXCEPTION_IF_NULL(param->debug_info());
142 param->debug_info()->set_name(name);
143 param->debug_info()->set_trace_info(nullptr);
144 MS_EXCEPTION_IF_NULL(default_value);
145 param->set_default_param(default_value);
146 param->set_abstract(default_value->ToAbstract());
147 if (manager_.lock()) {
148 manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
149 } else {
150 parameters_.push_back(param);
151 }
152 ++fv_param_count_;
153 return param;
154 }
155
has_flag(const std::string & key) const156 bool FuncGraph::has_flag(const std::string &key) const {
157 auto iter = attrs_.find(key);
158 if (iter != attrs_.cend()) {
159 MS_EXCEPTION_IF_NULL(iter->second);
160 if (iter->second->isa<BoolImm>()) {
161 return GetValue<bool>(iter->second);
162 }
163 MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
164 }
165 return false;
166 }
167
has_attr(const std::string & key) const168 bool FuncGraph::has_attr(const std::string &key) const {
169 auto iter = attrs_.find(key);
170 return !(iter == attrs_.cend());
171 }
172
get_attr(const std::string & key) const173 ValuePtr FuncGraph::get_attr(const std::string &key) const {
174 auto iter = attrs_.find(key);
175 return iter == attrs_.cend() ? nullptr : iter->second;
176 }
177
NewCNodeWeak(AnfNodeWeakPtrList && weak_inputs)178 CNodePtr FuncGraph::NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs) {
179 return std::make_shared<CNode>(std::move(weak_inputs), shared_from_base<FuncGraph>());
180 }
181
NewCNodeWeak(const AnfNodeWeakPtrList & weak_inputs)182 CNodePtr FuncGraph::NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs) {
183 return std::make_shared<CNode>(weak_inputs, shared_from_base<FuncGraph>());
184 }
185
NewCNode(AnfNodePtrList && inputs)186 CNodePtr FuncGraph::NewCNode(AnfNodePtrList &&inputs) {
187 std::vector<AnfNodeWeakPtr> weak_inputs;
188 weak_inputs.reserve(inputs.size());
189 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
190 [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
191 return std::make_shared<CNode>(std::move(weak_inputs), shared_from_base<FuncGraph>());
192 }
193
NewCNode(const AnfNodePtrList & inputs)194 CNodePtr FuncGraph::NewCNode(const AnfNodePtrList &inputs) {
195 std::vector<AnfNodeWeakPtr> weak_inputs;
196 weak_inputs.reserve(inputs.size());
197 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
198 [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
199 return std::make_shared<CNode>(std::move(weak_inputs), shared_from_base<FuncGraph>());
200 }
201
NewCNodeInOrderWeak(AnfNodeWeakPtrList && weak_inputs)202 CNodePtr FuncGraph::NewCNodeInOrderWeak(AnfNodeWeakPtrList &&weak_inputs) {
203 CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
204 (void)order_.emplace_back(CNodeWeakPtr(cnode));
205 return cnode;
206 }
207
NewCNodeInOrderWeak(const AnfNodeWeakPtrList & weak_inputs)208 CNodePtr FuncGraph::NewCNodeInOrderWeak(const AnfNodeWeakPtrList &weak_inputs) {
209 CNodePtr cnode = NewCNodeWeak(weak_inputs);
210 (void)order_.emplace_back(CNodeWeakPtr(cnode));
211 return cnode;
212 }
213
NewCNodeInOrder(AnfNodePtrList && inputs)214 CNodePtr FuncGraph::NewCNodeInOrder(AnfNodePtrList &&inputs) { return NewCNodeInOrder(inputs); }
215
NewCNodeInOrder(const AnfNodePtrList & inputs)216 CNodePtr FuncGraph::NewCNodeInOrder(const AnfNodePtrList &inputs) {
217 std::vector<AnfNodeWeakPtr> weak_inputs;
218 weak_inputs.reserve(inputs.size());
219 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
220 [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
221 CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
222 (void)order_.emplace_back(CNodeWeakPtr(cnode));
223 return cnode;
224 }
225
NewCNodeInFront(const AnfNodePtrList & inputs)226 CNodePtr FuncGraph::NewCNodeInFront(const AnfNodePtrList &inputs) {
227 std::vector<AnfNodeWeakPtr> weak_inputs;
228 weak_inputs.reserve(inputs.size());
229 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
230 [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
231 CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
232 (void)order_.emplace_front(CNodeWeakPtr(cnode));
233 return cnode;
234 }
235
NewCNodeBefore(const AnfNodePtr & position,const AnfNodePtrList & inputs)236 CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const AnfNodePtrList &inputs) {
237 std::vector<AnfNodeWeakPtr> weak_inputs;
238 weak_inputs.reserve(inputs.size());
239 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
240 [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
241 CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
242 CNodePtr pos_cnode = dyn_cast<CNode>(position);
243 auto iter = std::find_if(order_.cbegin(), order_.cend(), [&pos_cnode](const CNodeWeakPtr &node) {
244 return node.lock() != nullptr && node.lock() == pos_cnode;
245 });
246 (void)order_.insert(iter, CNodeWeakPtr(cnode));
247 return cnode;
248 }
249
NewCNodeAfter(const AnfNodePtr & position,const AnfNodePtrList & inputs)250 CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const AnfNodePtrList &inputs) {
251 std::vector<AnfNodeWeakPtr> weak_inputs;
252 weak_inputs.reserve(inputs.size());
253 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs),
254 [](const AnfNodePtr &node) { return AnfNodeWeakPtr(node); });
255 CNodePtr cnode = NewCNodeWeak(std::move(weak_inputs));
256 CNodePtr pos_cnode = dyn_cast<CNode>(position);
257 auto iter = std::find_if(order_.cbegin(), order_.cend(), [&pos_cnode](const CNodeWeakPtr &node) {
258 return node.lock() != nullptr && node.lock() == pos_cnode;
259 });
260 if (iter == order_.cend()) {
261 order_.push_front(CNodeWeakPtr(cnode));
262 } else {
263 (void)order_.insert(std::next(iter), CNodeWeakPtr(cnode));
264 }
265 return cnode;
266 }
267
own_nodes() const268 const std::list<AnfNodePtr> &FuncGraph::own_nodes() const { return own_nodes_; }
269
AddOwnNode(const AnfNodePtr & node)270 void FuncGraph::AddOwnNode(const AnfNodePtr &node) { (void)own_nodes_.emplace_back(node); }
271
AddOwnNode(const AnfNodePtrList & nodes)272 void FuncGraph::AddOwnNode(const AnfNodePtrList &nodes) {
273 (void)own_nodes_.insert(own_nodes_.end(), nodes.cbegin(), nodes.cend());
274 }
275
AddOwnNode(const AnfNodeWeakPtrList & weak_nodes)276 void FuncGraph::AddOwnNode(const AnfNodeWeakPtrList &weak_nodes) {
277 std::transform(weak_nodes.cbegin(), weak_nodes.cend(), std::back_inserter(own_nodes_),
278 [](const AnfNodeWeakPtr &weak_node) -> AnfNodePtr { return weak_node.lock(); });
279 }
280
RemoveOwnNode(const AnfNodePtr & node)281 void FuncGraph::RemoveOwnNode(const AnfNodePtr &node) {
282 auto iter = std::find(own_nodes_.cbegin(), own_nodes_.cend(), node);
283 if (iter != own_nodes_.cend()) {
284 own_nodes_.erase(iter);
285 }
286 }
287
ResetOwnNodes()288 void FuncGraph::ResetOwnNodes() { own_nodes_.clear(); }
289
DumpCNodeList()290 void FuncGraph::DumpCNodeList() {
291 MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
292 for (const auto &weak_cnode : order_) {
293 const auto &cnode = weak_cnode.lock();
294 if (cnode == nullptr) {
295 continue;
296 }
297 MS_LOG(INFO) << cnode->DebugString();
298 }
299 }
300
ToString() const301 std::string FuncGraph::ToString() const {
302 std::ostringstream buffer;
303 auto debug_info = const_cast<FuncGraph *>(this)->debug_info();
304 buffer << mindspore::trace::Label(debug_info);
305 buffer << "_" << debug_info->get_id();
306 return buffer.str();
307 }
308
debug_info()309 GraphDebugInfoPtr FuncGraph::debug_info() {
310 MS_EXCEPTION_IF_NULL(this->debug_info_);
311 if (this->debug_info_->get_graph() == nullptr) {
312 this->debug_info_->set_graph(shared_from_base<FuncGraph>());
313 }
314 return this->debug_info_;
315 }
316
nodes() const317 const AnfNodeSet &FuncGraph::nodes() const { return nodes_; }
318
switch_nodes() const319 const AnfNodeSet &FuncGraph::switch_nodes() const { return switch_nodes_; }
320
CopyNodes(const FuncGraphPtr & source)321 void FuncGraph::CopyNodes(const FuncGraphPtr &source) {
322 nodes_.update(source->nodes());
323 switch_nodes_.update(source->switch_nodes());
324 }
325
ClearNodes()326 void FuncGraph::ClearNodes() {
327 nodes_.clear();
328 switch_nodes_.clear();
329 }
330
AddNode(const AnfNodePtr & node)331 void FuncGraph::AddNode(const AnfNodePtr &node) {
332 nodes_.add(node);
333 if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
334 switch_nodes_.add(node);
335 }
336 }
337
DropNode(const AnfNodePtr & node)338 void FuncGraph::DropNode(const AnfNodePtr &node) {
339 if (node == nullptr) {
340 MS_LOG(ERROR) << "Node is nullptr";
341 return;
342 }
343 (void)nodes_.erase(node);
344 if (IsPrimitiveCNode(node, prim::kPrimSwitch)) {
345 switch_nodes_.erase(node);
346 }
347 auto graph = node->func_graph();
348 if (node->isa<Parameter>()) {
349 (void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
350 }
351 // Remove the node from order list.
352 if (graph != nullptr) {
353 graph->EraseUnusedNodeInOrder(node);
354 }
355 }
356
value_nodes() const357 const AnfNodeCounterMap &FuncGraph::value_nodes() const { return value_nodes_; }
358
CopyValueNodes(const FuncGraphPtr & source)359 void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
360 MS_EXCEPTION_IF_NULL(source);
361 auto &others = source->value_nodes();
362 for (auto it = others.begin(); it != others.end(); ++it) {
363 AddValueNode(it->first, it->second);
364 }
365 }
366
ClearValueNodes()367 void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
368
AddValueNode(const AnfNodePtr & node,int count)369 void FuncGraph::AddValueNode(const AnfNodePtr &node, int count) {
370 if (value_nodes_.count(node) == 0) {
371 value_nodes_[node] = count;
372 } else {
373 value_nodes_[node] += count;
374 }
375 }
376
DropValueNode(const AnfNodePtr & node)377 void FuncGraph::DropValueNode(const AnfNodePtr &node) {
378 if (value_nodes_.count(node) != 0) {
379 if (value_nodes_[node] == 1) {
380 (void)value_nodes_.erase(node);
381 } else {
382 value_nodes_[node]--;
383 if (value_nodes_[node] < 0) {
384 MS_LOG(INTERNAL_EXCEPTION) << "Count of ValueNode '" << node
385 << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
386 }
387 }
388 }
389 }
390
free_variables() const391 const AnfNodeCounterMap &FuncGraph::free_variables() const { return free_variables_; }
392
CopyFreeVariables(const FuncGraphPtr & source)393 void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
394 MS_EXCEPTION_IF_NULL(source);
395 auto &others = source->free_variables();
396 for (auto it = others.begin(); it != others.end(); ++it) {
397 const auto &free_var = it->first;
398 MS_EXCEPTION_IF_NULL(free_var);
399 if (free_var->func_graph().get() != this) {
400 (void)AddFreeVariable(free_var, it->second);
401 }
402 }
403 }
404
ClearFreeVariables()405 void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }
406
AddFreeVariable(const AnfNodePtr & node,int count)407 bool FuncGraph::AddFreeVariable(const AnfNodePtr &node, int count) {
408 if (free_variables_.count(node) == 0) {
409 free_variables_[node] = count;
410 return true;
411 } else {
412 free_variables_[node] += count;
413 return false;
414 }
415 }
416
DropFreeVariable(const AnfNodePtr & node)417 bool FuncGraph::DropFreeVariable(const AnfNodePtr &node) {
418 if (free_variables_.count(node) != 0) {
419 if (free_variables_[node] == 1) {
420 (void)free_variables_.erase(node);
421 return true;
422 } else {
423 free_variables_[node]--;
424 if (free_variables_[node] < 0) {
425 MS_LOG(INTERNAL_EXCEPTION) << "Count of free variable '" << node
426 << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
427 }
428 }
429 }
430 return false;
431 }
432
free_variables_total()433 const BaseRefCounterMap &FuncGraph::free_variables_total() {
434 auto mng = manager_.lock();
435 MS_EXCEPTION_IF_NULL(mng);
436 auto &fv_total = mng->free_variables_total();
437 return fv_total[shared_from_base<FuncGraph>()];
438 }
439
free_variables_nodes()440 AnfNodePtrList FuncGraph::free_variables_nodes() {
441 AnfNodePtrList nodes;
442 const auto &fv_total = this->free_variables_total();
443 for (auto &p : fv_total) {
444 auto key = p.first;
445 if (utils::isa<AnfNodePtr>(key)) {
446 nodes.push_back(utils::cast<AnfNodePtr>(key));
447 }
448 }
449 return nodes;
450 }
451
free_variables_func_graphs()452 std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
453 std::vector<FuncGraphPtr> func_graphs;
454 const auto &fv_total = this->free_variables_total();
455 for (auto &p : fv_total) {
456 auto key = p.first;
457 if (utils::isa<FuncGraphPtr>(key)) {
458 func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
459 }
460 }
461
462 return func_graphs;
463 }
464
func_graphs_used() const465 const FuncGraphCounterMap &FuncGraph::func_graphs_used() const { return func_graphs_used_; }
466
CopyFuncGraphsUsed(const FuncGraphPtr & source)467 void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
468 auto &others = source->func_graphs_used();
469 for (auto it = others.begin(); it != others.end(); ++it) {
470 (void)AddFuncGraphUsed(it->first, it->second);
471 }
472 (void)func_graphs_used_.erase(source);
473 }
474
ClearFuncGraphsUsed()475 void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
476
AddFuncGraphUsed(const FuncGraphPtr & fg,int count)477 bool FuncGraph::AddFuncGraphUsed(const FuncGraphPtr &fg, int count) {
478 if (func_graphs_used_.count(fg) == 0) {
479 func_graphs_used_[fg] = count;
480 return true;
481 } else {
482 func_graphs_used_[fg] += count;
483 return false;
484 }
485 }
486
DropFuncGraphUsed(const FuncGraphPtr & fg)487 bool FuncGraph::DropFuncGraphUsed(const FuncGraphPtr &fg) {
488 if (func_graphs_used_.count(fg) != 0) {
489 if (func_graphs_used_[fg] == 1) {
490 (void)func_graphs_used_.erase(fg);
491 return true;
492 } else {
493 func_graphs_used_[fg]--;
494 if (func_graphs_used_[fg] < 0) {
495 MS_LOG(INTERNAL_EXCEPTION) << "Count of FuncGraph '" << fg
496 << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
497 }
498 }
499 }
500 return false;
501 }
502
func_graphs_used_total()503 const FuncGraphSet &FuncGraph::func_graphs_used_total() {
504 auto mng = manager_.lock();
505 MS_EXCEPTION_IF_NULL(mng);
506 auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
507 return used;
508 }
509
func_graph_cnodes_index() const510 const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() const { return func_graph_cnodes_index_; }
511
CopyFuncGraphCNodesIndex(const FuncGraphPtr & source)512 void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
513 MS_EXCEPTION_IF_NULL(source);
514 auto &others = source->func_graph_cnodes_index();
515 for (auto it = others.begin(); it != others.end(); ++it) {
516 // Ignore the user graph who may own itself.
517 MS_EXCEPTION_IF_NULL(it->first);
518 MS_EXCEPTION_IF_NULL(it->first->first);
519 auto fg = it->first->first->func_graph();
520 MS_EXCEPTION_IF_NULL(fg);
521 if (fg.get() != this) {
522 AddFuncGraphCNodeIndex(it->first, it->second);
523 }
524 }
525 }
526
ClearFuncGraphCNodesIndex()527 void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }
528
AddFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair,int count)529 void FuncGraph::AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair, int count) {
530 if (func_graph_cnodes_index_.count(pair) == 0) {
531 func_graph_cnodes_index_[pair] = count;
532 } else {
533 func_graph_cnodes_index_[pair] += count;
534 }
535 }
536
DropFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair)537 void FuncGraph::DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair) {
538 if (func_graph_cnodes_index_.count(pair) != 0) {
539 if (func_graph_cnodes_index_[pair] == 1) {
540 (void)func_graph_cnodes_index_.erase(pair);
541 } else {
542 func_graph_cnodes_index_[pair]--;
543 if (func_graph_cnodes_index_[pair] < 0) {
544 MS_LOG(INTERNAL_EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
545 << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
546 }
547 }
548 }
549 }
550
meta_fg_prim_value_nodes() const551 const mindspore::HashMap<AnfNodePtr, int> &FuncGraph::meta_fg_prim_value_nodes() const {
552 return meta_fg_prim_value_nodes_;
553 }
554
CopyMetaFgPrimValueNodes(const FuncGraphPtr & source)555 void FuncGraph::CopyMetaFgPrimValueNodes(const FuncGraphPtr &source) {
556 MS_EXCEPTION_IF_NULL(source);
557 auto &others = source->meta_fg_prim_value_nodes();
558 for (const auto &other : others) {
559 AddMetaFgPrimValueNode(other.first, other.second);
560 }
561 }
562
ClearMetaFgPrimValueNodes()563 void FuncGraph::ClearMetaFgPrimValueNodes() { meta_fg_prim_value_nodes_.clear(); }
564
AddMetaFgPrimValueNode(const AnfNodePtr & value_node,int count)565 void FuncGraph::AddMetaFgPrimValueNode(const AnfNodePtr &value_node, int count) {
566 if (meta_fg_prim_value_nodes_.count(value_node) == 0) {
567 meta_fg_prim_value_nodes_[value_node] = count;
568 } else {
569 meta_fg_prim_value_nodes_[value_node] += count;
570 }
571 }
572
DropMetaFgPrimValueNode(const AnfNodePtr & value_node)573 void FuncGraph::DropMetaFgPrimValueNode(const AnfNodePtr &value_node) {
574 if (meta_fg_prim_value_nodes_.count(value_node) != 0) {
575 if (meta_fg_prim_value_nodes_[value_node] == 1) {
576 (void)meta_fg_prim_value_nodes_.erase(value_node);
577 } else {
578 meta_fg_prim_value_nodes_[value_node]--;
579 if (meta_fg_prim_value_nodes_[value_node] < 0) {
580 MS_LOG(INTERNAL_EXCEPTION) << "Count of MetaFgPrim ValueNode '" << value_node->DebugString()
581 << "' dec from 0. NodeInfo: " << trace::GetDebugInfoStr(debug_info());
582 }
583 }
584 }
585 }
586
parent()587 FuncGraphPtr FuncGraph::parent() {
588 // report the bug early.
589 if (manager_.lock() == nullptr) {
590 MS_LOG(INTERNAL_EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
591 << " NodeInfo: " << trace::GetDebugInfoStr(debug_info());
592 }
593 auto mng = manager_.lock();
594 MS_EXCEPTION_IF_NULL(mng);
595 return mng->parent(shared_from_base<FuncGraph>());
596 }
597
children()598 const FuncGraphSet &FuncGraph::children() {
599 auto mng = manager_.lock();
600 MS_EXCEPTION_IF_NULL(mng);
601 return mng->children(shared_from_base<FuncGraph>());
602 }
603
scope()604 const FuncGraphSet &FuncGraph::scope() {
605 auto mng = manager_.lock();
606 MS_EXCEPTION_IF_NULL(mng);
607 return mng->scopes(shared_from_base<FuncGraph>());
608 }
609
recursive()610 bool FuncGraph::recursive() {
611 auto mng = manager_.lock();
612 MS_EXCEPTION_IF_NULL(mng);
613 return mng->recursive(shared_from_base<FuncGraph>());
614 }
615
recursive_graphs()616 std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
617 auto mng = manager_.lock();
618 MS_EXCEPTION_IF_NULL(mng);
619 return mng->recursive_graphs(shared_from_base<FuncGraph>());
620 }
621
ClearAllResource()622 void FuncGraph::ClearAllResource() {
623 ClearNodes();
624 ClearValueNodes();
625 ClearFuncGraphCNodesIndex();
626 ClearFreeVariables();
627 ClearFuncGraphsUsed();
628 ClearMetaFgPrimValueNodes();
629 }
630
GetDefaultValueByName(const std::string & name)631 AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
632 auto itr = this->parameter_default_value_.find(name);
633 if (itr == parameter_default_value_.end()) {
634 return nullptr;
635 }
636 auto default_value = itr->second;
637 if (default_value == nullptr) {
638 MS_LOG(INTERNAL_EXCEPTION) << "Graph parameter " << name << " not exist";
639 }
640 if (IsValueNode<Null>(default_value)) {
641 return nullptr;
642 }
643 return default_value;
644 }
645
646 // set the default values
SetDefaultValues(const std::vector<std::string> & name_list,const AnfNodePtrList & value_list)647 void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const AnfNodePtrList &value_list) {
648 auto all_is_null =
649 std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode<Null>(node); });
650 if (value_list.empty()) {
651 all_is_null = true;
652 }
653 for (size_t i = 0; i < name_list.size(); ++i) {
654 if (!all_is_null) {
655 this->parameter_default_value_[name_list[i]] = value_list[i];
656 }
657 }
658 }
659
ClearDefaultValues()660 void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
661
GetDefaultValueCount()662 size_t FuncGraph::GetDefaultValueCount() {
663 int64_t null_count =
664 std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
665 [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<Null>(pair.second); });
666 return parameter_default_value_.size() - LongToSize(null_count);
667 }
668
GetVariableArgParameter()669 AnfNodePtr FuncGraph::GetVariableArgParameter() {
670 if (!has_vararg_) {
671 return nullptr;
672 }
673
674 size_t min_param_num = 1;
675 if (has_kwarg_) {
676 min_param_num += 1;
677 }
678 min_param_num += IntToSize(kw_only_args_count_);
679 min_param_num += fv_param_count_;
680
681 if (parameters_.size() < min_param_num) {
682 MS_LOG(INTERNAL_EXCEPTION) << "Length of parameters is " << parameters_.size()
683 << " which less than the sum of following: fv_param_count: " << fv_param_count_
684 << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
685 << ", kw_only_args_count_: " << kw_only_args_count_;
686 }
687 return parameters_[parameters_.size() - min_param_num];
688 }
689
GetVariableArgName()690 std::string FuncGraph::GetVariableArgName() {
691 if (!has_vararg_) {
692 return "";
693 }
694
695 const auto ¶m_node = GetVariableArgParameter();
696 MS_EXCEPTION_IF_NULL(param_node);
697 auto parameter = param_node->cast_ptr<Parameter>();
698 MS_EXCEPTION_IF_NULL(parameter);
699 return parameter->name();
700 }
701
GetVariableKwargParameter()702 AnfNodePtr FuncGraph::GetVariableKwargParameter() {
703 if (has_kwarg_) {
704 if (parameters_.size() < fv_param_count_ + 1) {
705 MS_LOG(INTERNAL_EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is "
706 << fv_param_count_ << ", parameters is less than 1 + fv_param_count";
707 }
708 return parameters_[(parameters_.size() - fv_param_count_) - 1];
709 }
710 return nullptr;
711 }
712
GetVariableKwargName()713 std::string FuncGraph::GetVariableKwargName() {
714 auto kwarg_param = GetVariableKwargParameter();
715 if (kwarg_param != nullptr) {
716 auto parameter = kwarg_param->cast_ptr<Parameter>();
717 MS_EXCEPTION_IF_NULL(parameter);
718 return parameter->name();
719 }
720 return "";
721 }
722
GetKwOnlyArgsParameters()723 AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() {
724 AnfNodePtrList kw_only_args;
725 if (kw_only_args_count_ == 0) {
726 return kw_only_args;
727 }
728
729 size_t min_param_num = 0;
730 size_t varargs_kwargs_num = 0;
731 if (has_kwarg_) {
732 min_param_num += 1;
733 varargs_kwargs_num += 1;
734 }
735 min_param_num += IntToSize(kw_only_args_count_);
736 min_param_num += fv_param_count_;
737
738 if (parameters_.size() < min_param_num) {
739 MS_LOG(INTERNAL_EXCEPTION) << "Length of parameters is " << parameters_.size()
740 << " which less than the sum of following: fv_param_count: " << fv_param_count_
741 << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
742 << ", kw_only_args_count: " << kw_only_args_count_;
743 }
744 size_t kw_only_args_start_offset = parameters_.size() - min_param_num;
745 std::copy(parameters_.cbegin() + kw_only_args_start_offset, parameters_.cend() - fv_param_count_ - varargs_kwargs_num,
746 std::back_inserter(kw_only_args));
747 return kw_only_args;
748 }
749
GetPositionalArgsCount() const750 int FuncGraph::GetPositionalArgsCount() const {
751 int count = SizeToInt(parameters_.size());
752 if (has_kwarg_) {
753 count--;
754 }
755 if (has_vararg_) {
756 count--;
757 }
758 return (count - kw_only_args_count_) - SizeToInt(fv_param_count_);
759 }
760
GetParameterByName(const std::string & name)761 AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
762 for (size_t i = 0; i < parameters_.size(); ++i) {
763 MS_EXCEPTION_IF_NULL(parameters_[i]);
764 auto param_cast = parameters_[i]->cast_ptr<Parameter>();
765 MS_EXCEPTION_IF_NULL(param_cast);
766 if (param_cast->name() == name) {
767 return parameters_[i];
768 }
769 }
770 return nullptr;
771 }
772
GetOrderedCnodes()773 std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
774 auto this_ptr = shared_from_base<FuncGraph>();
775 auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1);
776 auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1);
777
778 std::list<CNodePtr> cnodes;
779 auto nodes = mindspore::TopoSort(return_node(), SuccDepends, BelongSameGraph);
780 for (const auto &node : nodes) {
781 auto cnode = dyn_cast<CNode>(node);
782 if (cnode != nullptr) {
783 (void)cnodes.emplace_back(std::move(cnode));
784 }
785 }
786 return cnodes;
787 }
788
EraseUnusedNodeInOrder()789 void FuncGraph::EraseUnusedNodeInOrder() {
790 auto mng = manager_.lock();
791 if (mng != nullptr) {
792 auto &all_nodes = nodes();
793 // Erase unused cnode.
794 for (auto it = order_.begin(); it != order_.cend();) {
795 const auto &cnode = it->lock();
796 if (cnode == nullptr) {
797 it = order_.erase(it);
798 continue;
799 }
800 if (!all_nodes.contains(cnode)) {
801 MS_EXCEPTION_IF_NULL(cnode);
802 MS_LOG(DEBUG) << "Remove node: " << cnode->DebugString() << " in graph " << ToString() << " order.";
803 it = order_.erase(it);
804 continue;
805 }
806 (void)++it;
807 }
808 }
809 }
810
EraseUnusedNodeInOrder(const AnfNodePtr & node)811 void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) {
812 if (node == nullptr) {
813 return;
814 }
815 auto cnode = node->cast<CNodePtr>();
816 if (cnode != nullptr) {
817 auto iter = std::find_if(order_.cbegin(), order_.cend(), [&cnode](const CNodeWeakPtr &node) {
818 return node.lock() != nullptr && node.lock() == cnode;
819 });
820 if (iter != order_.cend()) {
821 (void)order_.erase(iter);
822 MS_LOG(DEBUG) << "Remove node: " << node->DebugString() << " from order list.";
823 }
824 }
825 }
826
827 // Maintain cnode order list when a cnode is replaced by a new one.
ReplaceInOrder(const AnfNodePtr & old_node,const AnfNodePtr & new_node)828 void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
829 MS_EXCEPTION_IF_NULL(old_node);
830 MS_EXCEPTION_IF_NULL(new_node);
831 if (order_.empty()) {
832 // Skip if order list is empty.
833 return;
834 }
835 auto old_cnode = old_node->cast<CNodePtr>();
836 if (old_cnode == nullptr) {
837 // Skip if old node is not cnode, since order list contains cnode only.
838 return;
839 }
840 // Search old node in order list.
841 auto iter = std::find_if(order_.cbegin(), order_.cend(), [&old_cnode](const CNodeWeakPtr &node) {
842 return node.lock() != nullptr && node.lock() == old_cnode;
843 });
844 if (iter == order_.cend()) {
845 // Skip if old node not found in order list.
846 return;
847 }
848 auto new_cnode = new_node->cast<CNodePtr>();
849 if (new_cnode != nullptr) {
850 // Insert new node just before the old node.
851 (void)order_.insert(iter, CNodeWeakPtr(new_cnode));
852 }
853 // Remove old node from order list.
854 // Unused children nodes can be cleared by EraseUnusedNodeInOrder().
855 (void)order_.erase(iter);
856 }
857
// Builds a CNode input list: a ValueNode wrapping `primitive`, followed by a copy of `inputs`.
static AnfNodePtrList MakeInputNodes(const PrimitivePtr &primitive, const AnfNodePtrList &inputs) {
  AnfNodePtrList all_inputs;
  all_inputs.reserve(inputs.size() + 1);
  all_inputs.emplace_back(std::make_shared<ValueNode>(primitive));
  for (const auto &input : inputs) {
    all_inputs.push_back(input);
  }
  return all_inputs;
}
865
NewCNode(const PrimitivePtr & primitive,const AnfNodePtrList & inputs)866 CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const AnfNodePtrList &inputs) {
867 auto input_node_list = MakeInputNodes(primitive, inputs);
868 return NewCNode(std::move(input_node_list));
869 }
870
NewCNodeInOrder(const PrimitivePtr & primitive,const AnfNodePtrList & inputs)871 CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const AnfNodePtrList &inputs) {
872 auto input_node_list = MakeInputNodes(primitive, inputs);
873 return NewCNodeInOrder(std::move(input_node_list));
874 }
875
SetMultiTarget() const876 void FuncGraph::SetMultiTarget() const {
877 auto graph_manager = manager();
878 MS_EXCEPTION_IF_NULL(graph_manager);
879 FuncGraphSet graphs = graph_manager->func_graphs();
880 AnfNodePtrList all_nodes;
881 for (auto &g : graphs) {
882 auto nodes = mindspore::TopoSort(g->get_return());
883 (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(all_nodes));
884 }
885
886 bool exist_multi_target = false;
887 if (mindspore::ContainMultiTarget(all_nodes)) {
888 exist_multi_target = true;
889 MS_LOG(INFO) << "The graph " << ToString() << " exists the multi target.";
890 }
891
892 for (auto &g : graphs) {
893 g->set_exist_multi_target(exist_multi_target);
894 }
895 }
896
set_used_forward_nodes(const AnfNodePtrList & used_forward_nodes)897 void FuncGraph::set_used_forward_nodes(const AnfNodePtrList &used_forward_nodes) {
898 (void)std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(), [this](const AnfNodePtr &node) {
899 MS_EXCEPTION_IF_NULL(node);
900 (void)used_forward_nodes_.insert(node);
901 });
902 }
903
TopoSort(const AnfNodePtr & node)904 AnfNodePtrList FuncGraph::TopoSort(const AnfNodePtr &node) { return mindspore::TopoSort(node); }
905
NewFgSeenGeneration()906 SeenNum NewFgSeenGeneration() {
907 static SeenNum fg_seen_generation = 0;
908 ++fg_seen_generation;
909 // 0 is invalid number.
910 if (fg_seen_generation == 0) {
911 ++fg_seen_generation;
912 }
913 return fg_seen_generation;
914 }
915 } // namespace mindspore
916