• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/common_test.h"
17 #include "common/py_func_graph_fetcher.h"
18 #include "ir/dtype.h"
19 #include "ir/manager.h"
20 #include "ir/func_graph_cloner.h"
21 #include "pipeline/jit/parse/parse.h"
22 #include "frontend/operator/ops.h"
23 #include "utils/log_adapter.h"
24 #include "debug/draw.h"
25 #include "utils/label.h"
26 
27 namespace mindspore {
28 
29 namespace {
// Splits `str` on every occurrence of `pattern` and returns the pieces in order, keeping empty pieces
// (e.g. "a;;b;" split on ";" gives {"a", "", "b", ""}, and "" gives {""}).
// An empty `pattern` returns {str} unchanged: an empty delimiter matches at every position, and the
// previous implementation never advanced in that case.
// Segments are located directly in `str` rather than by appending `pattern` as a sentinel, so a tail
// that overlaps the delimiter (e.g. "x-" split on "--") is no longer truncated.
std::vector<std::string> SplitString(const std::string &str, const std::string &pattern) {
  std::vector<std::string> result;
  if (pattern.empty()) {
    result.push_back(str);
    return result;
  }

  std::string::size_type start = 0;
  while (true) {
    const std::string::size_type pos = str.find(pattern, start);
    if (pos == std::string::npos) {
      // Whatever follows the last delimiter (possibly empty) is the final piece.
      result.push_back(str.substr(start));
      break;
    }
    result.push_back(str.substr(start, pos - start));
    start = pos + pattern.size();
  }
  return result;
}
47 }  // namespace
48 using std::dynamic_pointer_cast;
49 
50 using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
51 using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
52 
53 class NestingSpecs;
54 
55 class Stage {
56  public:
57   explicit Stage(std::vector<std::string> specs) {
58     for (auto arg : specs) {
59       auto spec = SplitString(arg, "=");
60       if (spec.size() <= 1) {
61         continue;
62       }
63       std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
64       specs_[ToFullString(spec[0])] = nesting;
65     }
66   }
67 
68   ~Stage() {}
69 
70   std::map<std::string, std::string> &subs() { return subs_; }
71 
72   void set_subs(const std::map<std::string, std::string> &subs) { subs_ = subs; }
73 
74  private:
75   std::string ToFullString(std::string s) {
76     if (s.find("fv") != std::string::npos) {
77       s = s.replace(s.find("fv"), 2, "free_variable");
78     }
79 
80     if (s.find("deps") != std::string::npos) {
81       s = s.replace(s.find("deps"), 4, "dependencies");
82     }
83 
84     return s;
85   }
86 
87   std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
88   std::map<std::string, std::string> subs_;
89 };
90 
91 class NestingSpecs {
92  public:
93   NestingSpecs(Stage *stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
94 
95   ~NestingSpecs() {}
96 
97   std::string Name(Any node) {
98     std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
99     if (stage_->subs().find(name) != stage_->subs().end()) {
100       return stage_->subs()[name];
101     }
102 
103     return name;
104   }
105 
106   void Check(std::shared_ptr<DepComputer> results) {
107     if (expected_.empty() && expected_recursive_.empty()) {
108       return;
109     }
110 
111     auto parent = dynamic_pointer_cast<ParentComputer>(results);
112     if (parent != nullptr) {
113       CheckParent(parent);
114       return;
115     }
116 
117     auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
118     if (recursive != nullptr) {
119       CheckRecursive(recursive);
120       return;
121     }
122   }
123 
124  private:
125   void ParseSpecs(std::string specs) {
126     if (specs.empty()) {
127       return;
128     }
129 
130     std::vector<std::string> str_list = SplitString(specs, ";");
131     for (auto spec : str_list) {
132       spec.erase(0, spec.find_first_not_of(" "));
133       spec.erase(spec.find_last_not_of(" ") + 1);
134       if (spec.empty()) {
135         continue;
136       }
137       if (spec.find("->") != std::string::npos) {
138         auto substr = SplitString(spec, "->");
139         ASSERT_GT(substr.size(), 1);
140         auto key = substr[0];
141         auto value = substr[1];
142         if (!value.empty()) {
143           expected_[key] = {value};
144         }
145       } else if (spec.find(":") != std::string::npos) {
146         auto substr = SplitString(spec, ":");
147         ASSERT_GT(substr.size(), 1);
148         auto key = substr[0];
149         auto values = SplitString(substr[1], ",");
150         std::set<std::string> values_set(values.begin(), values.end());
151         if (!values_set.empty()) {
152           expected_[key] = values_set;
153         }
154       } else {
155         expected_recursive_[spec] = true;
156       }
157     }
158   }
159 
160   void CheckParent(std::shared_ptr<ParentComputer> results) {
161     std::map<std::string, std::set<std::string>> clean_results;
162     for (auto &iter : results->parent_analysis()) {
163       auto key = iter.first;
164       auto value = iter.second;
165       if (key == nullptr) {
166         continue;
167       }
168       std::string k = Name(key);
169 
170       std::set<std::string> v;
171       if (value != nullptr && !Name(value).empty()) {
172         v.insert(Name(value));
173       }
174 
175       if (!v.empty()) {
176         clean_results[k] = v;
177       }
178     }
179 
180     ASSERT_EQ(clean_results, expected_);
181   }
182 
183   void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
184     std::map<std::string, bool> clean_results;
185     for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
186       auto key = iter->first;
187       auto value = iter->second;
188       if (key == nullptr) {
189         continue;
190       }
191       std::string k = Name(key);
192 
193       clean_results[k] = value;
194     }
195 
196     ASSERT_EQ(clean_results, expected_recursive_);
197   }
198 
199  private:
200   Stage *stage_;
201   std::map<std::string, std::set<std::string>> expected_;
202   std::map<std::string, bool> expected_recursive_;
203 };
204 
205 bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
206   for (auto node : manager->all_nodes()) {
207     if (node->isa<CNode>()) {
208       auto &inputs = node->cast<CNodePtr>()->inputs();
209       for (size_t i = 0; i < inputs.size(); ++i) {
210         auto inp = inputs[i];
211         if (!manager->all_nodes().contains(inp)) {
212           return false;
213         }
214 
215         if (manager->node_users().find(inp) != manager->node_users().end()) {
216           auto users = manager->node_users()[inp];
217           if (!users.contains(make_pair(node, i))) {
218             return false;
219           }
220         }
221       }
222     }
223 
224     if (manager->node_users().find(node) != manager->node_users().end()) {
225       auto users = manager->node_users()[node];
226       for (auto iter = users.begin(); iter != users.end(); ++iter) {
227         auto node2 = iter->first;
228         auto key = iter->second;
229         if (!manager->all_nodes().contains(node2)) {
230           return false;
231         }
232         if (node2->cast<CNodePtr>()->input(key) != node) {
233           return false;
234         }
235       }
236     }
237   }
238 
239   return true;
240 }
241 
242 class TestManager : public UT::Common {
243  public:
244   TestManager() : getPyFun("gtest_input.ir.manager_test") {}
245 
246   void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
247 
248  public:
249   std::vector<PrimitivePtr> swaps;
250   UT::PyFuncGraphFetcher getPyFun;
251 };
252 
253 FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
254   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
255   ParameterPtr x = func_graph->add_parameter();
256   ParameterPtr y = func_graph->add_parameter();
257   std::vector<AnfNodePtr> inputs;
258   inputs.push_back(NewValueNode(prim));
259   inputs.push_back(x);
260   inputs.push_back(y);
261   CNodePtr cnode_add = func_graph->NewCNode(inputs);
262   inputs.clear();
263   inputs.push_back(NewValueNode(prim::kPrimReturn));
264   inputs.push_back(cnode_add);
265   CNodePtr cnode_return = func_graph->NewCNode(inputs);
266   func_graph->set_return(cnode_return);
267   return func_graph;
268 }
269 
270 std::vector<FuncGraphPtr> MakeNestedGraph() {
271   /*
272    *def f(x):
273    *    def g():
274    *         return x
275    *    return g
276    */
277   FuncGraphPtr f = std::make_shared<FuncGraph>();
278   FuncGraphPtr fg = std::make_shared<FuncGraph>();
279 
280   ParameterPtr x = f->add_parameter();
281 
282   std::vector<AnfNodePtr> inputs;
283   inputs.push_back(NewValueNode(fg));
284   inputs.push_back(NewValueNode(prim::kPrimReturn));
285 
286   CNodePtr cnode_f = f->NewCNode(inputs);
287   f->set_return(cnode_f);
288 
289   inputs.clear();
290   inputs.push_back(NewValueNode(prim::kPrimReturn));
291   inputs.push_back(x);
292   CNodePtr cnode_g = fg->NewCNode(inputs);
293   fg->set_return(cnode_g);
294 
295   std::vector<FuncGraphPtr> result = {f, fg};
296   return result;
297 }
298 
299 std::vector<FuncGraphPtr> MakeNestedGraph2() {
300   /* build a closure func_graph */
301   /*
302    *def foo(x, y):
303    *    def bar(x1):
304    *         return x1 + y
305    *    return bar(x)
306    */
307   FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
308   ParameterPtr x = graph_foo->add_parameter();
309   ParameterPtr y = graph_foo->add_parameter();
310 
311   std::vector<AnfNodePtr> inputs;
312 
313   // build func_graph bar
314   FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
315   ParameterPtr x1 = graph_bar->add_parameter();
316   inputs.clear();
317   inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
318   inputs.push_back(x1);
319   inputs.push_back(y);
320   CNodePtr cnode_add = graph_bar->NewCNode(inputs);
321   inputs.clear();
322   inputs.push_back(NewValueNode(prim::kPrimReturn));
323   inputs.push_back(cnode_add);
324   CNodePtr cnode_return = graph_bar->NewCNode(inputs);
325   graph_bar->set_return(cnode_return);
326 
327   // build func_graph foo
328   inputs.clear();
329   inputs.push_back(NewValueNode(graph_bar));
330   inputs.push_back(x);
331   CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
332 
333   inputs.clear();
334   inputs.push_back(NewValueNode(prim::kPrimReturn));
335   inputs.push_back(cnode_graph_bar);
336   cnode_return = graph_foo->NewCNode(inputs);
337   graph_foo->set_return(cnode_return);
338 
339   std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
340   return result;
341 }
342 
343 // Add TestManager::CheckManager function to checkout the result
344 void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
345   auto size = mng->func_graphs().size();
346 
347   ASSERT_EQ(size, mng->free_variables_total().size());
348 }
349 
350 TEST_F(TestManager, test_scalar_add_manual) {
351   auto prim_scalar_add = prim::kPrimScalarAdd;
352   FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
353   auto mng = Manage(func_graph);
354 }
355 
356 TEST_F(TestManager, test_scalar_replace) {
357   auto prim_scalar_add = prim::kPrimScalarAdd;
358 
359   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
360   ParameterPtr x = func_graph->add_parameter();
361   ParameterPtr y = func_graph->add_parameter();
362   std::vector<AnfNodePtr> inputs;
363   inputs.push_back(NewValueNode(prim_scalar_add));
364   inputs.push_back(x);
365   inputs.push_back(y);
366   CNodePtr cnode_add = func_graph->NewCNode(inputs);
367   inputs.clear();
368   inputs.push_back(NewValueNode(prim::kPrimReturn));
369   inputs.push_back(cnode_add);
370   CNodePtr cnode_return = func_graph->NewCNode(inputs);
371   func_graph->set_return(cnode_return);
372 
373   auto mng = Manage(func_graph);
374   std::cout << "start " << x->ToString() << std::endl;
375   mng->Replace(cnode_add, x);
376 }
377 
378 TEST_F(TestManager, test_nested_manual) {
379   auto graphs = MakeNestedGraph();
380   auto f = graphs[0];
381   auto g = graphs[1];
382 
383   auto mng = Manage(f);
384 
385   ASSERT_EQ(6, mng->all_nodes().size());
386   ASSERT_EQ(2, mng->func_graphs().size());
387   ASSERT_EQ(4, mng->node_users().size());
388   ASSERT_EQ(1, mng->roots().size());
389   CheckAnalysisSize(mng);
390 
391   ASSERT_EQ(2, f->nodes().size());
392   ASSERT_EQ(1, g->nodes().size());
393 
394   auto &users = mng->node_users();
395   for (auto &iter : users) {
396     ASSERT_EQ(1, iter.second.size());
397   }
398 
399   ASSERT_EQ(1, f->func_graphs_used().size());
400   ASSERT_EQ(0, g->func_graphs_used().size());
401 
402   ASSERT_EQ(0, f->free_variables().size());
403   ASSERT_EQ(1, g->free_variables().size());
404 
405   auto fv_total = mng->free_variables_total();
406   ASSERT_EQ(0, fv_total[f].size());
407   ASSERT_EQ(1, fv_total[g].size());
408 
409   ASSERT_EQ(0, f->func_graph_cnodes_index().size());
410   ASSERT_EQ(1, g->func_graph_cnodes_index().size());
411 }
412 
413 TEST_F(TestManager, test_deep_nested2_manual) {
414   // create parser
415   FuncGraphPtr func_graph = getPyFun("test_custom");
416   return;
417 
418   // parse ast to func graph
419   FuncGraphPtr gfn = BasicClone(func_graph);
420   if (gfn == nullptr) {
421     return;
422   }
423 
424   auto mng = Manage(gfn);
425 
426   ASSERT_EQ(3, mng->func_graphs().size());
427   ASSERT_EQ(1, mng->roots().size());
428   ASSERT_EQ(4, gfn->nodes().size());
429   ASSERT_EQ(20, mng->all_nodes().size());
430   ASSERT_EQ(25, mng->node_users().size());
431   CheckAnalysisSize(mng);
432 }
433 
434 TEST_F(TestManager, test_deep_nested_manual) {
435   FuncGraphPtr f = std::make_shared<FuncGraph>();
436   FuncGraphPtr fg = std::make_shared<FuncGraph>();
437   FuncGraphPtr h = std::make_shared<FuncGraph>();
438 
439   ParameterPtr x = f->add_parameter();
440   ParameterPtr y = f->add_parameter();
441   ParameterPtr z = f->add_parameter();
442 
443   std::vector<AnfNodePtr> inputs;
444   inputs.push_back(NewValueNode(fg));
445   inputs.push_back(x);
446   inputs.push_back(y);
447   CNodePtr cnode_1 = f->NewCNode(inputs);
448 
449   inputs.clear();
450   inputs.push_back(cnode_1);
451   inputs.push_back(NewValueNode(prim::kPrimReturn));
452   CNodePtr cnode_0 = f->NewCNode(inputs);
453   f->set_return(cnode_0);
454 
455   ParameterPtr x1 = fg->add_parameter();
456   ParameterPtr y1 = fg->add_parameter();
457   inputs.clear();
458   inputs.push_back(NewValueNode(h));
459   inputs.push_back(x1);
460   CNodePtr cnode_3 = fg->NewCNode(inputs);
461 
462   inputs.clear();
463   inputs.push_back(cnode_3);
464   inputs.push_back(NewValueNode(prim::kPrimReturn));
465   CNodePtr cnode_2 = fg->NewCNode(inputs);
466   fg->set_return(cnode_2);
467 
468   ParameterPtr x2 = h->add_parameter();
469 
470   inputs.clear();
471   inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
472   inputs.push_back(x2);
473   inputs.push_back(y1);
474   CNodePtr cnode_6 = h->NewCNode(inputs);
475 
476   inputs.clear();
477   inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
478   inputs.push_back(z);
479   inputs.push_back(cnode_6);
480   CNodePtr cnode_5 = h->NewCNode(inputs);
481 
482   inputs.clear();
483   inputs.push_back(cnode_5);
484   inputs.push_back(NewValueNode(prim::kPrimReturn));
485   CNodePtr cnode_4 = h->NewCNode(inputs);
486   h->set_return(cnode_4);
487 
488   auto mng = Manage(f);
489 
490   ASSERT_EQ(3, mng->func_graphs().size());
491   ASSERT_EQ(1, mng->roots().size());
492   ASSERT_EQ(20, mng->all_nodes().size());
493   CheckAnalysisSize(mng);
494 }
495 
496 TEST_F(TestManager, test_parent1_manual) {
497   FuncGraphPtr fg = std::make_shared<FuncGraph>();
498 
499   Parameter param(fg);
500   std::vector<AnfNodePtr> params;
501   CNodePtr app = std::make_shared<CNode>(params, fg);
502   fg->set_return(app);
503   fg->set_parameters(params);
504 
505   std::shared_ptr<FuncGraphManager> manager = MakeManager();
506   manager->AddFuncGraph(fg, true);
507   FuncGraphPtr p = fg->parent();
508   assert(p == nullptr);
509 }
510 
511 TEST_F(TestManager, test_parent_manual) {
512   auto prim_scalar_add = prim::kPrimScalarAdd;
513   FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);
514 
515   std::shared_ptr<FuncGraphManager> manager = MakeManager();
516   manager->AddFuncGraph(fg);
517   FuncGraphPtr p = fg->parent();
518   assert(p == nullptr);
519 }
520 
521 TEST_F(TestManager, test_flat) {
522   std::vector<std::shared_ptr<Stage>> stages;
523   std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="};
524   std::map<std::string, int> size_list;
525   size_list["nodes"] = 2;
526 }
527 
528 TEST_F(TestManager, test_nested) {
529   std::vector<std::shared_ptr<Stage>> stages;
530   std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
531   std::map<std::string, int> size_list;
532   return;
533 }
534 
535 TEST_F(TestManager, test_calls) {
536   std::vector<std::shared_ptr<Stage>> stages;
537   std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
538                                     "fvs_direct=h:a", "fvs_total=h:a; g:h"};
539   std::map<std::string, int> size_list;
540   return;
541 }
542 
543 TEST_F(TestManager, test_unused_param) {
544   std::vector<std::shared_ptr<Stage>> stages;
545   std::vector<std::string> specs = {"nodes=X:x,y"};
546   std::map<std::string, int> size_list;
547 }
548 
549 TEST_F(TestManager, test_cannot_replace_return) {
550   FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
551   ASSERT_NE(fg, nullptr);
552 
553   auto mng = Manage(fg);
554   ASSERT_EQ(fg->manager(), mng);
555 
556   ASSERT_NE(mng, nullptr);
557   ASSERT_GT(fg->parameters().size(), 0);
558   ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0]));
559 }
560 
561 TEST_F(TestManager, test_weak_manager) {
562   FuncGraphPtr fg = getPyFun("ir_get_fn");
563 
564   auto mng1 = MakeManager({fg}, false);
565   ASSERT_EQ(fg->manager(), nullptr);
566   auto mng2 = MakeManager({fg}, true);
567   ASSERT_EQ(fg->manager(), mng2);
568   auto mng3 = MakeManager({fg}, false);
569   ASSERT_EQ(fg->manager(), mng2);
570 }
571 
572 TEST_F(TestManager, test_drop_root) {
573   FuncGraphPtr fg = getPyFun("ir_get_fn");
574 
575   auto mng = Manage(fg);
576   const auto &fgs = mng->func_graphs();
577   ASSERT_TRUE(fgs.contains(fg));
578   FuncGraphSet s;
579   s.add(fg);
580   mng->MaybeDropFuncGraphs(s);
581   ASSERT_TRUE(fgs.contains(fg));
582 }
583 
584 TEST_F(TestManager, test_keep_roots) {
585   FuncGraphPtr fg1 = getPyFun("ir_get_fn");
586   FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return");
587 
588   auto mng = Manage(fg1);
589   ASSERT_EQ(mng->func_graphs().size(), (size_t)1);
590   ASSERT_TRUE(mng->func_graphs().contains(fg1));
591 
592   mng->AddFuncGraph(fg2);
593   ASSERT_EQ(mng->func_graphs().size(), 2);
594   ASSERT_TRUE(mng->func_graphs().contains(fg2));
595 
596   mng->KeepRoots();
597   ASSERT_EQ(mng->func_graphs().size(), 1);
598   ASSERT_TRUE(mng->func_graphs().contains(fg1));
599 
600   mng->KeepRoots({fg2});
601   ASSERT_EQ(mng->func_graphs().size(), 1);
602   ASSERT_TRUE(mng->func_graphs().contains(fg2));
603 }
604 
605 TEST_F(TestManager, test_keep_roots_recursion) {
606   return;
607 
608   FuncGraphPtr fg = getPyFun("test_keep_roots_recursion");
609   ASSERT_NE(fg, nullptr);
610   auto mng = Manage(fg);
611   parse::ResolveAll(mng);
612 
613   ASSERT_NE(mng, nullptr);
614   ASSERT_EQ(mng->func_graphs().size(), 4);
615 
616   ASSERT_GT(fg->parameters().size(), 0);
617   mng->Replace(fg->output(), fg->parameters()[0]);
618   ASSERT_EQ(mng->func_graphs().size(), 3);
619 
620   mng->KeepRoots();
621   ASSERT_EQ(mng->func_graphs().size(), 1);
622 }
623 
624 TEST_F(TestManager, test_add_edge_replace) {
625   // fg(x, y, u):
626   //    x1 = load(x, u)
627   //    a = add(x1, y)
628   //    u1 = update_state(u, x1);
629   //    out = depend(a, u1)
630   //    return out
631   FuncGraphPtr fg = std::make_shared<FuncGraph>();
632   auto x = fg->add_parameter();
633   auto y = fg->add_parameter();
634   auto u = fg->add_parameter();
635   auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
636   auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
637   auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
638   auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
639   fg->set_output(out);
640 
641   // Create manager.
642   auto mgr = Manage(fg);
643   ASSERT_NE(mgr, nullptr);
644 
645   // Before AddEdge.
646   //    a = add(x1, y)
647   //    u1 = update_state(u, x1);
648   //    out = depend(a, u1)
649   auto a_users = mgr->node_users()[a];
650   ASSERT_EQ(a_users.size(), 1);
651 
652   mgr->AddEdge(u1, a);
653 
654   // After AddEdge.
655   //    a = add(x1, y)
656   //    u1 = update_state(u, x1, a);
657   //    out = depend(a, u1)
658   a_users = mgr->node_users()[a];
659   ASSERT_EQ(a_users.size(), 2);
660 
661   // Remove edge by replace update_state.
662   auto u2 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
663   mgr->Replace(u1, u2);
664 
665   // After replace update_state.
666   //    a = add(x1, y)
667   //    u2 = update_state(u, x1);
668   //    out = depend(a, u2)
669   a_users = mgr->node_users()[a];
670   ASSERT_EQ(a_users.size(), 1);
671 
672   mgr->AddEdge(u2, a);
673 
674   // After AddEdge to u2.
675   //    a = add(x1, y)
676   //    u2 = update_state(u, x1, a);
677   //    out = depend(a, u2)
678   a_users = mgr->node_users()[a];
679   ASSERT_EQ(a_users.size(), 2);
680 }
681 
682 TEST_F(TestManager, test_add_edge_replace_new) {
683   // fg(x, y, u):
684   //    x1 = load(x, u)
685   //    a = add(x1, y)
686   //    u1 = update_state(u, x1);
687   //    out = depend(a, u1)
688   //    return out
689   FuncGraphPtr fg = std::make_shared<FuncGraph>();
690   auto x = fg->add_parameter();
691   auto y = fg->add_parameter();
692   auto u = fg->add_parameter();
693   auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
694   auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
695   auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
696   auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
697   fg->set_output(out);
698 
699   // Create manager.
700   auto mgr = Manage(fg);
701   ASSERT_NE(mgr, nullptr);
702 
703   auto new_add = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
704   mgr->AddEdge(u1, new_add);
705 
706   //    x1 = load(x, u)
707   //    a = add(x1, y)
708   //    new_add = add(x1, y)
709   //    u1 = update_state(u, x1, new_add);
710   //    out = depend(a, u1)
711   //    return out
712   ASSERT_EQ(mgr->node_users()[x1].size(), 3);
713   ASSERT_EQ(mgr->node_users()[y].size(), 2);
714   ASSERT_EQ(mgr->node_users()[new_add].size(), 1);
715 
716   auto new_add1 = fg->NewCNode({NewValueNode(prim::kPrimAdd), y, y});
717   mgr->Replace(new_add, new_add1);
718 
719   //    x1 = load(x, u)
720   //    a = add(x1, y)
721   //    new_add1 = add(y, y)
722   //    u1 = update_state(u, x1, new_add1);
723   //    out = depend(a, u1)
724   //    return out
725   ASSERT_EQ(mgr->node_users()[x1].size(), 2);
726   ASSERT_EQ(mgr->node_users()[y].size(), 3);
727   ASSERT_EQ(mgr->node_users()[new_add].size(), 0);
728   ASSERT_EQ(mgr->node_users()[new_add1].size(), 1);
729 }
730 
731 TEST_F(TestManager, test_set_edge) {
732   // fg(x, y, u):
733   //    t = make_tuple(x, y)
734   //    d = depend(t, u);
735   //    get_item = tuple_get_item(d, 0)
736   //    return get_item
737   FuncGraphPtr fg = std::make_shared<FuncGraph>();
738   auto x = fg->add_parameter();
739   auto y = fg->add_parameter();
740   auto u = fg->add_parameter();
741   auto t = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), x, y});
742   auto d = fg->NewCNode({NewValueNode(prim::kPrimDepend), t, u});
743   auto get_item = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), d, NewValueNode(0)});
744   fg->set_output(get_item);
745 
746   // Create manager.
747   auto mgr = Manage(fg);
748   ASSERT_NE(mgr, nullptr);
749 
750   // Before SetEdge.
751   ASSERT_EQ(mgr->node_users()[t].size(), 1);
752   ASSERT_EQ(mgr->node_users()[d].size(), 1);
753 
754   auto depend = get_item->input(1)->cast<CNodePtr>();
755   mgr->SetEdge(get_item, 1, depend->input(1));
756 
757   // After SetEdge.
758   ASSERT_EQ(get_item->input(1), t);
759   ASSERT_EQ(depend->input(1), t);
760   ASSERT_EQ(mgr->node_users()[d].size(), 0);
761   ASSERT_EQ(mgr->node_users()[t].size(), 1);  // depend removed.
762   ASSERT_EQ(mgr->node_users()[t].front().first, get_item);
763 }
764 
765 }  // namespace mindspore
766