1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <memory> 18 19 #include "common/common_test.h" 20 #include "common/py_func_graph_fetcher.h" 21 22 #include "ir/anf.h" 23 #include "ir/visitor.h" 24 #include "ir/func_graph_cloner.h" 25 #include "frontend/optimizer/opt.h" 26 #include "frontend/optimizer/anf_visitor.h" 27 #include "frontend/optimizer/irpass.h" 28 #include "frontend/optimizer/irpass/arithmetic_simplify.h" 29 30 #include "debug/draw.h" 31 #include "frontend/operator/ops.h" 32 #include "frontend/optimizer/cse.h" 33 34 namespace mindspore { 35 namespace opt { 36 class TestOptOpt : public UT::Common { 37 public: 38 TestOptOpt() : getPyFun("gtest_input.optimizer.opt_test", true) {} 39 40 class IdempotentEliminater : public AnfVisitor { 41 public: 42 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 43 x_ = nullptr; 44 AnfVisitor::Match(P, {irpass::IsCNode})(node); 45 if (x_ == nullptr || node->func_graph() == nullptr) { 46 return nullptr; 47 } 48 49 return node->func_graph()->NewCNode({NewValueNode(P), x_}); 50 }; 51 52 void Visit(const CNodePtr &cnode) override { 53 if (IsPrimitiveCNode(cnode, P) && cnode->inputs().size() == 2) { 54 x_ = cnode->input(1); 55 } 56 } 57 58 private: 59 AnfNodePtr x_{nullptr}; 60 }; 61 62 class QctToP : public AnfVisitor { 63 public: 64 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 65 v_ = nullptr; 66 AnfVisitor::Match(Q, {irpass::IsVNode})(node); 67 if (v_ == nullptr || node->func_graph() == nullptr) { 68 return nullptr; 69 } 70 71 return node->func_graph()->NewCNode({NewValueNode(P), v_}); 72 }; 73 74 void Visit(const ValueNodePtr &vnode) override { v_ = vnode; } 75 76 private: 77 AnfNodePtr v_{nullptr}; 78 }; 79 80 void SetUp() { 81 elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd); 82 elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R); 83 idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P); 84 Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q); 85 } 86 87 bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { 88 equiv_node.clear(); 89 equiv_graph.clear(); 90 91 FuncGraphPtr gbefore_clone = BasicClone(gbefore); 92 OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>()); 93 transform(gbefore_clone, optimizer); 94 95 return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node); 96 } 97 98 bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) { 99 SubstitutionList eq(opts); 100 return CheckTransform(before, after, eq); 101 } 102 103 public: 104 UT::PyFuncGraphFetcher getPyFun; 105 106 FuncGraphPairMapEquiv equiv_graph; 107 NodeMapEquiv equiv_node; 108 109 static const PrimitivePtr P; 110 static const PrimitivePtr Q; 111 static const PrimitivePtr R; 112 113 SubstitutionPtr elim_Z; 114 SubstitutionPtr elim_R; 115 SubstitutionPtr idempotent_P; 116 SubstitutionPtr Qct_to_P; 117 }; 118 119 const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P"); 120 const PrimitivePtr TestOptOpt::Q = std::make_shared<Primitive>("Q"); 121 const PrimitivePtr TestOptOpt::R = std::make_shared<Primitive>("R"); 122 123 TEST_F(TestOptOpt, TestCheckOptIsClone) { 124 FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1"); 125 126 ASSERT_TRUE(nullptr != before); 127 ASSERT_TRUE(CheckOpt(before, before)); 128 ASSERT_FALSE(CheckOpt(before, before, std::vector<SubstitutionPtr>({elim_Z}))); 129 } 130 131 TEST_F(TestOptOpt, Elim) { 132 FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_1"); 133 FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after"); 134 135 ASSERT_TRUE(nullptr != before); 136 ASSERT_TRUE(nullptr != after); 137 ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_Z}))); 138 } 139 140 TEST_F(TestOptOpt, ElimTwo) { 141 FuncGraphPtr before = getPyFun.CallAndParseRet("test_add_zero", "before_2"); 142 FuncGraphPtr after = getPyFun.CallAndParseRet("test_add_zero", "after"); 143 144 ASSERT_TRUE(nullptr != before); 145 ASSERT_TRUE(nullptr != after); 146 ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_Z}))); 147 } 148 149 TEST_F(TestOptOpt, ElimR) { 150 FuncGraphPtr before = getPyFun.CallAndParseRet("test_elimR", "before_1"); 151 FuncGraphPtr after = getPyFun.CallAndParseRet("test_elimR", "after"); 152 153 ASSERT_TRUE(nullptr != before); 154 ASSERT_TRUE(nullptr != after); 155 ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({elim_R}))); 156 } 157 158 TEST_F(TestOptOpt, idempotent) { 159 FuncGraphPtr before_2 = getPyFun.CallAndParseRet("test_idempotent", "before_2"); 160 FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_idempotent", "before_1"); 161 FuncGraphPtr after = getPyFun.CallAndParseRet("test_idempotent", "after"); 162 163 ASSERT_TRUE(nullptr != before_2); 164 ASSERT_TRUE(nullptr != before_1); 165 ASSERT_TRUE(nullptr != after); 166 167 ASSERT_TRUE(CheckOpt(before_1, after, std::vector<SubstitutionPtr>({idempotent_P}))); 168 ASSERT_TRUE(CheckOpt(before_2, after, std::vector<SubstitutionPtr>({idempotent_P}))); 169 } 170 171 TEST_F(TestOptOpt, ConstantVariable) { 172 FuncGraphPtr before = getPyFun.CallAndParseRet("test_constant_variable", "before_1"); 173 FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_variable", "after"); 174 175 ASSERT_TRUE(nullptr != before); 176 ASSERT_TRUE(nullptr != after); 177 ASSERT_TRUE(CheckOpt(before, after, std::vector<SubstitutionPtr>({Qct_to_P}))); 178 } 179 180 TEST_F(TestOptOpt, CSE) { 181 // test a simple cse testcase test_f1 182 FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_cse", "test_f1"); 183 184 ASSERT_TRUE(nullptr != test_graph1); 185 186 // add func_graph the GraphManager 187 FuncGraphManagerPtr manager1 = Manage(test_graph1); 188 189 ASSERT_EQ(manager1->all_nodes().size(), 9); 190 191 auto cse = std::make_shared<CSE>(); 192 ASSERT_TRUE(cse != nullptr); 193 bool is_changed = cse->Cse(test_graph1, manager1); 194 195 ASSERT_TRUE(is_changed); 196 ASSERT_EQ(manager1->all_nodes().size(), 8); 197 198 // test a more complicated case test_f2 199 FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_cse", "test_f2"); 200 201 ASSERT_TRUE(nullptr != test_graph2); 202 203 FuncGraphManagerPtr manager2 = Manage(test_graph2); 204 ASSERT_EQ(manager2->all_nodes().size(), 16); 205 is_changed = cse->Cse(test_graph2, manager2); 206 ASSERT_TRUE(is_changed); 207 ASSERT_EQ(manager2->all_nodes().size(), 12); 208 } 209 210 } // namespace opt 211 } // namespace mindspore 212