1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/func_graph.h"
20
21 #include <algorithm>
22 #include <sstream>
23 #include <utility>
24
25 #include "utils/trace_base.h"
26 #include "ir/manager.h"
27 #include "utils/flags.h"
28 #include "utils/ordered_set.h"
29 #include "utils/convert_utils_base.h"
30 #include "abstract/abstract_function.h"
31
32 namespace mindspore {
/*
 * Methods of FuncGraph
 */
// Constructs an empty FuncGraph: no parameters, no return node, no attached manager
// (empty weak_ptr), no vararg/kwarg, stage -1, and freshly allocated debug info.
// NOTE: the initializer list order is kept as-is; it presumably mirrors the member
// declaration order in func_graph.h (reordering would trigger -Wreorder) — TODO confirm.
FuncGraph::FuncGraph()
    : attrs_(),
      transforms_(),
      parameter_default_value_(),
      seen_(0),
      parameters_(),
      has_vararg_(false),
      has_kwarg_(false),
      kwonlyargs_count_(0),
      hyper_param_count_(0),
      is_generated_(false),
      is_bprop_(false),
      return_(nullptr),
      manager_(std::weak_ptr<FuncGraphManager>()),
      stub_(false),
      stage_(-1) {
  debug_info_ = std::make_shared<GraphDebugInfo>();
  // Both switch flags are heap-allocated bools held through shared_ptr, initialised to false.
  switch_input_ = std::make_shared<bool>(false);
  switch_layer_input_ = std::make_shared<bool>(false);
}
56
ToAbstract()57 abstract::AbstractBasePtr FuncGraph::ToAbstract() {
58 auto temp_context = abstract::AnalysisContext::DummyContext();
59 return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
60 }
61
output() const62 AnfNodePtr FuncGraph::output() const {
63 constexpr size_t return_input_num = 2;
64 // If return value is set, return should have two inputs.
65 if (return_ != nullptr && return_->inputs().size() == return_input_num) {
66 return return_->input(1);
67 } else {
68 // If not set yet, return nullptr.
69 return nullptr;
70 }
71 }
72
get_inputs() const73 const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
74 std::vector<AnfNodePtr> input_params;
75 for (auto const &node : parameters_) {
76 MS_EXCEPTION_IF_NULL(node);
77 auto parameter = dyn_cast<Parameter>(node);
78 MS_EXCEPTION_IF_NULL(parameter);
79 if (!parameter->has_default()) {
80 input_params.push_back(parameter);
81 }
82 }
83 return input_params;
84 }
85
add_parameter()86 ParameterPtr FuncGraph::add_parameter() {
87 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
88 ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
89 add_parameter(p);
90 return p;
91 }
92
add_parameter(const ParameterPtr & p)93 void FuncGraph::add_parameter(const ParameterPtr &p) {
94 if (manager_.lock()) {
95 manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
96 } else {
97 parameters_.push_back(p);
98 }
99 }
100
InsertFrontParameter()101 ParameterPtr FuncGraph::InsertFrontParameter() {
102 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
103 ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
104 InsertFrontParameter(p);
105 return p;
106 }
107
InsertFrontParameter(const ParameterPtr & p)108 void FuncGraph::InsertFrontParameter(const ParameterPtr &p) {
109 if (manager_.lock()) {
110 manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p);
111 } else {
112 PrependParameter(p);
113 }
114 }
115
AddWeightParameter(const std::string & name)116 ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
117 FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
118 ParameterPtr p = std::make_shared<Parameter>(this_graph);
119 p->set_name(name);
120 p->debug_info()->set_name(name);
121
122 if (manager_.lock()) {
123 manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
124 } else {
125 parameters_.push_back(p);
126 }
127 hyper_param_count_++;
128 return p;
129 }
130
has_flag(const std::string & key)131 bool FuncGraph::has_flag(const std::string &key) {
132 auto iter = attrs_.find(key);
133 if (iter != attrs_.cend()) {
134 MS_EXCEPTION_IF_NULL(iter->second);
135 if (iter->second->isa<BoolImm>()) {
136 return GetValue<bool>(iter->second);
137 }
138 MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
139 }
140 return false;
141 }
142
has_attr(const std::string & key) const143 bool FuncGraph::has_attr(const std::string &key) const {
144 auto iter = attrs_.find(key);
145 return !(iter == attrs_.cend());
146 }
147
get_attr(const std::string & key) const148 ValuePtr FuncGraph::get_attr(const std::string &key) const {
149 auto iter = attrs_.find(key);
150 return iter == attrs_.cend() ? nullptr : iter->second;
151 }
152
NewCNode(const std::vector<AnfNodePtr> & inputs)153 CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
154 return std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
155 }
156
NewCNodeInOrder(const std::vector<AnfNodePtr> & inputs)157 CNodePtr FuncGraph::NewCNodeInOrder(const std::vector<AnfNodePtr> &inputs) {
158 CNodePtr cnode = NewCNode(inputs);
159 order_.push_back(cnode);
160 return cnode;
161 }
162
NewCNodeInFront(const std::vector<AnfNodePtr> & inputs)163 CNodePtr FuncGraph::NewCNodeInFront(const std::vector<AnfNodePtr> &inputs) {
164 CNodePtr cnode = NewCNode(inputs);
165 order_.push_front(cnode);
166 return cnode;
167 }
168
NewCNodeBefore(const AnfNodePtr & position,const std::vector<AnfNodePtr> & inputs)169 CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
170 CNodePtr cnode = NewCNode(inputs);
171 CNodePtr pos_cnode = dyn_cast<CNode>(position);
172 auto iter = order_.find(pos_cnode);
173 order_.insert(iter, cnode);
174 return cnode;
175 }
176
NewCNodeAfter(const AnfNodePtr & position,const std::vector<AnfNodePtr> & inputs)177 CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
178 CNodePtr cnode = NewCNode(inputs);
179 CNodePtr pos_cnode = dyn_cast<CNode>(position);
180 auto iter = order_.find(pos_cnode);
181 if (iter == order_.end()) {
182 order_.push_front(cnode);
183 } else {
184 order_.insert(std::next(iter), cnode);
185 }
186 return cnode;
187 }
188
DumpCNodeList()189 void FuncGraph::DumpCNodeList() {
190 MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:";
191 for (const auto &cnode : order_) {
192 MS_LOG(INFO) << cnode->DebugString();
193 }
194 }
195
ToString() const196 std::string FuncGraph::ToString() const {
197 std::ostringstream buffer;
198 auto debug_info = const_cast<FuncGraph *>(this)->shared_from_base<FuncGraph>()->debug_info();
199 buffer << mindspore::label_manage::Label(debug_info);
200 buffer << "." << debug_info->get_id();
201 return buffer.str();
202 }
203
debug_info()204 GraphDebugInfoPtr FuncGraph::debug_info() {
205 MS_EXCEPTION_IF_NULL(this->debug_info_);
206 if (this->debug_info_->get_graph() == nullptr) {
207 this->debug_info_->set_graph(shared_from_base<FuncGraph>());
208 }
209 return this->debug_info_;
210 }
211
nodes() const212 const AnfNodeSet &FuncGraph::nodes() const { return nodes_; }
213
CopyNodes(const FuncGraphPtr & source)214 void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_.update(source->nodes()); }
215
ClearNodes()216 void FuncGraph::ClearNodes() { nodes_.clear(); }
217
AddNode(const AnfNodePtr & node)218 void FuncGraph::AddNode(const AnfNodePtr &node) { nodes_.add(node); }
219
DropNode(const AnfNodePtr & node)220 void FuncGraph::DropNode(const AnfNodePtr &node) {
221 nodes_.erase(node);
222 if (node == nullptr) {
223 MS_LOG(ERROR) << "Node is nullptr";
224 return;
225 }
226 auto graph = node->func_graph();
227 if (node->isa<Parameter>()) {
228 (void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
229 }
230 // Remove the node from order list.
231 if (graph) {
232 graph->EraseUnusedNodeInOrder(node);
233 }
234 }
235
value_nodes() const236 const AnfNodeCounterMap &FuncGraph::value_nodes() const { return value_nodes_; }
237
CopyValueNodes(const FuncGraphPtr & source)238 void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
239 auto &others = source->value_nodes();
240 for (auto it = others.begin(); it != others.end(); ++it) {
241 AddValueNode(it->first, it->second);
242 }
243 }
244
ClearValueNodes()245 void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
246
AddValueNode(const AnfNodePtr & node,int count)247 void FuncGraph::AddValueNode(const AnfNodePtr &node, int count) {
248 if (value_nodes_.count(node) == 0) {
249 value_nodes_[node] = count;
250 } else {
251 value_nodes_[node] += count;
252 }
253 }
254
DropValueNode(const AnfNodePtr & node)255 void FuncGraph::DropValueNode(const AnfNodePtr &node) {
256 if (value_nodes_.count(node) != 0) {
257 if (value_nodes_[node] == 1) {
258 (void)value_nodes_.erase(node);
259 } else {
260 value_nodes_[node]--;
261 if (value_nodes_[node] < 0) {
262 MS_LOG(EXCEPTION) << "Count of ValueNode '" << node
263 << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
264 }
265 }
266 }
267 }
268
free_variables() const269 const AnfNodeCounterMap &FuncGraph::free_variables() const { return free_variables_; }
270
CopyFreeVariables(const FuncGraphPtr & source)271 void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
272 auto &others = source->free_variables();
273 for (auto it = others.begin(); it != others.end(); ++it) {
274 const auto &free_var = it->first;
275 MS_EXCEPTION_IF_NULL(free_var);
276 if (free_var->func_graph().get() != this) {
277 (void)AddFreeVariable(free_var, it->second);
278 }
279 }
280 }
281
ClearFreeVariables()282 void FuncGraph::ClearFreeVariables() { free_variables_.clear(); }
283
AddFreeVariable(const AnfNodePtr & node,int count)284 bool FuncGraph::AddFreeVariable(const AnfNodePtr &node, int count) {
285 if (free_variables_.count(node) == 0) {
286 free_variables_[node] = count;
287 return true;
288 } else {
289 free_variables_[node] += count;
290 return false;
291 }
292 }
293
DropFreeVariable(const AnfNodePtr & node)294 bool FuncGraph::DropFreeVariable(const AnfNodePtr &node) {
295 if (free_variables_.count(node) != 0) {
296 if (free_variables_[node] == 1) {
297 (void)free_variables_.erase(node);
298 return true;
299 } else {
300 free_variables_[node]--;
301 if (free_variables_[node] < 0) {
302 MS_LOG(EXCEPTION) << "Count of free variable '" << node
303 << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
304 }
305 }
306 }
307 return false;
308 }
309
free_variables_total()310 const BaseRefCounterMap &FuncGraph::free_variables_total() {
311 auto mng = manager_.lock();
312 MS_EXCEPTION_IF_NULL(mng);
313 auto &fv_total = mng->free_variables_total();
314 return fv_total[shared_from_base<FuncGraph>()];
315 }
316
free_variables_nodes()317 std::vector<AnfNodePtr> FuncGraph::free_variables_nodes() {
318 std::vector<AnfNodePtr> nodes;
319 const auto &fv_total = this->free_variables_total();
320 for (auto &p : fv_total) {
321 auto key = p.first;
322 if (utils::isa<AnfNodePtr>(key)) {
323 nodes.push_back(utils::cast<AnfNodePtr>(key));
324 }
325 }
326 return nodes;
327 }
328
free_variables_func_graphs()329 std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
330 std::vector<FuncGraphPtr> func_graphs;
331 const auto &fv_total = this->free_variables_total();
332 for (auto &p : fv_total) {
333 auto key = p.first;
334 if (utils::isa<FuncGraphPtr>(key)) {
335 func_graphs.push_back(utils::cast<FuncGraphPtr>(key));
336 }
337 }
338
339 return func_graphs;
340 }
341
func_graphs_used() const342 const FuncGraphCounterMap &FuncGraph::func_graphs_used() const { return func_graphs_used_; }
343
CopyFuncGraphsUsed(const FuncGraphPtr & source)344 void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
345 auto &others = source->func_graphs_used();
346 for (auto it = others.begin(); it != others.end(); ++it) {
347 (void)AddFuncGraphUsed(it->first, it->second);
348 }
349 func_graphs_used_.erase(source);
350 }
351
ClearFuncGraphsUsed()352 void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
353
AddFuncGraphUsed(const FuncGraphPtr & fg,int count)354 bool FuncGraph::AddFuncGraphUsed(const FuncGraphPtr &fg, int count) {
355 if (func_graphs_used_.count(fg) == 0) {
356 func_graphs_used_[fg] = count;
357 return true;
358 } else {
359 func_graphs_used_[fg] += count;
360 return false;
361 }
362 }
363
DropFuncGraphUsed(const FuncGraphPtr & fg)364 bool FuncGraph::DropFuncGraphUsed(const FuncGraphPtr &fg) {
365 if (func_graphs_used_.count(fg) != 0) {
366 if (func_graphs_used_[fg] == 1) {
367 (void)func_graphs_used_.erase(fg);
368 return true;
369 } else {
370 func_graphs_used_[fg]--;
371 if (func_graphs_used_[fg] < 0) {
372 MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
373 << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
374 }
375 }
376 }
377 return false;
378 }
379
func_graphs_used_total()380 const FuncGraphSet &FuncGraph::func_graphs_used_total() {
381 auto mng = manager_.lock();
382 MS_EXCEPTION_IF_NULL(mng);
383 auto &used = mng->func_graphs_used_total(shared_from_base<FuncGraph>());
384 return used;
385 }
386
func_graph_cnodes_index() const387 const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() const { return func_graph_cnodes_index_; }
388
CopyFuncGraphCNodesIndex(const FuncGraphPtr & source)389 void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
390 auto &others = source->func_graph_cnodes_index();
391 for (auto it = others.begin(); it != others.end(); ++it) {
392 // Ignore the user graph who may own itself.
393 auto fg = it->first->first->func_graph();
394 MS_EXCEPTION_IF_NULL(fg);
395 if (fg.get() != this) {
396 AddFuncGraphCNodeIndex(it->first, it->second);
397 }
398 }
399 }
400
ClearFuncGraphCNodesIndex()401 void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); }
402
AddFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair,int count)403 void FuncGraph::AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair, int count) {
404 if (func_graph_cnodes_index_.count(pair) == 0) {
405 func_graph_cnodes_index_[pair] = count;
406 } else {
407 func_graph_cnodes_index_[pair] += count;
408 }
409 }
410
DropFuncGraphCNodeIndex(const CNodeIndexPairPtr & pair)411 void FuncGraph::DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair) {
412 if (func_graph_cnodes_index_.count(pair) != 0) {
413 if (func_graph_cnodes_index_[pair] == 1) {
414 (void)func_graph_cnodes_index_.erase(pair);
415 } else {
416 func_graph_cnodes_index_[pair]--;
417 if (func_graph_cnodes_index_[pair] < 0) {
418 MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
419 << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
420 }
421 }
422 }
423 }
424
j_value_nodes() const425 const std::unordered_map<AnfNodePtr, int> &FuncGraph::j_value_nodes() const { return j_value_nodes_; }
426
CopyJValueNodes(const FuncGraphPtr & source)427 void FuncGraph::CopyJValueNodes(const FuncGraphPtr &source) {
428 MS_EXCEPTION_IF_NULL(source);
429 auto &others = source->j_value_nodes();
430 for (const auto &other : others) {
431 AddJValueNode(other.first, other.second);
432 }
433 }
434
ClearJValueNodes()435 void FuncGraph::ClearJValueNodes() { j_value_nodes_.clear(); }
436
AddJValueNode(const AnfNodePtr & value_node,int count)437 void FuncGraph::AddJValueNode(const AnfNodePtr &value_node, int count) {
438 if (j_value_nodes_.count(value_node) == 0) {
439 j_value_nodes_[value_node] = count;
440 } else {
441 j_value_nodes_[value_node] += count;
442 }
443 }
444
DropJValueNode(const AnfNodePtr & value_node)445 void FuncGraph::DropJValueNode(const AnfNodePtr &value_node) {
446 if (j_value_nodes_.count(value_node) != 0) {
447 if (j_value_nodes_[value_node] == 1) {
448 (void)j_value_nodes_.erase(value_node);
449 } else {
450 j_value_nodes_[value_node]--;
451 if (j_value_nodes_[value_node] < 0) {
452 MS_LOG(EXCEPTION) << "Count of J ValueNode '" << value_node->DebugString()
453 << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
454 }
455 }
456 }
457 }
458
parent()459 FuncGraphPtr FuncGraph::parent() {
460 // report the bug early.
461 if (manager_.lock() == nullptr) {
462 MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
463 << " NodeInfo: " << trace::GetDebugInfo(debug_info());
464 }
465 auto mng = manager_.lock();
466 MS_EXCEPTION_IF_NULL(mng);
467 return mng->parent(shared_from_base<FuncGraph>());
468 }
469
children()470 const FuncGraphSet &FuncGraph::children() {
471 auto mng = manager_.lock();
472 MS_EXCEPTION_IF_NULL(mng);
473 return mng->children(shared_from_base<FuncGraph>());
474 }
475
scope()476 const FuncGraphSet &FuncGraph::scope() {
477 auto mng = manager_.lock();
478 MS_EXCEPTION_IF_NULL(mng);
479 return mng->scopes(shared_from_base<FuncGraph>());
480 }
481
recursive()482 bool FuncGraph::recursive() {
483 auto mng = manager_.lock();
484 MS_EXCEPTION_IF_NULL(mng);
485 return mng->recursive(shared_from_base<FuncGraph>());
486 }
487
recursive_graphs()488 std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
489 auto mng = manager_.lock();
490 MS_EXCEPTION_IF_NULL(mng);
491 return mng->recursive_graphs(shared_from_base<FuncGraph>());
492 }
493
ClearAllManagerInfo()494 void FuncGraph::ClearAllManagerInfo() {
495 ClearNodes();
496 ClearValueNodes();
497 ClearFuncGraphCNodesIndex();
498 ClearFreeVariables();
499 ClearFuncGraphsUsed();
500 ClearJValueNodes();
501 }
502
GetDefaultValueByName(const std::string & name)503 AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
504 auto itr = this->parameter_default_value_.find(name);
505 if (itr == parameter_default_value_.end()) {
506 return nullptr;
507 }
508 auto default_value = itr->second;
509 if (default_value == nullptr) {
510 MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist";
511 }
512 if (IsValueNode<Null>(default_value)) {
513 return nullptr;
514 }
515 return default_value;
516 }
517
518 // set the default values
SetDefaultValues(const std::vector<std::string> & name_list,const std::vector<AnfNodePtr> & value_list)519 void FuncGraph::SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list) {
520 auto all_is_null =
521 std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode<Null>(node); });
522 if (value_list.empty()) {
523 all_is_null = true;
524 }
525 for (size_t i = 0; i < name_list.size(); ++i) {
526 if (!all_is_null) {
527 this->parameter_default_value_[name_list[i]] = value_list[i];
528 }
529 }
530 }
531
ClearDefaultValues()532 void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); }
533
GetDefaultValueCount()534 size_t FuncGraph::GetDefaultValueCount() {
535 int64_t null_count =
536 std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(),
537 [](const std::pair<std::string, AnfNodePtr> &pair) { return IsValueNode<Null>(pair.second); });
538 return parameter_default_value_.size() - LongToSize(null_count);
539 }
540
GetVariableArgParameter()541 AnfNodePtr FuncGraph::GetVariableArgParameter() {
542 if (!has_vararg_) {
543 return nullptr;
544 }
545
546 // one vararg + kwarg so the min param num is 2;
547 constexpr size_t min_param_num = 2;
548 if (has_kwarg_) {
549 if (parameters_.size() < hyper_param_count_ + min_param_num) {
550 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
551 << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
552 }
553 return parameters_[parameters_.size() - hyper_param_count_ - min_param_num];
554 }
555
556 if (parameters_.size() < hyper_param_count_ + 1) {
557 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
558 << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
559 }
560 return parameters_[parameters_.size() - hyper_param_count_ - 1];
561 }
562
GetVariableArgName()563 std::string FuncGraph::GetVariableArgName() {
564 if (!has_vararg_) {
565 return "";
566 }
567
568 // one vararg + kwarg so the min param num is 2;
569 constexpr size_t min_param_num = 2;
570 if (has_kwarg_) {
571 if (parameters_.size() < hyper_param_count_ + min_param_num) {
572 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
573 << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
574 }
575 const auto ¶meter = parameters_[parameters_.size() - hyper_param_count_ - min_param_num]->cast<ParameterPtr>();
576 MS_EXCEPTION_IF_NULL(parameter);
577 return parameter->name();
578 }
579
580 if (parameters_.size() < hyper_param_count_ + 1) {
581 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
582 << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
583 }
584 const auto ¶meter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>();
585 MS_EXCEPTION_IF_NULL(parameter);
586 return parameter->name();
587 }
588
GetVariableKwargParameter()589 AnfNodePtr FuncGraph::GetVariableKwargParameter() {
590 if (has_kwarg_) {
591 if (parameters_.size() < hyper_param_count_ + 1) {
592 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
593 << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
594 }
595 return parameters_[parameters_.size() - hyper_param_count_ - 1];
596 }
597 return nullptr;
598 }
599
GetVariableKwargName()600 std::string FuncGraph::GetVariableKwargName() {
601 if (has_kwarg_) {
602 if (parameters_.size() < hyper_param_count_ + 1) {
603 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
604 << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
605 }
606 const auto ¶meter = parameters_[parameters_.size() - hyper_param_count_ - 1]->cast<ParameterPtr>();
607 MS_EXCEPTION_IF_NULL(parameter);
608 return parameter->name();
609 }
610 return "";
611 }
612
GetPositionalArgsCount() const613 int FuncGraph::GetPositionalArgsCount() const {
614 int count = SizeToInt(parameters_.size());
615 if (has_kwarg_) {
616 count--;
617 }
618 if (has_vararg_) {
619 count--;
620 }
621 return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_);
622 }
623
GetParameterByName(const std::string & name)624 AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
625 for (size_t i = 0; i < parameters_.size(); ++i) {
626 MS_EXCEPTION_IF_NULL(parameters_[i]);
627 auto param_cast = parameters_[i]->cast<ParameterPtr>();
628 MS_EXCEPTION_IF_NULL(param_cast);
629 if (param_cast->name() == name) {
630 return parameters_[i];
631 }
632 }
633 return nullptr;
634 }
635
GetOrderedCnodes()636 std::list<CNodePtr> FuncGraph::GetOrderedCnodes() {
637 auto this_ptr = shared_from_base<FuncGraph>();
638 auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1);
639 auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1);
640
641 std::list<CNodePtr> cnodes;
642 auto nodes = mindspore::TopoSort(get_return(), SuccDepends, BelongSameGraph);
643 for (const auto &node : nodes) {
644 auto cnode = dyn_cast<CNode>(node);
645 if (cnode) {
646 cnodes.push_back(cnode);
647 }
648 }
649 return cnodes;
650 }
651
EraseUnusedNodeInOrder()652 void FuncGraph::EraseUnusedNodeInOrder() {
653 auto mng = manager_.lock();
654 if (mng) {
655 auto &all_nodes = nodes();
656 // Erase unused cnode.
657 for (auto it = order_.begin(); it != order_.end();) {
658 if (!all_nodes.contains(*it)) {
659 MS_LOG(DEBUG) << "Remove node: " << (*it)->ToString() << " in graph " << ToString() << " order.";
660 it = order_.erase(it);
661 continue;
662 }
663 (void)it++;
664 }
665 }
666 }
667
EraseUnusedNodeInOrder(const AnfNodePtr & node)668 void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) {
669 if (node) {
670 auto cnode = node->cast<CNodePtr>();
671 if (cnode) {
672 order_.erase(cnode);
673 MS_LOG(DEBUG) << "Remove node: " << node->DebugString() << " from order list.";
674 }
675 }
676 }
677
678 // Maintain cnode order list when a cnode is replaced by a new one.
ReplaceInOrder(const AnfNodePtr & old_node,const AnfNodePtr & new_node)679 void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
680 MS_EXCEPTION_IF_NULL(old_node);
681 MS_EXCEPTION_IF_NULL(new_node);
682 if (order_.empty()) {
683 // Skip if order list is empty.
684 return;
685 }
686 auto old_cnode = old_node->cast<CNodePtr>();
687 if (old_cnode == nullptr) {
688 // Skip if old node is not cnode, since order list contains cnode only.
689 return;
690 }
691 // Search old node in order list.
692 auto iter = order_.find(old_cnode);
693 if (iter == order_.end()) {
694 // Skip if old node not found in order list.
695 return;
696 }
697 auto new_cnode = new_node->cast<CNodePtr>();
698 if (new_cnode != nullptr) {
699 // Insert new node just before the old node.
700 order_.insert(iter, new_cnode);
701 }
702 // Remove old node from order list.
703 // Unused children nodes can be cleared by EraseUnusedNodeInOrder().
704 order_.erase(iter);
705 }
706
MakeInputNodes(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & inputs)707 static std::vector<AnfNodePtr> MakeInputNodes(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) {
708 std::vector<AnfNodePtr> input_node_list;
709 input_node_list.reserve(inputs.size() + 1);
710 input_node_list.emplace_back(std::make_shared<ValueNode>(primitive));
711 input_node_list.insert(input_node_list.end(), inputs.begin(), inputs.end());
712 return input_node_list;
713 }
714
NewCNode(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & inputs)715 CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) {
716 auto input_node_list = MakeInputNodes(primitive, inputs);
717 return NewCNode(input_node_list);
718 }
719
NewCNodeInOrder(const PrimitivePtr & primitive,const std::vector<AnfNodePtr> & inputs)720 CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &inputs) {
721 auto input_node_list = MakeInputNodes(primitive, inputs);
722 return NewCNodeInOrder(input_node_list);
723 }
724
add_weight(const tensor::MetaTensorPtr & meta_tensor)725 ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
726 auto parameter = add_parameter();
727 parameter->set_default_param(MakeValue(meta_tensor));
728 parameter->set_abstract(meta_tensor->ToAbstract());
729 return parameter;
730 }
731
ContainMultiTarget() const732 bool FuncGraph::ContainMultiTarget() const {
733 auto graph_manager = manager();
734 MS_EXCEPTION_IF_NULL(graph_manager);
735 FuncGraphSet graphs = graph_manager->func_graphs();
736 for (auto &g : graphs) {
737 auto nodes = mindspore::TopoSort(g->get_return());
738 if (mindspore::ContainMultiTarget(nodes)) {
739 return true;
740 }
741 }
742 return false;
743 }
744
set_used_forward_nodes(const std::vector<AnfNodePtr> & used_forward_nodes)745 void FuncGraph::set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes) {
746 (void)std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(), [this](const AnfNodePtr &node) {
747 MS_EXCEPTION_IF_NULL(node);
748 (void)used_forward_nodes_.emplace(node);
749 });
750 }
751
// Returns the next value of a process-wide, monotonically increasing "seen" generation counter.
// The first call returns 1. The counter is a plain function-local static with no synchronization.
size_t NewFgSeenGeneration() {
  static size_t fg_seen_generation = 0;
  fg_seen_generation += 1;
  return fg_seen_generation;
}
756
757 // Implement TopoSort api.
TopoSort(const AnfNodePtr & node)758 std::vector<AnfNodePtr> api::FuncGraph::TopoSort(const AnfNodePtr &node) { return mindspore::TopoSort(node); }
759
760 // Create an api::FuncGraph instance.
Create()761 api::FuncGraphPtr api::FuncGraph::Create() { return std::make_shared<mindspore::FuncGraph>(); }
762
MakeValueNode(const api::FuncGraphPtr & func_graph)763 AnfNodePtr api::FuncGraph::MakeValueNode(const api::FuncGraphPtr &func_graph) {
764 auto fg = std::dynamic_pointer_cast<mindspore::FuncGraph>(func_graph);
765 return NewValueNode(fg);
766 }
767
GetFuncGraphFromAnfNode(const AnfNodePtr & input)768 api::FuncGraphPtr api::FuncGraph::GetFuncGraphFromAnfNode(const AnfNodePtr &input) {
769 auto fg = GetValueNode<mindspore::FuncGraphPtr>(input);
770 return fg;
771 }
772
// Definition of the static FuncGraphTransform::func_graph_prim_: a shared Primitive named "FuncGraph".
const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
774 } // namespace mindspore
775