1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/common_test.h"
17 #include "common/py_func_graph_fetcher.h"
18 #include "ir/dtype.h"
19 #include "ir/manager.h"
20 #include "ir/func_graph_cloner.h"
21 #include "pipeline/jit/parse/parse.h"
22 #include "frontend/operator/ops.h"
23 #include "utils/log_adapter.h"
24 #include "debug/draw.h"
25 #include "utils/label.h"
26
27 namespace mindspore {
28
namespace {
// Splits `str` on every occurrence of `pattern` and returns the pieces in order.
//
// Adjacent or trailing separators produce empty pieces ("a=" -> {"a", ""}).
// An input without any separator, including the empty string, yields a single piece holding the whole input.
// A separator is only recognised when it lies entirely inside `str`.
// An empty `pattern` cannot separate anything, so the whole input is returned as one piece.
// The previous implementation looped forever on an empty pattern.
std::vector<std::string> SplitString(std::string str, std::string pattern) {
  std::vector<std::string> result;
  if (pattern.empty()) {
    result.push_back(str);
    return result;
  }

  std::string::size_type start = 0;
  for (;;) {
    const std::string::size_type pos = str.find(pattern, start);
    if (pos == std::string::npos) {
      // Everything after the last separator is the final piece (possibly empty).
      result.push_back(str.substr(start));
      return result;
    }
    result.push_back(str.substr(start, pos - start));
    start = pos + pattern.size();
  }
}
}  // namespace
48 using std::dynamic_pointer_cast;
49
50 using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>;
51 using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
52
53 class NestingSpecs;
54
55 class Stage {
56 public:
Stage(std::vector<std::string> specs)57 explicit Stage(std::vector<std::string> specs) {
58 for (auto arg : specs) {
59 auto spec = SplitString(arg, "=");
60 if (spec.size() <= 1) {
61 continue;
62 }
63 std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]);
64 specs_[ToFullString(spec[0])] = nesting;
65 }
66 }
67
~Stage()68 ~Stage() {}
69
subs()70 std::map<std::string, std::string> &subs() { return subs_; }
71
set_subs(const std::map<std::string,std::string> & subs)72 void set_subs(const std::map<std::string, std::string> &subs) { subs_ = subs; }
73
74 private:
ToFullString(std::string s)75 std::string ToFullString(std::string s) {
76 if (s.find("fv") != std::string::npos) {
77 s = s.replace(s.find("fv"), 2, "free_variable");
78 }
79
80 if (s.find("deps") != std::string::npos) {
81 s = s.replace(s.find("deps"), 4, "dependencies");
82 }
83
84 return s;
85 }
86
87 std::map<std::string, std::shared_ptr<NestingSpecs>> specs_;
88 std::map<std::string, std::string> subs_;
89 };
90
91 class NestingSpecs {
92 public:
NestingSpecs(Stage * stage,std::string specs)93 NestingSpecs(Stage *stage, std::string specs) : stage_(stage) { ParseSpecs(specs); }
94
~NestingSpecs()95 ~NestingSpecs() {}
96
Name(Any node)97 std::string Name(Any node) {
98 std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info());
99 if (stage_->subs().find(name) != stage_->subs().end()) {
100 return stage_->subs()[name];
101 }
102
103 return name;
104 }
105
Check(std::shared_ptr<DepComputer> results)106 void Check(std::shared_ptr<DepComputer> results) {
107 if (expected_.empty() && expected_recursive_.empty()) {
108 return;
109 }
110
111 auto parent = dynamic_pointer_cast<ParentComputer>(results);
112 if (parent != nullptr) {
113 CheckParent(parent);
114 return;
115 }
116
117 auto recursive = dynamic_pointer_cast<RecursiveComputer>(results);
118 if (recursive != nullptr) {
119 CheckRecursive(recursive);
120 return;
121 }
122 }
123
124 private:
ParseSpecs(std::string specs)125 void ParseSpecs(std::string specs) {
126 if (specs.empty()) {
127 return;
128 }
129
130 std::vector<std::string> str_list = SplitString(specs, ";");
131 for (auto spec : str_list) {
132 spec.erase(0, spec.find_first_not_of(" "));
133 spec.erase(spec.find_last_not_of(" ") + 1);
134 if (spec.empty()) {
135 continue;
136 }
137 if (spec.find("->") != std::string::npos) {
138 auto substr = SplitString(spec, "->");
139 ASSERT_GT(substr.size(), 1);
140 auto key = substr[0];
141 auto value = substr[1];
142 if (!value.empty()) {
143 expected_[key] = {value};
144 }
145 } else if (spec.find(":") != std::string::npos) {
146 auto substr = SplitString(spec, ":");
147 ASSERT_GT(substr.size(), 1);
148 auto key = substr[0];
149 auto values = SplitString(substr[1], ",");
150 std::set<std::string> values_set(values.begin(), values.end());
151 if (!values_set.empty()) {
152 expected_[key] = values_set;
153 }
154 } else {
155 expected_recursive_[spec] = true;
156 }
157 }
158 }
159
CheckParent(std::shared_ptr<ParentComputer> results)160 void CheckParent(std::shared_ptr<ParentComputer> results) {
161 std::map<std::string, std::set<std::string>> clean_results;
162 for (auto &iter : results->parent_analysis()) {
163 auto key = iter.first;
164 auto value = iter.second;
165 if (key == nullptr) {
166 continue;
167 }
168 std::string k = Name(key);
169
170 std::set<std::string> v;
171 if (value != nullptr && !Name(value).empty()) {
172 v.insert(Name(value));
173 }
174
175 if (!v.empty()) {
176 clean_results[k] = v;
177 }
178 }
179
180 ASSERT_EQ(clean_results, expected_);
181 }
182
CheckRecursive(std::shared_ptr<RecursiveComputer> results)183 void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
184 std::map<std::string, bool> clean_results;
185 for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
186 auto key = iter->first;
187 auto value = iter->second;
188 if (key == nullptr) {
189 continue;
190 }
191 std::string k = Name(key);
192
193 clean_results[k] = value;
194 }
195
196 ASSERT_EQ(clean_results, expected_recursive_);
197 }
198
199 private:
200 Stage *stage_;
201 std::map<std::string, std::set<std::string>> expected_;
202 std::map<std::string, bool> expected_recursive_;
203 };
204
CheckUsers(std::shared_ptr<FuncGraphManager> manager)205 bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) {
206 for (auto node : manager->all_nodes()) {
207 if (node->isa<CNode>()) {
208 auto &inputs = node->cast<CNodePtr>()->inputs();
209 for (size_t i = 0; i < inputs.size(); ++i) {
210 auto inp = inputs[i];
211 if (!manager->all_nodes().contains(inp)) {
212 return false;
213 }
214
215 if (manager->node_users().find(inp) != manager->node_users().end()) {
216 auto users = manager->node_users()[inp];
217 if (!users.contains(make_pair(node, i))) {
218 return false;
219 }
220 }
221 }
222 }
223
224 if (manager->node_users().find(node) != manager->node_users().end()) {
225 auto users = manager->node_users()[node];
226 for (auto iter = users.begin(); iter != users.end(); ++iter) {
227 auto node2 = iter->first;
228 auto key = iter->second;
229 if (!manager->all_nodes().contains(node2)) {
230 return false;
231 }
232 if (node2->cast<CNodePtr>()->input(key) != node) {
233 return false;
234 }
235 }
236 }
237 }
238
239 return true;
240 }
241
242 class TestManager : public UT::Common {
243 public:
TestManager()244 TestManager() : getPyFun("gtest_input.ir.manager_test") {}
245
246 void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng);
247
248 public:
249 std::vector<PrimitivePtr> swaps;
250 UT::PyFuncGraphFetcher getPyFun;
251 };
252
MakeFuncGraph(PrimitivePtr prim)253 FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) {
254 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
255 ParameterPtr x = func_graph->add_parameter();
256 ParameterPtr y = func_graph->add_parameter();
257 std::vector<AnfNodePtr> inputs;
258 inputs.push_back(NewValueNode(prim));
259 inputs.push_back(x);
260 inputs.push_back(y);
261 CNodePtr cnode_add = func_graph->NewCNode(inputs);
262 inputs.clear();
263 inputs.push_back(NewValueNode(prim::kPrimReturn));
264 inputs.push_back(cnode_add);
265 CNodePtr cnode_return = func_graph->NewCNode(inputs);
266 func_graph->set_return(cnode_return);
267 return func_graph;
268 }
269
MakeNestedGraph()270 std::vector<FuncGraphPtr> MakeNestedGraph() {
271 /*
272 *def f(x):
273 * def g():
274 * return x
275 * return g
276 */
277 FuncGraphPtr f = std::make_shared<FuncGraph>();
278 FuncGraphPtr fg = std::make_shared<FuncGraph>();
279
280 ParameterPtr x = f->add_parameter();
281
282 std::vector<AnfNodePtr> inputs;
283 inputs.push_back(NewValueNode(fg));
284 inputs.push_back(NewValueNode(prim::kPrimReturn));
285
286 CNodePtr cnode_f = f->NewCNode(inputs);
287 f->set_return(cnode_f);
288
289 inputs.clear();
290 inputs.push_back(NewValueNode(prim::kPrimReturn));
291 inputs.push_back(x);
292 CNodePtr cnode_g = fg->NewCNode(inputs);
293 fg->set_return(cnode_g);
294
295 std::vector<FuncGraphPtr> result = {f, fg};
296 return result;
297 }
298
MakeNestedGraph2()299 std::vector<FuncGraphPtr> MakeNestedGraph2() {
300 /* build a closure func_graph */
301 /*
302 *def foo(x, y):
303 * def bar(x1):
304 * return x1 + y
305 * return bar(x)
306 */
307 FuncGraphPtr graph_foo = std::make_shared<FuncGraph>();
308 ParameterPtr x = graph_foo->add_parameter();
309 ParameterPtr y = graph_foo->add_parameter();
310
311 std::vector<AnfNodePtr> inputs;
312
313 // build func_graph bar
314 FuncGraphPtr graph_bar = std::make_shared<FuncGraph>();
315 ParameterPtr x1 = graph_bar->add_parameter();
316 inputs.clear();
317 inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
318 inputs.push_back(x1);
319 inputs.push_back(y);
320 CNodePtr cnode_add = graph_bar->NewCNode(inputs);
321 inputs.clear();
322 inputs.push_back(NewValueNode(prim::kPrimReturn));
323 inputs.push_back(cnode_add);
324 CNodePtr cnode_return = graph_bar->NewCNode(inputs);
325 graph_bar->set_return(cnode_return);
326
327 // build func_graph foo
328 inputs.clear();
329 inputs.push_back(NewValueNode(graph_bar));
330 inputs.push_back(x);
331 CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs);
332
333 inputs.clear();
334 inputs.push_back(NewValueNode(prim::kPrimReturn));
335 inputs.push_back(cnode_graph_bar);
336 cnode_return = graph_foo->NewCNode(inputs);
337 graph_foo->set_return(cnode_return);
338
339 std::vector<FuncGraphPtr> result = {graph_foo, graph_bar};
340 return result;
341 }
342
343 // Add TestManager::CheckManager function to checkout the result
CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng)344 void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
345 auto size = mng->func_graphs().size();
346
347 ASSERT_EQ(size, mng->free_variables_total().size());
348 }
349
TEST_F(TestManager, test_scalar_add_manual) {
  auto prim_scalar_add = prim::kPrimScalarAdd;
  FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  auto mng = Manage(func_graph);

  // This test previously asserted nothing.
  // Check that the manager owns the single graph and tracks all of its nodes.
  ASSERT_NE(mng, nullptr);
  ASSERT_EQ(func_graph->manager(), mng);
  ASSERT_EQ(1, mng->func_graphs().size());
  ASSERT_TRUE(mng->func_graphs().contains(func_graph));
  // Six nodes: two parameters, the ScalarAdd value node, the add CNode,
  // the Return value node and the return CNode.
  ASSERT_EQ(6, mng->all_nodes().size());
  CheckAnalysisSize(mng);
}
355
TEST_F(TestManager, test_scalar_replace) {
  auto prim_scalar_add = prim::kPrimScalarAdd;

  // func_graph(x, y) = scalar_add(x, y)
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  ParameterPtr x = func_graph->add_parameter();
  ParameterPtr y = func_graph->add_parameter();
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim_scalar_add));
  inputs.push_back(x);
  inputs.push_back(y);
  CNodePtr cnode_add = func_graph->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(cnode_add);
  CNodePtr cnode_return = func_graph->NewCNode(inputs);
  func_graph->set_return(cnode_return);

  auto mng = Manage(func_graph);
  // Replacing scalar_add(x, y) with x must succeed and leave the return node returning x directly.
  // Previously the test only printed to stdout and checked nothing.
  ASSERT_TRUE(mng->Replace(cnode_add, x));
  ASSERT_EQ(cnode_return->input(1), x);
}
377
TEST_F(TestManager, test_nested_manual) {
  // MakeNestedGraph() returns an outer graph f that owns parameter x, and an inner graph g that uses x.
  const std::vector<FuncGraphPtr> graphs = MakeNestedGraph();
  const FuncGraphPtr &f = graphs[0];
  const FuncGraphPtr &g = graphs[1];

  auto mng = Manage(f);

  // Manager-wide bookkeeping.
  ASSERT_EQ(6, mng->all_nodes().size());
  ASSERT_EQ(2, mng->func_graphs().size());
  ASSERT_EQ(4, mng->node_users().size());
  ASSERT_EQ(1, mng->roots().size());
  CheckAnalysisSize(mng);

  // Nodes owned by each graph.
  ASSERT_EQ(2, f->nodes().size());
  ASSERT_EQ(1, g->nodes().size());

  // Every node that has users has exactly one.
  for (auto &entry : mng->node_users()) {
    ASSERT_EQ(1, entry.second.size());
  }

  // f references g; g references no graph.
  ASSERT_EQ(1, f->func_graphs_used().size());
  ASSERT_EQ(0, g->func_graphs_used().size());

  // Only g has a free variable (x), both directly and in total.
  ASSERT_EQ(0, f->free_variables().size());
  ASSERT_EQ(1, g->free_variables().size());

  auto fv_total = mng->free_variables_total();
  ASSERT_EQ(0, fv_total[f].size());
  ASSERT_EQ(1, fv_total[g].size());

  // g appears as an input of exactly one CNode.
  ASSERT_EQ(0, f->func_graph_cnodes_index().size());
  ASSERT_EQ(1, g->func_graph_cnodes_index().size());
}
412
TEST_F(TestManager, test_deep_nested2_manual) {
  // Fetch the Python-defined graph. Currently this fetch is the only thing the test exercises.
  FuncGraphPtr func_graph = getPyFun("test_custom");
  // NOTE(review): this early return disables every check below; remove it to re-enable them.
  return;

  // Clone the fetched graph and inspect the managed clone.
  FuncGraphPtr cloned = BasicClone(func_graph);
  if (cloned == nullptr) {
    return;
  }

  auto mng = Manage(cloned);

  ASSERT_EQ(3, mng->func_graphs().size());
  ASSERT_EQ(1, mng->roots().size());
  ASSERT_EQ(4, cloned->nodes().size());
  ASSERT_EQ(20, mng->all_nodes().size());
  ASSERT_EQ(25, mng->node_users().size());
  CheckAnalysisSize(mng);
}
433
TEST_F(TestManager, test_deep_nested_manual) {
  // f(x, y, z) returns fg(x, y).
  // fg(x1, y1) returns h(x1).
  // h(x2) returns z + (x2 + y1), capturing y1 from fg and z from f.
  // Every return node is Return(value): the Return primitive is input 0.
  // Previously the primitive and the returned value were swapped in all three graphs.
  FuncGraphPtr f = std::make_shared<FuncGraph>();
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  FuncGraphPtr h = std::make_shared<FuncGraph>();

  ParameterPtr x = f->add_parameter();
  ParameterPtr y = f->add_parameter();
  ParameterPtr z = f->add_parameter();

  CNodePtr cnode_1 = f->NewCNode({NewValueNode(fg), x, y});
  CNodePtr cnode_0 = f->NewCNode({NewValueNode(prim::kPrimReturn), cnode_1});
  f->set_return(cnode_0);

  ParameterPtr x1 = fg->add_parameter();
  ParameterPtr y1 = fg->add_parameter();
  CNodePtr cnode_3 = fg->NewCNode({NewValueNode(h), x1});
  CNodePtr cnode_2 = fg->NewCNode({NewValueNode(prim::kPrimReturn), cnode_3});
  fg->set_return(cnode_2);

  ParameterPtr x2 = h->add_parameter();
  CNodePtr cnode_6 = h->NewCNode({NewValueNode(prim::kPrimScalarAdd), x2, y1});
  CNodePtr cnode_5 = h->NewCNode({NewValueNode(prim::kPrimScalarAdd), z, cnode_6});
  CNodePtr cnode_4 = h->NewCNode({NewValueNode(prim::kPrimReturn), cnode_5});
  h->set_return(cnode_4);

  auto mng = Manage(f);

  ASSERT_EQ(3, mng->func_graphs().size());
  ASSERT_EQ(1, mng->roots().size());
  // 7 nodes for f, 6 for fg and 7 for h, counting parameters and value nodes.
  ASSERT_EQ(20, mng->all_nodes().size());
  CheckAnalysisSize(mng);
}
495
TEST_F(TestManager, test_parent1_manual) {
  FuncGraphPtr fg = std::make_shared<FuncGraph>();

  // A graph with no parameters whose return node has no inputs.
  // The unused stack-allocated Parameter the test used to construct was removed.
  std::vector<AnfNodePtr> params;
  CNodePtr app = std::make_shared<CNode>(params, fg);
  fg->set_return(app);
  fg->set_parameters(params);

  std::shared_ptr<FuncGraphManager> manager = MakeManager();
  manager->AddFuncGraph(fg, true);
  // The test expects fg to have no parent graph.
  // Use a gtest assertion: plain assert() is compiled out under NDEBUG and aborts the whole binary on failure.
  FuncGraphPtr p = fg->parent();
  ASSERT_EQ(p, nullptr);
}
510
TEST_F(TestManager, test_parent_manual) {
  auto prim_scalar_add = prim::kPrimScalarAdd;
  FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add);

  std::shared_ptr<FuncGraphManager> manager = MakeManager();
  manager->AddFuncGraph(fg);
  // fg only uses its own parameters, so it is expected to have no parent graph.
  // Use a gtest assertion: plain assert() is compiled out under NDEBUG and aborts the whole binary on failure.
  FuncGraphPtr p = fg->parent();
  ASSERT_EQ(p, nullptr);
}
520
TEST_F(TestManager, test_flat) {
  // Placeholder: these specs are declared but never fed to a Stage, so nothing is checked yet.
  std::vector<std::shared_ptr<Stage>> stages;
  const std::vector<std::string> specs{"nodes=X:x", "parents=", "fvs_direct="};
  std::map<std::string, int> size_list{{"nodes", 2}};
}
527
TEST_F(TestManager, test_nested) {
  // Placeholder: these specs are declared but never fed to a Stage, so nothing is checked yet.
  std::vector<std::shared_ptr<Stage>> stages;
  const std::vector<std::string> specs{"nodes=X:x", "parent=g->X", "fvs_direct=g:x"};
  std::map<std::string, int> size_list;
}
534
TEST_F(TestManager, test_calls) {
  // Placeholder: these specs are declared but never fed to a Stage, so nothing is checked yet.
  std::vector<std::shared_ptr<Stage>> stages;
  const std::vector<std::string> specs{"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h",
                                       "fvs_direct=h:a", "fvs_total=h:a; g:h"};
  std::map<std::string, int> size_list;
}
542
TEST_F(TestManager, test_unused_param) {
  // Placeholder: this spec is declared but never fed to a Stage, so nothing is checked yet.
  std::vector<std::shared_ptr<Stage>> stages;
  const std::vector<std::string> specs{"nodes=X:x,y"};
  std::map<std::string, int> size_list;
}
548
TEST_F(TestManager, test_cannot_replace_return) {
  FuncGraphPtr fg = getPyFun("test_cannot_replace_return");
  ASSERT_NE(fg, nullptr);

  auto mng = Manage(fg);
  ASSERT_NE(mng, nullptr);
  ASSERT_EQ(fg->manager(), mng);

  // The manager must refuse to replace a graph's return node.
  ASSERT_GT(fg->parameters().size(), 0);
  auto first_param = fg->parameters()[0];
  ASSERT_FALSE(mng->Replace(fg->get_return(), first_param));
}
560
TEST_F(TestManager, test_weak_manager) {
  FuncGraphPtr fg = getPyFun("ir_get_fn");

  // A manager created with manage=false does not attach itself to the graph.
  auto weak_mng = MakeManager({fg}, false);
  ASSERT_EQ(fg->manager(), nullptr);

  // A manager created with manage=true does attach itself.
  auto owning_mng = MakeManager({fg}, true);
  ASSERT_EQ(fg->manager(), owning_mng);

  // A later manage=false manager leaves the existing attachment untouched.
  auto another_weak_mng = MakeManager({fg}, false);
  ASSERT_EQ(fg->manager(), owning_mng);
}
571
TEST_F(TestManager, test_drop_root) {
  FuncGraphPtr fg = getPyFun("ir_get_fn");

  auto mng = Manage(fg);
  const auto &managed = mng->func_graphs();
  ASSERT_TRUE(managed.contains(fg));

  // Asking the manager to maybe-drop its root graph must leave the root managed.
  FuncGraphSet candidates;
  candidates.add(fg);
  mng->MaybeDropFuncGraphs(candidates);
  ASSERT_TRUE(managed.contains(fg));
}
583
TEST_F(TestManager, test_keep_roots) {
  FuncGraphPtr first = getPyFun("ir_get_fn");
  FuncGraphPtr second = getPyFun("test_cannot_replace_return");

  // Managing `first` yields exactly that one graph.
  auto mng = Manage(first);
  ASSERT_EQ(mng->func_graphs().size(), size_t{1});
  ASSERT_TRUE(mng->func_graphs().contains(first));

  // Adding `second` makes it a second managed graph.
  mng->AddFuncGraph(second);
  ASSERT_EQ(mng->func_graphs().size(), size_t{2});
  ASSERT_TRUE(mng->func_graphs().contains(second));

  // KeepRoots() with no argument keeps only the original root.
  mng->KeepRoots();
  ASSERT_EQ(mng->func_graphs().size(), size_t{1});
  ASSERT_TRUE(mng->func_graphs().contains(first));

  // KeepRoots({second}) switches the managed set to `second` alone.
  mng->KeepRoots({second});
  ASSERT_EQ(mng->func_graphs().size(), size_t{1});
  ASSERT_TRUE(mng->func_graphs().contains(second));
}
604
TEST_F(TestManager, test_keep_roots_recursion) {
  // NOTE(review): this early return disables the whole test; everything below is unreachable.
  return;

  FuncGraphPtr root = getPyFun("test_keep_roots_recursion");
  ASSERT_NE(root, nullptr);
  auto mng = Manage(root);
  parse::ResolveAll(mng);
  ASSERT_NE(mng, nullptr);

  // After resolving, four graphs are managed.
  ASSERT_EQ(mng->func_graphs().size(), 4);

  // Short-circuiting the output to the first parameter leaves three graphs managed.
  ASSERT_GT(root->parameters().size(), 0);
  mng->Replace(root->output(), root->parameters()[0]);
  ASSERT_EQ(mng->func_graphs().size(), 3);

  // KeepRoots() then trims the managed set down to the root alone.
  mng->KeepRoots();
  ASSERT_EQ(mng->func_graphs().size(), 1);
}
623
TEST_F(TestManager, test_add_edge_replace) {
  // Graph under test:
  //   fg(x, y, u):
  //     x1 = load(x, u)
  //     a = add(x1, y)
  //     u1 = update_state(u, x1)
  //     out = depend(a, u1)
  //     return out
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto x = fg->add_parameter();
  auto y = fg->add_parameter();
  auto u = fg->add_parameter();
  auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
  auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
  auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
  auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
  fg->set_output(out);

  auto mgr = Manage(fg);
  ASSERT_NE(mgr, nullptr);

  // Number of users the manager records for `a`.
  auto users_of_a = [&mgr, &a]() { return mgr->node_users()[a].size(); };

  // Initially only `out` uses `a`.
  ASSERT_EQ(users_of_a(), 1);

  // AddEdge appends `a` to u1: u1 = update_state(u, x1, a).
  mgr->AddEdge(u1, a);
  ASSERT_EQ(users_of_a(), 2);

  // Replacing u1 by a fresh update_state(u, x1) removes the extra edge; out = depend(a, u2).
  auto u2 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
  mgr->Replace(u1, u2);
  ASSERT_EQ(users_of_a(), 1);

  // The edge can be added again on the replacement: u2 = update_state(u, x1, a).
  mgr->AddEdge(u2, a);
  ASSERT_EQ(users_of_a(), 2);
}
681
TEST_F(TestManager, test_add_edge_replace_new) {
  // Graph under test:
  //   fg(x, y, u):
  //     x1 = load(x, u)
  //     a = add(x1, y)
  //     u1 = update_state(u, x1)
  //     out = depend(a, u1)
  //     return out
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto x = fg->add_parameter();
  auto y = fg->add_parameter();
  auto u = fg->add_parameter();
  auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u});
  auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
  auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1});
  auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1});
  fg->set_output(out);

  auto mgr = Manage(fg);
  ASSERT_NE(mgr, nullptr);

  // Number of users the manager records for `node`.
  auto user_count = [&mgr](const AnfNodePtr &node) { return mgr->node_users()[node].size(); };

  // Attach a brand-new node through AddEdge:
  //   new_add = add(x1, y)
  //   u1 = update_state(u, x1, new_add)
  auto new_add = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y});
  mgr->AddEdge(u1, new_add);
  ASSERT_EQ(user_count(x1), 3);
  ASSERT_EQ(user_count(y), 2);
  ASSERT_EQ(user_count(new_add), 1);

  // Swap it for another new node:
  //   new_add1 = add(y, y)
  //   u1 = update_state(u, x1, new_add1)
  auto new_add1 = fg->NewCNode({NewValueNode(prim::kPrimAdd), y, y});
  mgr->Replace(new_add, new_add1);
  ASSERT_EQ(user_count(x1), 2);
  ASSERT_EQ(user_count(y), 3);
  ASSERT_EQ(user_count(new_add), 0);
  ASSERT_EQ(user_count(new_add1), 1);
}
730
TEST_F(TestManager, test_set_edge) {
  // Graph under test:
  //   fg(x, y, u):
  //     t = make_tuple(x, y)
  //     d = depend(t, u)
  //     get_item = tuple_get_item(d, 0)
  //     return get_item
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  auto x = fg->add_parameter();
  auto y = fg->add_parameter();
  auto u = fg->add_parameter();
  auto t = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), x, y});
  auto d = fg->NewCNode({NewValueNode(prim::kPrimDepend), t, u});
  auto get_item = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), d, NewValueNode(0)});
  fg->set_output(get_item);

  auto mgr = Manage(fg);
  ASSERT_NE(mgr, nullptr);

  // Before SetEdge: t is used by d, and d by get_item.
  ASSERT_EQ(mgr->node_users()[t].size(), 1);
  ASSERT_EQ(mgr->node_users()[d].size(), 1);

  // Bypass the depend: point get_item's first input straight at the depend's first input (t).
  CNodePtr depend_node = get_item->input(1)->cast<CNodePtr>();
  AnfNodePtr bypassed = depend_node->input(1);
  mgr->SetEdge(get_item, 1, bypassed);

  // After SetEdge: get_item reads t directly, d loses its only user,
  // and the depend's use of t is no longer recorded.
  ASSERT_EQ(get_item->input(1), t);
  ASSERT_EQ(depend_node->input(1), t);
  ASSERT_EQ(mgr->node_users()[d].size(), 0);
  ASSERT_EQ(mgr->node_users()[t].size(), 1);
  ASSERT_EQ(mgr->node_users()[t].front().first, get_item);
}
764
765 } // namespace mindspore
766