• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/platform/test.h"
24 
25 namespace xla {
26 namespace {
27 
28 namespace m = match;
29 using PatternMatcherTest = HloTestBase;
30 
// Matches the root `add` of a tiny module against a compound pattern and
// verifies every capture that pattern performs. Then checks that mismatching
// patterns (wrong name, wrong opcode) are rejected without writing to their
// captures.
TEST_F(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
    ENTRY %two_plus_two_computation () -> f32[] {
      %two = f32[] constant(2)
      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();

  // Initialize every capture so no check below can read an indeterminate
  // pointer.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;
  Layout* matched_layout = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(
              match::Shape(&matched_shape)
                  .WithLayout(match::Layout(&matched_layout).WithDenseFormat()))
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst, root);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);

  // The nested sub-patterns must have captured as well.
  ASSERT_NE(matched_operand, nullptr);
  EXPECT_EQ(matched_operand, root->operand(0));
  EXPECT_EQ(matched_operand->opcode(), HloOpcode::kConstant);
  ASSERT_NE(matched_shape, nullptr);
  EXPECT_TRUE(ShapeUtil::Equal(*matched_shape, root->shape()));
  ASSERT_NE(matched_layout, nullptr);
  EXPECT_EQ(matched_layout, &matched_shape->layout());

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));

  // A failed match must leave its capture untouched.
  matched_inst = nullptr;
  EXPECT_FALSE(
      Match(root, match::Multiply(&matched_inst, match::Op(), match::Op())));
  EXPECT_EQ(matched_inst, nullptr);
}
70 
// Exercises the Shape predicates on an f32 scalar. The capture pointer starts
// as nullptr, so the EXPECT_EQ below stays well-defined even if the non-fatal
// match expectation before it fails.
TEST_F(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  // A scalar has no subshape {0}, so the whole conjunction must fail.
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
85 
// Exercises the Shape and Layout predicates on a rank-3 dense f32 array. Both
// capture pointers start as nullptr, so the EXPECT_EQ checks never read an
// indeterminate value when a non-fatal match expectation fails.
TEST_F(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithDenseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
104 
// Matches the subshapes of a two-element tuple shape through WithSubshape. It
// checks both capture and EqualTo comparison for each element, and checks that
// an out-of-range index ({2}) and a too-deep index ({0, 0}) are rejected.
TEST_F(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {0}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {0}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));

  // Reset the capture. Otherwise the ASSERT_NE below would trivially pass on
  // the pointer left over from the {0} match and prove nothing about {1}.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {1}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {1}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));

  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
152 
// WithFusionKind: a kLoop fusion matches only the kLoop predicate, and a
// non-fusion instruction (the fusion's parameter operand) matches none.
TEST_F(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    fused_computation {
      ROOT fp0 = f32[] parameter(0)
    }

    ENTRY while.v11 {
      p0 = f32[] parameter(0)
      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloInstruction* const fusion =
      hlo_module->entry_computation()->root_instruction();
  const HloInstruction* const param = fusion->operand(0);
  const auto loop_kind = HloInstruction::FusionKind::kLoop;
  const auto input_kind = HloInstruction::FusionKind::kInput;

  EXPECT_TRUE(Match(fusion, match::Op().WithFusionKind(loop_kind)));
  EXPECT_FALSE(Match(fusion, match::Op().WithFusionKind(input_kind)));
  EXPECT_FALSE(Match(param, match::Op().WithFusionKind(loop_kind)));
}
176 
// The module extracts tuple index 1, so only index 1 may match. This holds
// both for the generic WithTupleIndex predicate and for the GetTupleElement
// helper.
TEST_F(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY while.v11 {
      p0 = (f32[], f32[], f32[]) parameter(0)
      ROOT gte = f32[] get-tuple-element(p0), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloInstruction* const gte =
      hlo_module->entry_computation()->root_instruction();

  // Positive cases.
  EXPECT_TRUE(Match(gte, match::Op().WithTupleIndex(1)));
  EXPECT_TRUE(Match(gte, match::GetTupleElement(match::Op(), 1)));

  // Every other index must be rejected.
  EXPECT_FALSE(Match(gte, match::Op().WithTupleIndex(0)));
  EXPECT_FALSE(Match(gte, match::Op().WithTupleIndex(2)));
  EXPECT_FALSE(Match(gte, match::GetTupleElement(match::Op(), 0)));
}
195 
// AnyOf accepts the instruction when any alternative matches, independent of
// the order the alternatives are listed in, and rejects it when none does.
TEST_F(PatternMatcherTest, AnyOf) {
  using match::AnyOf;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* const one =
      hlo_module->entry_computation()->root_instruction();

  EXPECT_TRUE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(1))));
  EXPECT_TRUE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(1), ConstantScalar(0))));
  EXPECT_FALSE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(2))));
}
213 
TEST_F(PatternMatcherTest,ConstantScalar)214 TEST_F(PatternMatcherTest, ConstantScalar) {
215   using match::ConstantEffectiveScalar;
216   using match::ConstantScalar;
217   using match::Op;
218   using match::Tuple;
219 
220   constexpr char kModuleStr[] = R"(
221     HloModule test_module
222     ENTRY test {
223       a = s32[] constant(1)
224       b = s32[1,1] constant({{2}})
225       c = s32[1,2] constant({{2,2}})
226       d = f32[] constant(1)
227       e = f32[] constant(1.25)
228       ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
229     })";
230   TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
231                           ParseAndReturnVerifiedModule(kModuleStr));
232   auto* root = hlo_module->entry_computation()->root_instruction();
233 
234   const HloInstruction* a = root->operand(0);
235   const HloInstruction* b = root->operand(1);
236   const HloInstruction* c = root->operand(2);
237   const HloInstruction* d = root->operand(3);
238   const HloInstruction* e = root->operand(4);
239   EXPECT_TRUE(Match(a, ConstantScalar()));
240   EXPECT_TRUE(Match(a, ConstantScalar(1)));
241   EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
242   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
243   EXPECT_FALSE(Match(a, ConstantScalar(2)));
244   EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
245   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
246   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));
247 
248   EXPECT_FALSE(Match(b, ConstantScalar()));
249   EXPECT_FALSE(Match(b, ConstantScalar(2)));
250   EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
251   EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));
252 
253   EXPECT_FALSE(Match(c, ConstantScalar()));
254   EXPECT_FALSE(Match(c, ConstantScalar(2)));
255   EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
256   EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));
257 
258   EXPECT_TRUE(Match(d, ConstantScalar(1)));
259   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
260   EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
261   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));
262 
263   EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
264   EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
265   EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
266   EXPECT_FALSE(Match(e, ConstantScalar(1)));
267   EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));
268 
269   const HloInstruction* instr = nullptr;
270   EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
271   EXPECT_EQ(instr, a);
272 
273   instr = nullptr;
274   EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
275   EXPECT_EQ(instr, a);
276 
277   instr = nullptr;
278   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
279   EXPECT_EQ(instr, a);
280 
281   instr = nullptr;
282   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
283   EXPECT_EQ(instr, a);
284 }
285 
// MultiplyAnyOrder matches the multiply regardless of operand order and
// captures the multiply itself. It also still exposes the generic Op() API, so
// further predicates can be chained onto it.
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Initialized, and reset before each capturing match, so every EXPECT_EQ
  // below checks a value written by the immediately preceding Match call.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);

  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it.
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_EQ(instr, root);
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
316 
// AnyOf stops at the first alternative that matches. Only that alternative's
// captures are written, even if a later alternative would also have matched.
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
  using match::AnyOf;
  using match::Multiply;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* const product =
      hlo_module->entry_computation()->root_instruction();

  // Multiply listed first: it wins, and the catch-all never captures.
  {
    const HloInstruction* multiply_capture = nullptr;
    const HloInstruction* fallback_capture = nullptr;

    ASSERT_TRUE(Match(product,
                      AnyOf<HloInstruction>(
                          Multiply(&multiply_capture, Op(), Op()),
                          Op(&fallback_capture))));
    EXPECT_NE(nullptr, multiply_capture);
    EXPECT_EQ(nullptr, fallback_capture);
  }

  // Catch-all listed first: it wins, and Multiply never captures.
  {
    const HloInstruction* multiply_capture = nullptr;
    const HloInstruction* fallback_capture = nullptr;

    ASSERT_TRUE(Match(product,
                      AnyOf<HloInstruction>(
                          Op(&fallback_capture),
                          Multiply(&multiply_capture, Op(), Op()))));
    EXPECT_NE(nullptr, fallback_capture);
    EXPECT_EQ(nullptr, multiply_capture);
  }
}
352 
// AllOf matches only when every sub-pattern matches, whatever their order. The
// f16 scalar constant satisfies three shape-based patterns, but no conjunction
// that also requires a broadcast can match it.
TEST_F(PatternMatcherTest, AllOf) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* const f16_one =
      hlo_module->entry_computation()->root_instruction();

  const Shape f16_scalar = ShapeUtil::MakeShape(F16, {});
  const auto equal_shape = Constant().WithShapeEqualTo(&f16_scalar);
  const auto compatible_shape = Constant().WithShapeCompatibleTo(&f16_scalar);
  const auto any_scalar = Constant().WithShape(match::Shape().IsScalar());

  // Sanity: each pattern matches on its own.
  ASSERT_TRUE(Match(f16_one, any_scalar));
  ASSERT_TRUE(Match(f16_one, equal_shape));
  ASSERT_TRUE(Match(f16_one, compatible_shape));

  EXPECT_TRUE(Match(f16_one, AllOf<HloInstruction>(any_scalar, equal_shape,
                                                   compatible_shape)));
  EXPECT_TRUE(Match(f16_one, AllOf<HloInstruction>(equal_shape,
                                                   compatible_shape,
                                                   any_scalar)));

  EXPECT_FALSE(
      Match(f16_one, AllOf<HloInstruction>(Broadcast(Op()), equal_shape)));
  EXPECT_FALSE(
      Match(f16_one, AllOf<HloInstruction>(Broadcast(Op()), compatible_shape)));
  EXPECT_FALSE(
      Match(f16_one, AllOf<HloInstruction>(Broadcast(Op()), any_scalar)));
}
384 
// When AllOf fails overall, none of its sub-patterns may capture, even one
// (Constant) that would have matched on its own. A later standalone match of
// that sub-pattern does capture.
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* const forty_two =
      hlo_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  ASSERT_FALSE(Match(forty_two, AllOf<HloInstruction>(Constant(&captured),
                                                      Broadcast(Op()))));
  EXPECT_EQ(nullptr, captured);

  ASSERT_TRUE(Match(forty_two, Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
407 
// With MatchOption::capture disabled, a successful match leaves capture
// pointers untouched.
TEST_F(PatternMatcherTest, TestNoCapture) {
  using match::Constant;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* const forty_two =
      hlo_module->entry_computation()->root_instruction();

  const MatchOption without_capture{/*capture=*/false};
  const HloInstruction* captured = nullptr;
  ASSERT_TRUE(Match(forty_two, Constant(&captured), without_capture));
  EXPECT_EQ(nullptr, captured);
}
424 
// AnyOf only writes the captures of the alternative that matched. The root is
// add(u, v) with parameter operands. The first alternative needs a nested add
// and so fails. The second (a plain binary add) matches and captures both
// addends. The third alternative's capture is never reached.
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      u = f16[] parameter(0)
      v = f16[] parameter(1)
      ROOT add = f16[] add(u, v)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* const sum =
      hlo_module->entry_computation()->root_instruction();

  const HloInstruction* first = nullptr;
  const HloInstruction* second = nullptr;
  const HloInstruction* third = nullptr;
  auto binary_add = Add(Op(&first), Op(&second));
  auto ternary_or_binary_add = AnyOf<HloInstruction>(
      AddAnyOrder(binary_add, Op(&third)), binary_add, Op(&first));

  ASSERT_TRUE(Match(sum, ternary_or_binary_add));
  EXPECT_NE(nullptr, first);
  EXPECT_NE(nullptr, second);
  EXPECT_EQ(nullptr, third);
}
454 
// Concatenate matches its operands positionally. Only the exact operand order
// and count (reshape of 1, 2, 3, 4) is accepted. Swapped operands, a missing
// trailing operand, or a missing leading operand are all rejected.
TEST_F(PatternMatcherTest, TestConcat) {
  using match::Concatenate;
  using match::ConstantScalar;
  using match::Reshape;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      c1 = u32[] constant(1)
      c2 = u32[] constant(2)
      c3 = u32[] constant(3)
      c4 = u32[] constant(4)
      r1 = u32[1] reshape(c1)
      r2 = u32[1] reshape(c2)
      r3 = u32[1] reshape(c3)
      r4 = u32[1] reshape(c4)
      ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* const concat =
      hlo_module->entry_computation()->root_instruction();

  // Pattern for "a reshape of the scalar constant `value`".
  auto reshape_of = [](int value) { return Reshape(ConstantScalar(value)); };

  ASSERT_TRUE(Match(concat, Concatenate(reshape_of(1), reshape_of(2),
                                        reshape_of(3), reshape_of(4))));
  ASSERT_FALSE(Match(concat, Concatenate(reshape_of(2), reshape_of(1),
                                         reshape_of(3), reshape_of(4))));
  ASSERT_FALSE(Match(
      concat, Concatenate(reshape_of(1), reshape_of(2), reshape_of(3))));
  ASSERT_FALSE(Match(
      concat, Concatenate(reshape_of(2), reshape_of(3), reshape_of(4))));
}
492 
// Returns the human-readable text that `pattern` writes via DescribeTo, i.e.
// what the pattern would match.
template <typename Pattern>
string Description(const Pattern& pattern) {
  std::ostringstream description;
  pattern.DescribeTo(&description);
  return description.str();
}
499 
500 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)501 string Explanation(Elem* elem, const Pattern& pattern) {
502   std::stringstream ss;
503   MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
504   Match(elem, pattern, options);
505   return ss.str();
506 }
507 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)508 string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
509   return Explanation(elem.get(), pattern);
510 }
511 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)512 string Explanation(const Elem& elem, const Pattern& pattern) {
513   return Explanation(&elem, pattern);
514 }
515 
// Helper macro that checks a pattern's description and the explanation printed
// when attempting to match (and presumably failing) on a given object:
//   * Description(pattern) must equal `expected_desc`, and
//   * Explanation(elem, pattern) must equal `expected_explanation`.
//
// We use a macro rather than a function because we want good line numbers in
// errors.  We use this rather than writing a helper that returns a pair of
// (description, explanation) and doing something like
//
//   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
// while EXPECT_THAT does not.  This unified diff makes the errors much easier
// to read.
//
// `pattern` is expanded (and so constructed) once per check on purpose. Each
// expansion lives in its own full-expression, so temporaries written inside
// the pattern expression (e.g. a braced index such as `{0, 0}`) stay alive for
// that whole check. NOTE(review): patterns may keep non-owning references to
// such arguments; confirm before hoisting `pattern` into a local variable.
// Every argument is parenthesized so that a comma-containing expression cannot
// be split by the nested gtest macros.
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc,      \
                                    expected_explanation)              \
  do {                                                                 \
    EXPECT_EQ(Description((pattern)), (expected_desc));                \
    EXPECT_EQ(Explanation((elem), (pattern)), (expected_explanation)); \
  } while (0)
