1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <algorithm> 17 18 #include "common/common_test.h" 19 #include "common/py_func_graph_fetcher.h" 20 21 #include "ir/manager.h" 22 #include "utils/log_adapter.h" 23 #include "ir/func_graph_cloner.h" 24 #include "pipeline/jit/parse/parse.h" 25 #include "ir/graph_utils.h" 26 #include "debug/draw.h" 27 #include "base/core_ops.h" 28 29 namespace mindspore { 30 class TestCloner : public UT::Common { 31 public: 32 TestCloner() : getPyFun("gtest_input.ir.clone_test", true) { 33 one = NewValueNode(static_cast<int64_t>(1)); 34 two = NewValueNode(static_cast<int64_t>(2)); 35 three = NewValueNode(static_cast<int64_t>(3)); 36 } 37 38 FuncGraphPtr GraphForInline() { return nullptr; } 39 void SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, const std::vector<AnfNodePtr>& params, 40 FuncGraphPtr target); 41 42 public: 43 UT::PyFuncGraphFetcher getPyFun; 44 45 ValueNodePtr one; 46 ValueNodePtr two; 47 ValueNodePtr three; 48 }; 49 50 void TestCloner::SuccessfulInlining(const std::shared_ptr<Cloner> cl, FuncGraphPtr orig, 51 const std::vector<AnfNodePtr>& params, FuncGraphPtr target) { 52 auto g = (*cl)[orig]; 53 ASSERT_TRUE(g != target); 54 ASSERT_TRUE(g == orig); 55 56 auto new_root = (*cl)[orig->output()]; 57 ASSERT_TRUE(new_root != orig->output()); 58 59 AnfNodeSet orig_nodes = AnfNodeSet(DeepLinkedGraphSearch(orig->output())); 
60 AnfNodeSet new_nodes = AnfNodeSet(DeepLinkedGraphSearch(new_root)); 61 62 for (auto& p : params) { 63 ASSERT_TRUE(new_nodes.contains(p)); 64 } 65 66 for (auto& node : orig_nodes) { 67 if (node->func_graph() == orig) { 68 ASSERT_TRUE((*cl)[node]); 69 } 70 } 71 ASSERT_TRUE(target->output() == three); 72 } 73 74 TEST_F(TestCloner, test_clone_simple) { 75 std::string py_code = "test_clone_simple"; 76 77 FuncGraphPtr g = getPyFun.CallAndParseRet(py_code); 78 ASSERT_TRUE(g != nullptr); 79 80 std::vector<FuncGraphPtr> gs = {g}; 81 Cloner cl(gs, true); 82 auto g2 = cl[g]; 83 84 AnfNodeSet d1 = AnfNodeSet(DeepScopedGraphSearch(g->get_return())); 85 AnfNodeSet d2 = AnfNodeSet(DeepScopedGraphSearch(g2->get_return())); 86 87 auto common = d1 & d2; 88 ASSERT_EQ((size_t)0, common.size()); 89 90 Cloner cl2(gs); 91 auto g3 = cl2[g]; 92 93 std::vector<Primitive> results = {Primitive(prim::kScalarAdd), Primitive(prim::kScalarMul), Primitive("Return")}; 94 AnfNodeSet d3 = AnfNodeSet(DeepScopedGraphSearch(g3->get_return())); 95 common = d1 & d3; 96 for (auto& x : common) { 97 ASSERT_TRUE(x->isa<ValueNode>()); 98 ASSERT_TRUE(find(results.begin(), results.end(), *x->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()) != 99 results.end()); 100 } 101 } 102 103 TEST_F(TestCloner, test_clone_closure) { 104 std::string py_code = "test_clone_closure"; 105 106 // parse ast to graph 107 FuncGraphPtr parsed_f = getPyFun(py_code); 108 109 FuncGraphIndex idx(parsed_f); 110 auto g = idx.GetFirstFuncGraph("j"); 111 112 std::vector<FuncGraphPtr> gs = {g}; 113 Cloner cl(gs, true); 114 115 auto g_clone = cl[g]; 116 FuncGraphIndex idx2(g_clone, DeepLinkedGraphSearch); 117 118 std::string name_list = "xy"; 119 for (auto name : name_list) { 120 ASSERT_EQ(idx.GetFirstNode(std::string(1, name)), idx2.GetFirstNode(std::string(1, name))); 121 } 122 123 ASSERT_FALSE(idx.GetFirstNode("z") == idx2.GetFirstNode("z")); 124 ASSERT_FALSE(idx.GetFirstFuncGraph("j") == idx2.GetFirstFuncGraph("j")); 125 } 126 127 
/// Applies LiftingClone to the closure test graph and checks that the nested
/// graph "j" has exactly one more parameter than the lifted root graph.
TEST_F(TestCloner, test_clone_lifting) {
  std::string py_code = "test_clone_closure";

  // parse ast to graph
  FuncGraphPtr parsed_f = getPyFun(py_code);
  ASSERT_TRUE(parsed_f != nullptr);

  auto g_lifting = LiftingClone(parsed_f);
  ASSERT_TRUE(g_lifting != nullptr);

  FuncGraphIndex idx(g_lifting);
  auto g = idx.GetFirstFuncGraph("j");
  ASSERT_TRUE(g != nullptr);

  // Only the sizes are needed, so bind by reference instead of copying the parameter lists.
  const auto& params = g_lifting->parameters();
  const auto& child_params = g->parameters();
  ASSERT_EQ(params.size() + 1, child_params.size());
}

/// Clones a graph with nested functions and checks which nested graphs are
/// distinct from, and which are shared with, the original.
TEST_F(TestCloner, test_clone_scoping) {
  std::string py_code = "test_clone_scoping";

  // parse ast to graph
  FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  ASSERT_TRUE(g != nullptr);

  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl(gs, true);

  auto g2 = cl[g];
  ASSERT_TRUE(g2 != nullptr);

  FuncGraphIndex idx1(g);
  FuncGraphIndex idx2(g2);

  // Graphs named "f", "g" and "i" must differ between the original and the clone.
  std::string name_list = "fgi";
  for (auto name : name_list) {
    const std::string key(1, name);
    auto result1 = idx1.GetFirstFuncGraph(key);
    auto result2 = idx2.GetFirstFuncGraph(key);
    ASSERT_FALSE(result1 == result2);
  }

  // Graph "h" must be the same object in both indexes.
  // NOTE(review): if GetFirstFuncGraph yields a null pointer for an unknown name,
  // this comparison would pass vacuously - confirm that "h" is present in the test input.
  name_list = "h";
  for (auto name : name_list) {
    const std::string key(1, name);
    ASSERT_TRUE(idx1.GetFirstFuncGraph(key) == idx2.GetFirstFuncGraph(key));
  }
}

/// Compares two Cloner configurations on a graph parsed with resolving enabled.
TEST_F(TestCloner, test_clone_total) {
  std::string py_code = "test_clone_total";

  // parse ast to graph
  getPyFun.SetDoResolve();
  FuncGraphPtr g = getPyFun.CallAndParseRet(py_code);
  // NOTE(review): a null graph ends the test without a failure; this looks like a
  // deliberate best-effort and is kept as-is - confirm whether it should fail or skip instead.
  if (g == nullptr) {
    return;
  }

  FuncGraphIndex idx0(g);

  // Cloner built with (true, true, true): both "clone_total" and "clone_total_sub" must be new graphs.
  std::vector<FuncGraphPtr> gs = {g};
  Cloner cl1(gs, true, true, true);
  auto g2 = cl1[g];
  FuncGraphIndex idx1(g2);

  ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total_sub") == idx1.GetFirstFuncGraph("clone_total_sub"));
  ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx1.GetFirstFuncGraph("clone_total"));

  // Cloner built with (true) only: "clone_total" is new, "clone_total_sub" is shared with the original.
  Cloner cl2(gs, true);
  FuncGraphIndex idx2(cl2[g]);

  ASSERT_FALSE(idx0.GetFirstFuncGraph("clone_total") == idx2.GetFirstFuncGraph("clone_total"));
  ASSERT_TRUE(idx0.GetFirstFuncGraph("clone_total_sub") == idx2.GetFirstFuncGraph("clone_total_sub"));
}

}  // namespace mindspore