/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include <iostream>
17 #include <memory>
18
19 #include "common/common_test.h"
20 #include "common/py_func_graph_fetcher.h"
21
22 #include "ir/anf.h"
23 #include "ir/visitor.h"
24 #include "ir/func_graph_cloner.h"
25 #include "frontend/optimizer/opt.h"
26 #include "frontend/optimizer/anf_visitor.h"
27 #include "frontend/optimizer/irpass.h"
28 #include "frontend/optimizer/irpass/arithmetic_simplify.h"
29
30 #include "debug/draw.h"
31 #include "frontend/operator/ops.h"
32 #include "frontend/optimizer/cse.h"
33
34 namespace mindspore {
35 namespace opt {
36 class TestOptOpt : public UT::Common {
37 public:
TestOptOpt()38 TestOptOpt() : getPyFun("gtest_input.optimizer.opt_test", true) {}
39
40 class IdempotentEliminater : public AnfVisitor {
41 public:
operator ()(const OptimizerPtr &,const AnfNodePtr & node)42 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
43 x_ = nullptr;
44 AnfVisitor::Match(P, {irpass::IsCNode})(node);
45 if (x_ == nullptr || node->func_graph() == nullptr) {
46 return nullptr;
47 }
48
49 return node->func_graph()->NewCNode({NewValueNode(P), x_});
50 };
51
Visit(const CNodePtr & cnode)52 void Visit(const CNodePtr &cnode) override {
53 if (IsPrimitiveCNode(cnode, P) && cnode->inputs().size() == 2) {
54 x_ = cnode->input(1);
55 }
56 }
57
58 private:
59 AnfNodePtr x_{nullptr};
60 };
61
62 class QctToP : public AnfVisitor {
63 public:
operator ()(const OptimizerPtr &,const AnfNodePtr & node)64 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
65 v_ = nullptr;
66 AnfVisitor::Match(Q, {irpass::IsVNode})(node);
67 if (v_ == nullptr || node->func_graph() == nullptr) {
68 return nullptr;
69 }
70
71 return node->func_graph()->NewCNode({NewValueNode(P), v_});
72 };
73
Visit(const ValueNodePtr & vnode)74 void Visit(const ValueNodePtr &vnode) override { v_ = vnode; }
75
76 private:
77 AnfNodePtr v_{nullptr};
78 };
79
SetUp()80 void SetUp() {
81 elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
82 elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
83 idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
84 Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
85 }
86
CheckTransform(FuncGraphPtr gbefore,FuncGraphPtr gafter,const SubstitutionList & transform)87 bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
88 equiv_node.clear();
89 equiv_graph.clear();
90
91 FuncGraphPtr gbefore_clone = BasicClone(gbefore);
92 OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
93 transform(gbefore_clone, optimizer);
94
95 return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
96 }
97
CheckOpt(FuncGraphPtr before,FuncGraphPtr after,std::vector<SubstitutionPtr> opts={})98 bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) {
99 SubstitutionList eq(opts);
100 return CheckTransform(before, after, eq);
101 }
102
103 public:
104 UT::PyFuncGraphFetcher getPyFun;
105
106 FuncGraphPairMapEquiv equiv_graph;
107 NodeMapEquiv equiv_node;
108
109 static const PrimitivePtr P;
110 static const PrimitivePtr Q;
111 static const PrimitivePtr R;
112
113 SubstitutionPtr elim_Z;
114 SubstitutionPtr elim_R;
115 SubstitutionPtr idempotent_P;
116 SubstitutionPtr Qct_to_P;
117 };
118
119 const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
120 const PrimitivePtr TestOptOpt::Q = std::make_shared<Primitive>("Q");
121 const PrimitivePtr TestOptOpt::R = std::make_shared<Primitive>("R");
122
TEST_F(TestOptOpt, TestCheckOptIsClone) {
  // With no substitutions the graph must match itself. Once elim_Z is applied,
  // the transformed clone is expected to no longer match the untouched input,
  // which shows CheckOpt does not rewrite its argument in place.
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  ASSERT_TRUE(graph != nullptr);

  ASSERT_TRUE(CheckOpt(graph, graph));

  const std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_FALSE(CheckOpt(graph, graph, passes));
}
130
TEST_F(TestOptOpt, Elim) {
  // elim_Z is expected to turn test_add_zero/before_1 into test_add_zero/after.
  FuncGraphPtr graph_before = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  FuncGraphPtr graph_after = getPyFun.CallAndParseRet("test_add_zero", "after");
  ASSERT_TRUE(graph_before != nullptr);
  ASSERT_TRUE(graph_after != nullptr);

  const std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_TRUE(CheckOpt(graph_before, graph_after, passes));
}
139
TEST_F(TestOptOpt, ElimTwo) {
  // elim_Z is expected to turn test_add_zero/before_2 into test_add_zero/after.
  FuncGraphPtr graph_before = getPyFun.CallAndParseRet("test_add_zero", "before_2");
  FuncGraphPtr graph_after = getPyFun.CallAndParseRet("test_add_zero", "after");
  ASSERT_TRUE(graph_before != nullptr);
  ASSERT_TRUE(graph_after != nullptr);

  const std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_TRUE(CheckOpt(graph_before, graph_after, passes));
}
148
TEST_F(TestOptOpt, ElimR) {
  // elim_R is expected to turn test_elimR/before_1 into test_elimR/after.
  FuncGraphPtr graph_before = getPyFun.CallAndParseRet("test_elimR", "before_1");
  FuncGraphPtr graph_after = getPyFun.CallAndParseRet("test_elimR", "after");
  ASSERT_TRUE(graph_before != nullptr);
  ASSERT_TRUE(graph_after != nullptr);

  const std::vector<SubstitutionPtr> passes{elim_R};
  ASSERT_TRUE(CheckOpt(graph_before, graph_after, passes));
}
157
TEST_F(TestOptOpt, idempotent) {
  // Both "before" variants must reduce to the same expected graph under idempotent_P.
  // Graphs are fetched in the same order as before (before_2, before_1, after).
  FuncGraphPtr before_two = getPyFun.CallAndParseRet("test_idempotent", "before_2");
  FuncGraphPtr before_one = getPyFun.CallAndParseRet("test_idempotent", "before_1");
  FuncGraphPtr expected = getPyFun.CallAndParseRet("test_idempotent", "after");

  ASSERT_TRUE(before_two != nullptr);
  ASSERT_TRUE(before_one != nullptr);
  ASSERT_TRUE(expected != nullptr);

  const std::vector<SubstitutionPtr> passes{idempotent_P};
  ASSERT_TRUE(CheckOpt(before_one, expected, passes));
  ASSERT_TRUE(CheckOpt(before_two, expected, passes));
}
170
TEST_F(TestOptOpt, ConstantVariable) {
  // Qct_to_P is expected to turn test_constant_variable/before_1 into test_constant_variable/after.
  FuncGraphPtr graph_before = getPyFun.CallAndParseRet("test_constant_variable", "before_1");
  FuncGraphPtr graph_after = getPyFun.CallAndParseRet("test_constant_variable", "after");
  ASSERT_TRUE(graph_before != nullptr);
  ASSERT_TRUE(graph_after != nullptr);

  const std::vector<SubstitutionPtr> passes{Qct_to_P};
  ASSERT_TRUE(CheckOpt(graph_before, graph_after, passes));
}
179
TEST_F(TestOptOpt, CSE) {
  // Node counts are written as unsigned literals so ASSERT_EQ compares
  // size_t with an unsigned value rather than a signed int (-Wsign-compare).

  // Simple case: test_f1.
  FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_cse", "test_f1");
  ASSERT_TRUE(nullptr != test_graph1);

  // Register the graph with a FuncGraphManager so its nodes can be counted.
  FuncGraphManagerPtr manager1 = Manage(test_graph1);
  ASSERT_EQ(manager1->all_nodes().size(), 9U);

  auto cse = std::make_shared<CSE>();
  ASSERT_TRUE(cse != nullptr);
  bool is_changed = cse->Cse(test_graph1, manager1);

  // CSE is expected to report a change and remove exactly one node.
  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager1->all_nodes().size(), 8U);

  // More complicated case: test_f2 (the same CSE instance is reused).
  FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_cse", "test_f2");
  ASSERT_TRUE(nullptr != test_graph2);

  FuncGraphManagerPtr manager2 = Manage(test_graph2);
  ASSERT_EQ(manager2->all_nodes().size(), 16U);
  is_changed = cse->Cse(test_graph2, manager2);

  // CSE is expected to report a change and remove four nodes.
  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager2->all_nodes().size(), 12U);
}
209
210 } // namespace opt
211 } // namespace mindspore
212