1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/manager.h"
20
21 #include <algorithm>
22 #include <list>
23
24 #include "ir/func_graph.h"
25 #include "utils/convert_utils_base.h"
26 #include "utils/counter.h"
27 #include "base/core_ops.h"
28
29 namespace mindspore {
30 namespace change {
31
32 struct Edge {
33 CNodePtr cnode;
34 int index;
35 AnfNodePtr input;
Edgemindspore::change::Edge36 Edge(const CNodePtr &cnode, int index, const AnfNodePtr &input) : cnode(cnode), index(index), input(input) {}
37 ~Edge() = default;
38 };
39
40 struct EdgeHash {
operator ()mindspore::change::EdgeHash41 std::size_t operator()(const Edge &e) const noexcept {
42 const std::hash<AnfNodePtr> node_hash;
43 return hash_combine({node_hash(e.cnode), IntToSize(e.index), node_hash(e.input)});
44 }
45 };
46
47 struct EdgeEqual {
operator ()mindspore::change::EdgeEqual48 bool operator()(const Edge &lhs, const Edge &rhs) const noexcept {
49 return lhs.cnode == rhs.cnode && lhs.index == rhs.index && lhs.input == rhs.input;
50 }
51 };
52
// Counter (utils/counter.h) of edges recorded while applying transaction changes.
using EdgeCounter = Counter<Edge, EdgeHash, EdgeEqual>;
// Counter (utils/counter.h) of nodes recorded while applying transaction changes.
using NodeCounter = Counter<AnfNodePtr>;
55
56 struct ChangeCounter {
57 EdgeCounter new_edges;
58 EdgeCounter del_edges;
59 NodeCounter new_nodes;
60 NodeCounter del_nodes;
61
62 template <typename Func>
ForEachAddedEdgesmindspore::change::ChangeCounter63 void ForEachAddedEdges(Func &&func) {
64 new_edges.subtract_by(del_edges, std::forward<Func>(func));
65 }
66
67 template <typename Func>
ForEachRemovedEdgesmindspore::change::ChangeCounter68 void ForEachRemovedEdges(Func &&func) {
69 del_edges.subtract_by(new_edges, std::forward<Func>(func));
70 }
71
GetAddedNodesmindspore::change::ChangeCounter72 std::vector<AnfNodePtr> GetAddedNodes() { return new_nodes.subtract(del_nodes); }
GetRemovedNodesmindspore::change::ChangeCounter73 std::vector<AnfNodePtr> GetRemovedNodes() { return del_nodes.subtract(new_nodes); }
74 };
75
76 class SetEdge : public Change {
77 public:
SetEdge(const CNodePtr & cnode,int index,const AnfNodePtr & input)78 SetEdge(const CNodePtr &cnode, int index, const AnfNodePtr &input) : edge_{cnode, index, input} {}
79 ~SetEdge() override = default;
80
Apply(ChangeCounter * counter)81 void Apply(ChangeCounter *counter) override {
82 auto &old_input = edge_.cnode->input(IntToSize(edge_.index));
83 counter->del_nodes.add(old_input);
84 counter->del_edges.add(edge_.cnode, edge_.index, old_input);
85 edge_.cnode->set_input(IntToSize(edge_.index), edge_.input);
86 counter->new_nodes.add(edge_.input);
87 counter->new_edges.add(std::move(edge_));
88 }
89
90 private:
91 Edge edge_;
92 };
93
94 class AddEdge : public Change {
95 public:
AddEdge(const CNodePtr & cnode,const AnfNodePtr & input)96 AddEdge(const CNodePtr &cnode, const AnfNodePtr &input) : cnode_{cnode}, input_{input} {}
97 ~AddEdge() override = default;
98
Apply(ChangeCounter * counter)99 void Apply(ChangeCounter *counter) override {
100 int index = static_cast<int>(cnode_->size());
101 cnode_->add_input(input_);
102 counter->new_nodes.add(input_);
103 counter->new_edges.add(std::move(cnode_), index, std::move(input_));
104 }
105
106 private:
107 CNodePtr cnode_;
108 AnfNodePtr input_;
109 };
110
111 class SetParams : public Change {
112 public:
SetParams(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & params)113 SetParams(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> ¶ms)
114 : func_graph_{func_graph}, params_{params} {}
115 ~SetParams() override = default;
116
Apply(ChangeCounter * counter)117 void Apply(ChangeCounter *counter) override {
118 auto &old_params = func_graph_->parameters();
119 for (auto &p : old_params) {
120 counter->del_nodes.add(p);
121 }
122 func_graph_->set_parameters(params_);
123 for (auto &p : params_) {
124 counter->new_nodes.add(std::move(p));
125 }
126 }
127
128 private:
129 FuncGraphPtr func_graph_;
130 std::vector<AnfNodePtr> params_;
131 };
132
133 class AddParam : public Change {
134 public:
AddParam(const FuncGraphPtr & func_graph,const ParameterPtr & param)135 AddParam(const FuncGraphPtr &func_graph, const ParameterPtr ¶m) : func_graph_{func_graph}, param_{param} {}
136 ~AddParam() override = default;
137
Apply(ChangeCounter * counter)138 void Apply(ChangeCounter *counter) override {
139 func_graph_->append_parameter(param_);
140 counter->new_nodes.add(std::move(param_));
141 }
142
143 private:
144 FuncGraphPtr func_graph_;
145 ParameterPtr param_;
146 };
147
148 class InsertFrontParam : public Change {
149 public:
InsertFrontParam(const FuncGraphPtr & func_graph,const ParameterPtr & param)150 InsertFrontParam(const FuncGraphPtr &func_graph, const ParameterPtr ¶m)
151 : func_graph_{func_graph}, param_{param} {}
152 ~InsertFrontParam() override = default;
153
Apply(ChangeCounter * counter)154 void Apply(ChangeCounter *counter) override {
155 func_graph_->PrependParameter(param_);
156 counter->new_nodes.add(std::move(param_));
157 }
158
159 private:
160 FuncGraphPtr func_graph_;
161 ParameterPtr param_;
162 };
163
164 } // namespace change
165
MakeManager(const std::vector<FuncGraphPtr> & func_graphs,bool manage)166 FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
167 auto m = std::make_shared<FuncGraphManager>(func_graphs, manage);
168 m->Init();
169 return m;
170 }
171
Manage(const std::vector<FuncGraphPtr> & func_graphs,bool manage)172 FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage) {
173 FuncGraphManagerPtr m = nullptr;
174 bool root = false;
175
176 for (auto &fg : func_graphs) {
177 if (fg == nullptr) {
178 continue;
179 }
180 if (fg->manager() != nullptr) {
181 m = fg->manager();
182 break;
183 }
184 }
185
186 if (m == nullptr) {
187 std::vector<FuncGraphPtr> tmp;
188 m = MakeManager(tmp, manage);
189 root = true;
190 }
191
192 for (auto &fg : func_graphs) {
193 if (fg == nullptr) {
194 continue;
195 }
196 m->AddFuncGraph(fg, root);
197 }
198 return m;
199 }
200
Manage(FuncGraphPtr func_graph,bool manage)201 FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
202 std::vector<FuncGraphPtr> func_graphs = {func_graph};
203 return Manage(func_graphs, manage);
204 }
205
Manage(const api::FuncGraphPtr & func_graph,bool manage)206 api::FuncGraphManagerPtr api::FuncGraphManager::Manage(const api::FuncGraphPtr &func_graph, bool manage) {
207 return mindspore::Manage(std::dynamic_pointer_cast<mindspore::FuncGraph>(func_graph), manage);
208 }
209
FuncGraphManager(const std::vector<FuncGraphPtr> & roots,bool manage)210 FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
211 : roots_(roots), is_manage_(manage) {
212 Reset();
213 }
214
Reset()215 void FuncGraphManager::Reset() {
216 func_graphs_ = FuncGraphSet();
217 all_nodes_ = AnfNodeSet();
218 node_users_ = NodeUsersMap();
219 signals_ = std::make_shared<Signals>();
220 func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this);
221 func_graph_parent_ = std::make_shared<ParentComputer>(this);
222 children_ = std::make_shared<ChildrenComputer>(this);
223 scopes_ = std::make_shared<ScopeComputer>(this);
224 free_variables_total_ = std::make_shared<FVTotalComputer>(this);
225 func_graphs_used_total_ = std::make_shared<FuncGraphsUsedTotalComputer>(this);
226 recursive_ = std::make_shared<RecursiveComputer>(this);
227 j_total_ = std::make_shared<FuncGraphJTotalComputer>(this);
228 }
229
Init()230 void FuncGraphManager::Init() {
231 auto roots = roots_;
232 roots_ = FuncGraphSet();
233
234 for (auto &fg : roots) {
235 AddFuncGraph(fg, true);
236 }
237 }
238
func_graph_parents_total(const FuncGraphPtr & fg) const239 FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
240 MS_EXCEPTION_IF_NULL(fg);
241 MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
242 func_graph_parents_total_->Recompute(fg);
243 MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
244 return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
245 }
246
parent(const FuncGraphPtr & fg) const247 FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
248 MS_EXCEPTION_IF_NULL(fg);
249 MS_EXCEPTION_IF_NULL(func_graph_parent_);
250 MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
251 func_graph_parent_->Recompute(fg);
252 if (func_graph_parent_->parent_analysis().count(fg) == 0) {
253 MS_LOG(WARNING) << "This func graph is not in manager:" << fg->ToString();
254 return nullptr;
255 }
256 MS_LOG(DEBUG) << "End parents func graph " << fg->ToString();
257 return func_graph_parent_->parent_analysis()[fg];
258 }
259
children(const FuncGraphPtr & fg) const260 FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
261 MS_EXCEPTION_IF_NULL(fg);
262 MS_EXCEPTION_IF_NULL(children_);
263 MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
264 children_->Recompute(fg);
265 return children_->children_analysis()[fg];
266 }
267
scopes(const FuncGraphPtr & fg) const268 FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
269 MS_EXCEPTION_IF_NULL(fg);
270 MS_EXCEPTION_IF_NULL(scopes_);
271 MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString();
272 scopes_->Recompute(fg);
273 MS_LOG(DEBUG) << "End scopes func graph:" << fg->ToString();
274 return scopes_->scope_analysis()[fg];
275 }
276
free_variables_total() const277 FVTotalMap &FuncGraphManager::free_variables_total() const {
278 MS_EXCEPTION_IF_NULL(free_variables_total_);
279 free_variables_total_->Recompute();
280 return free_variables_total_->fv_total_analysis();
281 }
282
func_graphs_used_total(const FuncGraphPtr & fg) const283 FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
284 MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
285 func_graphs_used_total_->Recompute(fg);
286 return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
287 }
288
recursive(const FuncGraphPtr & fg) const289 bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
290 MS_EXCEPTION_IF_NULL(fg);
291 recursive_->Recompute(fg);
292 if (recursive_->recursive_analysis().count(fg) == 0) {
293 MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
294 return false;
295 }
296 return recursive_->recursive_analysis()[fg];
297 }
298
recursive_graphs(const FuncGraphPtr & fg) const299 std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
300 MS_EXCEPTION_IF_NULL(fg);
301 if (recursive(fg)) {
302 if (!recursive_->recursive_map().count(fg)) {
303 auto trace = std::list<FuncGraphPtr>();
304 recursive_->CheckRecursiveGraphs(fg, &trace);
305 }
306 if (recursive_->recursive_map().count(fg) == 0) {
307 MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
308 return nullptr;
309 }
310 return recursive_->recursive_map()[fg];
311 } else {
312 return nullptr;
313 }
314 }
315
func_graph_j_total(const FuncGraphPtr & fg) const316 bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const {
317 MS_EXCEPTION_IF_NULL(j_total_);
318 MS_EXCEPTION_IF_NULL(fg);
319 j_total_->Recompute(fg);
320 if (j_total_->j_total_analysis().count(fg) == 0) {
321 MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
322 return false;
323 }
324 return j_total_->j_total_analysis()[fg];
325 }
326
327 // Add a func graph to this manager, optionally as a root func graph.
AddFuncGraph(const FuncGraphPtr & func_graph,bool is_root)328 void FuncGraphManager::AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root) {
329 MS_EXCEPTION_IF_NULL(func_graph);
330 if (is_root) {
331 roots_.add(func_graph);
332 }
333 if (func_graphs_.contains(func_graph)) {
334 return;
335 }
336
337 // Add func_graph as a managed graph.
338 AddIntoManaged(func_graph);
339
340 // New nodes to be acquired.
341 std::vector<AnfNodePtr> new_nodes = func_graph->parameters();
342 auto return_node = func_graph->get_return();
343 if (return_node != nullptr) {
344 (void)new_nodes.emplace_back(std::move(return_node));
345 }
346
347 // Acquire all nodes from func_graph.
348 AcquireNodes(std::move(new_nodes));
349 }
350
351 // Clear the all information in manager
Clear()352 void FuncGraphManager::Clear() {
353 for (auto graph : func_graphs_) {
354 graph->DecAttachedMngCnt();
355 if (graph->attached_mng_cnt() == 0) {
356 graph->ClearAllManagerInfo();
357 } else if (graph->attached_mng_cnt() < 0) {
358 MS_LOG(EXCEPTION) << "graph:" << graph->ToString() << " attached cnt not right:" << graph->attached_mng_cnt();
359 }
360 }
361
362 func_graphs_.clear();
363 all_nodes_.clear();
364 node_users_.clear();
365 roots_.clear();
366
367 signals_->InvalidateComputer();
368 }
369
KeepRoots(const std::vector<FuncGraphPtr> & func_graphs)370 void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
371 MS_LOG(DEBUG) << "Start keep roots";
372 bool root_exist = false;
373 for (auto &item : func_graphs) {
374 if (roots_.contains(item)) {
375 root_exist = true;
376 break;
377 }
378 }
379
380 // if the new_root in roots_, we add new_root first, then calculate the func_graphs
381 // relation to new_root, remove the func_graphs not relation to new_root
382 // if the new_root not in roots_, we clear the all func_graphs in manager
383 // then add the new_root
384 if (root_exist || func_graphs.empty()) {
385 FuncGraphSet roots(func_graphs);
386 if (roots.empty()) {
387 roots = roots_;
388 } else {
389 roots_.clear();
390 for (auto &item : roots) {
391 AddFuncGraph(item, true);
392 }
393 }
394
395 FuncGraphSet keep;
396 for (auto &item : roots) {
397 MS_LOG(DEBUG) << "roots: " << item->ToString();
398 keep.update(func_graphs_used_total(item));
399 #ifdef DEBUG
400 for (auto &k : keep) {
401 MS_LOG(DEBUG) << "keep: " << k->ToString();
402 }
403 #endif
404 }
405 MaybeDropFuncGraphs(func_graphs_ - keep, true);
406 } else {
407 Clear();
408 FuncGraphSet roots(func_graphs);
409 for (auto &item : roots) {
410 AddFuncGraph(item, true);
411 }
412 }
413 }
414
RemoveRoots()415 void FuncGraphManager::RemoveRoots() {
416 MS_LOG(DEBUG) << "Start remove roots";
417 roots_.clear();
418 MaybeDropFuncGraphs(func_graphs_, true);
419 }
420
AddIntoManaged(const FuncGraphPtr & fg)421 void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
422 MS_EXCEPTION_IF_NULL(fg);
423 if (is_manage_) {
424 if (fg->manager() != nullptr && fg->manager().get() != this) {
425 MS_LOG(INFO) << "A func graph can only have one manager.";
426 }
427 fg->set_manager(shared_from_this());
428 }
429 func_graphs_.add(fg);
430 fg->IncAttachedMngCnt();
431 }
432
MaybeDropFuncGraphs(const FuncGraphSet & func_graphs,bool ignore_users)433 void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
434 std::list<FuncGraphPtr> todo(func_graphs.begin(), func_graphs.end());
435 std::set<FuncGraphPtr> dropped;
436 while (!todo.empty()) {
437 FuncGraphPtr func_graph = std::move(todo.front());
438 MS_EXCEPTION_IF_NULL(func_graph);
439 todo.pop_front();
440 MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString();
441 if (roots_.contains(func_graph)) {
442 MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
443 continue;
444 }
445 auto &users_cnode_index = func_graph->func_graph_cnodes_index();
446 if (!users_cnode_index.empty() && !ignore_users) {
447 MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
448 continue;
449 }
450 if (dropped.find(func_graph) != dropped.end()) {
451 MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString();
452 continue;
453 }
454 (void)dropped.insert(func_graph);
455 std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
456 auto drop_graphs = MaybeDropNodes(std::move(return_vec));
457 (void)todo.insert(todo.end(), drop_graphs.begin(), drop_graphs.end());
458 }
459 for (auto &fg : dropped) {
460 MS_EXCEPTION_IF_NULL(fg);
461 all_nodes_.difference_update(fg->parameters());
462 EraseOneGraph(fg);
463 if (fg->manager().get() == this) {
464 fg->set_manager(nullptr);
465 }
466 MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString();
467 }
468 }
469
ProcessEdgeAdd(const AnfNodePtr & node,int index,const AnfNodePtr & input)470 void FuncGraphManager::ProcessEdgeAdd(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
471 if (IsValueNode<FuncGraph>(input)) {
472 AddFuncGraph(GetValueNode<FuncGraphPtr>(input));
473 }
474 auto &users_node = node_users_[input];
475 users_node.add(std::make_pair(node, index));
476 OnEdgeAdded(node, index, input);
477 }
478
ProcessEdgeRemove(const AnfNodePtr & node,int index,const AnfNodePtr & input)479 void FuncGraphManager::ProcessEdgeRemove(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
480 auto iter = node_users_.find(input);
481 if (iter == node_users_.end()) {
482 return;
483 }
484 bool removed = iter->second.erase(std::make_pair(node, index));
485 if (removed) {
486 OnEdgeRemoved(node, index, input);
487 }
488 }
489
ProcessInputsEdgeAdd(const CNodePtr & cnode)490 void FuncGraphManager::ProcessInputsEdgeAdd(const CNodePtr &cnode) {
491 const size_t count = cnode->size();
492 for (size_t i = 0; i < count; ++i) {
493 ProcessEdgeAdd(cnode, static_cast<int>(i), cnode->input(i));
494 }
495 }
496
ProcessInputsEdgeRemove(const CNodePtr & cnode)497 void FuncGraphManager::ProcessInputsEdgeRemove(const CNodePtr &cnode) {
498 const size_t count = cnode->size();
499 for (size_t i = 0; i < count; ++i) {
500 ProcessEdgeRemove(cnode, static_cast<int>(i), cnode->input(i));
501 }
502 }
503
FollowGraph(const FuncGraphPtr & fg,size_t seen,std::vector<AnfNodePtr> * nodes)504 static inline void FollowGraph(const FuncGraphPtr &fg, size_t seen, std::vector<AnfNodePtr> *nodes) {
505 if (fg == nullptr) {
506 return;
507 }
508 if (auto ret = fg->get_return(); ret != nullptr && ret->seen_ != seen) {
509 (void)nodes->emplace_back(std::move(ret));
510 }
511 }
512
AcquireNodes(std::vector<AnfNodePtr> && nodes)513 void FuncGraphManager::AcquireNodes(std::vector<AnfNodePtr> &&nodes) {
514 auto seen = NewSeenGeneration();
515 while (!nodes.empty()) {
516 // Take the last one.
517 auto node = std::move(nodes.back());
518 nodes.pop_back();
519 MS_EXCEPTION_IF_NULL(node);
520 // Skip visited nodes.
521 if (node->seen_ == seen) {
522 continue;
523 }
524 node->seen_ = seen;
525 // Try add it to all_nodes_.
526 auto insert_result = all_nodes_.insert(node);
527 if (insert_result.second == false) {
528 // Skip acquired nodes.
529 continue;
530 }
531 // Add node to its func_graph.
532 auto fg = node->func_graph();
533 if (fg != nullptr) {
534 fg->AddNode(node);
535 }
536 // Follow graph for value node.
537 if (node->isa<ValueNode>()) {
538 auto graph = GetValueNode<FuncGraphPtr>(node);
539 FollowGraph(graph, seen, &nodes);
540 continue;
541 }
542 // Follow graph for cnode or parameter.
543 FollowGraph(fg, seen, &nodes);
544 // Handle cnode.
545 auto cnode = node->cast<CNodePtr>();
546 if (cnode != nullptr) {
547 // Handle input edges.
548 ProcessInputsEdgeAdd(cnode);
549 // Follow inputs.
550 auto &inputs = cnode->inputs();
551 (void)nodes.insert(nodes.end(), inputs.begin(), inputs.end());
552 }
553 }
554 }
555
MaybeDropNodes(std::vector<AnfNodePtr> && nodes)556 FuncGraphSet FuncGraphManager::MaybeDropNodes(std::vector<AnfNodePtr> &&nodes) {
557 FuncGraphSet drop_func_graphs;
558 while (!nodes.empty()) {
559 AnfNodePtr node = std::move(nodes.back());
560 nodes.pop_back();
561 if (node == nullptr) {
562 // Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception,
563 // this method may be triggered by desctuctor.
564 MS_LOG(WARNING) << "Node to be dropped is nullptr";
565 continue;
566 }
567 if (!all_nodes_.contains(node)) {
568 // Node not existed.
569 continue;
570 }
571 auto &users = node_users_[node];
572 if (!users.empty()) {
573 // Node is in used.
574 continue;
575 }
576 if (node->isa<Parameter>() && node->func_graph() != nullptr) {
577 // Node is a used parameter.
578 auto ¶meters = node->func_graph()->parameters();
579 if (std::find(parameters.begin(), parameters.end(), node) != parameters.end()) {
580 continue;
581 }
582 }
583 if (IsValueNode<FuncGraph>(node)) {
584 // The FuncGraph may need to be dropped.
585 auto fg = GetValueNode<FuncGraphPtr>(node);
586 drop_func_graphs.add(fg);
587 }
588 // Handle cnode.
589 if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
590 // Remove inputs edges.
591 ProcessInputsEdgeRemove(cnode);
592 // Handle inputs nodes.
593 auto &inputs = cnode->inputs();
594 (void)nodes.insert(nodes.end(), inputs.begin(), inputs.end());
595 }
596 // Remove it from all_nodes_;
597 (void)all_nodes_.erase(node);
598 // Drop node from its func graph.
599 if (auto fg = node->func_graph(); fg != nullptr) {
600 fg->DropNode(node);
601 }
602 // Remove it from node_users.
603 (void)node_users_.erase(node);
604 }
605 return drop_func_graphs;
606 }
607
SetParameters(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & parameters)608 void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters) {
609 auto tr = Transact();
610 tr.SetParameters(fg, parameters);
611 tr.Commit();
612 }
613
AddParameter(const FuncGraphPtr & fg,const AnfNodePtr & parameter)614 void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) {
615 auto tr = Transact();
616 tr.AddParameter(fg, parameter);
617 tr.Commit();
618 }
619
InsertFrontParameter(const FuncGraphPtr & fg,const AnfNodePtr & parameter)620 void FuncGraphManager::InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) {
621 auto tr = Transact();
622 tr.InsertFrontParameter(fg, parameter);
623 tr.Commit();
624 }
625
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)626 bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
627 auto func_graph = old_node->func_graph();
628 auto tr = Transact();
629 bool success = tr.Replace(old_node, new_node);
630 if (success) {
631 tr.Commit();
632 if (func_graph != nullptr) {
633 func_graph->ReplaceInOrder(old_node, new_node);
634 }
635 }
636 return success;
637 }
638
SetEdge(const AnfNodePtr & node,int index,const AnfNodePtr & value)639 void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
640 auto tr = Transact();
641 tr.SetEdge(node, index, value);
642 tr.Commit();
643 }
644
AddEdge(const AnfNodePtr & node,const AnfNodePtr & value)645 void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
646 auto tr = Transact();
647 tr.AddEdge(node, value);
648 tr.Commit();
649 }
650
MoveAllCNodeDropGraph(const FuncGraphPtr & source,const FuncGraphPtr & target,const ScopePtr & scope)651 void FuncGraphManager::MoveAllCNodeDropGraph(const FuncGraphPtr &source, const FuncGraphPtr &target,
652 const ScopePtr &scope) {
653 AnfNodePtr source_return = source->get_return();
654 AnfNodePtr source_output = source->output();
655 AnfNodePtr source_prim = source_return->cast<CNodePtr>()->input(0);
656
657 int index = 0;
658 (void)node_users_[source_prim].erase(make_pair(source_return, index));
659 OnEdgeRemoved(source_return, index, source_prim);
660 index = 1;
661 (void)node_users_[source_output].erase(make_pair(source_return, index));
662 OnEdgeRemoved(source_return, index, source_output);
663 (void)all_nodes_.erase(source_return);
664 (void)node_users_.erase(source_return);
665 source->DropNode(source_return);
666 for (auto &node : source->nodes()) {
667 node->set_func_graph(target);
668 if (node->scope() == kDefaultScope) {
669 node->set_scope(scope);
670 }
671 }
672
673 MoveAllNodes(source, target);
674 all_nodes_.difference_update(source->parameters());
675 EraseOneGraph(source);
676 source->set_dropped(true);
677 if (source->manager().get() == this) {
678 source->set_manager(nullptr);
679 }
680 }
681
OnEdgeAdded(const AnfNodePtr & node,int index,const AnfNodePtr & input)682 void FuncGraphManager::OnEdgeAdded(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
683 auto fg = node->func_graph();
684 if (input->isa<ValueNode>()) {
685 fg->AddValueNode(input);
686 if (IsValueNode<FuncGraph>(input)) {
687 auto used = GetValueNode<FuncGraphPtr>(input);
688 used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
689 if (fg->AddFuncGraphUsed(used)) {
690 signals_->InvalidateComputer();
691 }
692 }
693 if (IsPrimitiveCNode(node, prim::kPrimJ)) {
694 fg->AddJValueNode(input);
695 }
696 } else if (fg != nullptr && fg != input->func_graph()) {
697 if (fg->AddFreeVariable(input)) {
698 signals_->InvalidateComputer();
699 }
700 }
701 }
702
OnEdgeRemoved(const AnfNodePtr & node,int index,const AnfNodePtr & input)703 void FuncGraphManager::OnEdgeRemoved(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
704 auto fg = node->func_graph();
705 if (fg != nullptr && input->isa<ValueNode>()) {
706 fg->DropValueNode(input);
707 if (IsValueNode<FuncGraph>(input)) {
708 auto used = GetValueNode<FuncGraphPtr>(input);
709 used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
710 if (fg->DropFuncGraphUsed(used)) {
711 signals_->InvalidateComputer();
712 }
713 }
714 if (IsPrimitiveCNode(node, prim::kPrimJ)) {
715 fg->DropJValueNode(input);
716 }
717 } else if (fg != nullptr && fg != input->func_graph()) {
718 if (fg->DropFreeVariable(input)) {
719 signals_->InvalidateComputer();
720 }
721 }
722 }
723
MoveAllNodes(const FuncGraphPtr & source,const FuncGraphPtr & target)724 void FuncGraphManager::MoveAllNodes(const FuncGraphPtr &source, const FuncGraphPtr &target) {
725 target->CopyNodes(source);
726 target->CopyValueNodes(source);
727 target->CopyFuncGraphCNodesIndex(source);
728 target->CopyFreeVariables(source);
729 target->CopyFuncGraphsUsed(source);
730 target->CopyJValueNodes(source);
731 source->ClearAllManagerInfo();
732 signals_->InvalidateComputer();
733 }
734
CommitChanges(std::vector<change::ChangePtr> && changes)735 void FuncGraphManager::CommitChanges(std::vector<change::ChangePtr> &&changes) {
736 // Apply changes.
737 change::ChangeCounter counter;
738 for (auto &change : changes) {
739 change->Apply(&counter);
740 }
741 changes.clear();
742
743 // Process added edges.
744 counter.ForEachAddedEdges([this](const change::Edge &edge) { //
745 ProcessEdgeAdd(edge.cnode, edge.index, edge.input);
746 });
747
748 // Process added nodes.
749 AcquireNodes(counter.GetAddedNodes());
750
751 // Process removed edges.
752 counter.ForEachRemovedEdges([this](const change::Edge &edge) { //
753 ProcessEdgeRemove(edge.cnode, edge.index, edge.input);
754 });
755
756 // Process removed nodes.
757 auto drop_func_graphs = MaybeDropNodes(counter.GetRemovedNodes());
758 if (!drop_func_graphs.empty()) {
759 MaybeDropFuncGraphs(drop_func_graphs);
760 }
761 }
762
EraseOneGraph(const FuncGraphPtr & fg)763 void FuncGraphManager::EraseOneGraph(const FuncGraphPtr &fg) {
764 MS_EXCEPTION_IF_NULL(fg);
765 size_t erase_cnt = func_graphs_.erase(fg->shared_from_base<FuncGraph>());
766 if (!erase_cnt) {
767 return;
768 }
769 fg->DecAttachedMngCnt();
770 if (fg->attached_mng_cnt() == 0) {
771 fg->ClearAllManagerInfo();
772 }
773 }
774
SetParameters(FuncGraphPtr fg,const std::vector<AnfNodePtr> & params)775 void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms) {
776 (void)changes_.emplace_back(std::make_unique<change::SetParams>(fg, params));
777 }
778
AddParameter(FuncGraphPtr fg,const AnfNodePtr & param)779 void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) {
780 (void)changes_.emplace_back(std::make_unique<change::AddParam>(fg, param->cast<ParameterPtr>()));
781 }
782
InsertFrontParameter(FuncGraphPtr fg,const AnfNodePtr & param)783 void FuncGraphTransaction::InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) {
784 (void)changes_.emplace_back(std::make_unique<change::InsertFrontParam>(fg, param->cast<ParameterPtr>()));
785 }
786
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)787 bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
788 MS_EXCEPTION_IF_NULL(old_node);
789 MS_EXCEPTION_IF_NULL(new_node);
790 FuncGraphPtr old_func_graph = old_node->func_graph();
791 if (old_func_graph != nullptr && old_func_graph->get_return() == old_node) {
792 MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString();
793 return false;
794 }
795 auto &users = manager_->node_users()[old_node];
796 for (auto &node : users) {
797 SetEdge(node.first, node.second, new_node);
798 }
799 return true;
800 }
801
SetEdge(const AnfNodePtr & src_node,int k,const AnfNodePtr & v)802 void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
803 if (k < 0) {
804 MS_LOG(EXCEPTION) << "Invalid value k = " << k;
805 }
806 MS_EXCEPTION_IF_NULL(src_node);
807 auto cnode = src_node->cast<CNodePtr>();
808 if (cnode == nullptr) {
809 MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed.";
810 }
811 (void)changes_.emplace_back(std::make_unique<change::SetEdge>(cnode, k, v));
812 }
813
AddEdge(const AnfNodePtr & src_node,const AnfNodePtr & v)814 void FuncGraphTransaction::AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v) {
815 MS_EXCEPTION_IF_NULL(src_node);
816 auto cnode = src_node->cast<CNodePtr>();
817 if (cnode == nullptr) {
818 MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed.";
819 }
820 (void)changes_.emplace_back(std::make_unique<change::AddEdge>(cnode, v));
821 }
822
Commit()823 void FuncGraphTransaction::Commit() { manager_->CommitChanges(std::move(changes_)); }
824
DepComputer(const FuncGraphManager * const manager)825 DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) {
826 MS_EXCEPTION_IF_NULL(manager_);
827 manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
828 validate_ = false;
829 }
830
Recompute()831 void DepComputer::Recompute() {
832 if (!validate_) {
833 RealRecompute();
834 validate_ = true;
835 }
836 }
837
Recompute(const FuncGraphPtr & fg)838 void DepComputer::Recompute(const FuncGraphPtr &fg) {
839 if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
840 RealRecompute(fg);
841 func_graphs_validate_[fg] = true;
842 }
843 }
844
SeekParents(const FuncGraphPtr & fg,std::unordered_map<FuncGraphPtr,FuncGraphSetPtr> * seen_fgs)845 FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(
846 const FuncGraphPtr &fg, std::unordered_map<FuncGraphPtr, FuncGraphSetPtr> *seen_fgs) {
847 auto iter = seen_fgs->find(fg);
848 if (iter != seen_fgs->end()) {
849 return iter->second;
850 }
851 FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
852
853 // Append all the fvs in fg.
854 auto &fvs = fg->free_variables();
855 for (auto fv : fvs) {
856 parents->add(fv.first->func_graph());
857 }
858
859 // Search the fv in fg's child func graph.
860 auto &fgs = fg->func_graphs_used();
861 for (auto &item : fgs) {
862 auto gt = item.first;
863 if (gt->seen_ == 1) {
864 continue;
865 }
866 gt->seen_ = 1;
867 parents->update(SeekParents(gt, seen_fgs));
868 gt->seen_ = 0;
869 }
870 (void)parents->erase(fg);
871 (*seen_fgs)[fg] = parents;
872 return parents;
873 }
874
RealRecompute(FuncGraphPtr fg)875 void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
876 MS_EXCEPTION_IF_NULL(fg);
877 std::unordered_map<FuncGraphPtr, FuncGraphSetPtr> seen_fgs;
878 fg->seen_ = 1;
879 func_graph_parents_total_analysis_[fg].update(SeekParents(fg, &seen_fgs));
880 fg->seen_ = 0;
881 }
882
set_len_compare(const FuncGraphSetPair & lhs,const FuncGraphSetPair & rhs)883 bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
884 auto l1 = lhs.second.size();
885 auto l2 = rhs.second.size();
886 return l1 < l2;
887 }
888
RealRecompute(FuncGraphPtr fg)889 void ParentComputer::RealRecompute(FuncGraphPtr fg) {
890 this->parent_analysis_[fg] = nullptr;
891 // Note: must be a copy other than reference as it is modified thereafter.
892 auto deps = this->manager_->func_graph_parents_total(fg);
893
894 if (deps.empty()) {
895 this->parent_analysis_[fg] = nullptr;
896 return;
897 } else if (deps.size() == 1) {
898 this->parent_analysis_[fg] = deps.front();
899 return;
900 } else {
901 // return nearest parent as parent
902 FuncGraphSet deps_copy(deps);
903 for (auto &dep : deps) {
904 auto parent_deps = this->manager_->func_graph_parents_total(dep);
905 for (auto &p_d : parent_deps) {
906 if (deps_copy.count(p_d)) {
907 (void)deps_copy.erase(p_d);
908 }
909 }
910 if (deps_copy.size() == 1) {
911 this->parent_analysis_[fg] = deps_copy.front();
912 return;
913 }
914 }
915 }
916 }
917
RealRecompute(FuncGraphPtr fg)918 void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
919 MS_EXCEPTION_IF_NULL(manager_);
920 auto used_fg_total = manager_->func_graphs_used_total(fg);
921 for (auto &used_fg : used_fg_total) {
922 if (manager_->parent(used_fg) == fg) {
923 children_analysis_[fg].add(used_fg);
924 }
925 }
926 }
927
RealRecompute(FuncGraphPtr fg)928 void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
929 MS_EXCEPTION_IF_NULL(manager_);
930 auto &children = manager_->children(fg);
931
932 scope_analysis_[fg] = FuncGraphSet();
933 scope_analysis_[fg].add(fg);
934 for (auto &child : children) {
935 scope_analysis_[fg].add(child);
936 }
937 }
938
RealRecompute()939 void FVTotalComputer::RealRecompute() {
940 auto manager = DepComputer::manager_;
941 MS_EXCEPTION_IF_NULL(manager);
942
943 for (auto &fg : manager->func_graphs()) {
944 fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
945 }
946
947 for (auto &fg : manager->func_graphs()) {
948 // add all free variable nodes
949 AnfNodeCounterMap items = fg->free_variables();
950 for (auto &iter : items) {
951 auto curr = fg;
952 while (curr != nullptr) {
953 fv_total_analysis_[curr][iter.first] = iter.second;
954 curr = manager->parent(curr);
955 if (curr != nullptr) {
956 const AnfNodeSet &all_nodes = curr->nodes();
957 if (all_nodes.contains(iter.first)) {
958 break;
959 }
960 }
961 }
962 }
963
964 // add all FGs of free variables
965 auto &used = fg->func_graphs_used();
966 for (auto &iter : used) {
967 auto p = manager->parent(iter.first);
968 if (p == nullptr) {
969 continue;
970 }
971 auto curr = fg;
972 while (curr != nullptr && curr != p) {
973 fv_total_analysis_[curr][iter.first] = iter.second;
974 curr = manager->parent(curr);
975 }
976 }
977 }
978 }
979
RealRecompute(FuncGraphPtr fg)980 void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
981 MS_EXCEPTION_IF_NULL(manager_);
982 std::vector<FuncGraphPtr> todo;
983 std::vector<FuncGraphPtr> todo_new;
984
985 todo.push_back(fg);
986 while (!todo.empty()) {
987 todo_new.clear();
988 for (auto > : todo) {
989 for (auto &item : gt->func_graphs_used()) {
990 auto used_fg = item.first;
991 if (used_fg == fg) {
992 func_graph_used_total_analysis_[fg].add(used_fg);
993 continue;
994 }
995 if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) {
996 todo_new.push_back(used_fg);
997 }
998 MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString();
999 func_graph_used_total_analysis_[fg].add(used_fg);
1000 }
1001 }
1002 todo = todo_new;
1003 }
1004 }
1005
CheckRecursive(const FuncGraphManager * const manager,const FuncGraphPtr & fg)1006 bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
1007 MS_EXCEPTION_IF_NULL(manager);
1008 std::vector<FuncGraphPtr> todo;
1009 std::vector<FuncGraphPtr> todo_new;
1010 todo.push_back(fg);
1011 FuncGraphSet used_total;
1012 while (!todo.empty()) {
1013 todo_new.clear();
1014 for (auto > : todo) {
1015 for (auto &item : gt->func_graphs_used()) {
1016 auto used_g = item.first;
1017 if (used_g == fg) {
1018 return true;
1019 }
1020 if (used_total.count(used_g) == 0) {
1021 todo_new.push_back(used_g);
1022 }
1023 used_total.add(used_g);
1024 }
1025 }
1026 todo = todo_new;
1027 }
1028 return false;
1029 }
1030
RealRecompute(FuncGraphPtr fg)1031 void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
1032 this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
1033 }
1034
CheckRecursiveGraphs(const FuncGraphPtr & fg,std::list<FuncGraphPtr> * trace)1035 void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
1036 MS_EXCEPTION_IF_NULL(trace);
1037 auto res = std::find(trace->begin(), trace->end(), fg);
1038 // find recursive
1039 if (res != trace->end()) {
1040 auto recur_ptr = std::make_shared<std::list<FuncGraphPtr>>(res, trace->end());
1041 for (auto iter = res; iter != trace->end(); (void)iter++) {
1042 MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString();
1043 recursive_map_[*iter] = recur_ptr;
1044 }
1045 } else {
1046 trace->push_back(fg);
1047 auto &items = fg->func_graphs_used();
1048 for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
1049 CheckRecursiveGraphs(iter->first, trace);
1050 }
1051 trace->pop_back();
1052 if (!recursive_map_.count(fg)) {
1053 recursive_map_[fg] = nullptr;
1054 }
1055 }
1056 }
1057
SeekJ(const FuncGraphPtr & fg,size_t seen_num)1058 bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
1059 MS_EXCEPTION_IF_NULL(fg);
1060 if (fg->seen_ == seen_num) {
1061 MS_LOG(DEBUG) << fg->ToString() << " had been checked";
1062 return false;
1063 }
1064
1065 // Check J FuncGraph input.
1066 const auto &j_values = fg->j_value_nodes();
1067 if (!j_values.empty()) {
1068 auto contains_j =
1069 std::find_if(j_values.begin(), j_values.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) {
1070 // Check g1->J(fg)->g2->g cycle.
1071 if (IsValueNode<FuncGraph>(iter.first)) {
1072 auto func_graph = GetValueNode<FuncGraphPtr>(iter.first);
1073 return func_graph->seen_ != seen_num;
1074 }
1075 if (IsValueNode<Primitive>(iter.first)) {
1076 // Exclude the primitive of J itself.
1077 auto prim = GetValueNode<PrimitivePtr>(iter.first);
1078 return prim->name() != prim::kPrimJ->name();
1079 }
1080 return false;
1081 });
1082 if (contains_j != j_values.end()) {
1083 MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->DebugString() << ")";
1084 return true;
1085 }
1086 }
1087
1088 // Check J CNode as FV.
1089 const auto &fv_nodes = fg->free_variables();
1090 if (!fv_nodes.empty()) {
1091 auto contains_j_cnode =
1092 std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> &iter) {
1093 // Check if the FV is a J call CNode.
1094 if (IsPrimitiveCNode(iter.first, prim::kPrimJ)) {
1095 return true;
1096 }
1097 return false;
1098 });
1099 if (contains_j_cnode != fv_nodes.end()) {
1100 MS_LOG(DEBUG) << fg->ToString() << " contains FV J(" << contains_j_cnode->first->DebugString() << ")";
1101 return true;
1102 }
1103 }
1104
1105 // Check if func graphs used contains J(func_graph) or J(Primitive)
1106 fg->seen_ = seen_num;
1107 for (auto &item : fg->func_graphs_used()) {
1108 auto used_g = item.first;
1109 if (SeekJ(used_g, seen_num)) {
1110 MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString()
1111 << " which contains J(func_graph) or J(Primitive)";
1112 return true;
1113 }
1114 }
1115 MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph) or J(Primitive)";
1116 return false;
1117 }
1118
RealRecompute(FuncGraphPtr fg)1119 void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
1120 this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration());
1121 }
1122 } // namespace mindspore
1123