• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/manager.h"
20 
21 #include <algorithm>
22 #include <list>
23 
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "ir/func_graph.h"
27 #include "utils/convert_utils_base.h"
28 #include "utils/counter.h"
29 #include "utils/trace_base.h"
30 #include "utils/ms_context.h"
31 
32 namespace mindspore {
33 namespace change {
34 
35 struct Edge {
36   CNodePtr cnode;
37   int index;
38   AnfNodePtr input;
Edgemindspore::change::Edge39   Edge(const CNodePtr &cnode, int index, const AnfNodePtr &input) : cnode(cnode), index(index), input(input) {}
40   ~Edge() = default;
41 };
42 
43 struct EdgeHash {
operator ()mindspore::change::EdgeHash44   std::size_t operator()(const Edge &e) const noexcept {
45     const std::hash<AnfNodePtr> node_hash;
46     return hash_combine({node_hash(e.cnode), IntToSize(e.index), node_hash(e.input)});
47   }
48 };
49 
50 struct EdgeEqual {
operator ()mindspore::change::EdgeEqual51   bool operator()(const Edge &lhs, const Edge &rhs) const noexcept {
52     return lhs.cnode == rhs.cnode && lhs.index == rhs.index && lhs.input == rhs.input;
53   }
54 };
55 
56 using EdgeCounter = Counter<Edge, EdgeHash, EdgeEqual>;
57 using NodeCounter = Counter<AnfNodePtr>;
58 
59 struct ChangeCounter {
60   EdgeCounter new_edges;
61   EdgeCounter del_edges;
62   NodeCounter new_nodes;
63   NodeCounter del_nodes;
64 
65   template <typename Func>
ForEachAddedEdgesmindspore::change::ChangeCounter66   void ForEachAddedEdges(Func &&func) {
67     new_edges.subtract_by(del_edges, std::forward<Func>(func));
68   }
69 
70   template <typename Func>
ForEachRemovedEdgesmindspore::change::ChangeCounter71   void ForEachRemovedEdges(Func &&func) {
72     del_edges.subtract_by(new_edges, std::forward<Func>(func));
73   }
74 
GetAddedNodesmindspore::change::ChangeCounter75   std::vector<AnfNodePtr> GetAddedNodes() { return new_nodes.subtract(del_nodes); }
GetRemovedNodesmindspore::change::ChangeCounter76   std::vector<AnfNodePtr> GetRemovedNodes() { return del_nodes.subtract(new_nodes); }
77 };
78 
79 class SetEdge : public Change {
80  public:
SetEdge(const CNodePtr & cnode,int index,const AnfNodePtr & input)81   SetEdge(const CNodePtr &cnode, int index, const AnfNodePtr &input) : edge_{cnode, index, input} {}
82   ~SetEdge() override = default;
83 
Apply(ChangeCounter * counter)84   void Apply(ChangeCounter *counter) override {
85     auto &old_input = edge_.cnode->input(IntToSize(edge_.index));
86     counter->del_nodes.add(old_input);
87     counter->del_edges.add(edge_.cnode, edge_.index, old_input);
88     edge_.cnode->set_input(IntToSize(edge_.index), edge_.input);
89     counter->new_nodes.add(edge_.input);
90     counter->new_edges.add(std::move(edge_));
91   }
92 
93  private:
94   Edge edge_;
95 };
96 
97 class AddEdge : public Change {
98  public:
AddEdge(const CNodePtr & cnode,const AnfNodePtr & input)99   AddEdge(const CNodePtr &cnode, const AnfNodePtr &input) : cnode_{cnode}, input_{input} {}
100   ~AddEdge() override = default;
101 
Apply(ChangeCounter * counter)102   void Apply(ChangeCounter *counter) override {
103     int index = static_cast<int>(cnode_->size());
104     cnode_->add_input(input_);
105     counter->new_nodes.add(input_);
106     counter->new_edges.add(std::move(cnode_), index, std::move(input_));
107   }
108 
109  private:
110   CNodePtr cnode_;
111   AnfNodePtr input_;
112 };
113 
114 class SetParams : public Change {
115  public:
SetParams(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & params)116   SetParams(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &params)
117       : func_graph_{func_graph}, params_{params} {}
118   ~SetParams() override = default;
119 
Apply(ChangeCounter * counter)120   void Apply(ChangeCounter *counter) override {
121     auto &old_params = func_graph_->parameters();
122     for (auto &p : old_params) {
123       counter->del_nodes.add(p);
124     }
125     func_graph_->set_parameters(params_);
126     for (auto &p : params_) {
127       counter->new_nodes.add(std::move(p));
128     }
129   }
130 
131  private:
132   FuncGraphPtr func_graph_;
133   std::vector<AnfNodePtr> params_;
134 };
135 
136 class AddParam : public Change {
137  public:
AddParam(const FuncGraphPtr & func_graph,const ParameterPtr & param)138   AddParam(const FuncGraphPtr &func_graph, const ParameterPtr &param) : func_graph_{func_graph}, param_{param} {}
139   ~AddParam() override = default;
140 
Apply(ChangeCounter * counter)141   void Apply(ChangeCounter *counter) override {
142     func_graph_->append_parameter(param_);
143     counter->new_nodes.add(std::move(param_));
144   }
145 
146  private:
147   FuncGraphPtr func_graph_;
148   ParameterPtr param_;
149 };
150 
151 class InsertFrontParam : public Change {
152  public:
InsertFrontParam(const FuncGraphPtr & func_graph,const ParameterPtr & param)153   InsertFrontParam(const FuncGraphPtr &func_graph, const ParameterPtr &param)
154       : func_graph_{func_graph}, param_{param} {}
155   ~InsertFrontParam() override = default;
156 
Apply(ChangeCounter * counter)157   void Apply(ChangeCounter *counter) override {
158     func_graph_->PrependParameter(param_);
159     counter->new_nodes.add(std::move(param_));
160   }
161 
162  private:
163   FuncGraphPtr func_graph_;
164   ParameterPtr param_;
165 };
166 
167 }  // namespace change
168 
MakeManager(const std::vector<FuncGraphPtr> & func_graphs,bool manage,bool drop_unused_graph)169 FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage, bool drop_unused_graph) {
170   auto m = std::make_shared<FuncGraphManager>(func_graphs, manage, drop_unused_graph);
171   m->Init();
172   return m;
173 }
174 
Manage(const std::vector<FuncGraphPtr> & func_graphs,bool manage,bool drop_unused_graph)175 FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage, bool drop_unused_graph) {
176   FuncGraphManagerPtr m = nullptr;
177   bool root = false;
178 
179   for (auto &fg : func_graphs) {
180     if (fg == nullptr) {
181       continue;
182     }
183     if (fg->manager() != nullptr) {
184       m = fg->manager();
185       break;
186     }
187   }
188 
189   if (m == nullptr) {
190     std::vector<FuncGraphPtr> tmp;
191     m = MakeManager(tmp, manage, drop_unused_graph);
192     root = true;
193   }
194 
195   for (auto &fg : func_graphs) {
196     if (fg == nullptr) {
197       continue;
198     }
199     m->AddFuncGraph(fg, root);
200   }
201   return m;
202 }
203 
Manage(FuncGraphPtr func_graph,bool manage,bool drop_unused_graph)204 FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage, bool drop_unused_graph) {
205   std::vector<FuncGraphPtr> func_graphs = {func_graph};
206   return Manage(func_graphs, manage, drop_unused_graph);
207 }
208 
FuncGraphManager(const std::vector<FuncGraphPtr> & roots,bool manage,bool drop_unused_graph)209 FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage, bool drop_unused_graph)
210     : roots_(roots), is_manage_(manage), drop_unused_graph_(drop_unused_graph) {
211   Reset();
212 }
213 
~FuncGraphManager()214 FuncGraphManager::~FuncGraphManager() {
215   if (is_manage_) {
216     RemoveRoots();
217   }
218   Clear();
219 }
220 
Reset()221 void FuncGraphManager::Reset() {
222   func_graphs_ = FuncGraphSet();
223   all_nodes_ = AnfNodeSet();
224   node_users_ = NodeUsersMap();
225   signals_ = std::make_shared<Signals>();
226   func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this);
227   func_graph_parent_ = std::make_shared<ParentComputer>(this);
228   children_ = std::make_shared<ChildrenComputer>(this);
229   scopes_ = std::make_shared<ScopeComputer>(this);
230   free_variables_total_ = std::make_shared<FVTotalComputer>(this);
231   func_graphs_used_total_ = std::make_shared<FuncGraphsUsedTotalComputer>(this);
232   recursive_ = std::make_shared<RecursiveComputer>(this);
233   meta_fg_prim_total_ = std::make_shared<FuncGraphMetaFgPrimTotalComputer>(this);
234 }
235 
Init()236 void FuncGraphManager::Init() {
237   auto roots = roots_;
238   roots_ = FuncGraphSet();
239 
240   for (auto &fg : roots) {
241     AddFuncGraph(fg, true);
242   }
243 }
244 
func_graph_parents_total(const FuncGraphPtr & fg) const245 FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
246   if (fg == nullptr) {
247     MS_LOG(INTERNAL_EXCEPTION) << "The parameter 'fg' should not be null.";
248   }
249   MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
250   func_graph_parents_total_->Recompute(fg);
251   MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
252   return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
253 }
254 
// Return the parent graph ParentComputer records for `fg`, or nullptr (with a warning) when `fg` is not in its
// analysis, i.e. not managed here.
FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(func_graph_parent_);
  MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
  func_graph_parent_->Recompute(fg);
  auto &&parent_map = func_graph_parent_->parent_analysis();
  auto found = parent_map.find(fg);
  if (found == parent_map.end()) {
    MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
    return nullptr;
  }
  MS_LOG(DEBUG) << "End parents func graph " << fg->ToString();
  return found->second;
}
267 
children(const FuncGraphPtr & fg) const268 FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
269   MS_EXCEPTION_IF_NULL(fg);
270   MS_EXCEPTION_IF_NULL(children_);
271   MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
272   children_->Recompute(fg);
273   return children_->children_analysis()[fg];
274 }
275 
scopes(const FuncGraphPtr & fg) const276 FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
277   MS_EXCEPTION_IF_NULL(fg);
278   MS_EXCEPTION_IF_NULL(scopes_);
279   MS_LOG(DEBUG) << "Start scopes func graph: " << fg->ToString();
280   scopes_->Recompute(fg);
281   MS_LOG(DEBUG) << "End scopes func graph: " << fg->ToString();
282   return scopes_->scope_analysis()[fg];
283 }
284 
free_variables_total() const285 FVTotalMap &FuncGraphManager::free_variables_total() const {
286   MS_EXCEPTION_IF_NULL(free_variables_total_);
287   free_variables_total_->Recompute();
288   return free_variables_total_->fv_total_analysis();
289 }
290 
func_graphs_used_total(const FuncGraphPtr & fg) const291 FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
292   MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
293   func_graphs_used_total_->Recompute(fg);
294   return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
295 }
296 
recursive(const FuncGraphPtr & fg) const297 bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
298   MS_EXCEPTION_IF_NULL(fg);
299   MS_EXCEPTION_IF_NULL(recursive_);
300   recursive_->Recompute(fg);
301   if (recursive_->recursive_analysis().count(fg) == 0) {
302     MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
303     return false;
304   }
305   return recursive_->recursive_analysis()[fg];
306 }
307 
recursive_graphs(const FuncGraphPtr & fg) const308 std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
309   MS_EXCEPTION_IF_NULL(fg);
310   MS_EXCEPTION_IF_NULL(recursive_);
311   if (recursive(fg)) {
312     if (recursive_->recursive_map().count(fg) == 0) {
313       auto trace = std::list<FuncGraphPtr>();
314       recursive_->CheckRecursiveGraphs(fg, &trace);
315     }
316     if (recursive_->recursive_map().count(fg) == 0) {
317       MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
318       return nullptr;
319     }
320     return recursive_->recursive_map()[fg];
321   } else {
322     return nullptr;
323   }
324 }
325 
326 // Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ and kPrimVmap and kPrimTaylor.
func_graph_meta_fg_prim_total(const FuncGraphPtr & fg) const327 bool FuncGraphManager::func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const {
328   MS_EXCEPTION_IF_NULL(meta_fg_prim_total_);
329   MS_EXCEPTION_IF_NULL(fg);
330   meta_fg_prim_total_->Recompute(fg);
331   if (meta_fg_prim_total_->meta_fg_prim_total_analysis().count(fg) == 0) {
332     MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
333     return false;
334   }
335   return meta_fg_prim_total_->meta_fg_prim_total_analysis()[fg];
336 }
337 
338 // Add a func graph to this manager, optionally as a root func graph.
AddFuncGraph(const FuncGraphPtr & func_graph,bool is_root)339 void FuncGraphManager::AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root) {
340   MS_EXCEPTION_IF_NULL(func_graph);
341   if (is_root) {
342     roots_.add(func_graph);
343     return AddFuncGraphs(func_graph);
344   }
345 
346   if (func_graphs_.contains(func_graph)) {
347     return;
348   }
349 
350   // Add func_graph as a managed graph.
351   AddIntoManaged(func_graph);
352 
353   // New nodes to be acquired.
354   std::vector<AnfNodePtr> new_nodes = func_graph->parameters();
355   auto return_node = func_graph->get_return();
356   if (return_node != nullptr) {
357     (void)new_nodes.emplace_back(std::move(return_node));
358   } else {
359     MS_LOG(INFO) << "The func graph " << func_graph->ToString() << " has no return node.";
360   }
361 
362   // Acquire all nodes from func_graph.
363   AcquireNodes(std::move(new_nodes));
364 }
365 
366 // Add all func graphs from the root func graph.
AddFuncGraphs(const FuncGraphPtr & source_func_graph)367 void FuncGraphManager::AddFuncGraphs(const FuncGraphPtr &source_func_graph) {
368   MS_EXCEPTION_IF_NULL(source_func_graph);
369   todo_.clear();
370   todo_.emplace_back(source_func_graph);
371   while (!todo_.empty()) {
372     auto func_graph = todo_.front();
373     MS_EXCEPTION_IF_NULL(func_graph);
374     todo_.pop_front();
375     if (func_graphs_.contains(func_graph)) {
376       continue;
377     }
378 
379     // Add func_graph as a managed graph.
380     AddIntoManaged(func_graph);
381 
382     // New nodes to be acquired.
383     std::vector<AnfNodePtr> new_nodes = func_graph->parameters();
384     auto return_node = func_graph->get_return();
385     if (return_node != nullptr) {
386       (void)new_nodes.emplace_back(std::move(return_node));
387     } else {
388       MS_LOG(INFO) << "The func graph " << func_graph->ToString() << " has no return node.";
389     }
390 
391     // Acquire all nodes from func_graph.
392     AcquireNodes(std::move(new_nodes), false);
393   }
394 }
395 
396 // Drop the func graph
DropFuncGraph(const FuncGraphPtr & fg,bool force)397 void FuncGraphManager::DropFuncGraph(const FuncGraphPtr &fg, bool force) {
398   if (force || (is_manage_ && drop_unused_graph_ && !fg->reserved())) {
399     MS_LOG(INFO) << "Drop " << fg << "/" << fg->ToString() << ", use_count: " << fg.use_count()
400                  << ", type: " << fg->type_name();
401     fg->ResetReturnOwner();
402     fg->ResetOwnNodes();
403     fg->set_dropped(true);
404   }
405 }
406 
407 // Clear the all information in manager
Clear()408 void FuncGraphManager::Clear() noexcept {
409   roots_.clear();
410   for (auto &fg : func_graphs_) {
411     MS_EXCEPTION_IF_NULL(fg);
412     fg->DecAttachedMngCnt();
413     if (fg->attached_mng_cnt() == 0) {
414       fg->ClearAllResource();
415       DropFuncGraph(fg);
416     } else if (fg->attached_mng_cnt() < 0) {
417       MS_LOG(INTERNAL_EXCEPTION) << "The func graph '" << fg->ToString()
418                                  << "' attached cnt not right: " << fg->attached_mng_cnt();
419     }
420   }
421   func_graphs_.clear();
422 
423   all_nodes_.clear();
424   node_users_.clear();
425   todo_.clear();
426 
427   signals_->InvalidateComputer();
428 }
429 
KeepRoots(const std::vector<FuncGraphPtr> & func_graphs)430 void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
431   MS_LOG(DEBUG) << "Start keep roots";
432   bool root_exist = false;
433   for (auto &item : func_graphs) {
434     if (roots_.contains(item)) {
435       root_exist = true;
436       break;
437     }
438   }
439 
440   // If the new_root in roots_, we add new_root first, then calculate the func_graphs
441   // relation to new_root, remove the func_graphs not relation to new_root
442   // if the new_root not in roots_, we clear the all func_graphs in manager
443   // then add the new_root
444   if (root_exist || func_graphs.empty()) {
445     FuncGraphSet roots(func_graphs);
446     if (roots.empty()) {
447       roots = roots_;
448     } else {
449       roots_.clear();
450       for (auto &item : roots) {
451         AddFuncGraph(item, true);
452       }
453     }
454 
455     FuncGraphSet keep;
456     for (auto &item : roots) {
457       MS_LOG(DEBUG) << "roots: " << item->ToString();
458       keep.update(func_graphs_used_total(item));
459 #ifdef DEBUG
460       for (auto &k : keep) {
461         MS_LOG(DEBUG) << "keep: " << k->ToString();
462       }
463 #endif
464     }
465     MaybeDropFuncGraphs(func_graphs_ - keep, true);
466   } else {
467     Clear();
468     FuncGraphSet roots(func_graphs);
469     for (auto &item : roots) {
470       AddFuncGraph(item, true);
471     }
472   }
473 }
474 
RemoveRoots()475 void FuncGraphManager::RemoveRoots() {
476   MS_LOG(DEBUG) << "Start remove roots";
477   roots_.clear();
478   try {
479     MaybeDropFuncGraphs(func_graphs_, true);
480   } catch (std::exception &e) {
481     MS_LOG(DEBUG) << e.what();
482   }
483 }
484 
AddIntoManaged(const FuncGraphPtr & fg)485 void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
486   MS_EXCEPTION_IF_NULL(fg);
487   if (is_manage_) {
488     if (fg->manager() != nullptr && fg->manager().get() != this) {
489       MS_LOG(INFO) << "A func graph can only have one manager.";
490     }
491     fg->set_manager(shared_from_this());
492   }
493   func_graphs_.add(fg);
494   fg->IncAttachedMngCnt();
495 }
496 
MaybeDropFuncGraphs(const FuncGraphSet & func_graphs,bool ignore_users)497 void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
498   std::list<FuncGraphPtr> todo(func_graphs.begin(), func_graphs.end());
499   std::set<FuncGraphPtr> dropped;
500   while (!todo.empty()) {
501     FuncGraphPtr func_graph = std::move(todo.front());
502     MS_EXCEPTION_IF_NULL(func_graph);
503     todo.pop_front();
504     MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString();
505     if (roots_.contains(func_graph)) {
506       MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
507       continue;
508     }
509     auto &users_cnode_index = func_graph->func_graph_cnodes_index();
510     if (!users_cnode_index.empty() && !ignore_users) {
511       MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
512       continue;
513     }
514     if (dropped.find(func_graph) != dropped.end()) {
515       MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString();
516       continue;
517     }
518     (void)dropped.insert(func_graph);
519     MS_EXCEPTION_IF_NULL(func_graph->get_return());
520     std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
521     auto drop_graphs = MaybeDropNodes(std::move(return_vec));
522     (void)todo.insert(todo.end(), drop_graphs.begin(), drop_graphs.end());
523   }
524   for (auto &fg : dropped) {
525     MS_EXCEPTION_IF_NULL(fg);
526     all_nodes_.difference_update(fg->parameters());
527     EraseOneGraph(fg);
528     if (fg->manager().get() == this) {
529       fg->set_manager(nullptr);
530     }
531     MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString();
532   }
533 }
534 
ProcessEdgeAdd(const AnfNodePtr & node,int index,const AnfNodePtr & input)535 void FuncGraphManager::ProcessEdgeAdd(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
536   if (IsValueNode<FuncGraph>(input)) {
537     AddFuncGraph(GetValueNode<FuncGraphPtr>(input));
538   }
539   auto &users_node = node_users_[input];
540   users_node.add(std::make_pair(node, index));
541   OnEdgeAdded(node, index, input);
542 }
543 
ProcessEdgeRemove(const AnfNodePtr & node,int index,const AnfNodePtr & input)544 void FuncGraphManager::ProcessEdgeRemove(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
545   auto iter = node_users_.find(input);
546   if (iter == node_users_.end()) {
547     return;
548   }
549   bool removed = iter->second.erase(std::make_pair(node, index));
550   if (removed) {
551     OnEdgeRemoved(node, index, input);
552   }
553 }
554 
ProcessInputsEdgeAdd(const CNodePtr & cnode)555 void FuncGraphManager::ProcessInputsEdgeAdd(const CNodePtr &cnode) {
556   const size_t count = cnode->size();
557   for (size_t i = 0; i < count; ++i) {
558     ProcessEdgeAdd(cnode, static_cast<int>(i), cnode->input(i));
559   }
560 }
561 
ProcessInputsEdgeRemove(const CNodePtr & cnode)562 void FuncGraphManager::ProcessInputsEdgeRemove(const CNodePtr &cnode) {
563   const size_t count = cnode->size();
564   for (size_t i = 0; i < count; ++i) {
565     ProcessEdgeRemove(cnode, static_cast<int>(i), cnode->input(i));
566   }
567 }
568 
569 namespace {
FollowGraph(const FuncGraphPtr & fg,SeenNum seen,std::vector<AnfNodePtr> * nodes)570 inline void FollowGraph(const FuncGraphPtr &fg, SeenNum seen, std::vector<AnfNodePtr> *nodes) {
571   if (fg == nullptr) {
572     return;
573   }
574   if (auto res = fg->get_return(); res != nullptr && res->seen_ != seen) {
575     (void)nodes->emplace_back(std::move(res));
576   }
577 }
578 
FollowInputs(const CNodePtr & cnode,std::vector<AnfNodePtr> * nodes)579 inline void FollowInputs(const CNodePtr &cnode, std::vector<AnfNodePtr> *nodes) {
580   auto &weak_inputs = cnode->weak_inputs();
581   (void)std::transform(weak_inputs.cbegin(), weak_inputs.cend(), std::back_inserter(*nodes),
582                        [](const AnfNodeWeakPtr &weak_node) {
583                          auto node = weak_node.lock();
584                          MS_EXCEPTION_IF_NULL(node);
585                          return node;
586                        });
587 }
588 }  // namespace
589 
AcquireNodes(std::vector<AnfNodePtr> && nodes,bool recursive)590 void FuncGraphManager::AcquireNodes(std::vector<AnfNodePtr> &&nodes, bool recursive) {
591   auto seen = NewSeenGeneration();
592   while (!nodes.empty()) {
593     // Take the last one.
594     auto node = std::move(nodes.back());
595     nodes.pop_back();
596     MS_EXCEPTION_IF_NULL(node);
597     // Skip visited nodes.
598     if (node->seen_ == seen) {
599       continue;
600     }
601     node->seen_ = seen;
602     // Try add it to all_nodes_.
603     auto insert_result = all_nodes_.insert(node);
604     if (insert_result.second == false) {
605       // Skip acquired nodes.
606       continue;
607     }
608     // Add node to its func_graph.
609     auto fg = node->func_graph();
610     if (fg != nullptr) {
611       fg->AddNode(node);
612     }
613     // Follow graph for value node.
614     if (node->isa<ValueNode>()) {
615       auto graph = GetValueNode<FuncGraphPtr>(node);
616       FollowGraph(graph, seen, &nodes);
617       continue;
618     }
619     // Follow graph for cnode or parameter.
620     FollowGraph(fg, seen, &nodes);
621 
622     // Handle CNode.
623     auto cnode = node->cast<CNodePtr>();
624     if (cnode == nullptr) {
625       continue;
626     }
627 
628     // Handle input edges.
629     if (recursive) {
630       ProcessInputsEdgeAdd(cnode);
631       // Follow inputs.
632       FollowInputs(cnode, &nodes);
633       continue;
634     }
635     // The way not recursive.
636     for (size_t i = 0; i < cnode->size(); ++i) {
637       const auto &input = cnode->input(i);
638       if (input == nullptr) {
639         MS_LOG(INTERNAL_EXCEPTION) << "The input is null, " << cnode << "/" << cnode->DebugString() << "@" << fg << "/"
640                                    << fg->ToString();
641       }
642       if (IsValueNode<FuncGraph>(input)) {
643         todo_.emplace_back(GetValueNode<FuncGraphPtr>(input));
644       }
645 
646       auto &users_node = node_users_[input];
647       users_node.add(std::make_pair(node, i));
648       OnEdgeAdded(node, i, input);
649     }
650     // Follow inputs.
651     FollowInputs(cnode, &nodes);
652   }
653 }
654 
MaybeDropNodes(std::vector<AnfNodePtr> && nodes)655 FuncGraphSet FuncGraphManager::MaybeDropNodes(std::vector<AnfNodePtr> &&nodes) {
656   FuncGraphSet drop_func_graphs;
657   while (!nodes.empty()) {
658     AnfNodePtr node = std::move(nodes.back());
659     nodes.pop_back();
660     if (node == nullptr) {
661       // Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception,
662       // this method may be triggered by desctuctor.
663       MS_LOG(WARNING) << "Node to be dropped is nullptr";
664       continue;
665     }
666     if (!all_nodes_.contains(node)) {
667       // Node not existed.
668       continue;
669     }
670     auto &users = node_users_[node];
671     if (!users.empty()) {
672       // Node is in used.
673       continue;
674     }
675     if (node->isa<Parameter>() && node->func_graph() != nullptr) {
676       // Node is a used parameter.
677       auto &parameters = node->func_graph()->parameters();
678       if (std::find(parameters.begin(), parameters.end(), node) != parameters.end()) {
679         continue;
680       }
681     }
682     if (IsValueNode<FuncGraph>(node)) {
683       // The FuncGraph may need to be dropped.
684       auto fg = GetValueNode<FuncGraphPtr>(node);
685       drop_func_graphs.add(fg);
686     }
687     // Handle cnode.
688     if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
689       // Remove inputs edges.
690       ProcessInputsEdgeRemove(cnode);
691       // Handle inputs nodes.
692       FollowInputs(cnode, &nodes);
693     }
694     // Remove it from all_nodes_;
695     (void)all_nodes_.erase(node);
696     // Drop node from its func graph.
697     if (auto fg = node->func_graph(); fg != nullptr) {
698       fg->DropNode(node);
699     }
700     // Remove it from node_users.
701     (void)node_users_.erase(node);
702   }
703   return drop_func_graphs;
704 }
705 
SetParameters(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & parameters)706 void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters) {
707   auto tr = Transact();
708   tr.SetParameters(fg, parameters);
709   tr.Commit();
710 }
711 
AddParameter(const FuncGraphPtr & fg,const AnfNodePtr & parameter)712 void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) {
713   auto tr = Transact();
714   tr.AddParameter(fg, parameter);
715   tr.Commit();
716 }
717 
InsertFrontParameter(const FuncGraphPtr & fg,const AnfNodePtr & parameter)718 void FuncGraphManager::InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) {
719   auto tr = Transact();
720   tr.InsertFrontParameter(fg, parameter);
721   tr.Commit();
722 }
723 
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)724 bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
725   MS_EXCEPTION_IF_NULL(old_node);
726   MS_EXCEPTION_IF_NULL(new_node);
727   auto tr = Transact();
728   bool success = tr.Replace(old_node, new_node);
729   if (success) {
730     tr.Commit();
731   }
732   return success;
733 }
734 
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node,const AnfNodePtr & mask_node)735 bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const AnfNodePtr &mask_node) {
736   MS_EXCEPTION_IF_NULL(old_node);
737   MS_EXCEPTION_IF_NULL(new_node);
738   auto tr = Transact();
739   bool success = tr.Replace(old_node, new_node, mask_node);
740   if (success) {
741     tr.Commit();
742   }
743   return success;
744 }
745 
SetEdge(const AnfNodePtr & node,int index,const AnfNodePtr & value)746 void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
747   auto tr = Transact();
748   tr.SetEdge(node, index, value);
749   tr.Commit();
750 }
751 
AddEdge(const AnfNodePtr & node,const AnfNodePtr & value)752 void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
753   auto tr = Transact();
754   tr.AddEdge(node, value);
755   tr.Commit();
756 }
757 
MoveAllCNodeDropGraph(const FuncGraphPtr & source,const FuncGraphPtr & target,const AnfNodePtr & call_node,const ScopePtr & scope,bool update_debug_info)758 void FuncGraphManager::MoveAllCNodeDropGraph(const FuncGraphPtr &source, const FuncGraphPtr &target,
759                                              const AnfNodePtr &call_node, const ScopePtr &scope,
760                                              bool update_debug_info) {
761   MS_EXCEPTION_IF_NULL(source);
762   CNodePtr source_return = source->get_return();
763   MS_EXCEPTION_IF_NULL(source_return);
764   AnfNodePtr source_output = source->output();
765   const auto &source_prim = source_return->input(0);
766 
767   int index = 0;
768   (void)node_users_[source_prim].erase(make_pair(source_return, index));
769   OnEdgeRemoved(source_return, index, source_prim);
770   index = 1;
771   (void)node_users_[source_output].erase(make_pair(source_return, index));
772   OnEdgeRemoved(source_return, index, source_output);
773   (void)all_nodes_.erase(source_return);
774   (void)node_users_.erase(source_return);
775   source->DropNode(source_return);
776   for (auto &node : source->nodes()) {
777     node->set_func_graph(target);
778     if (node->scope() == kDefaultScope) {
779       node->set_scope(scope);
780     }
781     if (update_debug_info && node->isa<CNode>()) {
782       MS_LOG(DEBUG) << "call_node: " << call_node << "/" << call_node->DebugString() << ", node: " << node << "/"
783                     << node->DebugString();
784       UpdateInlineCNodeDebugInfo(call_node, node);
785     }
786   }
787 
788   MoveAllNodes(source, target);
789   all_nodes_.difference_update(source->parameters());
790   EraseOneGraph(source);
791   source->set_dropped(true);
792   if (source->manager().get() == this) {
793     source->set_manager(nullptr);
794   }
795 
796   if (source->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
797     target->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
798   }
799   if (source->has_flag(kTraining)) {
800     target->set_flag(kTraining, true);
801   }
802 }
803 
OnEdgeAdded(const AnfNodePtr & node,int index,const AnfNodePtr & input)804 void FuncGraphManager::OnEdgeAdded(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
805   auto fg = node->func_graph();
806   if (input->isa<ValueNode>()) {
807     fg->AddValueNode(input);
808     if (IsValueNode<FuncGraph>(input)) {
809       auto used = GetValueNode<FuncGraphPtr>(input);
810       used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
811       if (fg->AddFuncGraphUsed(used)) {
812         signals_->InvalidateComputer();
813       }
814     }
815     if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
816         IsPrimitiveCNode(node, prim::kPrimTaylor) || IsPrimitiveCNode(node, prim::kPrimShard)) {
817       fg->AddMetaFgPrimValueNode(input);
818     }
819   } else if (IsPrimitiveCNode(node, prim::kPrimVmap) && IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
820     // To handle the model ensembling scenario in vmap, whose input is a celllist, taking an arbitrary function graph
821     // is sufficient.
822     constexpr int64_t kIndex1 = 1;
823     auto func_union = dyn_cast<CNode>(input);
824     if (IsValueNode<FuncGraph>(func_union->input(kIndex1))) {
825       fg->AddMetaFgPrimValueNode(func_union->input(kIndex1));
826     }
827   } else if (fg != nullptr && fg != input->func_graph()) {
828     if (fg->AddFreeVariable(input)) {
829       signals_->InvalidateComputer();
830     }
831   }
832 }
833 
OnEdgeRemoved(const AnfNodePtr & node,int index,const AnfNodePtr & input)834 void FuncGraphManager::OnEdgeRemoved(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
835   auto fg = node->func_graph();
836   if (fg != nullptr && input->isa<ValueNode>()) {
837     fg->DropValueNode(input);
838     if (IsValueNode<FuncGraph>(input)) {
839       auto used = GetValueNode<FuncGraphPtr>(input);
840       used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
841       if (fg->DropFuncGraphUsed(used)) {
842         signals_->InvalidateComputer();
843       }
844     }
845     if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
846         IsPrimitiveCNode(node, prim::kPrimTaylor)) {
847       fg->DropMetaFgPrimValueNode(input);
848     }
849   } else if (fg != nullptr && fg != input->func_graph()) {
850     if (fg->DropFreeVariable(input)) {
851       signals_->InvalidateComputer();
852     }
853   }
854 }
855 
MoveAllNodes(const FuncGraphPtr & source,const FuncGraphPtr & target)856 void FuncGraphManager::MoveAllNodes(const FuncGraphPtr &source, const FuncGraphPtr &target) {
857   target->CopyNodes(source);
858   target->CopyValueNodes(source);
859   target->CopyFuncGraphCNodesIndex(source);
860   target->CopyFreeVariables(source);
861   target->CopyFuncGraphsUsed(source);
862   target->CopyMetaFgPrimValueNodes(source);
863   source->ClearAllResource();
864   signals_->InvalidateComputer();
865 }
866 
CommitChanges(std::vector<change::ChangePtr> && changes)867 void FuncGraphManager::CommitChanges(std::vector<change::ChangePtr> &&changes) {
868   // Apply changes.
869   change::ChangeCounter counter;
870   for (auto &change : changes) {
871     change->Apply(&counter);
872   }
873   changes.clear();
874 
875   // Process added edges.
876   counter.ForEachAddedEdges([this](const change::Edge &edge) {  //
877     ProcessEdgeAdd(edge.cnode, edge.index, edge.input);
878   });
879 
880   // Process added nodes.
881   AcquireNodes(counter.GetAddedNodes());
882 
883   // Process removed edges.
884   counter.ForEachRemovedEdges([this](const change::Edge &edge) {  //
885     ProcessEdgeRemove(edge.cnode, edge.index, edge.input);
886   });
887 
888   // Process removed nodes.
889   auto drop_func_graphs = MaybeDropNodes(counter.GetRemovedNodes());
890   if (!drop_func_graphs.empty()) {
891     MaybeDropFuncGraphs(drop_func_graphs);
892   }
893 }
894 
EraseOneGraph(const FuncGraphPtr & fg)895 void FuncGraphManager::EraseOneGraph(const FuncGraphPtr &fg) {
896   MS_EXCEPTION_IF_NULL(fg);
897   bool erase_ret = func_graphs_.erase(fg->shared_from_base<FuncGraph>());
898   if (!erase_ret) {
899     return;
900   }
901   fg->DecAttachedMngCnt();
902   if (fg->attached_mng_cnt() == 0) {
903     fg->ClearAllResource();
904   }
905 }
906 
SetParameters(FuncGraphPtr fg,const std::vector<AnfNodePtr> & params)907 void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params) {
908   (void)changes_.emplace_back(std::make_unique<change::SetParams>(fg, params));
909 }
910 
AddParameter(FuncGraphPtr fg,const AnfNodePtr & param)911 void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr &param) {
912   (void)changes_.emplace_back(std::make_unique<change::AddParam>(fg, param->cast<ParameterPtr>()));
913 }
914 
InsertFrontParameter(FuncGraphPtr fg,const AnfNodePtr & param)915 void FuncGraphTransaction::InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param) {
916   (void)changes_.emplace_back(std::make_unique<change::InsertFrontParam>(fg, param->cast<ParameterPtr>()));
917 }
918 
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)919 bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
920   MS_EXCEPTION_IF_NULL(old_node);
921   MS_EXCEPTION_IF_NULL(new_node);
922   FuncGraphPtr old_func_graph = old_node->func_graph();
923   if (old_func_graph != nullptr && old_func_graph->get_return() != nullptr &&
924       old_func_graph->get_return() == old_node) {
925     MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString();
926     return false;
927   }
928   auto &users = manager_->node_users()[old_node];
929   for (auto &node : users) {
930     SetEdge(node.first, node.second, new_node);
931   }
932   return true;
933 }
934 
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node,const AnfNodePtr & mask_node)935 bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node,
936                                    const AnfNodePtr &mask_node) {
937   MS_EXCEPTION_IF_NULL(old_node);
938   MS_EXCEPTION_IF_NULL(new_node);
939   FuncGraphPtr old_func_graph = old_node->func_graph();
940   if (old_func_graph != nullptr && old_func_graph->get_return() != nullptr &&
941       old_func_graph->get_return() == old_node) {
942     MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString();
943     return false;
944   }
945   auto &users = manager_->node_users()[old_node];
946   for (auto &node : users) {
947     if (node.first == mask_node) {
948       SetEdge(node.first, node.second, new_node);
949     }
950   }
951   return true;
952 }
953 
SetEdge(const AnfNodePtr & src_node,int k,const AnfNodePtr & v)954 void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
955   if (k < 0) {
956     MS_LOG(INTERNAL_EXCEPTION) << "Invalid value k = " << k;
957   }
958   MS_EXCEPTION_IF_NULL(src_node);
959   auto cnode = src_node->cast<CNodePtr>();
960   if (cnode == nullptr) {
961     MS_LOG(INTERNAL_EXCEPTION) << "src_node should be a cnode, but cast failed.";
962   }
963   (void)changes_.emplace_back(std::make_unique<change::SetEdge>(cnode, k, v));
964 }
965 
AddEdge(const AnfNodePtr & src_node,const AnfNodePtr & v)966 void FuncGraphTransaction::AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v) {
967   MS_EXCEPTION_IF_NULL(src_node);
968   auto cnode = src_node->cast<CNodePtr>();
969   if (cnode == nullptr) {
970     MS_LOG(INTERNAL_EXCEPTION) << "src_node should be a cnode, but cast failed.";
971   }
972   (void)changes_.emplace_back(std::make_unique<change::AddEdge>(cnode, v));
973 }
974 
Commit()975 void FuncGraphTransaction::Commit() { manager_->CommitChanges(std::move(changes_)); }
976 
DepComputer(const FuncGraphManager * const manager)977 DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager), validate_(false) {
978   MS_EXCEPTION_IF_NULL(manager_);
979   manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
980 }
981 
Recompute()982 void DepComputer::Recompute() {
983   if (!validate_) {
984     RealRecompute();
985     validate_ = true;
986   }
987 }
988 
Recompute(const FuncGraphPtr & fg)989 void DepComputer::Recompute(const FuncGraphPtr &fg) {
990   if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
991     RealRecompute(fg);
992     func_graphs_validate_[fg] = true;
993   }
994 }
995 
SeekParents(const FuncGraphPtr & func_graph)996 FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &func_graph) {
997   MS_EXCEPTION_IF_NULL(func_graph);
998   constexpr auto out_call_stack = 0;
999   constexpr auto in_call_stack = 1;
1000   auto seen = NewFgSeenGeneration();
1001   func_graph->seen_ = seen;
1002   std::deque<FuncGraphPtr> todo;
1003   (void)todo.emplace_back(func_graph);
1004   FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
1005   while (!todo.empty()) {
1006     const auto &fg = todo.back();
1007     // For the func graph both in deque and in call stack, just pop it.
1008     if (fg->extra_seen_ == in_call_stack) {
1009       fg->extra_seen_ = out_call_stack;
1010       todo.pop_back();
1011       continue;
1012     }
1013 
1014     // Append all the fvs in fg.
1015     auto &fvs = fg->free_variables();
1016     for (const auto &fv : fvs) {
1017       const auto &fv_node = fv.first;
1018       MS_EXCEPTION_IF_NULL(fv_node);
1019       auto fv_func_graph = fv_node->func_graph();
1020       if (fv_func_graph == nullptr) {
1021         MS_LOG(INFO) << "Meet a FV '" << fv_node->DebugString() << "' whose func graph is null, during seeking for "
1022                      << fg->ToString() << "\nFV: " << trace::GetDebugInfoStr(fv_node->debug_info());
1023         continue;
1024       }
1025       // Found a parent if not in the call stack.
1026       if (fv_func_graph->extra_seen_ != in_call_stack) {
1027         parents->add(fv_func_graph);
1028       }
1029     }
1030 
1031     // Before push the used func graphs of 'fg', mark it as in call stack.
1032     fg->extra_seen_ = in_call_stack;
1033     // Add the fg's used func graph to search.
1034     auto &fgs = fg->func_graphs_used();
1035     for (auto &item : fgs) {
1036       auto &gt = item.first;
1037       if (gt->seen_ != seen) {
1038         gt->extra_seen_ = out_call_stack;
1039         (void)todo.emplace_back(gt);
1040         gt->seen_ = seen;
1041       }
1042     }
1043   }
1044   return parents;
1045 }
1046 
RealRecompute(FuncGraphPtr fg)1047 void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
1048   MS_EXCEPTION_IF_NULL(fg);
1049   auto parents = SeekParents(fg);
1050   func_graph_parents_total_analysis_[fg].update(parents);
1051 }
1052 
set_len_compare(const FuncGraphSetPair & lhs,const FuncGraphSetPair & rhs)1053 bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
1054   auto l1 = lhs.second.size();
1055   auto l2 = rhs.second.size();
1056   return l1 < l2;
1057 }
1058 
RealRecompute(FuncGraphPtr fg)1059 void ParentComputer::RealRecompute(FuncGraphPtr fg) {
1060   this->parent_analysis_[fg] = nullptr;
1061   // Note: must be a copy other than reference as it is modified thereafter.
1062   auto deps = this->manager_->func_graph_parents_total(fg);
1063   if (deps.empty()) {
1064     this->parent_analysis_[fg] = nullptr;
1065     return;
1066   } else if (deps.size() == 1) {
1067     this->parent_analysis_[fg] = deps.front();
1068     return;
1069   } else {
1070     // return nearest parent as parent
1071     FuncGraphSet deps_copy(deps);
1072     for (auto &dep : deps) {
1073       auto parent_deps = this->manager_->func_graph_parents_total(dep);
1074       for (auto &p_d : parent_deps) {
1075         if (deps_copy.count(p_d) > 0) {
1076           (void)deps_copy.erase(p_d);
1077         }
1078       }
1079       if (deps_copy.size() == 1) {
1080         this->parent_analysis_[fg] = deps_copy.front();
1081         return;
1082       }
1083     }
1084   }
1085 }
1086 
RealRecompute(FuncGraphPtr fg)1087 void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
1088   MS_EXCEPTION_IF_NULL(manager_);
1089   auto used_fg_total = manager_->func_graphs_used_total(fg);
1090   for (auto &used_fg : used_fg_total) {
1091     if (manager_->parent(used_fg) == fg) {
1092       children_analysis_[fg].add(used_fg);
1093     }
1094   }
1095 }
1096 
RealRecompute(FuncGraphPtr fg)1097 void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
1098   MS_EXCEPTION_IF_NULL(manager_);
1099   auto &children = manager_->children(fg);
1100 
1101   scope_analysis_[fg] = FuncGraphSet();
1102   scope_analysis_[fg].add(fg);
1103   for (auto &child : children) {
1104     scope_analysis_[fg].add(child);
1105   }
1106 }
1107 
RealRecompute()1108 void FVTotalComputer::RealRecompute() {
1109   auto manager = DepComputer::manager_;
1110   MS_EXCEPTION_IF_NULL(manager);
1111 
1112   for (auto &fg : manager->func_graphs()) {
1113     fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
1114   }
1115 
1116   for (auto &fg : manager->func_graphs()) {
1117     // add all free variable nodes
1118     AnfNodeCounterMap items = fg->free_variables();
1119     for (auto &iter : items) {
1120       auto curr = fg;
1121       while (curr != nullptr) {
1122         fv_total_analysis_[curr][iter.first] = iter.second;
1123         curr = manager->parent(curr);
1124         if (curr != nullptr) {
1125           const AnfNodeSet &all_nodes = curr->nodes();
1126           if (all_nodes.contains(iter.first)) {
1127             break;
1128           }
1129         }
1130       }
1131     }
1132 
1133     // add all FGs of free variables
1134     auto &used = fg->func_graphs_used();
1135     for (auto &iter : used) {
1136       auto p = manager->parent(iter.first);
1137       if (p == nullptr) {
1138         continue;
1139       }
1140       auto curr = fg;
1141       while (curr != nullptr && curr != p) {
1142         fv_total_analysis_[curr][iter.first] = iter.second;
1143         curr = manager->parent(curr);
1144       }
1145     }
1146   }
1147 }
1148 
RealRecompute(FuncGraphPtr fg)1149 void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
1150   MS_EXCEPTION_IF_NULL(manager_);
1151   std::vector<FuncGraphPtr> todo;
1152   std::vector<FuncGraphPtr> todo_new;
1153 
1154   todo.push_back(fg);
1155   while (!todo.empty()) {
1156     todo_new.clear();
1157     for (auto &gt : todo) {
1158       for (auto &item : gt->func_graphs_used()) {
1159         auto used_fg = item.first;
1160         if (used_fg == fg) {
1161           func_graph_used_total_analysis_[fg].add(used_fg);
1162           continue;
1163         }
1164         if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) {
1165           todo_new.push_back(used_fg);
1166         }
1167         MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString();
1168         func_graph_used_total_analysis_[fg].add(used_fg);
1169       }
1170     }
1171     todo = todo_new;
1172   }
1173 }
1174 
CheckRecursive(const FuncGraphManager * const manager,const FuncGraphPtr & fg)1175 bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
1176   MS_EXCEPTION_IF_NULL(manager);
1177   std::vector<FuncGraphPtr> todo;
1178   std::vector<FuncGraphPtr> todo_new;
1179   todo.push_back(fg);
1180   FuncGraphSet used_total;
1181   while (!todo.empty()) {
1182     todo_new.clear();
1183     for (auto &gt : todo) {
1184       for (auto &item : gt->func_graphs_used()) {
1185         auto used_g = item.first;
1186         if (used_g == fg) {
1187           return true;
1188         }
1189         if (used_total.count(used_g) == 0) {
1190           todo_new.push_back(used_g);
1191         }
1192         used_total.add(used_g);
1193       }
1194     }
1195     todo = todo_new;
1196   }
1197   return false;
1198 }
1199 
RealRecompute(FuncGraphPtr fg)1200 void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
1201   this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
1202 }
1203 
CheckRecursiveGraphs(const FuncGraphPtr & fg,std::list<FuncGraphPtr> * trace)1204 void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
1205   MS_EXCEPTION_IF_NULL(trace);
1206   auto res = std::find(trace->begin(), trace->end(), fg);
1207   // Find recursive
1208   if (res != trace->end()) {
1209     auto recur_ptr = std::make_shared<std::list<FuncGraphPtr>>(res, trace->end());
1210     for (auto iter = res; iter != trace->end(); (void)iter++) {
1211       MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString();
1212       recursive_map_[*iter] = recur_ptr;
1213     }
1214   } else {
1215     trace->push_back(fg);
1216     auto &items = fg->func_graphs_used();
1217     for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
1218       CheckRecursiveGraphs(iter->first, trace);
1219     }
1220     trace->pop_back();
1221     if (recursive_map_.count(fg) == 0) {
1222       recursive_map_[fg] = nullptr;
1223     }
1224   }
1225 }
1226 
SeekMetaFgPrim(const FuncGraphPtr & fg,SeenNum seen_num)1227 bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, SeenNum seen_num) {
1228   MS_EXCEPTION_IF_NULL(fg);
1229   if (fg->seen_ == seen_num) {
1230     MS_LOG(DEBUG) << fg->ToString() << " had been checked";
1231     return false;
1232   }
1233 
1234   // Check MetaFgPrim (J/Vmap/Taylor) FuncGraph input.
1235   const auto &meta_fg_prim_values = fg->meta_fg_prim_value_nodes();
1236   if (!meta_fg_prim_values.empty()) {
1237     auto contains_meta_fg_prim =
1238       std::find_if(meta_fg_prim_values.begin(), meta_fg_prim_values.end(), [seen_num](const auto &iter) {
1239         // Check g1->MetaFgPrim(fg)->g2->g cycle.
1240         if (IsValueNode<FuncGraph>(iter.first)) {
1241           auto func_graph = GetValuePtr<FuncGraph>(iter.first);
1242           return func_graph->seen_ != seen_num;
1243         }
1244         if (IsValueNode<Primitive>(iter.first)) {
1245           // Exclude the primitive of MetaFgPrim (J/Vmap/Taylor) itself.
1246           auto prim = GetValueNode<PrimitivePtr>(iter.first);
1247           return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name() &&
1248                   prim->name() != prim::kPrimTaylor->name());
1249         }
1250         return false;
1251       });
1252     if (contains_meta_fg_prim != meta_fg_prim_values.end()) {
1253       MS_EXCEPTION_IF_NULL(contains_meta_fg_prim->first);
1254       MS_LOG(DEBUG) << fg->ToString() << " contains MetaFgPrim(" << contains_meta_fg_prim->first->DebugString() << ")";
1255       return true;
1256     }
1257   }
1258 
1259   // Check MetaFgPrim (J/Vmap/Taylor) CNode as FV.
1260   const auto &fv_nodes = fg->free_variables();
1261   if (!fv_nodes.empty()) {
1262     auto contains_meta_fg_prim_cnode = std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const auto &iter) {
1263       // Check if the FV is a MetaFgPrim (J/Vmap/Taylor) call CNode.
1264       return IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap) ||
1265              IsPrimitiveCNode(iter.first, prim::kPrimTaylor);
1266     });
1267     if (contains_meta_fg_prim_cnode != fv_nodes.end()) {
1268       MS_EXCEPTION_IF_NULL(contains_meta_fg_prim_cnode->first);
1269       MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap/Taylor) ("
1270                     << contains_meta_fg_prim_cnode->first->DebugString() << ")";
1271       return true;
1272     }
1273   }
1274 
1275   // Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive),
1276   // Taylor(func_graph) or Taylor(Primitive).
1277   fg->seen_ = seen_num;
1278   for (auto &item : fg->func_graphs_used()) {
1279     auto used_g = item.first;
1280     MS_EXCEPTION_IF_NULL(used_g);
1281     if (SeekMetaFgPrim(used_g, seen_num)) {
1282       MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString()
1283                     << " which contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
1284                     << "Taylor(func_graph) or Taylor(Primitive)";
1285       return true;
1286     }
1287   }
1288   MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
1289                 << "Taylor(func_graph) or Taylor(Primitive)";
1290   return false;
1291 }
1292 
RealRecompute(FuncGraphPtr fg)1293 void FuncGraphMetaFgPrimTotalComputer::RealRecompute(FuncGraphPtr fg) {
1294   this->meta_fg_prim_total_analysis_[fg] = SeekMetaFgPrim(fg, NewFgSeenGeneration());
1295 }
1296 }  // namespace mindspore
1297