• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "pattern_to_pattern_pass_utils.h"
18 #include "mindspore/core/ops/math_ops.h"
19 #include "include/backend/optimizer/node_pass.h"
20 
21 namespace mindspore {
22 namespace opt {
23 namespace {
// Small integer constants used as expected degrees/sizes and CNode input indices by the Mul2 check helpers.
constexpr auto kZero = 0;
constexpr auto kOne = 1;
constexpr auto kTwo = 2;
constexpr auto kThree = 3;

// Names used both as keys of the Mul2 test's node map and as node names in its patterns.
constexpr auto kA = "a";
constexpr auto kB = "b";
constexpr auto kC = "c";
constexpr auto kD = "d";
constexpr auto kE = "e";
constexpr auto kAAddB = "a_add_b";
constexpr auto kCAddD = "c_add_d";
constexpr auto kMul = "mul";
constexpr auto kAdd = "add";
38 
39 class TestFastMul0 : public PatternToPatternPass {
40   // a*b + a*c -> a*(b+c)
41  public:
TestFastMul0()42   explicit TestFastMul0() : PatternToPatternPass("test_fast_mul0") {}
43   ~TestFastMul0() override = default;
44 
DefineSrcPattern(SrcPattern * src_pattern)45   void DefineSrcPattern(SrcPattern *src_pattern) override {
46     (*src_pattern)
47       .AddVar("a")
48       .AddVar("b")
49       .AddVar("c")
50       .AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "a", "b"})
51       .AddCNode("ac", {std::make_shared<Primitive>(kMulOpName), "a", "c"})
52       .AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "ab", "ac"});
53   }
DefineDstPattern(DstPattern * dst_pattern)54   void DefineDstPattern(DstPattern *dst_pattern) override {
55     (*dst_pattern)
56       .AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
57       .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"});
58   }
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &) const59   bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
60 };
61 
62 class TestFastMul1 : public PatternToPatternPass {
63   // a*b + c*d -> a*c
64  public:
TestFastMul1()65   explicit TestFastMul1() : PatternToPatternPass("test_fast_mul1") {}
66   ~TestFastMul1() override = default;
67 
DefineSrcPattern(SrcPattern * src_pattern)68   void DefineSrcPattern(SrcPattern *src_pattern) override {
69     (*src_pattern)
70       .AddVar("a")
71       .AddVar("b")
72       .AddVar("c")
73       .AddVar("d")
74       .AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "a", "b"})
75       .AddCNode("cd", {std::make_shared<Primitive>(kMulOpName), "c", "d"})
76       .AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "ab", "cd"});
77   }
DefineDstPattern(DstPattern * dst_pattern)78   void DefineDstPattern(DstPattern *dst_pattern) override {
79     (*dst_pattern).AddCNode("ad", {std::make_shared<Primitive>(kMulOpName), "a", "d"});
80   }
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &) const81   bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
82 };
83 
84 class TestFastMul2 : public PatternToPatternPass {
85   // a*b -> b*a
86  public:
TestFastMul2()87   explicit TestFastMul2() : PatternToPatternPass("test_fast_mul2") {}
88   ~TestFastMul2() override = default;
89 
DefineSrcPattern(SrcPattern * src_pattern)90   void DefineSrcPattern(SrcPattern *src_pattern) override {
91     (*src_pattern).AddSeqVar("Sv").AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "Sv"});
92   }
DefineDstPattern(DstPattern * dst_pattern)93   void DefineDstPattern(DstPattern *dst_pattern) override {
94     auto ba = Unpacking("Sv");
95     auto ab = Unpacking("Sv");
96     ba[0] = ab[1];
97     ba[1] = ab[0];
98     (*dst_pattern).AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), ba});
99   }
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &) const100   bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
101 };
102 }  // namespace
103 
104 class TestFastPatternToPatternPass : public UT::Common {
105  public:
TestFastPatternToPatternPass()106   TestFastPatternToPatternPass() : fg_(std::make_shared<FuncGraph>()){};
107 
108  public:
109   FuncGraphPtr fg_;
110 };
111 
112 /// Feature: Fast PatternToPattern Pass
113 /// Description: Fast PatternToPattern Pass rewrite graph
114 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul0)115 TEST_F(TestFastPatternToPatternPass, Mul0) {
116   // a*b + a*c -> a*(b+c)
117   // init
118   auto check = CheckPattern();
119   auto pass = TestFastMul0();
120 
121   // build func graph
122   auto a = std::make_shared<AnfNode>(fg_);
123   auto b = std::make_shared<AnfNode>(fg_);
124   auto c = std::make_shared<AnfNode>(fg_);
125   AnfNodePtr ab =
126     std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b}, fg_);
127   AnfNodePtr ac =
128     std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, c}, fg_);
129   AnfNodePtr add = std::make_shared<CNode>(
130     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), ab, ac}, fg_);
131 
132   fg_->set_output(add);
133   auto manager = MakeManager({fg_});
134   if (manager) {
135     manager->AddFuncGraph(fg_);
136     fg_->set_manager(manager);
137   }
138   if (!fg_->has_user_data<FuncGraphPassIndex>()) {
139     fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
140   }
141   auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
142   GenIndex(fg_, func_graph_index);
143 
144   ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
145   ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1);
146   ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1);
147   ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
148   ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
149   ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2);
150 
151   ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
152   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
153   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
154 
155   auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
156   auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
157 
158   ASSERT_TRUE(add_set.size() == 1);
159   ASSERT_TRUE(mul_set.size() == 2);
160   ASSERT_TRUE(add_set.find(add) != add_set.end());
161   ASSERT_TRUE(mul_set.find(ab) != mul_set.end());
162   ASSERT_TRUE(mul_set.find(ac) != mul_set.end());
163 
164   auto new_node = pass.Run(fg_, add);
165   ASSERT_NE(new_node, nullptr);
166   (void)manager->Replace(add, new_node);
167   pass.AfterProcess(add, new_node, fg_, func_graph_index);
168 
169   ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
170   ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0);
171   ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0);
172   ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
173   ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
174   ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
175   ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1);
176   ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1);
177 
178   ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
179   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
180   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
181 
182   auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
183   auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
184 
185   ASSERT_TRUE(add_set_2.size() == 1);
186   ASSERT_TRUE(mul_set_2.size() == 1);
187   ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end());
188   ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end());
189 
190   // build pattern
191   check.src_pattern_.AddVar("a")
192     .AddVar("b")
193     .AddVar("c")
194     .AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
195     .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"});
196 
197   // pattern engine
198   ASSERT_TRUE(check.build_pattern_map(new_node));
199 
200   // check
201   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
202   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b));
203   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c));
204   ASSERT_EQ(check.m_->Get("bc")->cast<CNodePtr>()->size(), 3);
205   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(0),
206                             NewValueNode(std::make_shared<Primitive>(kAddOpName))));
207   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(1), b));
208   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(2), c));
209   ASSERT_EQ(check.m_->Get("mul")->cast<CNodePtr>()->size(), 3);
210   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(0),
211                             NewValueNode(std::make_shared<Primitive>(kMulOpName))));
212   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(1), a));
213   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(2), check.m_->Get("bc")));
214 }
215 
216 /// Feature: Fast PatternToPattern Pass
217 /// Description: Fast PatternToPattern Pass rewrite graph
218 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul0NotRoot)219 TEST_F(TestFastPatternToPatternPass, Mul0NotRoot) {
220   // (a*b + a*c) + d -> a*(b+c) + d
221   // init
222   auto check = CheckPattern();
223   auto pass = TestFastMul0();
224 
225   // build func graph
226   auto a = std::make_shared<AnfNode>(fg_);
227   auto b = std::make_shared<AnfNode>(fg_);
228   auto c = std::make_shared<AnfNode>(fg_);
229   auto d = std::make_shared<AnfNode>(fg_);
230   AnfNodePtr ab =
231     std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b}, fg_);
232   AnfNodePtr ac =
233     std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, c}, fg_);
234   AnfNodePtr add = std::make_shared<CNode>(
235     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), ab, ac}, fg_);
236   AnfNodePtr add1 = std::make_shared<CNode>(
237     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), add, d}, fg_);
238 
239   fg_->set_output(add1);
240   auto manager = MakeManager({fg_});
241   if (manager) {
242     manager->AddFuncGraph(fg_);
243     fg_->set_manager(manager);
244   }
245   if (!fg_->has_user_data<FuncGraphPassIndex>()) {
246     fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
247   }
248   auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
249   GenIndex(fg_, func_graph_index);
250 
251   ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
252   ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
253   ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1);
254   ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1);
255   ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
256   ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
257   ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
258   ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2);
259 
260   ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
261   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
262   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
263 
264   auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
265   auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
266 
267   ASSERT_TRUE(add_set.size() == 2);
268   ASSERT_TRUE(mul_set.size() == 2);
269   ASSERT_TRUE(add_set.find(add1) != add_set.end());
270   ASSERT_TRUE(add_set.find(add) != add_set.end());
271   ASSERT_TRUE(mul_set.find(ab) != mul_set.end());
272   ASSERT_TRUE(mul_set.find(ac) != mul_set.end());
273 
274   auto new_node = pass.Run(fg_, add);
275   ASSERT_NE(new_node, nullptr);
276   (void)manager->Replace(add, new_node);
277   pass.AfterProcess(add, new_node, fg_, func_graph_index);
278 
279   ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
280   ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0);
281   ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0);
282   ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
283   ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
284   ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
285   ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
286   ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
287   ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1);
288   ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1);
289 
290   ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
291   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
292   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
293 
294   auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
295   auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
296 
297   ASSERT_TRUE(add_set_2.size() == 2);
298   ASSERT_TRUE(mul_set_2.size() == 1);
299   ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end());
300   ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end());
301   ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end());
302 
303   // build pattern
304   check.src_pattern_.AddVar("a")
305     .AddVar("b")
306     .AddVar("c")
307     .AddVar("d")
308     .AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
309     .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"})
310     .AddCNode("add1", {std::make_shared<Primitive>(kAddOpName), "mul", "d"});
311 
312   // pattern engine
313   ASSERT_TRUE(check.build_pattern_map(add1));
314 
315   // check
316   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
317   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b));
318   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c));
319   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d));
320 
321   ASSERT_EQ(check.m_->Get("bc")->cast<CNodePtr>()->size(), 3);
322   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(0),
323                             NewValueNode(std::make_shared<Primitive>(kAddOpName))));
324   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(1), b));
325   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(2), c));
326 
327   ASSERT_EQ(check.m_->Get("mul")->cast<CNodePtr>()->size(), 3);
328   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(0),
329                             NewValueNode(std::make_shared<Primitive>(kMulOpName))));
330   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(1), a));
331   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(2), check.m_->Get("bc")));
332 
333   ASSERT_EQ(check.m_->Get("add1")->cast<CNodePtr>()->size(), 3);
334   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(0),
335                             NewValueNode(std::make_shared<Primitive>(kAddOpName))));
336   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(1), check.m_->Get("mul")));
337   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(2), d));
338 }
339 
340 /// Feature: Fast PatternToPattern Pass
341 /// Description: Fast PatternToPattern Pass rewrite graph
342 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul1)343 TEST_F(TestFastPatternToPatternPass, Mul1) {
344   // (a * (b1 + d) + (c1 * c2) * d) + e -> (a + d) + e
345   // init
346   auto check = CheckPattern();
347   auto pass = TestFastMul1();
348 
349   // build func graph
350   auto a = std::make_shared<AnfNode>(fg_);
351   auto b = std::make_shared<AnfNode>(fg_);
352   auto c1 = std::make_shared<AnfNode>(fg_);
353   auto c2 = std::make_shared<AnfNode>(fg_);
354   auto d = std::make_shared<AnfNode>(fg_);
355   auto e = std::make_shared<AnfNode>(fg_);
356 
357   AnfNodePtr b_add_d =
358     std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), b, d}, fg_);
359   AnfNodePtr c1_mul_c2 = std::make_shared<CNode>(
360     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), c1, c2}, fg_);
361   AnfNodePtr a_mul = std::make_shared<CNode>(
362     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b_add_d}, fg_);
363   AnfNodePtr d_mul = std::make_shared<CNode>(
364     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), c1_mul_c2, d}, fg_);
365   AnfNodePtr add = std::make_shared<CNode>(
366     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), a_mul, d_mul}, fg_);
367   AnfNodePtr add1 = std::make_shared<CNode>(
368     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), add, e}, fg_);
369 
370   fg_->set_output(add1);
371   auto manager = MakeManager({fg_});
372   if (manager) {
373     manager->AddFuncGraph(fg_);
374     fg_->set_manager(manager);
375   }
376   if (!fg_->has_user_data<FuncGraphPassIndex>()) {
377     fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
378   }
379   auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
380   GenIndex(fg_, func_graph_index);
381 
382   ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 1);
383   ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 1);
384   ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 1);
385   ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 1);
386   ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
387   ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
388 
389   ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
390   ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
391   ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 1);
392   ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 1);
393   ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 2);
394   ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1);
395 
396   ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
397   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
398   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
399 
400   auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
401   auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
402 
403   ASSERT_TRUE(add_set.size() == 3);
404   ASSERT_TRUE(mul_set.size() == 3);
405   ASSERT_TRUE(add_set.find(add1) != add_set.end());
406   ASSERT_TRUE(add_set.find(add) != add_set.end());
407   ASSERT_TRUE(add_set.find(b_add_d) != add_set.end());
408   ASSERT_TRUE(mul_set.find(a_mul) != mul_set.end());
409   ASSERT_TRUE(mul_set.find(d_mul) != mul_set.end());
410   ASSERT_TRUE(mul_set.find(c1_mul_c2) != mul_set.end());
411 
412   auto new_node = pass.Run(fg_, add);
413   ASSERT_NE(new_node, nullptr);
414   (void)manager->Replace(add, new_node);
415   pass.AfterProcess(add, new_node, fg_, func_graph_index);
416 
417   ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 0);
418   ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 0);
419   ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 0);
420   ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 0);
421   ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
422   ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
423   ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("ad")) == 1);
424 
425   ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
426   ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 0);
427   ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 0);
428   ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 0);
429   ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
430   ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1);
431 
432   ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
433   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
434   ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
435 
436   auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
437   auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
438 
439   ASSERT_TRUE(add_set_2.size() == 1);
440   ASSERT_TRUE(mul_set_2.size() == 1);
441   ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end());
442   ASSERT_TRUE(mul_set_2.find(pass.m_->Get("ad")) != mul_set_2.end());
443 
444   // build pattern
445   check.src_pattern_.AddVar("a")
446     .AddVar("d")
447     .AddVar("e")
448     .AddCNode("ad", {std::make_shared<Primitive>(kMulOpName), "a", "d"})
449     .AddCNode("add1", {std::make_shared<Primitive>(kAddOpName), "ad", "e"});
450 
451   // pattern engine
452   ASSERT_TRUE(check.build_pattern_map(add1));
453 
454   // check
455   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
456   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d));
457   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("e"), e));
458 
459   ASSERT_EQ(check.m_->Get("ad")->cast<CNodePtr>()->size(), 3);
460   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(0),
461                             NewValueNode(std::make_shared<Primitive>(kMulOpName))));
462   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(1), a));
463   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(2), d));
464 
465   ASSERT_EQ(check.m_->Get("add1")->cast<CNodePtr>()->size(), 3);
466   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(0),
467                             NewValueNode(std::make_shared<Primitive>(kAddOpName))));
468   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(1), check.m_->Get("ad")));
469   ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(2), e));
470 }
471 
472 namespace {
Check0(const FuncGraphIndexPtr & fg,const std::map<std::string,AnfNodePtr> & node_map)473 void Check0(const FuncGraphIndexPtr &fg, const std::map<std::string, AnfNodePtr> &node_map) {
474   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne);
475   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne);
476   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kOne);
477   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne);
478 
479   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne);
480   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne);
481   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne);
482   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne);
483   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne);
484 
485   ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo);
486   ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end());
487   ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end());
488 
489   auto &add_set = fg->name_to_cnode_[kAddOpName];
490   auto &mul_set = fg->name_to_cnode_[kMulOpName];
491 
492   ASSERT_TRUE(add_set.size() == kThree);
493   ASSERT_TRUE(mul_set.size() == kOne);
494   ASSERT_TRUE(add_set.find(node_map.at(kAdd)) != add_set.end());
495   ASSERT_TRUE(add_set.find(node_map.at(kAAddB)) != add_set.end());
496   ASSERT_TRUE(add_set.find(node_map.at(kCAddD)) != add_set.end());
497   ASSERT_TRUE(mul_set.find(node_map.at(kMul)) != mul_set.end());
498 }
Check1(const TestFastMul2 & pass,const FuncGraphIndexPtr & fg,const std::map<std::string,AnfNodePtr> & node_map)499 void Check1(const TestFastMul2 &pass, const FuncGraphIndexPtr &fg, const std::map<std::string, AnfNodePtr> &node_map) {
500   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne);
501   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne);
502   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kZero);
503   ASSERT_TRUE(fg->node_degree_.at(pass.m_->Get(kMul)) == kOne);
504   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne);
505 
506   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne);
507   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne);
508   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne);
509   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne);
510   ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne);
511 
512   ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo);
513   ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end());
514   ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end());
515 
516   auto &add_set_2 = fg->name_to_cnode_[kAddOpName];
517   auto &mul_set_2 = fg->name_to_cnode_[kMulOpName];
518 
519   ASSERT_TRUE(add_set_2.size() == kThree);
520   ASSERT_TRUE(mul_set_2.size() == kOne);
521   ASSERT_TRUE(add_set_2.find(node_map.at(kAAddB)) != add_set_2.end());
522   ASSERT_TRUE(add_set_2.find(node_map.at(kCAddD)) != add_set_2.end());
523   ASSERT_TRUE(mul_set_2.find(pass.m_->Get(kMul)) != mul_set_2.end());
524 }
525 
Check2(const CheckPattern & check,const std::map<std::string,AnfNodePtr> & node_map)526 void Check2(const CheckPattern &check, const std::map<std::string, AnfNodePtr> &node_map) {
527   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kA), node_map.at(kA)));
528   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kB), node_map.at(kB)));
529   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kC), node_map.at(kC)));
530   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kD), node_map.at(kD)));
531   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kE), node_map.at(kE)));
532 
533   ASSERT_EQ(check.m_->Get(kAAddB)->cast<CNodePtr>()->size(), kThree);
534   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kZero),
535                             NewValueNode(std::make_shared<Primitive>(kAddOpName))));
536   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kOne), node_map.at(kA)));
537   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kTwo), node_map.at(kB)));
538 
539   ASSERT_EQ(check.m_->Get(kCAddD)->cast<CNodePtr>()->size(), kThree);
540   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kZero),
541                             NewValueNode(std::make_shared<Primitive>(kAddOpName))));
542   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kOne), node_map.at(kC)));
543   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kTwo), node_map.at(kD)));
544 
545   ASSERT_EQ(check.m_->Get(kMul)->cast<CNodePtr>()->size(), kThree);
546   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kZero),
547                             NewValueNode(std::make_shared<Primitive>(kMulOpName))));
548   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kOne), node_map.at(kCAddD)));
549   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kTwo), node_map.at(kAAddB)));
550 
551   ASSERT_EQ(check.m_->Get(kAdd)->cast<CNodePtr>()->size(), kThree);
552   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kZero),
553                             NewValueNode(std::make_shared<Primitive>(kAddOpName))));
554   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kOne), check.m_->Get(kMul)));
555   ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kTwo), node_map.at(kE)));
556 }
557 }  // namespace
558 
559 /// Feature: Fast PatternToPattern Pass
560 /// Description: Fast PatternToPattern Pass rewrite graph
561 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul2)562 TEST_F(TestFastPatternToPatternPass, Mul2) {
563   // ((a + b) * (c + d)) + e -> ((c + d) * (a + b)) + e
564   // init
565   auto check = CheckPattern();
566   auto pass = TestFastMul2();
567 
568   // build func graph
569   auto a = std::make_shared<AnfNode>(fg_);
570   auto b = std::make_shared<AnfNode>(fg_);
571   auto c = std::make_shared<AnfNode>(fg_);
572   auto d = std::make_shared<AnfNode>(fg_);
573   auto e = std::make_shared<AnfNode>(fg_);
574 
575   AnfNodePtr a_add_b =
576     std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), a, b}, fg_);
577   AnfNodePtr c_add_d =
578     std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), c, d}, fg_);
579   AnfNodePtr mul = std::make_shared<CNode>(
580     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a_add_b, c_add_d}, fg_);
581   AnfNodePtr add = std::make_shared<CNode>(
582     std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), mul, e}, fg_);
583 
584   std::map<std::string, AnfNodePtr> node_map;
585   node_map.emplace("a", a);
586   node_map.emplace("b", b);
587   node_map.emplace("c", c);
588   node_map.emplace("d", d);
589   node_map.emplace("e", e);
590   node_map.emplace("a_add_b", a_add_b);
591   node_map.emplace("c_add_d", c_add_d);
592   node_map.emplace("mul", mul);
593   node_map.emplace("add", add);
594 
595   fg_->set_output(add);
596   auto manager = MakeManager({fg_});
597   if (manager) {
598     manager->AddFuncGraph(fg_);
599     fg_->set_manager(manager);
600   }
601   if (!fg_->has_user_data<FuncGraphPassIndex>()) {
602     fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
603   }
604   auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
605   GenIndex(fg_, func_graph_index);
606 
607   Check0(func_graph_index, node_map);
608   auto new_node = pass.Run(fg_, mul);
609   ASSERT_NE(new_node, nullptr);
610   (void)manager->Replace(mul, new_node);
611   pass.AfterProcess(mul, new_node, fg_, func_graph_index);
612   Check1(pass, func_graph_index, node_map);
613 
614   // build pattern
615   check.src_pattern_.AddVar("a")
616     .AddVar("b")
617     .AddVar("c")
618     .AddVar("d")
619     .AddVar("e")
620     .AddCNode("a_add_b", {std::make_shared<Primitive>(kAddOpName), "a", "b"})
621     .AddCNode("c_add_d", {std::make_shared<Primitive>(kAddOpName), "c", "d"})
622     .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "c_add_d", "a_add_b"})
623     .AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "mul", "e"});
624 
625   // pattern engine
626   ASSERT_TRUE(check.build_pattern_map(add));
627   Check2(check, node_map);
628 }
629 }  // namespace opt
630 }  // namespace mindspore
631