1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common_test.h" 17 #include "common/py_func_graph_fetcher.h" 18 #include "ir/dtype.h" 19 #include "ir/manager.h" 20 #include "ir/func_graph_cloner.h" 21 #include "pipeline/jit/parse/parse.h" 22 #include "frontend/operator/ops.h" 23 #include "utils/log_adapter.h" 24 #include "debug/draw.h" 25 #include "utils/label.h" 26 27 namespace mindspore { 28 29 namespace { 30 std::vector<std::string> SplitString(std::string str, std::string pattern) { 31 std::string::size_type pos; 32 std::vector<std::string> result; 33 str += pattern; 34 std::string::size_type size = str.size(); 35 36 for (std::string::size_type i = 0; i < size; ++i) { 37 pos = str.find(pattern, i); 38 if (pos < size) { 39 std::string s = str.substr(i, pos - i); 40 result.push_back(s); 41 i = pos + pattern.size() - 1; 42 } 43 } 44 45 return result; 46 } 47 } // namespace 48 using std::dynamic_pointer_cast; 49 50 using TodoList = std::vector<std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>>; 51 using TodoListItem = std::vector<std::pair<std::set<std::pair<AnfNodePtr, int>>, AnfNodePtr>>; 52 53 class NestingSpecs; 54 55 class Stage { 56 public: 57 explicit Stage(std::vector<std::string> specs) { 58 for (auto arg : specs) { 59 auto spec = SplitString(arg, "="); 60 if (spec.size() <= 1) { 61 continue; 62 } 63 std::shared_ptr<NestingSpecs> nesting = std::make_shared<NestingSpecs>(this, spec[1]); 64 specs_[ToFullString(spec[0])] = nesting; 65 } 66 } 67 68 ~Stage() {} 69 70 std::map<std::string, std::string> &subs() { return subs_; } 71 72 void set_subs(const std::map<std::string, std::string> &subs) { subs_ = subs; } 73 74 private: 75 std::string ToFullString(std::string s) { 76 if (s.find("fv") != std::string::npos) { 77 s = s.replace(s.find("fv"), 2, "free_variable"); 78 } 79 80 if (s.find("deps") != std::string::npos) { 81 s = s.replace(s.find("deps"), 4, "dependencies"); 82 } 83 84 return s; 85 } 86 87 std::map<std::string, std::shared_ptr<NestingSpecs>> specs_; 88 std::map<std::string, std::string> subs_; 89 }; 90 91 class NestingSpecs { 92 public: 93 NestingSpecs(Stage *stage, std::string specs) : stage_(stage) { ParseSpecs(specs); } 94 95 ~NestingSpecs() {} 96 97 std::string Name(Any node) { 98 std::string name = label_manage::Label(node.cast<AnfNodePtr>()->debug_info()); 99 if (stage_->subs().find(name) != stage_->subs().end()) { 100 return stage_->subs()[name]; 101 } 102 103 return name; 104 } 105 106 void Check(std::shared_ptr<DepComputer> results) { 107 if (expected_.empty() && expected_recursive_.empty()) { 108 return; 109 } 110 111 auto parent = dynamic_pointer_cast<ParentComputer>(results); 112 if (parent != nullptr) { 113 CheckParent(parent); 114 return; 115 } 116 117 auto recursive = dynamic_pointer_cast<RecursiveComputer>(results); 118 if (recursive != nullptr) { 119 CheckRecursive(recursive); 120 return; 121 } 122 } 123 124 private: 125 void ParseSpecs(std::string specs) { 126 if (specs.empty()) { 127 return; 128 } 129 130 std::vector<std::string> str_list = SplitString(specs, ";"); 131 for (auto spec : str_list) { 132 spec.erase(0, spec.find_first_not_of(" ")); 133 spec.erase(spec.find_last_not_of(" ") + 1); 134 if (spec.empty()) { 135 continue; 136 } 137 if (spec.find("->") != std::string::npos) { 138 auto substr = SplitString(spec, "->"); 139 ASSERT_GT(substr.size(), 1); 140 auto key = substr[0]; 141 auto value = substr[1]; 142 if (!value.empty()) { 143 expected_[key] = {value}; 144 } 145 } else if (spec.find(":") != std::string::npos) { 146 auto substr = SplitString(spec, ":"); 147 ASSERT_GT(substr.size(), 1); 148 auto key = substr[0]; 149 auto values = SplitString(substr[1], ","); 150 std::set<std::string> values_set(values.begin(), values.end()); 151 if (!values_set.empty()) { 152 expected_[key] = values_set; 153 } 154 } else { 155 expected_recursive_[spec] = true; 156 } 157 } 158 } 159 160 void CheckParent(std::shared_ptr<ParentComputer> results) { 161 std::map<std::string, std::set<std::string>> clean_results; 162 for (auto &iter : results->parent_analysis()) { 163 auto key = iter.first; 164 auto value = iter.second; 165 if (key == nullptr) { 166 continue; 167 } 168 std::string k = Name(key); 169 170 std::set<std::string> v; 171 if (value != nullptr && !Name(value).empty()) { 172 v.insert(Name(value)); 173 } 174 175 if (!v.empty()) { 176 clean_results[k] = v; 177 } 178 } 179 180 ASSERT_EQ(clean_results, expected_); 181 } 182 183 void CheckRecursive(std::shared_ptr<RecursiveComputer> results) { 184 std::map<std::string, bool> clean_results; 185 for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) { 186 auto key = iter->first; 187 auto value = iter->second; 188 if (key == nullptr) { 189 continue; 190 } 191 std::string k = Name(key); 192 193 clean_results[k] = value; 194 } 195 196 ASSERT_EQ(clean_results, expected_recursive_); 197 } 198 199 private: 200 Stage *stage_; 201 std::map<std::string, std::set<std::string>> expected_; 202 std::map<std::string, bool> expected_recursive_; 203 }; 204 205 bool CheckUsers(std::shared_ptr<FuncGraphManager> manager) { 206 for (auto node : manager->all_nodes()) { 207 if (node->isa<CNode>()) { 208 auto &inputs = node->cast<CNodePtr>()->inputs(); 209 for (size_t i = 0; i < inputs.size(); ++i) { 210 auto inp = inputs[i]; 211 if (!manager->all_nodes().contains(inp)) { 212 return false; 213 } 214 215 if (manager->node_users().find(inp) != manager->node_users().end()) { 216 auto users = manager->node_users()[inp]; 217 if (!users.contains(make_pair(node, i))) { 218 return false; 219 } 220 } 221 } 222 } 223 224 if (manager->node_users().find(node) != manager->node_users().end()) { 225 auto users = manager->node_users()[node]; 226 for (auto iter = users.begin(); iter != users.end(); ++iter) { 227 auto node2 = iter->first; 228 auto key = iter->second; 229 if (!manager->all_nodes().contains(node2)) { 230 return false; 231 } 232 if (node2->cast<CNodePtr>()->input(key) != node) { 233 return false; 234 } 235 } 236 } 237 } 238 239 return true; 240 } 241 242 class TestManager : public UT::Common { 243 public: 244 TestManager() : getPyFun("gtest_input.ir.manager_test") {} 245 246 void CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng); 247 248 public: 249 std::vector<PrimitivePtr> swaps; 250 UT::PyFuncGraphFetcher getPyFun; 251 }; 252 253 FuncGraphPtr MakeFuncGraph(PrimitivePtr prim) { 254 FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); 255 ParameterPtr x = func_graph->add_parameter(); 256 ParameterPtr y = func_graph->add_parameter(); 257 std::vector<AnfNodePtr> inputs; 258 inputs.push_back(NewValueNode(prim)); 259 inputs.push_back(x); 260 inputs.push_back(y); 261 CNodePtr cnode_add = func_graph->NewCNode(inputs); 262 inputs.clear(); 263 inputs.push_back(NewValueNode(prim::kPrimReturn)); 264 inputs.push_back(cnode_add); 265 CNodePtr cnode_return = func_graph->NewCNode(inputs); 266 func_graph->set_return(cnode_return); 267 return func_graph; 268 } 269 270 std::vector<FuncGraphPtr> MakeNestedGraph() { 271 /* 272 *def f(x): 273 * def g(): 274 * return x 275 * return g 276 */ 277 FuncGraphPtr f = std::make_shared<FuncGraph>(); 278 FuncGraphPtr fg = std::make_shared<FuncGraph>(); 279 280 ParameterPtr x = f->add_parameter(); 281 282 std::vector<AnfNodePtr> inputs; 283 inputs.push_back(NewValueNode(fg)); 284 inputs.push_back(NewValueNode(prim::kPrimReturn)); 285 286 CNodePtr cnode_f = f->NewCNode(inputs); 287 f->set_return(cnode_f); 288 289 inputs.clear(); 290 inputs.push_back(NewValueNode(prim::kPrimReturn)); 291 inputs.push_back(x); 292 CNodePtr cnode_g = fg->NewCNode(inputs); 293 fg->set_return(cnode_g); 294 295 std::vector<FuncGraphPtr> result = {f, fg}; 296 return result; 297 } 298 299 std::vector<FuncGraphPtr> MakeNestedGraph2() { 300 /* build a closure func_graph */ 301 /* 302 *def foo(x, y): 303 * def bar(x1): 304 * return x1 + y 305 * return bar(x) 306 */ 307 FuncGraphPtr graph_foo = std::make_shared<FuncGraph>(); 308 ParameterPtr x = graph_foo->add_parameter(); 309 ParameterPtr y = graph_foo->add_parameter(); 310 311 std::vector<AnfNodePtr> inputs; 312 313 // build func_graph bar 314 FuncGraphPtr graph_bar = std::make_shared<FuncGraph>(); 315 ParameterPtr x1 = graph_bar->add_parameter(); 316 inputs.clear(); 317 inputs.push_back(NewValueNode(prim::kPrimScalarAdd)); 318 inputs.push_back(x1); 319 inputs.push_back(y); 320 CNodePtr cnode_add = graph_bar->NewCNode(inputs); 321 inputs.clear(); 322 inputs.push_back(NewValueNode(prim::kPrimReturn)); 323 inputs.push_back(cnode_add); 324 CNodePtr cnode_return = graph_bar->NewCNode(inputs); 325 graph_bar->set_return(cnode_return); 326 327 // build func_graph foo 328 inputs.clear(); 329 inputs.push_back(NewValueNode(graph_bar)); 330 inputs.push_back(x); 331 CNodePtr cnode_graph_bar = graph_foo->NewCNode(inputs); 332 333 inputs.clear(); 334 inputs.push_back(NewValueNode(prim::kPrimReturn)); 335 inputs.push_back(cnode_graph_bar); 336 cnode_return = graph_foo->NewCNode(inputs); 337 graph_foo->set_return(cnode_return); 338 339 std::vector<FuncGraphPtr> result = {graph_foo, graph_bar}; 340 return result; 341 } 342 343 // Add TestManager::CheckManager function to checkout the result 344 void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) { 345 auto size = mng->func_graphs().size(); 346 347 ASSERT_EQ(size, mng->free_variables_total().size()); 348 } 349 350 TEST_F(TestManager, test_scalar_add_manual) { 351 auto prim_scalar_add = prim::kPrimScalarAdd; 352 FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); 353 auto mng = Manage(func_graph); 354 } 355 356 TEST_F(TestManager, test_scalar_replace) { 357 auto prim_scalar_add = prim::kPrimScalarAdd; 358 359 FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); 360 ParameterPtr x = func_graph->add_parameter(); 361 ParameterPtr y = func_graph->add_parameter(); 362 std::vector<AnfNodePtr> inputs; 363 inputs.push_back(NewValueNode(prim_scalar_add)); 364 inputs.push_back(x); 365 inputs.push_back(y); 366 CNodePtr cnode_add = func_graph->NewCNode(inputs); 367 inputs.clear(); 368 inputs.push_back(NewValueNode(prim::kPrimReturn)); 369 inputs.push_back(cnode_add); 370 CNodePtr cnode_return = func_graph->NewCNode(inputs); 371 func_graph->set_return(cnode_return); 372 373 auto mng = Manage(func_graph); 374 std::cout << "start " << x->ToString() << std::endl; 375 mng->Replace(cnode_add, x); 376 } 377 378 TEST_F(TestManager, test_nested_manual) { 379 auto graphs = MakeNestedGraph(); 380 auto f = graphs[0]; 381 auto g = graphs[1]; 382 383 auto mng = Manage(f); 384 385 ASSERT_EQ(6, mng->all_nodes().size()); 386 ASSERT_EQ(2, mng->func_graphs().size()); 387 ASSERT_EQ(4, mng->node_users().size()); 388 ASSERT_EQ(1, mng->roots().size()); 389 CheckAnalysisSize(mng); 390 391 ASSERT_EQ(2, f->nodes().size()); 392 ASSERT_EQ(1, g->nodes().size()); 393 394 auto &users = mng->node_users(); 395 for (auto &iter : users) { 396 ASSERT_EQ(1, iter.second.size()); 397 } 398 399 ASSERT_EQ(1, f->func_graphs_used().size()); 400 ASSERT_EQ(0, g->func_graphs_used().size()); 401 402 ASSERT_EQ(0, f->free_variables().size()); 403 ASSERT_EQ(1, g->free_variables().size()); 404 405 auto fv_total = mng->free_variables_total(); 406 ASSERT_EQ(0, fv_total[f].size()); 407 ASSERT_EQ(1, fv_total[g].size()); 408 409 ASSERT_EQ(0, f->func_graph_cnodes_index().size()); 410 ASSERT_EQ(1, g->func_graph_cnodes_index().size()); 411 } 412 413 TEST_F(TestManager, test_deep_nested2_manual) { 414 // create parser 415 FuncGraphPtr func_graph = getPyFun("test_custom"); 416 return; 417 418 // parse ast to func graph 419 FuncGraphPtr gfn = BasicClone(func_graph); 420 if (gfn == nullptr) { 421 return; 422 } 423 424 auto mng = Manage(gfn); 425 426 ASSERT_EQ(3, mng->func_graphs().size()); 427 ASSERT_EQ(1, mng->roots().size()); 428 ASSERT_EQ(4, gfn->nodes().size()); 429 ASSERT_EQ(20, mng->all_nodes().size()); 430 ASSERT_EQ(25, mng->node_users().size()); 431 CheckAnalysisSize(mng); 432 } 433 434 TEST_F(TestManager, test_deep_nested_manual) { 435 FuncGraphPtr f = std::make_shared<FuncGraph>(); 436 FuncGraphPtr fg = std::make_shared<FuncGraph>(); 437 FuncGraphPtr h = std::make_shared<FuncGraph>(); 438 439 ParameterPtr x = f->add_parameter(); 440 ParameterPtr y = f->add_parameter(); 441 ParameterPtr z = f->add_parameter(); 442 443 std::vector<AnfNodePtr> inputs; 444 inputs.push_back(NewValueNode(fg)); 445 inputs.push_back(x); 446 inputs.push_back(y); 447 CNodePtr cnode_1 = f->NewCNode(inputs); 448 449 inputs.clear(); 450 inputs.push_back(cnode_1); 451 inputs.push_back(NewValueNode(prim::kPrimReturn)); 452 CNodePtr cnode_0 = f->NewCNode(inputs); 453 f->set_return(cnode_0); 454 455 ParameterPtr x1 = fg->add_parameter(); 456 ParameterPtr y1 = fg->add_parameter(); 457 inputs.clear(); 458 inputs.push_back(NewValueNode(h)); 459 inputs.push_back(x1); 460 CNodePtr cnode_3 = fg->NewCNode(inputs); 461 462 inputs.clear(); 463 inputs.push_back(cnode_3); 464 inputs.push_back(NewValueNode(prim::kPrimReturn)); 465 CNodePtr cnode_2 = fg->NewCNode(inputs); 466 fg->set_return(cnode_2); 467 468 ParameterPtr x2 = h->add_parameter(); 469 470 inputs.clear(); 471 inputs.push_back(NewValueNode(prim::kPrimScalarAdd)); 472 inputs.push_back(x2); 473 inputs.push_back(y1); 474 CNodePtr cnode_6 = h->NewCNode(inputs); 475 476 inputs.clear(); 477 inputs.push_back(NewValueNode(prim::kPrimScalarAdd)); 478 inputs.push_back(z); 479 inputs.push_back(cnode_6); 480 CNodePtr cnode_5 = h->NewCNode(inputs); 481 482 inputs.clear(); 483 inputs.push_back(cnode_5); 484 inputs.push_back(NewValueNode(prim::kPrimReturn)); 485 CNodePtr cnode_4 = h->NewCNode(inputs); 486 h->set_return(cnode_4); 487 488 auto mng = Manage(f); 489 490 ASSERT_EQ(3, mng->func_graphs().size()); 491 ASSERT_EQ(1, mng->roots().size()); 492 ASSERT_EQ(20, mng->all_nodes().size()); 493 CheckAnalysisSize(mng); 494 } 495 496 TEST_F(TestManager, test_parent1_manual) { 497 FuncGraphPtr fg = std::make_shared<FuncGraph>(); 498 499 Parameter param(fg); 500 std::vector<AnfNodePtr> params; 501 CNodePtr app = std::make_shared<CNode>(params, fg); 502 fg->set_return(app); 503 fg->set_parameters(params); 504 505 std::shared_ptr<FuncGraphManager> manager = MakeManager(); 506 manager->AddFuncGraph(fg, true); 507 FuncGraphPtr p = fg->parent(); 508 assert(p == nullptr); 509 } 510 511 TEST_F(TestManager, test_parent_manual) { 512 auto prim_scalar_add = prim::kPrimScalarAdd; 513 FuncGraphPtr fg = MakeFuncGraph(prim_scalar_add); 514 515 std::shared_ptr<FuncGraphManager> manager = MakeManager(); 516 manager->AddFuncGraph(fg); 517 FuncGraphPtr p = fg->parent(); 518 assert(p == nullptr); 519 } 520 521 TEST_F(TestManager, test_flat) { 522 std::vector<std::shared_ptr<Stage>> stages; 523 std::vector<std::string> specs = {"nodes=X:x", "parents=", "fvs_direct="}; 524 std::map<std::string, int> size_list; 525 size_list["nodes"] = 2; 526 } 527 528 TEST_F(TestManager, test_nested) { 529 std::vector<std::shared_ptr<Stage>> stages; 530 std::vector<std::string> specs = {"nodes=X:x", "parent=g->X", "fvs_direct=g:x"}; 531 std::map<std::string, int> size_list; 532 return; 533 } 534 535 TEST_F(TestManager, test_calls) { 536 std::vector<std::shared_ptr<Stage>> stages; 537 std::vector<std::string> specs = {"parents=g->X; h->X", "children=X:g,h", "scopes=X:X,g,h; g:g; h:h", 538 "fvs_direct=h:a", "fvs_total=h:a; g:h"}; 539 std::map<std::string, int> size_list; 540 return; 541 } 542 543 TEST_F(TestManager, test_unused_param) { 544 std::vector<std::shared_ptr<Stage>> stages; 545 std::vector<std::string> specs = {"nodes=X:x,y"}; 546 std::map<std::string, int> size_list; 547 } 548 549 TEST_F(TestManager, test_cannot_replace_return) { 550 FuncGraphPtr fg = getPyFun("test_cannot_replace_return"); 551 ASSERT_NE(fg, nullptr); 552 553 auto mng = Manage(fg); 554 ASSERT_EQ(fg->manager(), mng); 555 556 ASSERT_NE(mng, nullptr); 557 ASSERT_GT(fg->parameters().size(), 0); 558 ASSERT_FALSE(mng->Replace(fg->get_return(), fg->parameters()[0])); 559 } 560 561 TEST_F(TestManager, test_weak_manager) { 562 FuncGraphPtr fg = getPyFun("ir_get_fn"); 563 564 auto mng1 = MakeManager({fg}, false); 565 ASSERT_EQ(fg->manager(), nullptr); 566 auto mng2 = MakeManager({fg}, true); 567 ASSERT_EQ(fg->manager(), mng2); 568 auto mng3 = MakeManager({fg}, false); 569 ASSERT_EQ(fg->manager(), mng2); 570 } 571 572 TEST_F(TestManager, test_drop_root) { 573 FuncGraphPtr fg = getPyFun("ir_get_fn"); 574 575 auto mng = Manage(fg); 576 const auto &fgs = mng->func_graphs(); 577 ASSERT_TRUE(fgs.contains(fg)); 578 FuncGraphSet s; 579 s.add(fg); 580 mng->MaybeDropFuncGraphs(s); 581 ASSERT_TRUE(fgs.contains(fg)); 582 } 583 584 TEST_F(TestManager, test_keep_roots) { 585 FuncGraphPtr fg1 = getPyFun("ir_get_fn"); 586 FuncGraphPtr fg2 = getPyFun("test_cannot_replace_return"); 587 588 auto mng = Manage(fg1); 589 ASSERT_EQ(mng->func_graphs().size(), (size_t)1); 590 ASSERT_TRUE(mng->func_graphs().contains(fg1)); 591 592 mng->AddFuncGraph(fg2); 593 ASSERT_EQ(mng->func_graphs().size(), 2); 594 ASSERT_TRUE(mng->func_graphs().contains(fg2)); 595 596 mng->KeepRoots(); 597 ASSERT_EQ(mng->func_graphs().size(), 1); 598 ASSERT_TRUE(mng->func_graphs().contains(fg1)); 599 600 mng->KeepRoots({fg2}); 601 ASSERT_EQ(mng->func_graphs().size(), 1); 602 ASSERT_TRUE(mng->func_graphs().contains(fg2)); 603 } 604 605 TEST_F(TestManager, test_keep_roots_recursion) { 606 return; 607 608 FuncGraphPtr fg = getPyFun("test_keep_roots_recursion"); 609 ASSERT_NE(fg, nullptr); 610 auto mng = Manage(fg); 611 parse::ResolveAll(mng); 612 613 ASSERT_NE(mng, nullptr); 614 ASSERT_EQ(mng->func_graphs().size(), 4); 615 616 ASSERT_GT(fg->parameters().size(), 0); 617 mng->Replace(fg->output(), fg->parameters()[0]); 618 ASSERT_EQ(mng->func_graphs().size(), 3); 619 620 mng->KeepRoots(); 621 ASSERT_EQ(mng->func_graphs().size(), 1); 622 } 623 624 TEST_F(TestManager, test_add_edge_replace) { 625 // fg(x, y, u): 626 // x1 = load(x, u) 627 // a = add(x1, y) 628 // u1 = update_state(u, x1); 629 // out = depend(a, u1) 630 // return out 631 FuncGraphPtr fg = std::make_shared<FuncGraph>(); 632 auto x = fg->add_parameter(); 633 auto y = fg->add_parameter(); 634 auto u = fg->add_parameter(); 635 auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u}); 636 auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y}); 637 auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1}); 638 auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1}); 639 fg->set_output(out); 640 641 // Create manager. 642 auto mgr = Manage(fg); 643 ASSERT_NE(mgr, nullptr); 644 645 // Before AddEdge. 646 // a = add(x1, y) 647 // u1 = update_state(u, x1); 648 // out = depend(a, u1) 649 auto a_users = mgr->node_users()[a]; 650 ASSERT_EQ(a_users.size(), 1); 651 652 mgr->AddEdge(u1, a); 653 654 // After AddEdge. 655 // a = add(x1, y) 656 // u1 = update_state(u, x1, a); 657 // out = depend(a, u1) 658 a_users = mgr->node_users()[a]; 659 ASSERT_EQ(a_users.size(), 2); 660 661 // Remove edge by replace update_state. 662 auto u2 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1}); 663 mgr->Replace(u1, u2); 664 665 // After replace update_state. 666 // a = add(x1, y) 667 // u2 = update_state(u, x1); 668 // out = depend(a, u2) 669 a_users = mgr->node_users()[a]; 670 ASSERT_EQ(a_users.size(), 1); 671 672 mgr->AddEdge(u2, a); 673 674 // After AddEdge to u2. 675 // a = add(x1, y) 676 // u2 = update_state(u, x1, a); 677 // out = depend(a, u2) 678 a_users = mgr->node_users()[a]; 679 ASSERT_EQ(a_users.size(), 2); 680 } 681 682 TEST_F(TestManager, test_add_edge_replace_new) { 683 // fg(x, y, u): 684 // x1 = load(x, u) 685 // a = add(x1, y) 686 // u1 = update_state(u, x1); 687 // out = depend(a, u1) 688 // return out 689 FuncGraphPtr fg = std::make_shared<FuncGraph>(); 690 auto x = fg->add_parameter(); 691 auto y = fg->add_parameter(); 692 auto u = fg->add_parameter(); 693 auto x1 = fg->NewCNode({NewValueNode(prim::kPrimLoad), x, u}); 694 auto a = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y}); 695 auto u1 = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u, x1}); 696 auto out = fg->NewCNode({NewValueNode(prim::kPrimDepend), a, u1}); 697 fg->set_output(out); 698 699 // Create manager. 700 auto mgr = Manage(fg); 701 ASSERT_NE(mgr, nullptr); 702 703 auto new_add = fg->NewCNode({NewValueNode(prim::kPrimAdd), x1, y}); 704 mgr->AddEdge(u1, new_add); 705 706 // x1 = load(x, u) 707 // a = add(x1, y) 708 // new_add = add(x1, y) 709 // u1 = update_state(u, x1, new_add); 710 // out = depend(a, u1) 711 // return out 712 ASSERT_EQ(mgr->node_users()[x1].size(), 3); 713 ASSERT_EQ(mgr->node_users()[y].size(), 2); 714 ASSERT_EQ(mgr->node_users()[new_add].size(), 1); 715 716 auto new_add1 = fg->NewCNode({NewValueNode(prim::kPrimAdd), y, y}); 717 mgr->Replace(new_add, new_add1); 718 719 // x1 = load(x, u) 720 // a = add(x1, y) 721 // new_add1 = add(y, y) 722 // u1 = update_state(u, x1, new_add1); 723 // out = depend(a, u1) 724 // return out 725 ASSERT_EQ(mgr->node_users()[x1].size(), 2); 726 ASSERT_EQ(mgr->node_users()[y].size(), 3); 727 ASSERT_EQ(mgr->node_users()[new_add].size(), 0); 728 ASSERT_EQ(mgr->node_users()[new_add1].size(), 1); 729 } 730 731 TEST_F(TestManager, test_set_edge) { 732 // fg(x, y, u): 733 // t = make_tuple(x, y) 734 // d = depend(t, u); 735 // get_item = tuple_get_item(d, 0) 736 // return get_item 737 FuncGraphPtr fg = std::make_shared<FuncGraph>(); 738 auto x = fg->add_parameter(); 739 auto y = fg->add_parameter(); 740 auto u = fg->add_parameter(); 741 auto t = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), x, y}); 742 auto d = fg->NewCNode({NewValueNode(prim::kPrimDepend), t, u}); 743 auto get_item = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), d, NewValueNode(0)}); 744 fg->set_output(get_item); 745 746 // Create manager. 747 auto mgr = Manage(fg); 748 ASSERT_NE(mgr, nullptr); 749 750 // Before SetEdge. 751 ASSERT_EQ(mgr->node_users()[t].size(), 1); 752 ASSERT_EQ(mgr->node_users()[d].size(), 1); 753 754 auto depend = get_item->input(1)->cast<CNodePtr>(); 755 mgr->SetEdge(get_item, 1, depend->input(1)); 756 757 // After SetEdge. 758 ASSERT_EQ(get_item->input(1), t); 759 ASSERT_EQ(depend->input(1), t); 760 ASSERT_EQ(mgr->node_users()[d].size(), 0); 761 ASSERT_EQ(mgr->node_users()[t].size(), 1); // depend removed. 762 ASSERT_EQ(mgr->node_users()[t].front().first, get_item); 763 } 764 765 } // namespace mindspore 766