534 
// The Layout pattern describes itself and explains both a null input and an
// equality mismatch.
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
  const Layout expected = LayoutUtil::MakeLayout({1, 2});
  const Layout actual = LayoutUtil::MakeLayout({2, 2});
  const Layout* const no_layout = nullptr;

  EXPECT_DESC_AND_EXPLANATION(no_layout, m::Layout(), "a layout",
                              "Layout is null");
  EXPECT_DESC_AND_EXPLANATION(actual, m::Layout().EqualTo(&expected),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
}
545 
// WithCustomCallTarget matches on the custom-call target string. When the
// target differs, it explains the mismatch by quoting the expected target and
// printing the offending instruction.
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY test {
      ROOT out = f32[] custom-call(), custom_call_target="test_target"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloInstruction* const custom_call =
      hlo_module->entry_computation()->root_instruction();

  EXPECT_TRUE(
      Match(custom_call, match::Op().WithCustomCallTarget("test_target")));
  EXPECT_FALSE(
      Match(custom_call, match::Op().WithCustomCallTarget("other_target")));

  EXPECT_DESC_AND_EXPLANATION(
      custom_call, match::Op().WithCustomCallTarget("other_target"),
      "an HloInstruction custom call with target 'other_target'",
      "HloInstruction is not a custom call with a target 'other_target'\nin "
      "out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
568 
TEST_F(PatternMatcherTest,ShapeDescribeToAndExplain)569 TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
570   auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
571   auto layout = shape.layout();
572 
573   EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
574                               "a shape", "Shape is null");
575   EXPECT_DESC_AND_EXPLANATION(
576       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
577       m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
578       "Shape not equal to f32[1,2]{0,1}\n"
579       "in f32[1,2]{1,0}");
580   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
581                               m::Shape().CompatibleTo(&shape),
582                               "a shape compatible with f32[1,2]",
583                               "Shape not compatible with f32[1,2]\n"
584                               "in f32[2,2]{1,0}");
585   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
586                               "a shape with element type F16",
587                               "Shape does not have element type F16\n"
588                               "in f32[1,2]{0,1}");
589   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
590                               "a shape that represents a scalar",
591                               "Shape is not a scalar\n"
592                               "in f32[1,2]{0,1}");
593   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
594                               "a shape that represents an array",
595                               "Shape is not an array\n"
596                               "in ()");
597   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
598                               "a shape that represents a tuple",
599                               "Shape is not a tuple\n"
600                               "in f32[1,2]{0,1}");
601   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
602                               "a shape that is an effective scalar",
603                               "Shape is not an effective scalar\n"
604                               "in f32[1,2]{0,1}");
605   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
606                               "a shape that has 42 dimensions",
607                               "Shape does not have rank 42\n"
608                               "in f32[1,2]{0,1}");
609   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
610                               "a shape that is a scalar",
611                               "Shape is not a scalar\n"
612                               "in f32[1,2]{0,1}");
613   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
614                               "a shape:\n"
615                               " * that has 1 dimension AND\n"
616                               " * that represents an array",
617                               "Shape does not have rank 1\n"
618                               "in f32[1,2]{0,1}");
619   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
620                               m::Shape().IsArray().WithRank(1),
621                               "a shape:\n"
622                               " * that represents an array AND\n"
623                               " * that has 1 dimension",
624                               "Shape is not an array\n"
625                               "in ()");
626   EXPECT_DESC_AND_EXPLANATION(
627       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
628       m::Shape().WithLayoutEqualTo(&layout),
629       "a shape with\n  a layout equal to {0,1}",
630       "Layout {1,0} is not equal to expected {0,1}\n"
631       "in f32[1,2]{1,0}");
632   EXPECT_DESC_AND_EXPLANATION(shape,
633                               m::Shape().WithSubshapeEqualTo({10}, &shape),
634                               "a shape with subshape at index {10} which is\n"
635                               "  a shape equal to f32[1,2]{0,1}",
636                               "No subshape at {10}\n"
637                               "in f32[1,2]{0,1}");
638   EXPECT_DESC_AND_EXPLANATION(
639       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
640       m::Shape().WithSubshapeEqualTo({0}, &shape),
641       "a shape with subshape at index {0} which is\n"
642       "  a shape equal to f32[1,2]{0,1}",
643       "Shape not equal to f32[1,2]{0,1}\n"
644       "in f32[2,2]{1,0}\n"
645       "in subshape at {0}\n"
646       "in (f32[2,2])");
647   EXPECT_DESC_AND_EXPLANATION(shape,
648                               m::Shape().WithSubshapeCompatibleTo({10}, &shape),
649                               "a shape with subshape at index {10} which is\n"
650                               "  a shape compatible with f32[1,2]",
651                               "No subshape at {10}\n"
652                               "in f32[1,2]{0,1}");
653   EXPECT_DESC_AND_EXPLANATION(
654       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
655       m::Shape().WithSubshapeCompatibleTo({0}, &shape),
656       "a shape with subshape at index {0} which is\n"
657       "  a shape compatible with f32[1,2]",
658       "Shape not compatible with f32[1,2]\n"
659       "in f32[2,2]{1,0}\n"
660       "in subshape at {0}\n"
661       "in (f32[2,2])");
662   EXPECT_DESC_AND_EXPLANATION(
663       ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
664       m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
665       "a shape with subshape at index {0,0} which is\n"
666       "  a shape that represents a scalar",
667       "Shape is not a scalar\n"
668       "in f32[1,2]{0,1}\n"
669       "in subshape at {0,0}\n"
670       "in ((f32[1,2]))");
671 }
672 
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)673 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
674                                         std::unique_ptr<HloInstruction> instr) {
675   instr->SetAndSanitizeName(string(name));
676   return instr;
677 }
678 
TEST_F(PatternMatcherTest,HloInstructionDescribeToAndExplain)679 TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
680   std::unique_ptr<HloInstruction> iota =
681       SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
682                                               /*iota_dimension=*/0));
683   std::unique_ptr<HloInstruction> constant =
684       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
685 
686   EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
687                               m::Op(), "an HloInstruction",
688                               "HloInstruction* is null");
689   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
690                               "an HloInstruction named \"foo\"",
691                               "HloInstruction not named \"foo\"\n"
692                               "in i = s32[42]{0} iota(), iota_dimension=0");
693   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
694                               "an HloInstruction with opcode add",
695                               "HloInstruction doesn't have opcode add\n"
696                               "in i = s32[42]{0} iota(), iota_dimension=0");
697   EXPECT_DESC_AND_EXPLANATION(
698       constant, m::Op().IsNonConstant(),
699       "an HloInstruction with any opcode other than constant",
700       "HloInstruction has opcode constant, expected anything else\n"
701       "in c = s32[] constant(0)");
702   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
703                               "an HloInstruction with 42 operands",
704                               "HloInstruction doesn't have 42 operands\n"
705                               "in i = s32[42]{0} iota(), iota_dimension=0");
706   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
707                               "an HloInstruction outputting\n"
708                               "  a shape that represents a tuple",
709                               "Shape is not a tuple\n"
710                               "in s32[42]{0}\n"
711                               "in output shape\n"
712                               "in i = s32[42]{0} iota(), iota_dimension=0");
713   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(F32, {42}),
714                               "an HloInstruction outputting\n"
715                               "  a shape:\n"
716                               "   * with element type F32 AND\n"
717                               "   * with dimensions [42]",
718                               "Shape does not have element type F32\n"
719                               "in s32[42]{0}\n"
720                               "in output shape\n"
721                               "in i = s32[42]{0} iota(), iota_dimension=0");
722   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(S32, {128}),
723                               "an HloInstruction outputting\n"
724                               "  a shape:\n"
725                               "   * with element type S32 AND\n"
726                               "   * with dimensions [128]",
727                               "Shape does not have dimensions [128]\n"
728                               "in s32[42]{0}\n"
729                               "in output shape\n"
730                               "in i = s32[42]{0} iota(), iota_dimension=0");
731   EXPECT_DESC_AND_EXPLANATION(
732       iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
733       "an HloInstruction with operand 2 which is:\n"
734       "  an HloInstruction with opcode add",
735       "desired operand index 2 is out of bounds\n"
736       "in i = s32[42]{0} iota(), iota_dimension=0");
737 
738   EXPECT_DESC_AND_EXPLANATION(
739       SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
740                                                 HloOpcode::kAdd, constant.get(),
741                                                 constant.get())),
742       m::Op().WithOperand(1, m::Op().IsNonConstant()),
743       "an HloInstruction with operand 1 which is:\n"
744       "  an HloInstruction with any opcode other than constant",
745       "HloInstruction has opcode constant, expected anything else\n"
746       "in c = s32[] constant(0)\n"
747       "in operand 1\n"
748       "in a = s32[] add(s32[] c, s32[] c)");
749   EXPECT_DESC_AND_EXPLANATION(
750       iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
751       "an HloInstruction with fusion kind kLoop",
752       "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
753       "in i = s32[42]{0} iota(), iota_dimension=0");
754   EXPECT_DESC_AND_EXPLANATION(
755       iota, m::Op().WithTupleIndex(42),
756       "an HloInstruction which is a GTE with index 42",
757       "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
758       "in i = s32[42]{0} iota(), iota_dimension=0");
759   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
760                               "an HloInstruction which is a constant scalar",
761                               "HloInstruction is not a constant\n"
762                               "in i = s32[42]{0} iota(), iota_dimension=0");
763   EXPECT_DESC_AND_EXPLANATION(
764       SetName("c", HloInstruction::CreateConstant(
765                        LiteralUtil::CreateR1<int>({1, 2}))),
766       m::Op().IsConstantEffectiveScalar(),
767       "an HloInstruction which is a constant effective scalar",
768       "HloInstruction is not an effective scalar\n"
769       "in c = s32[2]{0} constant({1, 2})");
770   EXPECT_DESC_AND_EXPLANATION(
771       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
772       m::Op().IsConstantScalar(42),
773       "an HloInstruction which is a constant scalar with value 42",
774       "HloInstruction's constant value 10 did not match expected value 42\n"
775       "in c = s32[] constant(10)");
776   EXPECT_DESC_AND_EXPLANATION(
777       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
778       m::Op().IsConstantEffectiveScalar(1.25),
779       "an HloInstruction which is a constant effective scalar with value 1.25",
780       "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
781       "in c = f64[] constant(2.25)");
782   EXPECT_DESC_AND_EXPLANATION(
783       constant, m::Op().Is(iota.get()),
784       absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
785                    " (i = s32[42]{0} iota(), iota_dimension=0)"),
786       absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
787                    absl::Hex(iota.get()),
788                    " (i = s32[42]{0} iota(), iota_dimension=0)\n"
789                    "in c = s32[] constant(0)"));
790 }
791 
TEST_F(PatternMatcherTest,HloInstructionMatcherAnyOrderDescribeTo)792 TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
793   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
794   EXPECT_DESC_AND_EXPLANATION(
795       SetName("a", HloInstruction::CreateBinary(
796                        scalar_s32, HloOpcode::kAdd,
797                        SetName("b", HloInstruction::CreateConstant(
798                                         LiteralUtil::CreateR0(0)))
799                            .get(),
800                        SetName("c", HloInstruction::CreateConstant(
801                                         LiteralUtil::CreateR0(0)))
802                            .get())),
803       m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
804       "an HloInstruction:\n"
805       " * with opcode add AND\n"
806       " * with two operands in either order:\n"
807       "    - an HloInstruction named \"b\"\n"
808       "    - an HloInstruction named \"bar\"",
809       "HloInstruction's operands (ignoring order) did not match second "
810       "matcher.  Specifically,\n"
811       " - an HloInstruction named \"bar\"\n"
812       "does not match LHS:\n"
813       " - HloInstruction not named \"bar\"\n"
814       "   in b = s32[] constant(0)\n"
815       "does not match RHS:\n"
816       " - HloInstruction not named \"bar\"\n"
817       "   in c = s32[] constant(0)\n"
818       "in a = s32[] add(s32[] b, s32[] c)");
819 
820   EXPECT_DESC_AND_EXPLANATION(
821       SetName("a",
822               HloInstruction::CreateBinary(
823                   scalar_s32, HloOpcode::kAdd,
824                   HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
825                   SetName("c", HloInstruction::CreateConstant(
826                                    LiteralUtil::CreateR0(0)))
827                       .get())),
828       m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
829       "an HloInstruction:\n"
830       " * with opcode add AND\n"
831       " * with two operands in either order:\n"
832       "    - an HloInstruction which is a constant scalar\n"
833       "    - an HloInstruction with opcode constant",
834       "HloInstruction's LHS operand did not match either of the two matchers.  "
835       "Specifically,\n"
836       " - an HloInstruction which is a constant scalar\n"
837       "does not match LHS:\n"
838       " - HloInstruction is not a constant\n"
839       "   in p = s32[] parameter(0)\n"
840       "and\n"
841       " - an HloInstruction with opcode constant\n"
842       "does not match LHS:\n"
843       " - HloInstruction doesn't have opcode constant\n"
844       "   in p = s32[] parameter(0)\n"
845       "in a = s32[] add(s32[] p, s32[] c)");
846 }
847 
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  // Neither alternative matches a constant named "c". The explanation lists
  // each alternative, numbered, together with its own failure.
  constexpr char kDescription[] =
      "any of:\n"
      " - an HloInstruction named \"foo\" OR\n"
      " - an HloInstruction named \"bar\"";
  constexpr char kExplanation[] =
      "None of the following matchers succeeded:\n"
      "Matcher #1\n"
      " - an HloInstruction named \"foo\"\n"
      "failed with\n"
      " - HloInstruction not named \"foo\"\n"
      "   in c = s32[] constant(0)\n"
      "Matcher #2\n"
      " - an HloInstruction named \"bar\"\n"
      "failed with\n"
      " - HloInstruction not named \"bar\"\n"
      "   in c = s32[] constant(0)";

  auto named_c =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  EXPECT_DESC_AND_EXPLANATION(
      named_c,
      m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
                               m::Op().WithName("bar")),
      kDescription, kExplanation);
}
868 
TEST_F(PatternMatcherTest, Parameter) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto param_one = HloInstruction::CreateParameter(1, scalar_f32, "p1");
  auto not_a_param =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // m::Parameter() accepts any parameter; m::Parameter(n) also checks the
  // parameter number.
  EXPECT_TRUE(Match(param_one.get(), m::Parameter()));
  EXPECT_TRUE(Match(param_one.get(), m::Parameter(1)));
  EXPECT_FALSE(Match(param_one.get(), m::Parameter(0)));

  // A non-parameter is rejected whatever number is requested.
  EXPECT_FALSE(Match(not_a_param.get(), m::Parameter()));
  EXPECT_FALSE(Match(not_a_param.get(), m::Parameter(1)));

  // Wrong opcode: the opcode check reports the failure.
  EXPECT_DESC_AND_EXPLANATION(not_a_param, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // Right opcode, wrong parameter number.
  auto param_zero = HloInstruction::CreateParameter(0, scalar_f32, "p0");
  EXPECT_EQ(Explanation(param_zero.get(), m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
892 
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
  auto p0 =
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");

  // With no users at all, WithOneUse() and WithOneUser() both fail with the
  // same explanation.
  constexpr char kNoUsersExplanation[] =
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)";
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(p0, m::Op().WithOneUse(),
                              "an HloInstruction which has exactly one use",
                              kNoUsersExplanation);
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      p0, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      kNoUsersExplanation);

  {
    // A single reshape user satisfies both predicates; adding a second reshape
    // user breaks both. The reshapes are scoped so they are gone before the
    // final section below.
    auto first_reshape =
        SetName("r", HloInstruction::CreateReshape(
                         ShapeUtil::MakeShape(F32, {1}), p0.get()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));

    auto second_reshape =
        SetName("r1", HloInstruction::CreateReshape(
                          ShapeUtil::MakeShape(F32, {1}), p0.get()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));

    constexpr char kTwoUsersExplanation[] =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
              kTwoUsersExplanation);
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUser()),
              kTwoUsersExplanation);
  }

  // One user that consumes p0 twice: exactly one user, but not exactly one use.
  auto double_use = SetName(
      "add", HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {}),
                                          HloOpcode::kAdd, p0.get(), p0.get()));
  EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
