• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 
19 #include "common/common_test.h"
20 #include "common/py_func_graph_fetcher.h"
21 
22 #include "ir/anf.h"
23 #include "ir/visitor.h"
24 #include "ir/func_graph_cloner.h"
25 #include "frontend/optimizer/opt.h"
26 #include "frontend/optimizer/anf_visitor.h"
27 #include "frontend/optimizer/irpass.h"
28 #include "frontend/optimizer/irpass/arithmetic_simplify.h"
29 
30 #include "debug/draw.h"
31 #include "frontend/operator/ops.h"
32 #include "frontend/optimizer/cse.h"
33 
34 namespace mindspore {
35 namespace opt {
36 class TestOptOpt : public UT::Common {
37  public:
TestOptOpt()38   TestOptOpt() : getPyFun("gtest_input.optimizer.opt_test", true) {}
39 
40   class IdempotentEliminater : public AnfVisitor {
41    public:
operator ()(const OptimizerPtr &,const AnfNodePtr & node)42     AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
43       x_ = nullptr;
44       AnfVisitor::Match(P, {irpass::IsCNode})(node);
45       if (x_ == nullptr || node->func_graph() == nullptr) {
46         return nullptr;
47       }
48 
49       return node->func_graph()->NewCNode({NewValueNode(P), x_});
50     };
51 
Visit(const CNodePtr & cnode)52     void Visit(const CNodePtr &cnode) override {
53       if (IsPrimitiveCNode(cnode, P) && cnode->inputs().size() == 2) {
54         x_ = cnode->input(1);
55       }
56     }
57 
58    private:
59     AnfNodePtr x_{nullptr};
60   };
61 
62   class QctToP : public AnfVisitor {
63    public:
operator ()(const OptimizerPtr &,const AnfNodePtr & node)64     AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
65       v_ = nullptr;
66       AnfVisitor::Match(Q, {irpass::IsVNode})(node);
67       if (v_ == nullptr || node->func_graph() == nullptr) {
68         return nullptr;
69       }
70 
71       return node->func_graph()->NewCNode({NewValueNode(P), v_});
72     };
73 
Visit(const ValueNodePtr & vnode)74     void Visit(const ValueNodePtr &vnode) override { v_ = vnode; }
75 
76    private:
77     AnfNodePtr v_{nullptr};
78   };
79 
SetUp()80   void SetUp() {
81     elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
82     elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
83     idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
84     Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
85   }
86 
CheckTransform(FuncGraphPtr gbefore,FuncGraphPtr gafter,const SubstitutionList & transform)87   bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
88     equiv_node.clear();
89     equiv_graph.clear();
90 
91     FuncGraphPtr gbefore_clone = BasicClone(gbefore);
92     OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
93     transform(gbefore_clone, optimizer);
94 
95     return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
96   }
97 
CheckOpt(FuncGraphPtr before,FuncGraphPtr after,std::vector<SubstitutionPtr> opts={})98   bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) {
99     SubstitutionList eq(opts);
100     return CheckTransform(before, after, eq);
101   }
102 
103  public:
104   UT::PyFuncGraphFetcher getPyFun;
105 
106   FuncGraphPairMapEquiv equiv_graph;
107   NodeMapEquiv equiv_node;
108 
109   static const PrimitivePtr P;
110   static const PrimitivePtr Q;
111   static const PrimitivePtr R;
112 
113   SubstitutionPtr elim_Z;
114   SubstitutionPtr elim_R;
115   SubstitutionPtr idempotent_P;
116   SubstitutionPtr Qct_to_P;
117 };
118 
119 const PrimitivePtr TestOptOpt::P = std::make_shared<Primitive>("P");
120 const PrimitivePtr TestOptOpt::Q = std::make_shared<Primitive>("Q");
121 const PrimitivePtr TestOptOpt::R = std::make_shared<Primitive>("R");
122 
// A graph is isomorphic to itself when no substitution runs. Once elim_Z has
// rewritten the clone that CheckOpt works on, the result must no longer match
// the untouched original.
TEST_F(TestOptOpt, TestCheckOptIsClone) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  ASSERT_TRUE(graph != nullptr);

  ASSERT_TRUE(CheckOpt(graph, graph));

  std::vector<SubstitutionPtr> only_elim_z{elim_Z};
  ASSERT_FALSE(CheckOpt(graph, graph, only_elim_z));
}
130 
// elim_Z should turn test_add_zero's before_1 graph into its `after` graph.
TEST_F(TestOptOpt, Elim) {
  FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_add_zero", "before_1");
  FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_add_zero", "after");

  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
139 
// elim_Z should also reduce test_add_zero's before_2 variant to the same `after` graph.
TEST_F(TestOptOpt, ElimTwo) {
  FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_add_zero", "before_2");
  FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_add_zero", "after");

  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  std::vector<SubstitutionPtr> passes{elim_Z};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
148 
// elim_R (PrimEliminater for R) should turn test_elimR's before_1 graph into its `after` graph.
TEST_F(TestOptOpt, ElimR) {
  FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_elimR", "before_1");
  FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_elimR", "after");

  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  std::vector<SubstitutionPtr> passes{elim_R};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
157 
// idempotent_P should reduce both before_1 and before_2 of test_idempotent to `after`.
TEST_F(TestOptOpt, idempotent) {
  FuncGraphPtr second_input = getPyFun.CallAndParseRet("test_idempotent", "before_2");
  FuncGraphPtr first_input = getPyFun.CallAndParseRet("test_idempotent", "before_1");
  FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_idempotent", "after");

  ASSERT_TRUE(second_input != nullptr);
  ASSERT_TRUE(first_input != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  std::vector<SubstitutionPtr> passes{idempotent_P};
  ASSERT_TRUE(CheckOpt(first_input, expected_graph, passes));
  ASSERT_TRUE(CheckOpt(second_input, expected_graph, passes));
}
170 
// Qct_to_P should turn test_constant_variable's before_1 graph into its `after` graph.
TEST_F(TestOptOpt, ConstantVariable) {
  FuncGraphPtr input_graph = getPyFun.CallAndParseRet("test_constant_variable", "before_1");
  FuncGraphPtr expected_graph = getPyFun.CallAndParseRet("test_constant_variable", "after");

  ASSERT_TRUE(input_graph != nullptr);
  ASSERT_TRUE(expected_graph != nullptr);

  std::vector<SubstitutionPtr> passes{Qct_to_P};
  ASSERT_TRUE(CheckOpt(input_graph, expected_graph, passes));
}
179 
// Runs common-subexpression elimination on two Python-defined graphs and checks
// that it reports a change and shrinks the manager's node count:
// test_f1 goes from 9 to 8 nodes, test_f2 from 16 to 12. One CSE object is reused for both.
TEST_F(TestOptOpt, CSE) {
  // A simple CSE case: test_f1.
  FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_cse", "test_f1");

  ASSERT_TRUE(nullptr != test_graph1);

  // Register the func_graph with a FuncGraphManager so its nodes are tracked.
  FuncGraphManagerPtr manager1 = Manage(test_graph1);

  // Expected counts are unsigned literals so ASSERT_EQ does not compare a
  // size_t with a signed int (which triggers -Wsign-compare).
  ASSERT_EQ(manager1->all_nodes().size(), 9u);

  auto cse = std::make_shared<CSE>();
  ASSERT_TRUE(cse != nullptr);
  bool is_changed = cse->Cse(test_graph1, manager1);

  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager1->all_nodes().size(), 8u);

  // A more complicated case: test_f2.
  FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_cse", "test_f2");

  ASSERT_TRUE(nullptr != test_graph2);

  FuncGraphManagerPtr manager2 = Manage(test_graph2);
  ASSERT_EQ(manager2->all_nodes().size(), 16u);
  is_changed = cse->Cse(test_graph2, manager2);
  ASSERT_TRUE(is_changed);
  ASSERT_EQ(manager2->all_nodes().size(), 12u);
}
209 
210 }  // namespace opt
211 }  // namespace mindspore
212