• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/platform/test.h"
24 
25 namespace xla {
26 namespace {
27 
28 namespace m = match;
29 using PatternMatcherTest = HloTestBase;
30 
// Matches the root `add` of a tiny module against a pattern that captures the
// instruction, its output shape and layout, and its first operand. Then checks
// a few positive and negative matches, including that a failed match does not
// write to its capture slot.
TEST_F(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
    ENTRY %two_plus_two_computation () -> f32[] {
      %two = f32[] constant(2)
      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();

  // Start every capture slot at null so no assertion below can read an
  // indeterminate pointer, and so a missing capture is detectable.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;
  Layout* matched_layout = nullptr;

  ASSERT_TRUE(Match(
      root,
      m::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(m::Shape(&matched_shape)
                         .WithLayout(m::Layout(&matched_layout).WithDenseFormat()))
          .WithOperand(
              0, m::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst, root);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);
  // The nested sub-patterns must have captured as well, not just the outer Op.
  EXPECT_EQ(matched_operand, root->operand(0));
  EXPECT_EQ(matched_shape, &root->shape());
  EXPECT_EQ(matched_layout, &root->shape().layout());

  EXPECT_TRUE(Match(root, m::Add(m::Constant(), m::Constant())));

  EXPECT_FALSE(Match(root, m::Op().WithName("bad_name")));
  matched_inst = nullptr;
  EXPECT_FALSE(Match(root, m::Multiply(&matched_inst, m::Op(), m::Op())));
  // A failed match must leave its capture slot untouched.
  EXPECT_EQ(matched_inst, nullptr);
}
70 
// Checks the Shape predicates against a rank-0 f32 shape, and that the
// captured pointer refers to the matched shape object.
TEST_F(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Initialized so the EXPECT_EQ below compares a defined value even if the
  // non-fatal EXPECT_TRUE fails and nothing is captured.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, m::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, m::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, m::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, m::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, m::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, m::Shape().WithRank(0)));
  // A scalar has no subshape at index {0}.
  EXPECT_FALSE(Match(
      &scalar_shape,
      m::Shape().WithSubshape({0}, m::Shape()).WithElementType(F32)));
}
85 
// Checks the Shape predicates against an f32[2,3,4] array shape, and that
// shape and layout captures point into the matched shape.
TEST_F(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Initialized so the EXPECT_EQs below never read an indeterminate pointer
  // when a preceding (non-fatal) EXPECT_TRUE fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, m::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, m::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, m::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, m::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, m::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, m::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, m::Shape().WithSubshape({0}, m::Shape())));

  Layout* matched_layout = nullptr;
  EXPECT_TRUE(Match(&array_shape,
                    m::Shape().WithLayout(
                        m::Layout(&matched_layout).WithDenseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
104 
// Checks tuple predicates and WithSubshape(). Covers capturing a subshape,
// comparing it with EqualTo, and rejecting indices that do not exist.
TEST_F(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, m::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, m::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, m::Shape().IsScalar()));

  const Shape& element0 = ShapeUtil::GetSubshape(tuple_shape, {0});
  const Shape& element1 = ShapeUtil::GetSubshape(tuple_shape, {1});

  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      m::Shape().WithSubshape(
          {0}, m::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(ShapeUtil::Equal(*subshape, element0));
  EXPECT_TRUE(Match(
      &tuple_shape, m::Shape().WithSubshape({0}, m::Shape().EqualTo(&element0))));
  EXPECT_FALSE(Match(
      &tuple_shape, m::Shape().WithSubshape({0}, m::Shape().EqualTo(&element1))));

  // Reset the capture slot. Otherwise the ASSERT_NE below would pass on the
  // stale pointer left by the {0} match even if nothing were captured at {1}.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      m::Shape().WithSubshape(
          {1}, m::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(ShapeUtil::Equal(*subshape, element1));
  EXPECT_TRUE(Match(
      &tuple_shape, m::Shape().WithSubshape({1}, m::Shape().EqualTo(&element1))));
  EXPECT_FALSE(Match(
      &tuple_shape, m::Shape().WithSubshape({1}, m::Shape().EqualTo(&element0))));

  // Out-of-range and too-deep indices do not match.
  EXPECT_FALSE(
      Match(&tuple_shape, m::Shape().WithSubshape({2}, m::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, m::Shape().WithSubshape({0, 0}, m::Shape())));
}
152 
// WithFusionKind() matches a fusion only when the kind agrees, and it does not
// match the fusion's (non-fusion) parameter operand.
TEST_F(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    fused_computation {
      ROOT fp0 = f32[] parameter(0)
    }

    ENTRY while.v11 {
      p0 = f32[] parameter(0)
      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloInstruction* fusion = hlo_module->entry_computation()->root_instruction();
  const HloInstruction* param = fusion->operand(0);
  constexpr auto kLoop = HloInstruction::FusionKind::kLoop;
  constexpr auto kInput = HloInstruction::FusionKind::kInput;

  EXPECT_TRUE(Match(fusion, m::Op().WithFusionKind(kLoop)));
  EXPECT_FALSE(Match(fusion, m::Op().WithFusionKind(kInput)));
  EXPECT_FALSE(Match(param, m::Op().WithFusionKind(kLoop)));
}
176 
// A get-tuple-element with index=1 should match tuple index 1 only, both via
// WithTupleIndex() and via the GetTupleElement(operand, index) shorthand.
TEST_F(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY while.v11 {
      p0 = (f32[], f32[], f32[]) parameter(0)
      ROOT gte = f32[] get-tuple-element(p0), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* gte = hlo_module->entry_computation()->root_instruction();

  EXPECT_TRUE(Match(gte, m::Op().WithTupleIndex(1)));
  EXPECT_FALSE(Match(gte, m::Op().WithTupleIndex(0)));
  EXPECT_FALSE(Match(gte, m::Op().WithTupleIndex(2)));

  EXPECT_TRUE(Match(gte, m::GetTupleElement(m::Op(), 1)));
  EXPECT_FALSE(Match(gte, m::GetTupleElement(m::Op(), 0)));
}
195 
// AnyOf matches when at least one alternative matches, whatever its position,
// and fails when none does.
TEST_F(PatternMatcherTest, AnyOf) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* one = hlo_module->entry_computation()->root_instruction();

  const auto zero_or_one = m::AnyOf<HloInstruction>(m::ConstantScalar(0),
                                                    m::ConstantScalar(1));
  const auto one_or_zero = m::AnyOf<HloInstruction>(m::ConstantScalar(1),
                                                    m::ConstantScalar(0));
  const auto zero_or_two = m::AnyOf<HloInstruction>(m::ConstantScalar(0),
                                                    m::ConstantScalar(2));

  EXPECT_TRUE(Match(one, zero_or_one));
  EXPECT_TRUE(Match(one, one_or_zero));
  EXPECT_FALSE(Match(one, zero_or_two));
}
213 
TEST_F(PatternMatcherTest,ConstantScalar)214 TEST_F(PatternMatcherTest, ConstantScalar) {
215   using match::ConstantEffectiveScalar;
216   using match::ConstantScalar;
217   using match::Op;
218   using match::Tuple;
219 
220   constexpr char kModuleStr[] = R"(
221     HloModule test_module
222     ENTRY test {
223       a = s32[] constant(1)
224       b = s32[1,1] constant({{2}})
225       c = s32[1,2] constant({{2,2}})
226       d = f32[] constant(1)
227       e = f32[] constant(1.25)
228       ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
229     })";
230   TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
231                           ParseAndReturnVerifiedModule(kModuleStr));
232   auto* root = hlo_module->entry_computation()->root_instruction();
233 
234   const HloInstruction* a = root->operand(0);
235   const HloInstruction* b = root->operand(1);
236   const HloInstruction* c = root->operand(2);
237   const HloInstruction* d = root->operand(3);
238   const HloInstruction* e = root->operand(4);
239   EXPECT_TRUE(Match(a, ConstantScalar()));
240   EXPECT_TRUE(Match(a, ConstantScalar(1)));
241   EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
242   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
243   EXPECT_FALSE(Match(a, ConstantScalar(2)));
244   EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
245   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
246   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));
247 
248   EXPECT_FALSE(Match(b, ConstantScalar()));
249   EXPECT_FALSE(Match(b, ConstantScalar(2)));
250   EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
251   EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));
252 
253   EXPECT_FALSE(Match(c, ConstantScalar()));
254   EXPECT_FALSE(Match(c, ConstantScalar(2)));
255   EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
256   EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));
257 
258   EXPECT_TRUE(Match(d, ConstantScalar(1)));
259   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
260   EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
261   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));
262 
263   EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
264   EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
265   EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
266   EXPECT_FALSE(Match(e, ConstantScalar(1)));
267   EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));
268 
269   const HloInstruction* instr = nullptr;
270   EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
271   EXPECT_EQ(instr, a);
272 
273   instr = nullptr;
274   EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
275   EXPECT_EQ(instr, a);
276 
277   instr = nullptr;
278   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
279   EXPECT_EQ(instr, a);
280 
281   instr = nullptr;
282   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
283   EXPECT_EQ(instr, a);
284 }
285 
// MultiplyAnyOrder matches a multiply regardless of operand order. It captures
// the multiply itself and exposes the usual Op() API (e.g. IsNonConstant()).
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Previously left uninitialized and never inspected. Start at null and
  // verify each capture, resetting between matches.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);

  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it, and that chaining keeps the capture.
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_EQ(instr, root);
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
316 
// AnyOf tries its alternatives in order and stops at the first match. Only the
// winning alternative's capture slot is filled.
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* product = hlo_module->entry_computation()->root_instruction();

  {
    // Multiply first: it wins, so the catch-all Op() is never tried.
    const HloInstruction* as_multiply = nullptr;
    const HloInstruction* as_anything = nullptr;
    ASSERT_TRUE(Match(product, m::AnyOf<HloInstruction>(
                                   m::Multiply(&as_multiply, m::Op(), m::Op()),
                                   m::Op(&as_anything))));
    EXPECT_NE(nullptr, as_multiply);
    EXPECT_EQ(nullptr, as_anything);
  }
  {
    // Catch-all first: it wins, so the Multiply alternative never captures.
    const HloInstruction* as_multiply = nullptr;
    const HloInstruction* as_anything = nullptr;
    ASSERT_TRUE(Match(product, m::AnyOf<HloInstruction>(
                                   m::Op(&as_anything),
                                   m::Multiply(&as_multiply, m::Op(), m::Op()))));
    EXPECT_NE(nullptr, as_anything);
    EXPECT_EQ(nullptr, as_multiply);
  }
}
352 
// AllOf matches only when every sub-pattern matches, in any order. A single
// failing sub-pattern (here Broadcast) makes the whole AllOf fail.
TEST_F(PatternMatcherTest, AllOf) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* f16_one = hlo_module->entry_computation()->root_instruction();

  auto f16_scalar = ShapeUtil::MakeShape(F16, {});
  auto is_f16 = m::Constant().WithShapeEqualTo(&f16_scalar);
  auto is_f16_compatible = m::Constant().WithShapeCompatibleTo(&f16_scalar);
  auto is_scalar = m::Constant().WithShape(m::Shape().IsScalar());

  // Sanity check: each building block matches on its own.
  ASSERT_TRUE(Match(f16_one, is_scalar));
  ASSERT_TRUE(Match(f16_one, is_f16));
  ASSERT_TRUE(Match(f16_one, is_f16_compatible));

  EXPECT_TRUE(Match(f16_one, m::AllOf<HloInstruction>(is_scalar, is_f16,
                                                      is_f16_compatible)));
  EXPECT_TRUE(Match(f16_one, m::AllOf<HloInstruction>(
                                 is_f16, is_f16_compatible, is_scalar)));

  auto not_a_broadcast = m::Broadcast(m::Op());
  EXPECT_FALSE(
      Match(f16_one, m::AllOf<HloInstruction>(not_a_broadcast, is_f16)));
  EXPECT_FALSE(Match(
      f16_one, m::AllOf<HloInstruction>(not_a_broadcast, is_f16_compatible)));
  EXPECT_FALSE(
      Match(f16_one, m::AllOf<HloInstruction>(not_a_broadcast, is_scalar)));
}
384 
// When AllOf fails overall, a sub-pattern that did match (Constant) must not
// have written to its capture slot.
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* v = hlo_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  auto constant_and_broadcast = m::AllOf<HloInstruction>(
      m::Constant(&captured), m::Broadcast(m::Op()));
  ASSERT_FALSE(Match(v, constant_and_broadcast));
  EXPECT_EQ(nullptr, captured);

  // Matched on its own, the same Constant pattern does capture.
  ASSERT_TRUE(Match(v, m::Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
407 
// With MatchOption::capture disabled, a successful match leaves capture slots
// untouched.
TEST_F(PatternMatcherTest, TestNoCapture) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* v = hlo_module->entry_computation()->root_instruction();

  const MatchOption no_capture{/*capture=*/false};
  const HloInstruction* captured = nullptr;
  ASSERT_TRUE(Match(v, m::Constant(&captured), no_capture));
  EXPECT_EQ(nullptr, captured);
}
424 
// Only the AnyOf alternative that actually matched should capture.
//
// Alternatives, in order:
//   1. a three-term sum: fails, because neither operand of the root add is
//      itself an add;
//   2. the two-term sum: matches;
//   3. a catch-all: never reached.
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      u = f16[] parameter(0)
      v = f16[] parameter(1)
      ROOT add = f16[] add(u, v)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* sum = hlo_module->entry_computation()->root_instruction();

  const HloInstruction* first_term = nullptr;
  const HloInstruction* second_term = nullptr;
  const HloInstruction* third_term = nullptr;
  auto two_term_sum = m::Add(m::Op(&first_term), m::Op(&second_term));
  auto any_sum = m::AnyOf<HloInstruction>(
      m::AddAnyOrder(two_term_sum, m::Op(&third_term)), two_term_sum,
      m::Op(&first_term));

  ASSERT_TRUE(Match(sum, any_sum));
  EXPECT_NE(nullptr, first_term);
  EXPECT_NE(nullptr, second_term);
  EXPECT_EQ(nullptr, third_term);
}
454 
// Concatenate(...) matches operands positionally, and requires the operand
// count to agree exactly (too few patterns do not match).
TEST_F(PatternMatcherTest, TestConcat) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      c1 = u32[] constant(1)
      c2 = u32[] constant(2)
      c3 = u32[] constant(3)
      c4 = u32[] constant(4)
      r1 = u32[1] reshape(c1)
      r2 = u32[1] reshape(c2)
      r3 = u32[1] reshape(c3)
      r4 = u32[1] reshape(c4)
      ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* concat = hlo_module->entry_computation()->root_instruction();

  // Pattern for reshape(constant(value)), the form of every concat operand.
  auto reshaped = [](int value) {
    return m::Reshape(m::ConstantScalar(value));
  };

  ASSERT_TRUE(Match(concat, m::Concatenate(reshaped(1), reshaped(2),
                                           reshaped(3), reshaped(4))));
  // Swapping the first two operands breaks the match.
  ASSERT_FALSE(Match(concat, m::Concatenate(reshaped(2), reshaped(1),
                                            reshaped(3), reshaped(4))));
  // Three patterns cannot match four operands, whether prefix or suffix.
  ASSERT_FALSE(Match(
      concat, m::Concatenate(reshaped(1), reshaped(2), reshaped(3))));
  ASSERT_FALSE(Match(
      concat, m::Concatenate(reshaped(2), reshaped(3), reshaped(4))));
}
492 
// Returns the human-readable description that `pattern` writes through its
// DescribeTo(std::ostream*) method.
template <typename Pattern>
std::string Description(const Pattern& pattern) {
  std::ostringstream description;
  pattern.DescribeTo(&description);
  return description.str();
}
499 
500 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)501 string Explanation(Elem* elem, const Pattern& pattern) {
502   std::stringstream ss;
503   MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
504   Match(elem, pattern, options);
505   return ss.str();
506 }
507 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)508 string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
509   return Explanation(elem.get(), pattern);
510 }
511 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)512 string Explanation(const Elem& elem, const Pattern& pattern) {
513   return Explanation(&elem, pattern);
514 }
515 
// Helper macro for checking a pattern's description and the explanation printed
// when attempting to match (and presumably failing) on a given object.
//
// We use a macro rather than a function because we want good line numbers in
// errors.  We use this rather than writing a helper that returns a pair of
// (description, explanation) and doing something like
//
//   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
// while EXPECT_THAT does not.  This unified diff makes the errors much easier
// to read.
//
// Every macro argument is parenthesized at each use, so arbitrary expressions
// (e.g. conditionals) expand safely.
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc,      \
                                    expected_explanation)              \
  do {                                                                 \
    EXPECT_EQ(Description((pattern)), (expected_desc));                \
    EXPECT_EQ(Explanation((elem), (pattern)), (expected_explanation)); \
  } while (0)
