1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pattern_to_pattern_pass_utils.h"
18 #include "mindspore/core/ops/math_ops.h"
19 #include "include/backend/optimizer/node_pass.h"
20
21 namespace mindspore {
22 namespace opt {
23 namespace {
// Small integer literals shared by the Check0/Check1/Check2 helpers below: they are compared against node degrees
// and container sizes and are used as CNode input indices.
constexpr int kZero = 0;
constexpr int kOne = 1;
constexpr int kTwo = 2;
constexpr int kThree = 3;

// Names of the nodes of the Mul2 test graph. The helpers use them as keys into the test's node map and as
// pattern-map node names.
constexpr auto kA = "a";
constexpr auto kB = "b";
constexpr auto kC = "c";
constexpr auto kD = "d";
constexpr auto kE = "e";
constexpr auto kAAddB = "a_add_b";
constexpr auto kCAddD = "c_add_d";
constexpr auto kMul = "mul";
constexpr auto kAdd = "add";
38
39 class TestFastMul0 : public PatternToPatternPass {
40 // a*b + a*c -> a*(b+c)
41 public:
TestFastMul0()42 explicit TestFastMul0() : PatternToPatternPass("test_fast_mul0") {}
43 ~TestFastMul0() override = default;
44
DefineSrcPattern(SrcPattern * src_pattern)45 void DefineSrcPattern(SrcPattern *src_pattern) override {
46 (*src_pattern)
47 .AddVar("a")
48 .AddVar("b")
49 .AddVar("c")
50 .AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "a", "b"})
51 .AddCNode("ac", {std::make_shared<Primitive>(kMulOpName), "a", "c"})
52 .AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "ab", "ac"});
53 }
DefineDstPattern(DstPattern * dst_pattern)54 void DefineDstPattern(DstPattern *dst_pattern) override {
55 (*dst_pattern)
56 .AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
57 .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"});
58 }
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &) const59 bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
60 };
61
62 class TestFastMul1 : public PatternToPatternPass {
63 // a*b + c*d -> a*c
64 public:
TestFastMul1()65 explicit TestFastMul1() : PatternToPatternPass("test_fast_mul1") {}
66 ~TestFastMul1() override = default;
67
DefineSrcPattern(SrcPattern * src_pattern)68 void DefineSrcPattern(SrcPattern *src_pattern) override {
69 (*src_pattern)
70 .AddVar("a")
71 .AddVar("b")
72 .AddVar("c")
73 .AddVar("d")
74 .AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "a", "b"})
75 .AddCNode("cd", {std::make_shared<Primitive>(kMulOpName), "c", "d"})
76 .AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "ab", "cd"});
77 }
DefineDstPattern(DstPattern * dst_pattern)78 void DefineDstPattern(DstPattern *dst_pattern) override {
79 (*dst_pattern).AddCNode("ad", {std::make_shared<Primitive>(kMulOpName), "a", "d"});
80 }
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &) const81 bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
82 };
83
84 class TestFastMul2 : public PatternToPatternPass {
85 // a*b -> b*a
86 public:
TestFastMul2()87 explicit TestFastMul2() : PatternToPatternPass("test_fast_mul2") {}
88 ~TestFastMul2() override = default;
89
DefineSrcPattern(SrcPattern * src_pattern)90 void DefineSrcPattern(SrcPattern *src_pattern) override {
91 (*src_pattern).AddSeqVar("Sv").AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "Sv"});
92 }
DefineDstPattern(DstPattern * dst_pattern)93 void DefineDstPattern(DstPattern *dst_pattern) override {
94 auto ba = Unpacking("Sv");
95 auto ab = Unpacking("Sv");
96 ba[0] = ab[1];
97 ba[1] = ab[0];
98 (*dst_pattern).AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), ba});
99 }
CheckMatchedDAG(const PatternMap &,const FuncGraphPtr &,const AnfNodePtr &) const100 bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
101 };
102 } // namespace
103
104 class TestFastPatternToPatternPass : public UT::Common {
105 public:
TestFastPatternToPatternPass()106 TestFastPatternToPatternPass() : fg_(std::make_shared<FuncGraph>()){};
107
108 public:
109 FuncGraphPtr fg_;
110 };
111
112 /// Feature: Fast PatternToPattern Pass
113 /// Description: Fast PatternToPattern Pass rewrite graph
114 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul0)115 TEST_F(TestFastPatternToPatternPass, Mul0) {
116 // a*b + a*c -> a*(b+c)
117 // init
118 auto check = CheckPattern();
119 auto pass = TestFastMul0();
120
121 // build func graph
122 auto a = std::make_shared<AnfNode>(fg_);
123 auto b = std::make_shared<AnfNode>(fg_);
124 auto c = std::make_shared<AnfNode>(fg_);
125 AnfNodePtr ab =
126 std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b}, fg_);
127 AnfNodePtr ac =
128 std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, c}, fg_);
129 AnfNodePtr add = std::make_shared<CNode>(
130 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), ab, ac}, fg_);
131
132 fg_->set_output(add);
133 auto manager = MakeManager({fg_});
134 if (manager) {
135 manager->AddFuncGraph(fg_);
136 fg_->set_manager(manager);
137 }
138 if (!fg_->has_user_data<FuncGraphPassIndex>()) {
139 fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
140 }
141 auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
142 GenIndex(fg_, func_graph_index);
143
144 ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
145 ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1);
146 ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1);
147 ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
148 ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
149 ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2);
150
151 ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
152 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
153 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
154
155 auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
156 auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
157
158 ASSERT_TRUE(add_set.size() == 1);
159 ASSERT_TRUE(mul_set.size() == 2);
160 ASSERT_TRUE(add_set.find(add) != add_set.end());
161 ASSERT_TRUE(mul_set.find(ab) != mul_set.end());
162 ASSERT_TRUE(mul_set.find(ac) != mul_set.end());
163
164 auto new_node = pass.Run(fg_, add);
165 ASSERT_NE(new_node, nullptr);
166 (void)manager->Replace(add, new_node);
167 pass.AfterProcess(add, new_node, fg_, func_graph_index);
168
169 ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
170 ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0);
171 ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0);
172 ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
173 ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
174 ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
175 ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1);
176 ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1);
177
178 ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
179 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
180 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
181
182 auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
183 auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
184
185 ASSERT_TRUE(add_set_2.size() == 1);
186 ASSERT_TRUE(mul_set_2.size() == 1);
187 ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end());
188 ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end());
189
190 // build pattern
191 check.src_pattern_.AddVar("a")
192 .AddVar("b")
193 .AddVar("c")
194 .AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
195 .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"});
196
197 // pattern engine
198 ASSERT_TRUE(check.build_pattern_map(new_node));
199
200 // check
201 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
202 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b));
203 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c));
204 ASSERT_EQ(check.m_->Get("bc")->cast<CNodePtr>()->size(), 3);
205 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(0),
206 NewValueNode(std::make_shared<Primitive>(kAddOpName))));
207 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(1), b));
208 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(2), c));
209 ASSERT_EQ(check.m_->Get("mul")->cast<CNodePtr>()->size(), 3);
210 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(0),
211 NewValueNode(std::make_shared<Primitive>(kMulOpName))));
212 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(1), a));
213 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(2), check.m_->Get("bc")));
214 }
215
216 /// Feature: Fast PatternToPattern Pass
217 /// Description: Fast PatternToPattern Pass rewrite graph
218 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul0NotRoot)219 TEST_F(TestFastPatternToPatternPass, Mul0NotRoot) {
220 // (a*b + a*c) + d -> a*(b+c) + d
221 // init
222 auto check = CheckPattern();
223 auto pass = TestFastMul0();
224
225 // build func graph
226 auto a = std::make_shared<AnfNode>(fg_);
227 auto b = std::make_shared<AnfNode>(fg_);
228 auto c = std::make_shared<AnfNode>(fg_);
229 auto d = std::make_shared<AnfNode>(fg_);
230 AnfNodePtr ab =
231 std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b}, fg_);
232 AnfNodePtr ac =
233 std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, c}, fg_);
234 AnfNodePtr add = std::make_shared<CNode>(
235 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), ab, ac}, fg_);
236 AnfNodePtr add1 = std::make_shared<CNode>(
237 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), add, d}, fg_);
238
239 fg_->set_output(add1);
240 auto manager = MakeManager({fg_});
241 if (manager) {
242 manager->AddFuncGraph(fg_);
243 fg_->set_manager(manager);
244 }
245 if (!fg_->has_user_data<FuncGraphPassIndex>()) {
246 fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
247 }
248 auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
249 GenIndex(fg_, func_graph_index);
250
251 ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
252 ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
253 ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1);
254 ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1);
255 ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
256 ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
257 ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
258 ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2);
259
260 ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
261 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
262 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
263
264 auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
265 auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
266
267 ASSERT_TRUE(add_set.size() == 2);
268 ASSERT_TRUE(mul_set.size() == 2);
269 ASSERT_TRUE(add_set.find(add1) != add_set.end());
270 ASSERT_TRUE(add_set.find(add) != add_set.end());
271 ASSERT_TRUE(mul_set.find(ab) != mul_set.end());
272 ASSERT_TRUE(mul_set.find(ac) != mul_set.end());
273
274 auto new_node = pass.Run(fg_, add);
275 ASSERT_NE(new_node, nullptr);
276 (void)manager->Replace(add, new_node);
277 pass.AfterProcess(add, new_node, fg_, func_graph_index);
278
279 ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
280 ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0);
281 ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0);
282 ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
283 ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
284 ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
285 ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
286 ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
287 ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1);
288 ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1);
289
290 ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
291 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
292 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
293
294 auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
295 auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
296
297 ASSERT_TRUE(add_set_2.size() == 2);
298 ASSERT_TRUE(mul_set_2.size() == 1);
299 ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end());
300 ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end());
301 ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end());
302
303 // build pattern
304 check.src_pattern_.AddVar("a")
305 .AddVar("b")
306 .AddVar("c")
307 .AddVar("d")
308 .AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
309 .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"})
310 .AddCNode("add1", {std::make_shared<Primitive>(kAddOpName), "mul", "d"});
311
312 // pattern engine
313 ASSERT_TRUE(check.build_pattern_map(add1));
314
315 // check
316 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
317 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b));
318 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c));
319 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d));
320
321 ASSERT_EQ(check.m_->Get("bc")->cast<CNodePtr>()->size(), 3);
322 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(0),
323 NewValueNode(std::make_shared<Primitive>(kAddOpName))));
324 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(1), b));
325 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(2), c));
326
327 ASSERT_EQ(check.m_->Get("mul")->cast<CNodePtr>()->size(), 3);
328 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(0),
329 NewValueNode(std::make_shared<Primitive>(kMulOpName))));
330 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(1), a));
331 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(2), check.m_->Get("bc")));
332
333 ASSERT_EQ(check.m_->Get("add1")->cast<CNodePtr>()->size(), 3);
334 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(0),
335 NewValueNode(std::make_shared<Primitive>(kAddOpName))));
336 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(1), check.m_->Get("mul")));
337 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(2), d));
338 }
339
340 /// Feature: Fast PatternToPattern Pass
341 /// Description: Fast PatternToPattern Pass rewrite graph
342 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul1)343 TEST_F(TestFastPatternToPatternPass, Mul1) {
344 // (a * (b1 + d) + (c1 * c2) * d) + e -> (a + d) + e
345 // init
346 auto check = CheckPattern();
347 auto pass = TestFastMul1();
348
349 // build func graph
350 auto a = std::make_shared<AnfNode>(fg_);
351 auto b = std::make_shared<AnfNode>(fg_);
352 auto c1 = std::make_shared<AnfNode>(fg_);
353 auto c2 = std::make_shared<AnfNode>(fg_);
354 auto d = std::make_shared<AnfNode>(fg_);
355 auto e = std::make_shared<AnfNode>(fg_);
356
357 AnfNodePtr b_add_d =
358 std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), b, d}, fg_);
359 AnfNodePtr c1_mul_c2 = std::make_shared<CNode>(
360 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), c1, c2}, fg_);
361 AnfNodePtr a_mul = std::make_shared<CNode>(
362 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b_add_d}, fg_);
363 AnfNodePtr d_mul = std::make_shared<CNode>(
364 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), c1_mul_c2, d}, fg_);
365 AnfNodePtr add = std::make_shared<CNode>(
366 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), a_mul, d_mul}, fg_);
367 AnfNodePtr add1 = std::make_shared<CNode>(
368 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), add, e}, fg_);
369
370 fg_->set_output(add1);
371 auto manager = MakeManager({fg_});
372 if (manager) {
373 manager->AddFuncGraph(fg_);
374 fg_->set_manager(manager);
375 }
376 if (!fg_->has_user_data<FuncGraphPassIndex>()) {
377 fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
378 }
379 auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
380 GenIndex(fg_, func_graph_index);
381
382 ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 1);
383 ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 1);
384 ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 1);
385 ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 1);
386 ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
387 ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
388
389 ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
390 ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
391 ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 1);
392 ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 1);
393 ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 2);
394 ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1);
395
396 ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
397 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
398 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
399
400 auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
401 auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
402
403 ASSERT_TRUE(add_set.size() == 3);
404 ASSERT_TRUE(mul_set.size() == 3);
405 ASSERT_TRUE(add_set.find(add1) != add_set.end());
406 ASSERT_TRUE(add_set.find(add) != add_set.end());
407 ASSERT_TRUE(add_set.find(b_add_d) != add_set.end());
408 ASSERT_TRUE(mul_set.find(a_mul) != mul_set.end());
409 ASSERT_TRUE(mul_set.find(d_mul) != mul_set.end());
410 ASSERT_TRUE(mul_set.find(c1_mul_c2) != mul_set.end());
411
412 auto new_node = pass.Run(fg_, add);
413 ASSERT_NE(new_node, nullptr);
414 (void)manager->Replace(add, new_node);
415 pass.AfterProcess(add, new_node, fg_, func_graph_index);
416
417 ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 0);
418 ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 0);
419 ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 0);
420 ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 0);
421 ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
422 ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
423 ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("ad")) == 1);
424
425 ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
426 ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 0);
427 ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 0);
428 ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 0);
429 ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
430 ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1);
431
432 ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
433 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
434 ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
435
436 auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
437 auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
438
439 ASSERT_TRUE(add_set_2.size() == 1);
440 ASSERT_TRUE(mul_set_2.size() == 1);
441 ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end());
442 ASSERT_TRUE(mul_set_2.find(pass.m_->Get("ad")) != mul_set_2.end());
443
444 // build pattern
445 check.src_pattern_.AddVar("a")
446 .AddVar("d")
447 .AddVar("e")
448 .AddCNode("ad", {std::make_shared<Primitive>(kMulOpName), "a", "d"})
449 .AddCNode("add1", {std::make_shared<Primitive>(kAddOpName), "ad", "e"});
450
451 // pattern engine
452 ASSERT_TRUE(check.build_pattern_map(add1));
453
454 // check
455 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
456 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d));
457 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("e"), e));
458
459 ASSERT_EQ(check.m_->Get("ad")->cast<CNodePtr>()->size(), 3);
460 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(0),
461 NewValueNode(std::make_shared<Primitive>(kMulOpName))));
462 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(1), a));
463 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(2), d));
464
465 ASSERT_EQ(check.m_->Get("add1")->cast<CNodePtr>()->size(), 3);
466 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(0),
467 NewValueNode(std::make_shared<Primitive>(kAddOpName))));
468 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(1), check.m_->Get("ad")));
469 ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(2), e));
470 }
471
472 namespace {
Check0(const FuncGraphIndexPtr & fg,const std::map<std::string,AnfNodePtr> & node_map)473 void Check0(const FuncGraphIndexPtr &fg, const std::map<std::string, AnfNodePtr> &node_map) {
474 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne);
475 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne);
476 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kOne);
477 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne);
478
479 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne);
480 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne);
481 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne);
482 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne);
483 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne);
484
485 ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo);
486 ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end());
487 ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end());
488
489 auto &add_set = fg->name_to_cnode_[kAddOpName];
490 auto &mul_set = fg->name_to_cnode_[kMulOpName];
491
492 ASSERT_TRUE(add_set.size() == kThree);
493 ASSERT_TRUE(mul_set.size() == kOne);
494 ASSERT_TRUE(add_set.find(node_map.at(kAdd)) != add_set.end());
495 ASSERT_TRUE(add_set.find(node_map.at(kAAddB)) != add_set.end());
496 ASSERT_TRUE(add_set.find(node_map.at(kCAddD)) != add_set.end());
497 ASSERT_TRUE(mul_set.find(node_map.at(kMul)) != mul_set.end());
498 }
Check1(const TestFastMul2 & pass,const FuncGraphIndexPtr & fg,const std::map<std::string,AnfNodePtr> & node_map)499 void Check1(const TestFastMul2 &pass, const FuncGraphIndexPtr &fg, const std::map<std::string, AnfNodePtr> &node_map) {
500 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne);
501 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne);
502 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kZero);
503 ASSERT_TRUE(fg->node_degree_.at(pass.m_->Get(kMul)) == kOne);
504 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne);
505
506 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne);
507 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne);
508 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne);
509 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne);
510 ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne);
511
512 ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo);
513 ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end());
514 ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end());
515
516 auto &add_set_2 = fg->name_to_cnode_[kAddOpName];
517 auto &mul_set_2 = fg->name_to_cnode_[kMulOpName];
518
519 ASSERT_TRUE(add_set_2.size() == kThree);
520 ASSERT_TRUE(mul_set_2.size() == kOne);
521 ASSERT_TRUE(add_set_2.find(node_map.at(kAAddB)) != add_set_2.end());
522 ASSERT_TRUE(add_set_2.find(node_map.at(kCAddD)) != add_set_2.end());
523 ASSERT_TRUE(mul_set_2.find(pass.m_->Get(kMul)) != mul_set_2.end());
524 }
525
Check2(const CheckPattern & check,const std::map<std::string,AnfNodePtr> & node_map)526 void Check2(const CheckPattern &check, const std::map<std::string, AnfNodePtr> &node_map) {
527 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kA), node_map.at(kA)));
528 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kB), node_map.at(kB)));
529 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kC), node_map.at(kC)));
530 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kD), node_map.at(kD)));
531 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kE), node_map.at(kE)));
532
533 ASSERT_EQ(check.m_->Get(kAAddB)->cast<CNodePtr>()->size(), kThree);
534 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kZero),
535 NewValueNode(std::make_shared<Primitive>(kAddOpName))));
536 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kOne), node_map.at(kA)));
537 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kTwo), node_map.at(kB)));
538
539 ASSERT_EQ(check.m_->Get(kCAddD)->cast<CNodePtr>()->size(), kThree);
540 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kZero),
541 NewValueNode(std::make_shared<Primitive>(kAddOpName))));
542 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kOne), node_map.at(kC)));
543 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kTwo), node_map.at(kD)));
544
545 ASSERT_EQ(check.m_->Get(kMul)->cast<CNodePtr>()->size(), kThree);
546 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kZero),
547 NewValueNode(std::make_shared<Primitive>(kMulOpName))));
548 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kOne), node_map.at(kCAddD)));
549 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kTwo), node_map.at(kAAddB)));
550
551 ASSERT_EQ(check.m_->Get(kAdd)->cast<CNodePtr>()->size(), kThree);
552 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kZero),
553 NewValueNode(std::make_shared<Primitive>(kAddOpName))));
554 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kOne), check.m_->Get(kMul)));
555 ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kTwo), node_map.at(kE)));
556 }
557 } // namespace
558
559 /// Feature: Fast PatternToPattern Pass
560 /// Description: Fast PatternToPattern Pass rewrite graph
561 /// Expectation: Get correct Graph
TEST_F(TestFastPatternToPatternPass,Mul2)562 TEST_F(TestFastPatternToPatternPass, Mul2) {
563 // ((a + b) * (c + d)) + e -> ((c + d) * (a + b)) + e
564 // init
565 auto check = CheckPattern();
566 auto pass = TestFastMul2();
567
568 // build func graph
569 auto a = std::make_shared<AnfNode>(fg_);
570 auto b = std::make_shared<AnfNode>(fg_);
571 auto c = std::make_shared<AnfNode>(fg_);
572 auto d = std::make_shared<AnfNode>(fg_);
573 auto e = std::make_shared<AnfNode>(fg_);
574
575 AnfNodePtr a_add_b =
576 std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), a, b}, fg_);
577 AnfNodePtr c_add_d =
578 std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), c, d}, fg_);
579 AnfNodePtr mul = std::make_shared<CNode>(
580 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a_add_b, c_add_d}, fg_);
581 AnfNodePtr add = std::make_shared<CNode>(
582 std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), mul, e}, fg_);
583
584 std::map<std::string, AnfNodePtr> node_map;
585 node_map.emplace("a", a);
586 node_map.emplace("b", b);
587 node_map.emplace("c", c);
588 node_map.emplace("d", d);
589 node_map.emplace("e", e);
590 node_map.emplace("a_add_b", a_add_b);
591 node_map.emplace("c_add_d", c_add_d);
592 node_map.emplace("mul", mul);
593 node_map.emplace("add", add);
594
595 fg_->set_output(add);
596 auto manager = MakeManager({fg_});
597 if (manager) {
598 manager->AddFuncGraph(fg_);
599 fg_->set_manager(manager);
600 }
601 if (!fg_->has_user_data<FuncGraphPassIndex>()) {
602 fg_->set_user_data<FuncGraphPassIndex>(std::make_shared<FuncGraphPassIndex>());
603 }
604 auto func_graph_index = fg_->user_data<FuncGraphPassIndex>();
605 GenIndex(fg_, func_graph_index);
606
607 Check0(func_graph_index, node_map);
608 auto new_node = pass.Run(fg_, mul);
609 ASSERT_NE(new_node, nullptr);
610 (void)manager->Replace(mul, new_node);
611 pass.AfterProcess(mul, new_node, fg_, func_graph_index);
612 Check1(pass, func_graph_index, node_map);
613
614 // build pattern
615 check.src_pattern_.AddVar("a")
616 .AddVar("b")
617 .AddVar("c")
618 .AddVar("d")
619 .AddVar("e")
620 .AddCNode("a_add_b", {std::make_shared<Primitive>(kAddOpName), "a", "b"})
621 .AddCNode("c_add_d", {std::make_shared<Primitive>(kAddOpName), "c", "d"})
622 .AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "c_add_d", "a_add_b"})
623 .AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "mul", "e"});
624
625 // pattern engine
626 ASSERT_TRUE(check.build_pattern_map(add));
627 Check2(check, node_map);
628 }
629 } // namespace opt
630 } // namespace mindspore
631