1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/manager.h"
20
21 #include <algorithm>
22 #include <list>
23
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "ir/func_graph.h"
27 #include "utils/convert_utils_base.h"
28 #include "utils/counter.h"
29 #include "utils/trace_base.h"
30 #include "utils/ms_context.h"
31
32 namespace mindspore {
33 namespace change {
34
35 struct Edge {
36 CNodePtr cnode;
37 int index;
38 AnfNodePtr input;
Edgemindspore::change::Edge39 Edge(const CNodePtr &cnode, int index, const AnfNodePtr &input) : cnode(cnode), index(index), input(input) {}
40 ~Edge() = default;
41 };
42
43 struct EdgeHash {
operator ()mindspore::change::EdgeHash44 std::size_t operator()(const Edge &e) const noexcept {
45 const std::hash<AnfNodePtr> node_hash;
46 return hash_combine({node_hash(e.cnode), IntToSize(e.index), node_hash(e.input)});
47 }
48 };
49
50 struct EdgeEqual {
operator ()mindspore::change::EdgeEqual51 bool operator()(const Edge &lhs, const Edge &rhs) const noexcept {
52 return lhs.cnode == rhs.cnode && lhs.index == rhs.index && lhs.input == rhs.input;
53 }
54 };
55
// Counter keyed by Edge (hashed/compared with EdgeHash/EdgeEqual); records edges added or removed by changes.
using EdgeCounter = Counter<Edge, EdgeHash, EdgeEqual>;
// Counter keyed by node; records nodes added or removed by changes.
using NodeCounter = Counter<AnfNodePtr>;
58
59 struct ChangeCounter {
60 EdgeCounter new_edges;
61 EdgeCounter del_edges;
62 NodeCounter new_nodes;
63 NodeCounter del_nodes;
64
65 template <typename Func>
ForEachAddedEdgesmindspore::change::ChangeCounter66 void ForEachAddedEdges(Func &&func) {
67 new_edges.subtract_by(del_edges, std::forward<Func>(func));
68 }
69
70 template <typename Func>
ForEachRemovedEdgesmindspore::change::ChangeCounter71 void ForEachRemovedEdges(Func &&func) {
72 del_edges.subtract_by(new_edges, std::forward<Func>(func));
73 }
74
GetAddedNodesmindspore::change::ChangeCounter75 std::vector<AnfNodePtr> GetAddedNodes() { return new_nodes.subtract(del_nodes); }
GetRemovedNodesmindspore::change::ChangeCounter76 std::vector<AnfNodePtr> GetRemovedNodes() { return del_nodes.subtract(new_nodes); }
77 };
78
79 class SetEdge : public Change {
80 public:
SetEdge(const CNodePtr & cnode,int index,const AnfNodePtr & input)81 SetEdge(const CNodePtr &cnode, int index, const AnfNodePtr &input) : edge_{cnode, index, input} {}
82 ~SetEdge() override = default;
83
Apply(ChangeCounter * counter)84 void Apply(ChangeCounter *counter) override {
85 auto &old_input = edge_.cnode->input(IntToSize(edge_.index));
86 counter->del_nodes.add(old_input);
87 counter->del_edges.add(edge_.cnode, edge_.index, old_input);
88 edge_.cnode->set_input(IntToSize(edge_.index), edge_.input);
89 counter->new_nodes.add(edge_.input);
90 counter->new_edges.add(std::move(edge_));
91 }
92
93 private:
94 Edge edge_;
95 };
96
97 class AddEdge : public Change {
98 public:
AddEdge(const CNodePtr & cnode,const AnfNodePtr & input)99 AddEdge(const CNodePtr &cnode, const AnfNodePtr &input) : cnode_{cnode}, input_{input} {}
100 ~AddEdge() override = default;
101
Apply(ChangeCounter * counter)102 void Apply(ChangeCounter *counter) override {
103 int index = static_cast<int>(cnode_->size());
104 cnode_->add_input(input_);
105 counter->new_nodes.add(input_);
106 counter->new_edges.add(std::move(cnode_), index, std::move(input_));
107 }
108
109 private:
110 CNodePtr cnode_;
111 AnfNodePtr input_;
112 };
113
114 class SetParams : public Change {
115 public:
SetParams(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & params)116 SetParams(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> ¶ms)
117 : func_graph_{func_graph}, params_{params} {}
118 ~SetParams() override = default;
119
Apply(ChangeCounter * counter)120 void Apply(ChangeCounter *counter) override {
121 auto &old_params = func_graph_->parameters();
122 for (auto &p : old_params) {
123 counter->del_nodes.add(p);
124 }
125 func_graph_->set_parameters(params_);
126 for (auto &p : params_) {
127 counter->new_nodes.add(std::move(p));
128 }
129 }
130
131 private:
132 FuncGraphPtr func_graph_;
133 std::vector<AnfNodePtr> params_;
134 };
135
136 class AddParam : public Change {
137 public:
AddParam(const FuncGraphPtr & func_graph,const ParameterPtr & param)138 AddParam(const FuncGraphPtr &func_graph, const ParameterPtr ¶m) : func_graph_{func_graph}, param_{param} {}
139 ~AddParam() override = default;
140
Apply(ChangeCounter * counter)141 void Apply(ChangeCounter *counter) override {
142 func_graph_->append_parameter(param_);
143 counter->new_nodes.add(std::move(param_));
144 }
145
146 private:
147 FuncGraphPtr func_graph_;
148 ParameterPtr param_;
149 };
150
151 class InsertFrontParam : public Change {
152 public:
InsertFrontParam(const FuncGraphPtr & func_graph,const ParameterPtr & param)153 InsertFrontParam(const FuncGraphPtr &func_graph, const ParameterPtr ¶m)
154 : func_graph_{func_graph}, param_{param} {}
155 ~InsertFrontParam() override = default;
156
Apply(ChangeCounter * counter)157 void Apply(ChangeCounter *counter) override {
158 func_graph_->PrependParameter(param_);
159 counter->new_nodes.add(std::move(param_));
160 }
161
162 private:
163 FuncGraphPtr func_graph_;
164 ParameterPtr param_;
165 };
166
167 } // namespace change
168
MakeManager(const std::vector<FuncGraphPtr> & func_graphs,bool manage,bool drop_unused_graph)169 FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs, bool manage, bool drop_unused_graph) {
170 auto m = std::make_shared<FuncGraphManager>(func_graphs, manage, drop_unused_graph);
171 m->Init();
172 return m;
173 }
174
Manage(const std::vector<FuncGraphPtr> & func_graphs,bool manage,bool drop_unused_graph)175 FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage, bool drop_unused_graph) {
176 FuncGraphManagerPtr m = nullptr;
177 bool root = false;
178
179 for (auto &fg : func_graphs) {
180 if (fg == nullptr) {
181 continue;
182 }
183 if (fg->manager() != nullptr) {
184 m = fg->manager();
185 break;
186 }
187 }
188
189 if (m == nullptr) {
190 std::vector<FuncGraphPtr> tmp;
191 m = MakeManager(tmp, manage, drop_unused_graph);
192 root = true;
193 }
194
195 for (auto &fg : func_graphs) {
196 if (fg == nullptr) {
197 continue;
198 }
199 m->AddFuncGraph(fg, root);
200 }
201 return m;
202 }
203
Manage(FuncGraphPtr func_graph,bool manage,bool drop_unused_graph)204 FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage, bool drop_unused_graph) {
205 std::vector<FuncGraphPtr> func_graphs = {func_graph};
206 return Manage(func_graphs, manage, drop_unused_graph);
207 }
208
FuncGraphManager(const std::vector<FuncGraphPtr> & roots,bool manage,bool drop_unused_graph)209 FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage, bool drop_unused_graph)
210 : roots_(roots), is_manage_(manage), drop_unused_graph_(drop_unused_graph) {
211 Reset();
212 }
213
~FuncGraphManager()214 FuncGraphManager::~FuncGraphManager() {
215 if (is_manage_) {
216 RemoveRoots();
217 }
218 Clear();
219 }
220
Reset()221 void FuncGraphManager::Reset() {
222 func_graphs_ = FuncGraphSet();
223 all_nodes_ = AnfNodeSet();
224 node_users_ = NodeUsersMap();
225 signals_ = std::make_shared<Signals>();
226 func_graph_parents_total_ = std::make_shared<FuncGraphParentsTotalComputer>(this);
227 func_graph_parent_ = std::make_shared<ParentComputer>(this);
228 children_ = std::make_shared<ChildrenComputer>(this);
229 scopes_ = std::make_shared<ScopeComputer>(this);
230 free_variables_total_ = std::make_shared<FVTotalComputer>(this);
231 func_graphs_used_total_ = std::make_shared<FuncGraphsUsedTotalComputer>(this);
232 recursive_ = std::make_shared<RecursiveComputer>(this);
233 meta_fg_prim_total_ = std::make_shared<FuncGraphMetaFgPrimTotalComputer>(this);
234 }
235
Init()236 void FuncGraphManager::Init() {
237 auto roots = roots_;
238 roots_ = FuncGraphSet();
239
240 for (auto &fg : roots) {
241 AddFuncGraph(fg, true);
242 }
243 }
244
func_graph_parents_total(const FuncGraphPtr & fg) const245 FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
246 if (fg == nullptr) {
247 MS_LOG(INTERNAL_EXCEPTION) << "The parameter 'fg' should not be null.";
248 }
249 MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
250 func_graph_parents_total_->Recompute(fg);
251 MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
252 return func_graph_parents_total_->func_graph_parents_total_analysis()[fg];
253 }
254
parent(const FuncGraphPtr & fg) const255 FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const {
256 MS_EXCEPTION_IF_NULL(fg);
257 MS_EXCEPTION_IF_NULL(func_graph_parent_);
258 MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString();
259 func_graph_parent_->Recompute(fg);
260 if (func_graph_parent_->parent_analysis().count(fg) == 0) {
261 MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
262 return nullptr;
263 }
264 MS_LOG(DEBUG) << "End parents func graph " << fg->ToString();
265 return func_graph_parent_->parent_analysis()[fg];
266 }
267
children(const FuncGraphPtr & fg) const268 FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const {
269 MS_EXCEPTION_IF_NULL(fg);
270 MS_EXCEPTION_IF_NULL(children_);
271 MS_LOG(DEBUG) << "Start child func graph " << fg->ToString();
272 children_->Recompute(fg);
273 return children_->children_analysis()[fg];
274 }
275
scopes(const FuncGraphPtr & fg) const276 FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const {
277 MS_EXCEPTION_IF_NULL(fg);
278 MS_EXCEPTION_IF_NULL(scopes_);
279 MS_LOG(DEBUG) << "Start scopes func graph: " << fg->ToString();
280 scopes_->Recompute(fg);
281 MS_LOG(DEBUG) << "End scopes func graph: " << fg->ToString();
282 return scopes_->scope_analysis()[fg];
283 }
284
free_variables_total() const285 FVTotalMap &FuncGraphManager::free_variables_total() const {
286 MS_EXCEPTION_IF_NULL(free_variables_total_);
287 free_variables_total_->Recompute();
288 return free_variables_total_->fv_total_analysis();
289 }
290
func_graphs_used_total(const FuncGraphPtr & fg) const291 FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const {
292 MS_EXCEPTION_IF_NULL(func_graphs_used_total_);
293 func_graphs_used_total_->Recompute(fg);
294 return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
295 }
296
recursive(const FuncGraphPtr & fg) const297 bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
298 MS_EXCEPTION_IF_NULL(fg);
299 MS_EXCEPTION_IF_NULL(recursive_);
300 recursive_->Recompute(fg);
301 if (recursive_->recursive_analysis().count(fg) == 0) {
302 MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
303 return false;
304 }
305 return recursive_->recursive_analysis()[fg];
306 }
307
recursive_graphs(const FuncGraphPtr & fg) const308 std::shared_ptr<std::list<FuncGraphPtr>> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const {
309 MS_EXCEPTION_IF_NULL(fg);
310 MS_EXCEPTION_IF_NULL(recursive_);
311 if (recursive(fg)) {
312 if (recursive_->recursive_map().count(fg) == 0) {
313 auto trace = std::list<FuncGraphPtr>();
314 recursive_->CheckRecursiveGraphs(fg, &trace);
315 }
316 if (recursive_->recursive_map().count(fg) == 0) {
317 MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
318 return nullptr;
319 }
320 return recursive_->recursive_map()[fg];
321 } else {
322 return nullptr;
323 }
324 }
325
326 // Check if the function graph embed with `MetaFGPrim`, which currently covers kPrimJ and kPrimVmap and kPrimTaylor.
func_graph_meta_fg_prim_total(const FuncGraphPtr & fg) const327 bool FuncGraphManager::func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const {
328 MS_EXCEPTION_IF_NULL(meta_fg_prim_total_);
329 MS_EXCEPTION_IF_NULL(fg);
330 meta_fg_prim_total_->Recompute(fg);
331 if (meta_fg_prim_total_->meta_fg_prim_total_analysis().count(fg) == 0) {
332 MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString();
333 return false;
334 }
335 return meta_fg_prim_total_->meta_fg_prim_total_analysis()[fg];
336 }
337
338 // Add a func graph to this manager, optionally as a root func graph.
AddFuncGraph(const FuncGraphPtr & func_graph,bool is_root)339 void FuncGraphManager::AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root) {
340 MS_EXCEPTION_IF_NULL(func_graph);
341 if (is_root) {
342 roots_.add(func_graph);
343 return AddFuncGraphs(func_graph);
344 }
345
346 if (func_graphs_.contains(func_graph)) {
347 return;
348 }
349
350 // Add func_graph as a managed graph.
351 AddIntoManaged(func_graph);
352
353 // New nodes to be acquired.
354 std::vector<AnfNodePtr> new_nodes = func_graph->parameters();
355 auto return_node = func_graph->get_return();
356 if (return_node != nullptr) {
357 (void)new_nodes.emplace_back(std::move(return_node));
358 } else {
359 MS_LOG(INFO) << "The func graph " << func_graph->ToString() << " has no return node.";
360 }
361
362 // Acquire all nodes from func_graph.
363 AcquireNodes(std::move(new_nodes));
364 }
365
366 // Add all func graphs from the root func graph.
AddFuncGraphs(const FuncGraphPtr & source_func_graph)367 void FuncGraphManager::AddFuncGraphs(const FuncGraphPtr &source_func_graph) {
368 MS_EXCEPTION_IF_NULL(source_func_graph);
369 todo_.clear();
370 todo_.emplace_back(source_func_graph);
371 while (!todo_.empty()) {
372 auto func_graph = todo_.front();
373 MS_EXCEPTION_IF_NULL(func_graph);
374 todo_.pop_front();
375 if (func_graphs_.contains(func_graph)) {
376 continue;
377 }
378
379 // Add func_graph as a managed graph.
380 AddIntoManaged(func_graph);
381
382 // New nodes to be acquired.
383 std::vector<AnfNodePtr> new_nodes = func_graph->parameters();
384 auto return_node = func_graph->get_return();
385 if (return_node != nullptr) {
386 (void)new_nodes.emplace_back(std::move(return_node));
387 } else {
388 MS_LOG(INFO) << "The func graph " << func_graph->ToString() << " has no return node.";
389 }
390
391 // Acquire all nodes from func_graph.
392 AcquireNodes(std::move(new_nodes), false);
393 }
394 }
395
396 // Drop the func graph
DropFuncGraph(const FuncGraphPtr & fg,bool force)397 void FuncGraphManager::DropFuncGraph(const FuncGraphPtr &fg, bool force) {
398 if (force || (is_manage_ && drop_unused_graph_ && !fg->reserved())) {
399 MS_LOG(INFO) << "Drop " << fg << "/" << fg->ToString() << ", use_count: " << fg.use_count()
400 << ", type: " << fg->type_name();
401 fg->ResetReturnOwner();
402 fg->ResetOwnNodes();
403 fg->set_dropped(true);
404 }
405 }
406
407 // Clear the all information in manager
Clear()408 void FuncGraphManager::Clear() noexcept {
409 roots_.clear();
410 for (auto &fg : func_graphs_) {
411 MS_EXCEPTION_IF_NULL(fg);
412 fg->DecAttachedMngCnt();
413 if (fg->attached_mng_cnt() == 0) {
414 fg->ClearAllResource();
415 DropFuncGraph(fg);
416 } else if (fg->attached_mng_cnt() < 0) {
417 MS_LOG(INTERNAL_EXCEPTION) << "The func graph '" << fg->ToString()
418 << "' attached cnt not right: " << fg->attached_mng_cnt();
419 }
420 }
421 func_graphs_.clear();
422
423 all_nodes_.clear();
424 node_users_.clear();
425 todo_.clear();
426
427 signals_->InvalidateComputer();
428 }
429
KeepRoots(const std::vector<FuncGraphPtr> & func_graphs)430 void FuncGraphManager::KeepRoots(const std::vector<FuncGraphPtr> &func_graphs) {
431 MS_LOG(DEBUG) << "Start keep roots";
432 bool root_exist = false;
433 for (auto &item : func_graphs) {
434 if (roots_.contains(item)) {
435 root_exist = true;
436 break;
437 }
438 }
439
440 // If the new_root in roots_, we add new_root first, then calculate the func_graphs
441 // relation to new_root, remove the func_graphs not relation to new_root
442 // if the new_root not in roots_, we clear the all func_graphs in manager
443 // then add the new_root
444 if (root_exist || func_graphs.empty()) {
445 FuncGraphSet roots(func_graphs);
446 if (roots.empty()) {
447 roots = roots_;
448 } else {
449 roots_.clear();
450 for (auto &item : roots) {
451 AddFuncGraph(item, true);
452 }
453 }
454
455 FuncGraphSet keep;
456 for (auto &item : roots) {
457 MS_LOG(DEBUG) << "roots: " << item->ToString();
458 keep.update(func_graphs_used_total(item));
459 #ifdef DEBUG
460 for (auto &k : keep) {
461 MS_LOG(DEBUG) << "keep: " << k->ToString();
462 }
463 #endif
464 }
465 MaybeDropFuncGraphs(func_graphs_ - keep, true);
466 } else {
467 Clear();
468 FuncGraphSet roots(func_graphs);
469 for (auto &item : roots) {
470 AddFuncGraph(item, true);
471 }
472 }
473 }
474
RemoveRoots()475 void FuncGraphManager::RemoveRoots() {
476 MS_LOG(DEBUG) << "Start remove roots";
477 roots_.clear();
478 try {
479 MaybeDropFuncGraphs(func_graphs_, true);
480 } catch (std::exception &e) {
481 MS_LOG(DEBUG) << e.what();
482 }
483 }
484
AddIntoManaged(const FuncGraphPtr & fg)485 void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) {
486 MS_EXCEPTION_IF_NULL(fg);
487 if (is_manage_) {
488 if (fg->manager() != nullptr && fg->manager().get() != this) {
489 MS_LOG(INFO) << "A func graph can only have one manager.";
490 }
491 fg->set_manager(shared_from_this());
492 }
493 func_graphs_.add(fg);
494 fg->IncAttachedMngCnt();
495 }
496
MaybeDropFuncGraphs(const FuncGraphSet & func_graphs,bool ignore_users)497 void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) {
498 std::list<FuncGraphPtr> todo(func_graphs.begin(), func_graphs.end());
499 std::set<FuncGraphPtr> dropped;
500 while (!todo.empty()) {
501 FuncGraphPtr func_graph = std::move(todo.front());
502 MS_EXCEPTION_IF_NULL(func_graph);
503 todo.pop_front();
504 MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString();
505 if (roots_.contains(func_graph)) {
506 MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
507 continue;
508 }
509 auto &users_cnode_index = func_graph->func_graph_cnodes_index();
510 if (!users_cnode_index.empty() && !ignore_users) {
511 MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
512 continue;
513 }
514 if (dropped.find(func_graph) != dropped.end()) {
515 MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString();
516 continue;
517 }
518 (void)dropped.insert(func_graph);
519 MS_EXCEPTION_IF_NULL(func_graph->get_return());
520 std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
521 auto drop_graphs = MaybeDropNodes(std::move(return_vec));
522 (void)todo.insert(todo.end(), drop_graphs.begin(), drop_graphs.end());
523 }
524 for (auto &fg : dropped) {
525 MS_EXCEPTION_IF_NULL(fg);
526 all_nodes_.difference_update(fg->parameters());
527 EraseOneGraph(fg);
528 if (fg->manager().get() == this) {
529 fg->set_manager(nullptr);
530 }
531 MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString();
532 }
533 }
534
ProcessEdgeAdd(const AnfNodePtr & node,int index,const AnfNodePtr & input)535 void FuncGraphManager::ProcessEdgeAdd(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
536 if (IsValueNode<FuncGraph>(input)) {
537 AddFuncGraph(GetValueNode<FuncGraphPtr>(input));
538 }
539 auto &users_node = node_users_[input];
540 users_node.add(std::make_pair(node, index));
541 OnEdgeAdded(node, index, input);
542 }
543
ProcessEdgeRemove(const AnfNodePtr & node,int index,const AnfNodePtr & input)544 void FuncGraphManager::ProcessEdgeRemove(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
545 auto iter = node_users_.find(input);
546 if (iter == node_users_.end()) {
547 return;
548 }
549 bool removed = iter->second.erase(std::make_pair(node, index));
550 if (removed) {
551 OnEdgeRemoved(node, index, input);
552 }
553 }
554
ProcessInputsEdgeAdd(const CNodePtr & cnode)555 void FuncGraphManager::ProcessInputsEdgeAdd(const CNodePtr &cnode) {
556 const size_t count = cnode->size();
557 for (size_t i = 0; i < count; ++i) {
558 ProcessEdgeAdd(cnode, static_cast<int>(i), cnode->input(i));
559 }
560 }
561
ProcessInputsEdgeRemove(const CNodePtr & cnode)562 void FuncGraphManager::ProcessInputsEdgeRemove(const CNodePtr &cnode) {
563 const size_t count = cnode->size();
564 for (size_t i = 0; i < count; ++i) {
565 ProcessEdgeRemove(cnode, static_cast<int>(i), cnode->input(i));
566 }
567 }
568
569 namespace {
FollowGraph(const FuncGraphPtr & fg,SeenNum seen,std::vector<AnfNodePtr> * nodes)570 inline void FollowGraph(const FuncGraphPtr &fg, SeenNum seen, std::vector<AnfNodePtr> *nodes) {
571 if (fg == nullptr) {
572 return;
573 }
574 if (auto res = fg->get_return(); res != nullptr && res->seen_ != seen) {
575 (void)nodes->emplace_back(std::move(res));
576 }
577 }
578
FollowInputs(const CNodePtr & cnode,std::vector<AnfNodePtr> * nodes)579 inline void FollowInputs(const CNodePtr &cnode, std::vector<AnfNodePtr> *nodes) {
580 auto &weak_inputs = cnode->weak_inputs();
581 (void)std::transform(weak_inputs.cbegin(), weak_inputs.cend(), std::back_inserter(*nodes),
582 [](const AnfNodeWeakPtr &weak_node) {
583 auto node = weak_node.lock();
584 MS_EXCEPTION_IF_NULL(node);
585 return node;
586 });
587 }
588 } // namespace
589
AcquireNodes(std::vector<AnfNodePtr> && nodes,bool recursive)590 void FuncGraphManager::AcquireNodes(std::vector<AnfNodePtr> &&nodes, bool recursive) {
591 auto seen = NewSeenGeneration();
592 while (!nodes.empty()) {
593 // Take the last one.
594 auto node = std::move(nodes.back());
595 nodes.pop_back();
596 MS_EXCEPTION_IF_NULL(node);
597 // Skip visited nodes.
598 if (node->seen_ == seen) {
599 continue;
600 }
601 node->seen_ = seen;
602 // Try add it to all_nodes_.
603 auto insert_result = all_nodes_.insert(node);
604 if (insert_result.second == false) {
605 // Skip acquired nodes.
606 continue;
607 }
608 // Add node to its func_graph.
609 auto fg = node->func_graph();
610 if (fg != nullptr) {
611 fg->AddNode(node);
612 }
613 // Follow graph for value node.
614 if (node->isa<ValueNode>()) {
615 auto graph = GetValueNode<FuncGraphPtr>(node);
616 FollowGraph(graph, seen, &nodes);
617 continue;
618 }
619 // Follow graph for cnode or parameter.
620 FollowGraph(fg, seen, &nodes);
621
622 // Handle CNode.
623 auto cnode = node->cast<CNodePtr>();
624 if (cnode == nullptr) {
625 continue;
626 }
627
628 // Handle input edges.
629 if (recursive) {
630 ProcessInputsEdgeAdd(cnode);
631 // Follow inputs.
632 FollowInputs(cnode, &nodes);
633 continue;
634 }
635 // The way not recursive.
636 for (size_t i = 0; i < cnode->size(); ++i) {
637 const auto &input = cnode->input(i);
638 if (input == nullptr) {
639 MS_LOG(INTERNAL_EXCEPTION) << "The input is null, " << cnode << "/" << cnode->DebugString() << "@" << fg << "/"
640 << fg->ToString();
641 }
642 if (IsValueNode<FuncGraph>(input)) {
643 todo_.emplace_back(GetValueNode<FuncGraphPtr>(input));
644 }
645
646 auto &users_node = node_users_[input];
647 users_node.add(std::make_pair(node, i));
648 OnEdgeAdded(node, i, input);
649 }
650 // Follow inputs.
651 FollowInputs(cnode, &nodes);
652 }
653 }
654
MaybeDropNodes(std::vector<AnfNodePtr> && nodes)655 FuncGraphSet FuncGraphManager::MaybeDropNodes(std::vector<AnfNodePtr> &&nodes) {
656 FuncGraphSet drop_func_graphs;
657 while (!nodes.empty()) {
658 AnfNodePtr node = std::move(nodes.back());
659 nodes.pop_back();
660 if (node == nullptr) {
661 // Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception,
662 // this method may be triggered by desctuctor.
663 MS_LOG(WARNING) << "Node to be dropped is nullptr";
664 continue;
665 }
666 if (!all_nodes_.contains(node)) {
667 // Node not existed.
668 continue;
669 }
670 auto &users = node_users_[node];
671 if (!users.empty()) {
672 // Node is in used.
673 continue;
674 }
675 if (node->isa<Parameter>() && node->func_graph() != nullptr) {
676 // Node is a used parameter.
677 auto ¶meters = node->func_graph()->parameters();
678 if (std::find(parameters.begin(), parameters.end(), node) != parameters.end()) {
679 continue;
680 }
681 }
682 if (IsValueNode<FuncGraph>(node)) {
683 // The FuncGraph may need to be dropped.
684 auto fg = GetValueNode<FuncGraphPtr>(node);
685 drop_func_graphs.add(fg);
686 }
687 // Handle cnode.
688 if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
689 // Remove inputs edges.
690 ProcessInputsEdgeRemove(cnode);
691 // Handle inputs nodes.
692 FollowInputs(cnode, &nodes);
693 }
694 // Remove it from all_nodes_;
695 (void)all_nodes_.erase(node);
696 // Drop node from its func graph.
697 if (auto fg = node->func_graph(); fg != nullptr) {
698 fg->DropNode(node);
699 }
700 // Remove it from node_users.
701 (void)node_users_.erase(node);
702 }
703 return drop_func_graphs;
704 }
705
SetParameters(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & parameters)706 void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters) {
707 auto tr = Transact();
708 tr.SetParameters(fg, parameters);
709 tr.Commit();
710 }
711
AddParameter(const FuncGraphPtr & fg,const AnfNodePtr & parameter)712 void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) {
713 auto tr = Transact();
714 tr.AddParameter(fg, parameter);
715 tr.Commit();
716 }
717
InsertFrontParameter(const FuncGraphPtr & fg,const AnfNodePtr & parameter)718 void FuncGraphManager::InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) {
719 auto tr = Transact();
720 tr.InsertFrontParameter(fg, parameter);
721 tr.Commit();
722 }
723
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)724 bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
725 MS_EXCEPTION_IF_NULL(old_node);
726 MS_EXCEPTION_IF_NULL(new_node);
727 auto tr = Transact();
728 bool success = tr.Replace(old_node, new_node);
729 if (success) {
730 tr.Commit();
731 }
732 return success;
733 }
734
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node,const AnfNodePtr & mask_node)735 bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const AnfNodePtr &mask_node) {
736 MS_EXCEPTION_IF_NULL(old_node);
737 MS_EXCEPTION_IF_NULL(new_node);
738 auto tr = Transact();
739 bool success = tr.Replace(old_node, new_node, mask_node);
740 if (success) {
741 tr.Commit();
742 }
743 return success;
744 }
745
SetEdge(const AnfNodePtr & node,int index,const AnfNodePtr & value)746 void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) {
747 auto tr = Transact();
748 tr.SetEdge(node, index, value);
749 tr.Commit();
750 }
751
AddEdge(const AnfNodePtr & node,const AnfNodePtr & value)752 void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) {
753 auto tr = Transact();
754 tr.AddEdge(node, value);
755 tr.Commit();
756 }
757
MoveAllCNodeDropGraph(const FuncGraphPtr & source,const FuncGraphPtr & target,const AnfNodePtr & call_node,const ScopePtr & scope,bool update_debug_info)758 void FuncGraphManager::MoveAllCNodeDropGraph(const FuncGraphPtr &source, const FuncGraphPtr &target,
759 const AnfNodePtr &call_node, const ScopePtr &scope,
760 bool update_debug_info) {
761 MS_EXCEPTION_IF_NULL(source);
762 CNodePtr source_return = source->get_return();
763 MS_EXCEPTION_IF_NULL(source_return);
764 AnfNodePtr source_output = source->output();
765 const auto &source_prim = source_return->input(0);
766
767 int index = 0;
768 (void)node_users_[source_prim].erase(make_pair(source_return, index));
769 OnEdgeRemoved(source_return, index, source_prim);
770 index = 1;
771 (void)node_users_[source_output].erase(make_pair(source_return, index));
772 OnEdgeRemoved(source_return, index, source_output);
773 (void)all_nodes_.erase(source_return);
774 (void)node_users_.erase(source_return);
775 source->DropNode(source_return);
776 for (auto &node : source->nodes()) {
777 node->set_func_graph(target);
778 if (node->scope() == kDefaultScope) {
779 node->set_scope(scope);
780 }
781 if (update_debug_info && node->isa<CNode>()) {
782 MS_LOG(DEBUG) << "call_node: " << call_node << "/" << call_node->DebugString() << ", node: " << node << "/"
783 << node->DebugString();
784 UpdateInlineCNodeDebugInfo(call_node, node);
785 }
786 }
787
788 MoveAllNodes(source, target);
789 all_nodes_.difference_update(source->parameters());
790 EraseOneGraph(source);
791 source->set_dropped(true);
792 if (source->manager().get() == this) {
793 source->set_manager(nullptr);
794 }
795
796 if (source->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
797 target->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
798 }
799 if (source->has_flag(kTraining)) {
800 target->set_flag(kTraining, true);
801 }
802 }
803
OnEdgeAdded(const AnfNodePtr & node,int index,const AnfNodePtr & input)804 void FuncGraphManager::OnEdgeAdded(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
805 auto fg = node->func_graph();
806 if (input->isa<ValueNode>()) {
807 fg->AddValueNode(input);
808 if (IsValueNode<FuncGraph>(input)) {
809 auto used = GetValueNode<FuncGraphPtr>(input);
810 used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
811 if (fg->AddFuncGraphUsed(used)) {
812 signals_->InvalidateComputer();
813 }
814 }
815 if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
816 IsPrimitiveCNode(node, prim::kPrimTaylor) || IsPrimitiveCNode(node, prim::kPrimShard)) {
817 fg->AddMetaFgPrimValueNode(input);
818 }
819 } else if (IsPrimitiveCNode(node, prim::kPrimVmap) && IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
820 // To handle the model ensembling scenario in vmap, whose input is a celllist, taking an arbitrary function graph
821 // is sufficient.
822 constexpr int64_t kIndex1 = 1;
823 auto func_union = dyn_cast<CNode>(input);
824 if (IsValueNode<FuncGraph>(func_union->input(kIndex1))) {
825 fg->AddMetaFgPrimValueNode(func_union->input(kIndex1));
826 }
827 } else if (fg != nullptr && fg != input->func_graph()) {
828 if (fg->AddFreeVariable(input)) {
829 signals_->InvalidateComputer();
830 }
831 }
832 }
833
OnEdgeRemoved(const AnfNodePtr & node,int index,const AnfNodePtr & input)834 void FuncGraphManager::OnEdgeRemoved(const AnfNodePtr &node, int index, const AnfNodePtr &input) {
835 auto fg = node->func_graph();
836 if (fg != nullptr && input->isa<ValueNode>()) {
837 fg->DropValueNode(input);
838 if (IsValueNode<FuncGraph>(input)) {
839 auto used = GetValueNode<FuncGraphPtr>(input);
840 used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
841 if (fg->DropFuncGraphUsed(used)) {
842 signals_->InvalidateComputer();
843 }
844 }
845 if (IsPrimitiveCNode(node, prim::kPrimJ) || IsPrimitiveCNode(node, prim::kPrimVmap) ||
846 IsPrimitiveCNode(node, prim::kPrimTaylor)) {
847 fg->DropMetaFgPrimValueNode(input);
848 }
849 } else if (fg != nullptr && fg != input->func_graph()) {
850 if (fg->DropFreeVariable(input)) {
851 signals_->InvalidateComputer();
852 }
853 }
854 }
855
MoveAllNodes(const FuncGraphPtr & source,const FuncGraphPtr & target)856 void FuncGraphManager::MoveAllNodes(const FuncGraphPtr &source, const FuncGraphPtr &target) {
857 target->CopyNodes(source);
858 target->CopyValueNodes(source);
859 target->CopyFuncGraphCNodesIndex(source);
860 target->CopyFreeVariables(source);
861 target->CopyFuncGraphsUsed(source);
862 target->CopyMetaFgPrimValueNodes(source);
863 source->ClearAllResource();
864 signals_->InvalidateComputer();
865 }
866
CommitChanges(std::vector<change::ChangePtr> && changes)867 void FuncGraphManager::CommitChanges(std::vector<change::ChangePtr> &&changes) {
868 // Apply changes.
869 change::ChangeCounter counter;
870 for (auto &change : changes) {
871 change->Apply(&counter);
872 }
873 changes.clear();
874
875 // Process added edges.
876 counter.ForEachAddedEdges([this](const change::Edge &edge) { //
877 ProcessEdgeAdd(edge.cnode, edge.index, edge.input);
878 });
879
880 // Process added nodes.
881 AcquireNodes(counter.GetAddedNodes());
882
883 // Process removed edges.
884 counter.ForEachRemovedEdges([this](const change::Edge &edge) { //
885 ProcessEdgeRemove(edge.cnode, edge.index, edge.input);
886 });
887
888 // Process removed nodes.
889 auto drop_func_graphs = MaybeDropNodes(counter.GetRemovedNodes());
890 if (!drop_func_graphs.empty()) {
891 MaybeDropFuncGraphs(drop_func_graphs);
892 }
893 }
894
EraseOneGraph(const FuncGraphPtr & fg)895 void FuncGraphManager::EraseOneGraph(const FuncGraphPtr &fg) {
896 MS_EXCEPTION_IF_NULL(fg);
897 bool erase_ret = func_graphs_.erase(fg->shared_from_base<FuncGraph>());
898 if (!erase_ret) {
899 return;
900 }
901 fg->DecAttachedMngCnt();
902 if (fg->attached_mng_cnt() == 0) {
903 fg->ClearAllResource();
904 }
905 }
906
SetParameters(FuncGraphPtr fg,const std::vector<AnfNodePtr> & params)907 void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms) {
908 (void)changes_.emplace_back(std::make_unique<change::SetParams>(fg, params));
909 }
910
AddParameter(FuncGraphPtr fg,const AnfNodePtr & param)911 void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) {
912 (void)changes_.emplace_back(std::make_unique<change::AddParam>(fg, param->cast<ParameterPtr>()));
913 }
914
InsertFrontParameter(FuncGraphPtr fg,const AnfNodePtr & param)915 void FuncGraphTransaction::InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) {
916 (void)changes_.emplace_back(std::make_unique<change::InsertFrontParam>(fg, param->cast<ParameterPtr>()));
917 }
918
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node)919 bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
920 MS_EXCEPTION_IF_NULL(old_node);
921 MS_EXCEPTION_IF_NULL(new_node);
922 FuncGraphPtr old_func_graph = old_node->func_graph();
923 if (old_func_graph != nullptr && old_func_graph->get_return() != nullptr &&
924 old_func_graph->get_return() == old_node) {
925 MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString();
926 return false;
927 }
928 auto &users = manager_->node_users()[old_node];
929 for (auto &node : users) {
930 SetEdge(node.first, node.second, new_node);
931 }
932 return true;
933 }
934
Replace(const AnfNodePtr & old_node,const AnfNodePtr & new_node,const AnfNodePtr & mask_node)935 bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node,
936 const AnfNodePtr &mask_node) {
937 MS_EXCEPTION_IF_NULL(old_node);
938 MS_EXCEPTION_IF_NULL(new_node);
939 FuncGraphPtr old_func_graph = old_node->func_graph();
940 if (old_func_graph != nullptr && old_func_graph->get_return() != nullptr &&
941 old_func_graph->get_return() == old_node) {
942 MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString();
943 return false;
944 }
945 auto &users = manager_->node_users()[old_node];
946 for (auto &node : users) {
947 if (node.first == mask_node) {
948 SetEdge(node.first, node.second, new_node);
949 }
950 }
951 return true;
952 }
953
SetEdge(const AnfNodePtr & src_node,int k,const AnfNodePtr & v)954 void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) {
955 if (k < 0) {
956 MS_LOG(INTERNAL_EXCEPTION) << "Invalid value k = " << k;
957 }
958 MS_EXCEPTION_IF_NULL(src_node);
959 auto cnode = src_node->cast<CNodePtr>();
960 if (cnode == nullptr) {
961 MS_LOG(INTERNAL_EXCEPTION) << "src_node should be a cnode, but cast failed.";
962 }
963 (void)changes_.emplace_back(std::make_unique<change::SetEdge>(cnode, k, v));
964 }
965
AddEdge(const AnfNodePtr & src_node,const AnfNodePtr & v)966 void FuncGraphTransaction::AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v) {
967 MS_EXCEPTION_IF_NULL(src_node);
968 auto cnode = src_node->cast<CNodePtr>();
969 if (cnode == nullptr) {
970 MS_LOG(INTERNAL_EXCEPTION) << "src_node should be a cnode, but cast failed.";
971 }
972 (void)changes_.emplace_back(std::make_unique<change::AddEdge>(cnode, v));
973 }
974
Commit()975 void FuncGraphTransaction::Commit() { manager_->CommitChanges(std::move(changes_)); }
976
DepComputer(const FuncGraphManager * const manager)977 DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager), validate_(false) {
978 MS_EXCEPTION_IF_NULL(manager_);
979 manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
980 }
981
Recompute()982 void DepComputer::Recompute() {
983 if (!validate_) {
984 RealRecompute();
985 validate_ = true;
986 }
987 }
988
Recompute(const FuncGraphPtr & fg)989 void DepComputer::Recompute(const FuncGraphPtr &fg) {
990 if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) {
991 RealRecompute(fg);
992 func_graphs_validate_[fg] = true;
993 }
994 }
995
SeekParents(const FuncGraphPtr & func_graph)996 FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &func_graph) {
997 MS_EXCEPTION_IF_NULL(func_graph);
998 constexpr auto out_call_stack = 0;
999 constexpr auto in_call_stack = 1;
1000 auto seen = NewFgSeenGeneration();
1001 func_graph->seen_ = seen;
1002 std::deque<FuncGraphPtr> todo;
1003 (void)todo.emplace_back(func_graph);
1004 FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
1005 while (!todo.empty()) {
1006 const auto &fg = todo.back();
1007 // For the func graph both in deque and in call stack, just pop it.
1008 if (fg->extra_seen_ == in_call_stack) {
1009 fg->extra_seen_ = out_call_stack;
1010 todo.pop_back();
1011 continue;
1012 }
1013
1014 // Append all the fvs in fg.
1015 auto &fvs = fg->free_variables();
1016 for (const auto &fv : fvs) {
1017 const auto &fv_node = fv.first;
1018 MS_EXCEPTION_IF_NULL(fv_node);
1019 auto fv_func_graph = fv_node->func_graph();
1020 if (fv_func_graph == nullptr) {
1021 MS_LOG(INFO) << "Meet a FV '" << fv_node->DebugString() << "' whose func graph is null, during seeking for "
1022 << fg->ToString() << "\nFV: " << trace::GetDebugInfoStr(fv_node->debug_info());
1023 continue;
1024 }
1025 // Found a parent if not in the call stack.
1026 if (fv_func_graph->extra_seen_ != in_call_stack) {
1027 parents->add(fv_func_graph);
1028 }
1029 }
1030
1031 // Before push the used func graphs of 'fg', mark it as in call stack.
1032 fg->extra_seen_ = in_call_stack;
1033 // Add the fg's used func graph to search.
1034 auto &fgs = fg->func_graphs_used();
1035 for (auto &item : fgs) {
1036 auto > = item.first;
1037 if (gt->seen_ != seen) {
1038 gt->extra_seen_ = out_call_stack;
1039 (void)todo.emplace_back(gt);
1040 gt->seen_ = seen;
1041 }
1042 }
1043 }
1044 return parents;
1045 }
1046
RealRecompute(FuncGraphPtr fg)1047 void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
1048 MS_EXCEPTION_IF_NULL(fg);
1049 auto parents = SeekParents(fg);
1050 func_graph_parents_total_analysis_[fg].update(parents);
1051 }
1052
set_len_compare(const FuncGraphSetPair & lhs,const FuncGraphSetPair & rhs)1053 bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
1054 auto l1 = lhs.second.size();
1055 auto l2 = rhs.second.size();
1056 return l1 < l2;
1057 }
1058
RealRecompute(FuncGraphPtr fg)1059 void ParentComputer::RealRecompute(FuncGraphPtr fg) {
1060 this->parent_analysis_[fg] = nullptr;
1061 // Note: must be a copy other than reference as it is modified thereafter.
1062 auto deps = this->manager_->func_graph_parents_total(fg);
1063 if (deps.empty()) {
1064 this->parent_analysis_[fg] = nullptr;
1065 return;
1066 } else if (deps.size() == 1) {
1067 this->parent_analysis_[fg] = deps.front();
1068 return;
1069 } else {
1070 // return nearest parent as parent
1071 FuncGraphSet deps_copy(deps);
1072 for (auto &dep : deps) {
1073 auto parent_deps = this->manager_->func_graph_parents_total(dep);
1074 for (auto &p_d : parent_deps) {
1075 if (deps_copy.count(p_d) > 0) {
1076 (void)deps_copy.erase(p_d);
1077 }
1078 }
1079 if (deps_copy.size() == 1) {
1080 this->parent_analysis_[fg] = deps_copy.front();
1081 return;
1082 }
1083 }
1084 }
1085 }
1086
RealRecompute(FuncGraphPtr fg)1087 void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
1088 MS_EXCEPTION_IF_NULL(manager_);
1089 auto used_fg_total = manager_->func_graphs_used_total(fg);
1090 for (auto &used_fg : used_fg_total) {
1091 if (manager_->parent(used_fg) == fg) {
1092 children_analysis_[fg].add(used_fg);
1093 }
1094 }
1095 }
1096
RealRecompute(FuncGraphPtr fg)1097 void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
1098 MS_EXCEPTION_IF_NULL(manager_);
1099 auto &children = manager_->children(fg);
1100
1101 scope_analysis_[fg] = FuncGraphSet();
1102 scope_analysis_[fg].add(fg);
1103 for (auto &child : children) {
1104 scope_analysis_[fg].add(child);
1105 }
1106 }
1107
RealRecompute()1108 void FVTotalComputer::RealRecompute() {
1109 auto manager = DepComputer::manager_;
1110 MS_EXCEPTION_IF_NULL(manager);
1111
1112 for (auto &fg : manager->func_graphs()) {
1113 fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
1114 }
1115
1116 for (auto &fg : manager->func_graphs()) {
1117 // add all free variable nodes
1118 AnfNodeCounterMap items = fg->free_variables();
1119 for (auto &iter : items) {
1120 auto curr = fg;
1121 while (curr != nullptr) {
1122 fv_total_analysis_[curr][iter.first] = iter.second;
1123 curr = manager->parent(curr);
1124 if (curr != nullptr) {
1125 const AnfNodeSet &all_nodes = curr->nodes();
1126 if (all_nodes.contains(iter.first)) {
1127 break;
1128 }
1129 }
1130 }
1131 }
1132
1133 // add all FGs of free variables
1134 auto &used = fg->func_graphs_used();
1135 for (auto &iter : used) {
1136 auto p = manager->parent(iter.first);
1137 if (p == nullptr) {
1138 continue;
1139 }
1140 auto curr = fg;
1141 while (curr != nullptr && curr != p) {
1142 fv_total_analysis_[curr][iter.first] = iter.second;
1143 curr = manager->parent(curr);
1144 }
1145 }
1146 }
1147 }
1148
RealRecompute(FuncGraphPtr fg)1149 void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
1150 MS_EXCEPTION_IF_NULL(manager_);
1151 std::vector<FuncGraphPtr> todo;
1152 std::vector<FuncGraphPtr> todo_new;
1153
1154 todo.push_back(fg);
1155 while (!todo.empty()) {
1156 todo_new.clear();
1157 for (auto > : todo) {
1158 for (auto &item : gt->func_graphs_used()) {
1159 auto used_fg = item.first;
1160 if (used_fg == fg) {
1161 func_graph_used_total_analysis_[fg].add(used_fg);
1162 continue;
1163 }
1164 if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) {
1165 todo_new.push_back(used_fg);
1166 }
1167 MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString();
1168 func_graph_used_total_analysis_[fg].add(used_fg);
1169 }
1170 }
1171 todo = todo_new;
1172 }
1173 }
1174
CheckRecursive(const FuncGraphManager * const manager,const FuncGraphPtr & fg)1175 bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
1176 MS_EXCEPTION_IF_NULL(manager);
1177 std::vector<FuncGraphPtr> todo;
1178 std::vector<FuncGraphPtr> todo_new;
1179 todo.push_back(fg);
1180 FuncGraphSet used_total;
1181 while (!todo.empty()) {
1182 todo_new.clear();
1183 for (auto > : todo) {
1184 for (auto &item : gt->func_graphs_used()) {
1185 auto used_g = item.first;
1186 if (used_g == fg) {
1187 return true;
1188 }
1189 if (used_total.count(used_g) == 0) {
1190 todo_new.push_back(used_g);
1191 }
1192 used_total.add(used_g);
1193 }
1194 }
1195 todo = todo_new;
1196 }
1197 return false;
1198 }
1199
RealRecompute(FuncGraphPtr fg)1200 void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
1201 this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
1202 }
1203
CheckRecursiveGraphs(const FuncGraphPtr & fg,std::list<FuncGraphPtr> * trace)1204 void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
1205 MS_EXCEPTION_IF_NULL(trace);
1206 auto res = std::find(trace->begin(), trace->end(), fg);
1207 // Find recursive
1208 if (res != trace->end()) {
1209 auto recur_ptr = std::make_shared<std::list<FuncGraphPtr>>(res, trace->end());
1210 for (auto iter = res; iter != trace->end(); (void)iter++) {
1211 MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString();
1212 recursive_map_[*iter] = recur_ptr;
1213 }
1214 } else {
1215 trace->push_back(fg);
1216 auto &items = fg->func_graphs_used();
1217 for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
1218 CheckRecursiveGraphs(iter->first, trace);
1219 }
1220 trace->pop_back();
1221 if (recursive_map_.count(fg) == 0) {
1222 recursive_map_[fg] = nullptr;
1223 }
1224 }
1225 }
1226
SeekMetaFgPrim(const FuncGraphPtr & fg,SeenNum seen_num)1227 bool FuncGraphMetaFgPrimTotalComputer::SeekMetaFgPrim(const FuncGraphPtr &fg, SeenNum seen_num) {
1228 MS_EXCEPTION_IF_NULL(fg);
1229 if (fg->seen_ == seen_num) {
1230 MS_LOG(DEBUG) << fg->ToString() << " had been checked";
1231 return false;
1232 }
1233
1234 // Check MetaFgPrim (J/Vmap/Taylor) FuncGraph input.
1235 const auto &meta_fg_prim_values = fg->meta_fg_prim_value_nodes();
1236 if (!meta_fg_prim_values.empty()) {
1237 auto contains_meta_fg_prim =
1238 std::find_if(meta_fg_prim_values.begin(), meta_fg_prim_values.end(), [seen_num](const auto &iter) {
1239 // Check g1->MetaFgPrim(fg)->g2->g cycle.
1240 if (IsValueNode<FuncGraph>(iter.first)) {
1241 auto func_graph = GetValuePtr<FuncGraph>(iter.first);
1242 return func_graph->seen_ != seen_num;
1243 }
1244 if (IsValueNode<Primitive>(iter.first)) {
1245 // Exclude the primitive of MetaFgPrim (J/Vmap/Taylor) itself.
1246 auto prim = GetValueNode<PrimitivePtr>(iter.first);
1247 return (prim->name() != prim::kPrimJ->name() && prim->name() != prim::kPrimVmap->name() &&
1248 prim->name() != prim::kPrimTaylor->name());
1249 }
1250 return false;
1251 });
1252 if (contains_meta_fg_prim != meta_fg_prim_values.end()) {
1253 MS_EXCEPTION_IF_NULL(contains_meta_fg_prim->first);
1254 MS_LOG(DEBUG) << fg->ToString() << " contains MetaFgPrim(" << contains_meta_fg_prim->first->DebugString() << ")";
1255 return true;
1256 }
1257 }
1258
1259 // Check MetaFgPrim (J/Vmap/Taylor) CNode as FV.
1260 const auto &fv_nodes = fg->free_variables();
1261 if (!fv_nodes.empty()) {
1262 auto contains_meta_fg_prim_cnode = std::find_if(fv_nodes.begin(), fv_nodes.end(), [seen_num](const auto &iter) {
1263 // Check if the FV is a MetaFgPrim (J/Vmap/Taylor) call CNode.
1264 return IsPrimitiveCNode(iter.first, prim::kPrimJ) || IsPrimitiveCNode(iter.first, prim::kPrimVmap) ||
1265 IsPrimitiveCNode(iter.first, prim::kPrimTaylor);
1266 });
1267 if (contains_meta_fg_prim_cnode != fv_nodes.end()) {
1268 MS_EXCEPTION_IF_NULL(contains_meta_fg_prim_cnode->first);
1269 MS_LOG(DEBUG) << fg->ToString() << " contains FV MetaFgPrim (J/Vmap/Taylor) ("
1270 << contains_meta_fg_prim_cnode->first->DebugString() << ")";
1271 return true;
1272 }
1273 }
1274
1275 // Check if func graphs used contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive),
1276 // Taylor(func_graph) or Taylor(Primitive).
1277 fg->seen_ = seen_num;
1278 for (auto &item : fg->func_graphs_used()) {
1279 auto used_g = item.first;
1280 MS_EXCEPTION_IF_NULL(used_g);
1281 if (SeekMetaFgPrim(used_g, seen_num)) {
1282 MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString()
1283 << " which contains J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
1284 << "Taylor(func_graph) or Taylor(Primitive)";
1285 return true;
1286 }
1287 }
1288 MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph), J(Primitive), Vmap(func_graph), Vmap(Primitive), "
1289 << "Taylor(func_graph) or Taylor(Primitive)";
1290 return false;
1291 }
1292
RealRecompute(FuncGraphPtr fg)1293 void FuncGraphMetaFgPrimTotalComputer::RealRecompute(FuncGraphPtr fg) {
1294 this->meta_fg_prim_total_analysis_[fg] = SeekMetaFgPrim(fg, NewFgSeenGeneration());
1295 }
1296 } // namespace mindspore
1297