534 
// Description and failure explanation of Layout patterns, for a null layout
// and for a layout that differs from the expected one.
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
  const Layout expected = LayoutUtil::MakeLayout({1, 2});
  const Layout actual = LayoutUtil::MakeLayout({2, 2});
  const Layout* const null_layout = nullptr;

  EXPECT_DESC_AND_EXPLANATION(null_layout, m::Layout(), "a layout",
                              "Layout is null");
  EXPECT_DESC_AND_EXPLANATION(actual, m::Layout().EqualTo(&expected),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
}
545 
// WithCustomCallTarget() matches only the named target. Also checks its
// description and the explanation emitted for a mismatched target.
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY test {
      ROOT out = f32[] custom-call(), custom_call_target="test_target"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* custom_call =
      hlo_module->entry_computation()->root_instruction();

  const auto right_target = m::Op().WithCustomCallTarget("test_target");
  const auto wrong_target = m::Op().WithCustomCallTarget("other_target");
  EXPECT_TRUE(Match(custom_call, right_target));
  EXPECT_FALSE(Match(custom_call, wrong_target));

  EXPECT_DESC_AND_EXPLANATION(
      custom_call, wrong_target,
      "an HloInstruction custom call with target 'other_target'",
      "HloInstruction is not a custom call with a target 'other_target'\nin "
      "out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
568 
TEST_F(PatternMatcherTest,ShapeDescribeToAndExplain)569 TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
570   auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
571   auto layout = shape.layout();
572 
573   EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
574                               "a shape", "Shape is null");
575   EXPECT_DESC_AND_EXPLANATION(
576       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
577       m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
578       "Shape not equal to f32[1,2]{0,1}\n"
579       "in f32[1,2]{1,0}");
580   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
581                               m::Shape().CompatibleTo(&shape),
582                               "a shape compatible with f32[1,2]",
583                               "Shape not compatible with f32[1,2]\n"
584                               "in f32[2,2]{1,0}");
585   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
586                               "a shape with element type F16",
587                               "Shape does not have element type F16\n"
588                               "in f32[1,2]{0,1}");
589   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
590                               "a shape that represents a scalar",
591                               "Shape is not a scalar\n"
592                               "in f32[1,2]{0,1}");
593   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
594                               "a shape that represents an array",
595                               "Shape is not an array\n"
596                               "in ()");
597   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
598                               "a shape that represents a tuple",
599                               "Shape is not a tuple\n"
600                               "in f32[1,2]{0,1}");
601   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
602                               "a shape that is an effective scalar",
603                               "Shape is not an effective scalar\n"
604                               "in f32[1,2]{0,1}");
605   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
606                               "a shape that has 42 dimensions",
607                               "Shape does not have rank 42\n"
608                               "in f32[1,2]{0,1}");
609   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
610                               "a shape that is a scalar",
611                               "Shape is not a scalar\n"
612                               "in f32[1,2]{0,1}");
613   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
614                               "a shape:\n"
615                               " * that has 1 dimension AND\n"
616                               " * that represents an array",
617                               "Shape does not have rank 1\n"
618                               "in f32[1,2]{0,1}");
619   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
620                               m::Shape().IsArray().WithRank(1),
621                               "a shape:\n"
622                               " * that represents an array AND\n"
623                               " * that has 1 dimension",
624                               "Shape is not an array\n"
625                               "in ()");
626   EXPECT_DESC_AND_EXPLANATION(
627       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
628       m::Shape().WithLayoutEqualTo(&layout),
629       "a shape with\n  a layout equal to {0,1}",
630       "Layout {1,0} is not equal to expected {0,1}\n"
631       "in f32[1,2]{1,0}");
632   EXPECT_DESC_AND_EXPLANATION(shape,
633                               m::Shape().WithSubshapeEqualTo({10}, &shape),
634                               "a shape with subshape at index {10} which is\n"
635                               "  a shape equal to f32[1,2]{0,1}",
636                               "No subshape at {10}\n"
637                               "in f32[1,2]{0,1}");
638   EXPECT_DESC_AND_EXPLANATION(
639       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
640       m::Shape().WithSubshapeEqualTo({0}, &shape),
641       "a shape with subshape at index {0} which is\n"
642       "  a shape equal to f32[1,2]{0,1}",
643       "Shape not equal to f32[1,2]{0,1}\n"
644       "in f32[2,2]{1,0}\n"
645       "in subshape at {0}\n"
646       "in (f32[2,2])");
647   EXPECT_DESC_AND_EXPLANATION(shape,
648                               m::Shape().WithSubshapeCompatibleTo({10}, &shape),
649                               "a shape with subshape at index {10} which is\n"
650                               "  a shape compatible with f32[1,2]",
651                               "No subshape at {10}\n"
652                               "in f32[1,2]{0,1}");
653   EXPECT_DESC_AND_EXPLANATION(
654       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
655       m::Shape().WithSubshapeCompatibleTo({0}, &shape),
656       "a shape with subshape at index {0} which is\n"
657       "  a shape compatible with f32[1,2]",
658       "Shape not compatible with f32[1,2]\n"
659       "in f32[2,2]{1,0}\n"
660       "in subshape at {0}\n"
661       "in (f32[2,2])");
662   EXPECT_DESC_AND_EXPLANATION(
663       ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
664       m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
665       "a shape with subshape at index {0,0} which is\n"
666       "  a shape that represents a scalar",
667       "Shape is not a scalar\n"
668       "in f32[1,2]{0,1}\n"
669       "in subshape at {0,0}\n"
670       "in ((f32[1,2]))");
671 }
672 
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)673 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
674                                         std::unique_ptr<HloInstruction> instr) {
675   instr->SetAndSanitizeName(string(name));
676   return instr;
677 }
678 
TEST_F(PatternMatcherTest,HloInstructionDescribeToAndExplain)679 TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
680   std::unique_ptr<HloInstruction> iota =
681       SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
682                                               /*iota_dimension=*/0));
683   std::unique_ptr<HloInstruction> constant =
684       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
685 
686   EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
687                               m::Op(), "an HloInstruction",
688                               "HloInstruction* is null");
689   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
690                               "an HloInstruction named \"foo\"",
691                               "HloInstruction not named \"foo\"\n"
692                               "in i = s32[42]{0} iota(), iota_dimension=0");
693   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
694                               "an HloInstruction with opcode add",
695                               "HloInstruction doesn't have opcode add\n"
696                               "in i = s32[42]{0} iota(), iota_dimension=0");
697   EXPECT_DESC_AND_EXPLANATION(
698       constant, m::Op().IsNonConstant(),
699       "an HloInstruction with any opcode other than constant",
700       "HloInstruction has opcode constant, expected anything else\n"
701       "in c = s32[] constant(0)");
702   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
703                               "an HloInstruction with 42 operands",
704                               "HloInstruction doesn't have 42 operands\n"
705                               "in i = s32[42]{0} iota(), iota_dimension=0");
706   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
707                               "an HloInstruction outputting\n"
708                               "  a shape that represents a tuple",
709                               "Shape is not a tuple\n"
710                               "in s32[42]{0}\n"
711                               "in output shape\n"
712                               "in i = s32[42]{0} iota(), iota_dimension=0");
713   EXPECT_DESC_AND_EXPLANATION(
714       iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
715       "an HloInstruction with operand 2 which is:\n"
716       "  an HloInstruction with opcode add",
717       "desired operand index 2 is out of bounds\n"
718       "in i = s32[42]{0} iota(), iota_dimension=0");
719 
720   EXPECT_DESC_AND_EXPLANATION(
721       SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
722                                                 HloOpcode::kAdd, constant.get(),
723                                                 constant.get())),
724       m::Op().WithOperand(1, m::Op().IsNonConstant()),
725       "an HloInstruction with operand 1 which is:\n"
726       "  an HloInstruction with any opcode other than constant",
727       "HloInstruction has opcode constant, expected anything else\n"
728       "in c = s32[] constant(0)\n"
729       "in operand 1\n"
730       "in a = s32[] add(s32[] c, s32[] c)");
731   EXPECT_DESC_AND_EXPLANATION(
732       iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
733       "an HloInstruction with fusion kind kLoop",
734       "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
735       "in i = s32[42]{0} iota(), iota_dimension=0");
736   EXPECT_DESC_AND_EXPLANATION(
737       iota, m::Op().WithTupleIndex(42),
738       "an HloInstruction which is a GTE with index 42",
739       "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
740       "in i = s32[42]{0} iota(), iota_dimension=0");
741   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
742                               "an HloInstruction which is a constant scalar",
743                               "HloInstruction is not a constant\n"
744                               "in i = s32[42]{0} iota(), iota_dimension=0");
745   EXPECT_DESC_AND_EXPLANATION(
746       SetName("c", HloInstruction::CreateConstant(
747                        LiteralUtil::CreateR1<int>({1, 2}))),
748       m::Op().IsConstantEffectiveScalar(),
749       "an HloInstruction which is a constant effective scalar",
750       "HloInstruction is not an effective scalar\n"
751       "in c = s32[2]{0} constant({1, 2})");
752   EXPECT_DESC_AND_EXPLANATION(
753       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
754       m::Op().IsConstantScalar(42),
755       "an HloInstruction which is a constant scalar with value 42",
756       "HloInstruction's constant value 10 did not match expected value 42\n"
757       "in c = s32[] constant(10)");
758   EXPECT_DESC_AND_EXPLANATION(
759       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
760       m::Op().IsConstantEffectiveScalar(1.25),
761       "an HloInstruction which is a constant effective scalar with value 1.25",
762       "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
763       "in c = f64[] constant(2.25)");
764   EXPECT_DESC_AND_EXPLANATION(
765       constant, m::Op().Is(iota.get()),
766       absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
767                    " (i = s32[42]{0} iota(), iota_dimension=0)"),
768       absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
769                    absl::Hex(iota.get()),
770                    " (i = s32[42]{0} iota(), iota_dimension=0)\n"
771                    "in c = s32[] constant(0)"));
772 }
773 
TEST_F(PatternMatcherTest,HloInstructionMatcherAnyOrderDescribeTo)774 TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
775   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
776   EXPECT_DESC_AND_EXPLANATION(
777       SetName("a", HloInstruction::CreateBinary(
778                        scalar_s32, HloOpcode::kAdd,
779                        SetName("b", HloInstruction::CreateConstant(
780                                         LiteralUtil::CreateR0(0)))
781                            .get(),
782                        SetName("c", HloInstruction::CreateConstant(
783                                         LiteralUtil::CreateR0(0)))
784                            .get())),
785       m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
786       "an HloInstruction:\n"
787       " * with opcode add AND\n"
788       " * with two operands in either order:\n"
789       "    - an HloInstruction named \"b\"\n"
790       "    - an HloInstruction named \"bar\"",
791       "HloInstruction's operands (ignoring order) did not match second "
792       "matcher.  Specifically,\n"
793       " - an HloInstruction named \"bar\"\n"
794       "does not match LHS:\n"
795       " - HloInstruction not named \"bar\"\n"
796       "   in b = s32[] constant(0)\n"
797       "does not match RHS:\n"
798       " - HloInstruction not named \"bar\"\n"
799       "   in c = s32[] constant(0)\n"
800       "in a = s32[] add(s32[] b, s32[] c)");
801 
802   EXPECT_DESC_AND_EXPLANATION(
803       SetName("a",
804               HloInstruction::CreateBinary(
805                   scalar_s32, HloOpcode::kAdd,
806                   HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
807                   SetName("c", HloInstruction::CreateConstant(
808                                    LiteralUtil::CreateR0(0)))
809                       .get())),
810       m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
811       "an HloInstruction:\n"
812       " * with opcode add AND\n"
813       " * with two operands in either order:\n"
814       "    - an HloInstruction which is a constant scalar\n"
815       "    - an HloInstruction with opcode constant",
816       "HloInstruction's LHS operand did not match either of the two matchers.  "
817       "Specifically,\n"
818       " - an HloInstruction which is a constant scalar\n"
819       "does not match LHS:\n"
820       " - HloInstruction is not a constant\n"
821       "   in p = s32[] parameter(0)\n"
822       "and\n"
823       " - an HloInstruction with opcode constant\n"
824       "does not match LHS:\n"
825       " - HloInstruction doesn't have opcode constant\n"
826       "   in p = s32[] parameter(0)\n"
827       "in a = s32[] add(s32[] p, s32[] c)");
828 }
829 
// Checks that m::AnyOf lists every alternative in its description and, when no
// alternative matches, reports each alternative's own failure.
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  constexpr char kExpectedDescription[] =
      "any of:\n"
      " - an HloInstruction named \"foo\" OR\n"
      " - an HloInstruction named \"bar\"";
  constexpr char kExpectedExplanation[] =
      "None of the following matchers succeeded:\n"
      "Matcher #1\n"
      " - an HloInstruction named \"foo\"\n"
      "failed with\n"
      " - HloInstruction not named \"foo\"\n"
      "   in c = s32[] constant(0)\n"
      "Matcher #2\n"
      " - an HloInstruction named \"bar\"\n"
      "failed with\n"
      " - HloInstruction not named \"bar\"\n"
      "   in c = s32[] constant(0)";

  auto named_c =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  EXPECT_DESC_AND_EXPLANATION(
      named_c,
      m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
                               m::Op().WithName("bar")),
      kExpectedDescription, kExpectedExplanation);
}
850 
// Checks m::Parameter() with and without a parameter number, against both a
// parameter and a non-parameter instruction.
TEST_F(PatternMatcherTest, Parameter) {
  auto param_one =
      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
  auto not_a_param =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // A parameter matches the unnumbered pattern and its own number only.
  EXPECT_TRUE(Match(param_one.get(), m::Parameter()));
  EXPECT_TRUE(Match(param_one.get(), m::Parameter(1)));
  EXPECT_FALSE(Match(param_one.get(), m::Parameter(0)));

  // A non-parameter matches neither form.
  EXPECT_FALSE(Match(not_a_param.get(), m::Parameter()));
  EXPECT_FALSE(Match(not_a_param.get(), m::Parameter(1)));

  // Wrong opcode fails on the first predicate...
  EXPECT_DESC_AND_EXPLANATION(not_a_param, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // ...while a parameter with the wrong number fails on the number check.
  auto param_zero =
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
  EXPECT_EQ(Explanation(param_zero.get(), m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
874 
// Checks WithOneUse() (exactly one use) and WithOneUser() (exactly one user,
// possibly using the instruction several times) as users of a parameter are
// added and removed.
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
  auto param =
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");

  // No users at all: both predicates fail.
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(
      param, m::Op().WithOneUse(),
      "an HloInstruction which has exactly one use",
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)");

  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      param, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)");

  {
    // A single reshape user: one use and one user.
    auto reshape =
        SetName("r", HloInstruction::CreateReshape(
                         ShapeUtil::MakeShape(F32, {1}), param.get()));
    EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse()));
    EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser()));

    // A second, distinct user: both predicates fail and every user is listed.
    auto reshape1 =
        SetName("r1", HloInstruction::CreateReshape(
                          ShapeUtil::MakeShape(F32, {1}), param.get()));
    EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
    EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));

    // A true compile-time constant, matching the file's `constexpr char k...[]`
    // convention (previously a mutable `const char*` with a constant name).
    constexpr char kMultipleUserExplanation[] =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()),
              kMultipleUserExplanation);
    EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()),
              kMultipleUserExplanation);
  }

  // r and r1 went out of scope above; the checks below rely on p0 having no
  // users left, so pin that precondition explicitly.
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));

  // One user that reads p0 twice: one user, but two uses.
  auto add = SetName("add", HloInstruction::CreateBinary(
                                ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd,
                                param.get(), param.get()));
  EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser()));
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
  EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