947 
TEST_F(PatternMatcherTest, Comparison) {
  // Operands are f32[1]. A compare yields booleans, so its result shape is
  // pred[1]; building it with the operand shape would produce ill-formed HLO.
  auto shape = ShapeUtil::MakeShape(F32, {1});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  auto eq = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(pred_shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  // Generic and direction-specific matchers, with and without operand checks.
  EXPECT_TRUE(Match(eq.get(), m::Compare()));
  EXPECT_TRUE(Match(eq.get(), m::Eq()));
  EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(ne.get(), m::Compare()));
  EXPECT_TRUE(Match(ne.get(), m::Ne()));
  EXPECT_TRUE(Match(
      le.get(),
      m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
                                    m::Add(m::Parameter(0), m::Parameter(1)))));

  // Wrong opcode, wrong direction, and (for the ordered Eq) wrong operand order.
  EXPECT_FALSE(Match(eq.get(), m::Add()));
  EXPECT_FALSE(Match(eq.get(), m::Ne()));
  EXPECT_FALSE(
      Match(le.get(),
            m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));
  EXPECT_DESC_AND_EXPLANATION(
      eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = pred[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
990 
TEST_F(PatternMatcherTest, CustomCallMatchers) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY test {
      p0 = f32[] parameter(0)
      p1 = f32[] parameter(1)
      ROOT out = f32[] custom-call(p0, p1), custom_call_target="test_target"
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  // Non-capturing overloads: any custom call, by target, by target + operands.
  EXPECT_TRUE(Match(root, m::CustomCall()));
  EXPECT_TRUE(Match(root, m::CustomCall("test_target")));
  EXPECT_TRUE(Match(
      root, m::CustomCall("test_target", m::Parameter(0), m::Parameter(1))));

  // Capturing overloads must write the matched instruction through the
  // pointer. Each capture is reset to nullptr first, so a value left over from
  // an earlier match cannot hide a missing write.
  HloInstruction* instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr)));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target")));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target", m::Parameter(0),
                                        m::Parameter(1))));
  EXPECT_EQ(instr, root);

  // Same overloads with a pointer-to-const capture.
  const HloInstruction* const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr)));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target")));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target",
                                        m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(const_instr, root);

  // Wrong target, or operands in the wrong order, must not match.
  EXPECT_FALSE(Match(root, m::CustomCall("other_target")));
  EXPECT_FALSE(Match(
      root, m::CustomCall("test_target", m::Parameter(1), m::Parameter(0))));
}
1026 
1027 }  // namespace
1028 }  // namespace xla
1029