929 
// Checks the generic m::Compare() pattern and the direction-specific patterns
// (m::Eq, m::Ne, m::Le, m::EqAnyOrder), including operand-order sensitivity.
TEST_F(PatternMatcherTest, Comparison) {
  // NOTE(review): compare results are normally PRED-typed; f32[1] is kept here
  // because the expected explanation below spells out "f32[1]{0} compare".
  const Shape f32_vec = ShapeUtil::MakeShape(F32, {1});
  auto lhs = HloInstruction::CreateParameter(0, f32_vec, "param.0");
  auto rhs = HloInstruction::CreateParameter(1, f32_vec, "param.1");
  auto cmp_eq = HloInstruction::CreateCompare(f32_vec, lhs.get(), rhs.get(),
                                              ComparisonDirection::kEq);
  auto cmp_ne = HloInstruction::CreateCompare(f32_vec, lhs.get(), rhs.get(),
                                              ComparisonDirection::kNe);
  auto sum = HloInstruction::CreateBinary(f32_vec, HloOpcode::kAdd, lhs.get(),
                                          rhs.get());
  auto cmp_le = HloInstruction::CreateCompare(f32_vec, lhs.get(), sum.get(),
                                              ComparisonDirection::kLe);

  // Patterns that should match.
  EXPECT_TRUE(Match(cmp_eq.get(), m::Compare()));
  EXPECT_TRUE(Match(cmp_eq.get(), m::Eq()));
  EXPECT_TRUE(Match(cmp_eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(
      Match(cmp_eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(cmp_ne.get(), m::Compare()));
  EXPECT_TRUE(Match(cmp_ne.get(), m::Ne()));
  EXPECT_TRUE(Match(
      cmp_le.get(),
      m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(cmp_le.get(),
                    m::Le(m::Parameter(0),
                          m::Add(m::Parameter(0), m::Parameter(1)))));

  // Patterns that should not match: wrong opcode, wrong direction, or wrong
  // operand order for an order-sensitive comparison.
  EXPECT_FALSE(Match(cmp_eq.get(), m::Add()));
  EXPECT_FALSE(Match(cmp_eq.get(), m::Ne()));
  EXPECT_FALSE(Match(
      cmp_le.get(),
      m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(cmp_eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));

  // The direction check fails before the user-count predicate is reached.
  EXPECT_DESC_AND_EXPLANATION(
      cmp_eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
972 
973 }  // namespace
974 }  // namespace xla
975