1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/platform/test.h"
24
25 namespace xla {
26 namespace {
27
28 namespace m = match;
29 using PatternMatcherTest = HloTestBase;
30
// Matches a parsed `add` against a pattern constraining its name, opcode,
// shape, layout and first operand, and verifies that every capture slot is
// populated with a pointer into the matched instruction. Also checks a
// structural Add() pattern, two patterns that must not match, and that a
// failed match leaves its capture slot untouched.
TEST_F(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
ENTRY %two_plus_two_computation () -> f32[] {
  %two = f32[] constant(2)
  ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();

  // Initialized so that no check below can ever read an indeterminate pointer.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;
  Layout* matched_layout = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(
              match::Shape(&matched_shape)
                  .WithLayout(match::Layout(&matched_layout).WithDenseFormat()))
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst, root);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);

  // The nested captures must point at the operand / shape / layout of the
  // matched instruction itself.
  ASSERT_NE(matched_operand, nullptr);
  EXPECT_EQ(matched_operand, root->operand(0));
  EXPECT_EQ(matched_operand->opcode(), HloOpcode::kConstant);
  ASSERT_NE(matched_shape, nullptr);
  EXPECT_EQ(matched_shape, &root->shape());
  ASSERT_NE(matched_layout, nullptr);
  EXPECT_EQ(matched_layout, &matched_shape->layout());

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));

  // A failed match must not write to its capture slot.
  matched_inst = nullptr;
  EXPECT_FALSE(
      Match(root, match::Multiply(&matched_inst, match::Op(), match::Op())));
  EXPECT_EQ(matched_inst, nullptr);
}
70
// Checks the shape predicates that hold (and do not hold) for an f32 scalar,
// including that a capturing Shape() pattern records the matched shape.
TEST_F(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Initialized so the EXPECT_EQ below compares a well-defined value even if
  // the preceding (non-fatal) EXPECT_TRUE fails and nothing is captured.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  // A scalar has no subshape {0}, so the whole pattern must fail.
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
85
// Checks the shape predicates for a rank-3 dense f32 array, including the
// shape and layout captures.
TEST_F(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Initialized so the EXPECT_EQ checks below never read an indeterminate
  // pointer when a preceding non-fatal EXPECT_TRUE fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithDenseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
104
// Checks tuple-shape predicates and WithSubshape() at valid indices (with
// captures and EqualTo) and at invalid ones.
TEST_F(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  const Shape& element0 = ShapeUtil::GetSubshape(tuple_shape, {0});
  const Shape& element1 = ShapeUtil::GetSubshape(tuple_shape, {1});

  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(ShapeUtil::Equal(*subshape, element0));
  EXPECT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape({0}, match::Shape().EqualTo(&element0))));
  EXPECT_FALSE(Match(
      &tuple_shape,
      match::Shape().WithSubshape({0}, match::Shape().EqualTo(&element1))));

  // Reset so the ASSERT_NE below really checks the second capture instead of
  // the pointer left over from the first one.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(ShapeUtil::Equal(*subshape, element1));
  EXPECT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape({1}, match::Shape().EqualTo(&element1))));
  EXPECT_FALSE(Match(
      &tuple_shape,
      match::Shape().WithSubshape({1}, match::Shape().EqualTo(&element0))));

  // Out-of-range and too-deep indices never match.
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
152
// WithFusionKind() matches a fusion only when its kind is exactly the one
// requested, and never matches a non-fusion instruction.
TEST_F(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
HloModule test_module

fused_computation {
  ROOT fp0 = f32[] parameter(0)
}

ENTRY while.v11 {
  p0 = f32[] parameter(0)
  ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloInstruction* fusion = module->entry_computation()->root_instruction();
  HloInstruction* fusion_param = fusion->operand(0);
  const auto loop_kind = HloInstruction::FusionKind::kLoop;
  const auto input_kind = HloInstruction::FusionKind::kInput;

  EXPECT_TRUE(Match(fusion, match::Op().WithFusionKind(loop_kind)));
  EXPECT_FALSE(Match(fusion, match::Op().WithFusionKind(input_kind)));
  EXPECT_FALSE(Match(fusion_param, match::Op().WithFusionKind(loop_kind)));
}
176
// WithTupleIndex() and GetTupleElement() match a get-tuple-element only when
// its tuple index equals the requested one.
TEST_F(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
HloModule test_module

ENTRY while.v11 {
  p0 = (f32[], f32[], f32[]) parameter(0)
  ROOT gte = f32[] get-tuple-element(p0), index=1
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* gte = module->entry_computation()->root_instruction();

  // Only index 1 (the one used in the HLO above) may match.
  for (int index = 0; index < 3; ++index) {
    EXPECT_EQ(Match(gte, match::Op().WithTupleIndex(index)), index == 1)
        << "tuple index " << index;
  }
  EXPECT_FALSE(Match(gte, match::GetTupleElement(match::Op(), 0)));
  EXPECT_TRUE(Match(gte, match::GetTupleElement(match::Op(), 1)));
}
195
// AnyOf() succeeds when at least one alternative matches, regardless of the
// order in which the alternatives are listed.
TEST_F(PatternMatcherTest, AnyOf) {
  using match::AnyOf;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* one = module->entry_computation()->root_instruction();

  EXPECT_TRUE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(1))));
  EXPECT_TRUE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(1), ConstantScalar(0))));
  EXPECT_FALSE(Match(
      one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(2))));
}
213
// ConstantScalar() accepts only rank-0 constants, while
// ConstantEffectiveScalar() also accepts single-element constants such as
// s32[1,1]. When a value is supplied, the constant must equal it. The
// capturing overloads record the matched instruction.
TEST_F(PatternMatcherTest, ConstantScalar) {
  using match::ConstantEffectiveScalar;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
  a = s32[] constant(1)
  b = s32[1,1] constant({{2}})
  c = s32[1,2] constant({{2,2}})
  d = f32[] constant(1)
  e = f32[] constant(1.25)
  ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* tuple = module->entry_computation()->root_instruction();

  const HloInstruction* s32_one = tuple->operand(0);          // s32[] 1
  const HloInstruction* s32_1x1_two = tuple->operand(1);      // s32[1,1]
  const HloInstruction* s32_1x2_twos = tuple->operand(2);     // s32[1,2]
  const HloInstruction* f32_one = tuple->operand(3);          // f32[] 1
  const HloInstruction* f32_one_quarter = tuple->operand(4);  // f32[] 1.25

  EXPECT_TRUE(Match(s32_one, ConstantScalar()));
  EXPECT_TRUE(Match(s32_one, ConstantScalar(1)));
  EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar(1)));
  EXPECT_FALSE(Match(s32_one, ConstantScalar(2)));
  EXPECT_FALSE(Match(s32_one, ConstantScalar(2.01)));
  EXPECT_FALSE(Match(s32_one, ConstantEffectiveScalar(2)));
  EXPECT_FALSE(Match(s32_one, ConstantEffectiveScalar(1.01)));

  // A single-element non-scalar is only an *effective* scalar.
  EXPECT_FALSE(Match(s32_1x1_two, ConstantScalar()));
  EXPECT_FALSE(Match(s32_1x1_two, ConstantScalar(2)));
  EXPECT_TRUE(Match(s32_1x1_two, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(s32_1x1_two, ConstantEffectiveScalar(2)));

  // Two elements: neither kind of scalar.
  EXPECT_FALSE(Match(s32_1x2_twos, ConstantScalar()));
  EXPECT_FALSE(Match(s32_1x2_twos, ConstantScalar(2)));
  EXPECT_FALSE(Match(s32_1x2_twos, ConstantEffectiveScalar()));
  EXPECT_FALSE(Match(s32_1x2_twos, ConstantEffectiveScalar(2)));

  EXPECT_TRUE(Match(f32_one, ConstantScalar(1)));
  EXPECT_TRUE(Match(f32_one, ConstantEffectiveScalar(1)));
  EXPECT_TRUE(Match(f32_one, ConstantScalar(1.0)));
  EXPECT_TRUE(Match(f32_one, ConstantEffectiveScalar(1.0)));

  EXPECT_TRUE(Match(f32_one_quarter, ConstantScalar(1.25f)));
  EXPECT_TRUE(Match(f32_one_quarter, ConstantScalar(1.25)));
  EXPECT_TRUE(Match(f32_one_quarter, ConstantEffectiveScalar(1.25)));
  EXPECT_FALSE(Match(f32_one_quarter, ConstantScalar(1)));
  EXPECT_FALSE(Match(f32_one_quarter, ConstantEffectiveScalar(1)));

  // Every capturing overload must record the matched instruction.
  const HloInstruction* captured = nullptr;
  EXPECT_TRUE(Match(s32_one, ConstantScalar(&captured)));
  EXPECT_EQ(captured, s32_one);

  captured = nullptr;
  EXPECT_TRUE(Match(s32_one, ConstantScalar(&captured, 1)));
  EXPECT_EQ(captured, s32_one);

  captured = nullptr;
  EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar(&captured)));
  EXPECT_EQ(captured, s32_one);

  captured = nullptr;
  EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar(&captured, 1)));
  EXPECT_EQ(captured, s32_one);
}
285
// MultiplyAnyOrder() matches a multiply whose operands satisfy the two
// sub-patterns in either order, and captures the multiply itself.
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
  lhs = f16[] constant(42)
  rhs = f16[] constant(52)
  ROOT multiply = f16[] multiply(lhs, rhs)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Initialized (and reset between matches) so each EXPECT_EQ verifies the
  // capture of the match directly above it.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);

  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it.
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_EQ(instr, root);
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
316
// AnyOf() stops at the first alternative that matches, so only that
// alternative's captures are written.
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
  using match::AnyOf;
  using match::Multiply;
  using match::Op;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
  lhs = f16[] constant(42)
  rhs = f16[] constant(52)
  ROOT multiply = f16[] multiply(lhs, rhs)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* multiply = module->entry_computation()->root_instruction();

  // Multiply() is listed first: it wins and the catch-all is never captured.
  {
    const HloInstruction* multiply_capture = nullptr;
    const HloInstruction* catch_all_capture = nullptr;
    ASSERT_TRUE(Match(multiply,
                      AnyOf<HloInstruction>(Multiply(&multiply_capture, Op(), Op()),
                                            Op(&catch_all_capture))));
    EXPECT_NE(nullptr, multiply_capture);
    EXPECT_EQ(nullptr, catch_all_capture);
  }
  // The catch-all Op() is listed first: it wins and Multiply() is skipped.
  {
    const HloInstruction* multiply_capture = nullptr;
    const HloInstruction* catch_all_capture = nullptr;
    ASSERT_TRUE(Match(multiply,
                      AnyOf<HloInstruction>(Op(&catch_all_capture),
                                            Multiply(&multiply_capture, Op(), Op()))));
    EXPECT_NE(nullptr, catch_all_capture);
    EXPECT_EQ(nullptr, multiply_capture);
  }
}
352
// AllOf() matches only when every sub-pattern matches, in any listed order.
TEST_F(PatternMatcherTest, AllOf) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* constant = module->entry_computation()->root_instruction();

  const Shape f16_scalar = ShapeUtil::MakeShape(F16, {});
  auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar());
  auto f16_pattern = Constant().WithShapeEqualTo(&f16_scalar);
  auto f16_compatible_pattern = Constant().WithShapeCompatibleTo(&f16_scalar);

  // Each sub-pattern matches on its own...
  ASSERT_TRUE(Match(constant, scalar_pattern));
  ASSERT_TRUE(Match(constant, f16_pattern));
  ASSERT_TRUE(Match(constant, f16_compatible_pattern));

  // ...so their conjunction matches regardless of order...
  EXPECT_TRUE(Match(constant,
                    AllOf<HloInstruction>(scalar_pattern, f16_pattern,
                                          f16_compatible_pattern)));
  EXPECT_TRUE(Match(constant,
                    AllOf<HloInstruction>(f16_pattern, f16_compatible_pattern,
                                          scalar_pattern)));

  // ...but adding a sub-pattern that cannot match (Broadcast) breaks it.
  EXPECT_FALSE(Match(constant,
                     AllOf<HloInstruction>(Broadcast(Op()), f16_pattern)));
  EXPECT_FALSE(Match(
      constant, AllOf<HloInstruction>(Broadcast(Op()), f16_compatible_pattern)));
  EXPECT_FALSE(Match(constant,
                     AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
}
384
// A failing AllOf() must not write the captures of its sub-patterns, even
// those that would have matched on their own.
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
  ROOT v = f16[] constant(42)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* v = module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  ASSERT_FALSE(Match(
      v, AllOf<HloInstruction>(Constant(&captured), Broadcast(Op()))));
  EXPECT_EQ(nullptr, captured);

  // The same sub-pattern alone does match, and then it captures.
  ASSERT_TRUE(Match(v, Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
407
// With capturing disabled via MatchOption, a successful match leaves the
// capture slot untouched.
TEST_F(PatternMatcherTest, TestNoCapture) {
  using match::Constant;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
  ROOT v = f16[] constant(42)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* v = module->entry_computation()->root_instruction();

  MatchOption without_capture{/*capture=*/false};
  const HloInstruction* captured = nullptr;
  ASSERT_TRUE(Match(v, Constant(&captured), without_capture));
  EXPECT_EQ(nullptr, captured);
}
424
// When AnyOf() matches, only the captures of the alternative that actually
// matched are written. Here the plain two-operand add wins, so the third
// addend's capture stays null.
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Op;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
  u = f16[] parameter(0)
  v = f16[] parameter(1)
  ROOT add = f16[] add(u, v)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* add = module->entry_computation()->root_instruction();

  const HloInstruction* first_addend = nullptr;
  const HloInstruction* second_addend = nullptr;
  const HloInstruction* third_addend = nullptr;
  auto two_way_add = Add(Op(&first_addend), Op(&second_addend));
  auto three_way_add = AnyOf<HloInstruction>(
      AddAnyOrder(two_way_add, Op(&third_addend)), two_way_add,
      Op(&first_addend));

  ASSERT_TRUE(Match(add, three_way_add));
  EXPECT_NE(nullptr, first_addend);
  EXPECT_NE(nullptr, second_addend);
  EXPECT_EQ(nullptr, third_addend);
}
454
// Concatenate() matches operands positionally: order and arity must agree.
TEST_F(PatternMatcherTest, TestConcat) {
  using match::Concatenate;
  using match::ConstantScalar;
  using match::Reshape;

  constexpr char kModuleStr[] = R"(
HloModule test_module
ENTRY test {
  c1 = u32[] constant(1)
  c2 = u32[] constant(2)
  c3 = u32[] constant(3)
  c4 = u32[] constant(4)
  r1 = u32[1] reshape(c1)
  r2 = u32[1] reshape(c2)
  r3 = u32[1] reshape(c3)
  r4 = u32[1] reshape(c4)
  ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* concat = module->entry_computation()->root_instruction();

  // Pattern for "a reshape of the scalar constant `value`".
  auto reshaped = [](int value) { return Reshape(ConstantScalar(value)); };

  ASSERT_TRUE(Match(concat, Concatenate(reshaped(1), reshaped(2), reshaped(3),
                                        reshaped(4))));
  // Swapped operands.
  ASSERT_FALSE(Match(concat, Concatenate(reshaped(2), reshaped(1),
                                         reshaped(3), reshaped(4))));
  // Too few operands (leading prefix and trailing suffix).
  ASSERT_FALSE(
      Match(concat, Concatenate(reshaped(1), reshaped(2), reshaped(3))));
  ASSERT_FALSE(
      Match(concat, Concatenate(reshaped(2), reshaped(3), reshaped(4))));
}
492
493 template <typename Pattern>
Description(const Pattern & pattern)494 string Description(const Pattern& pattern) {
495 std::stringstream ss;
496 pattern.DescribeTo(&ss);
497 return ss.str();
498 }
499
500 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)501 string Explanation(Elem* elem, const Pattern& pattern) {
502 std::stringstream ss;
503 MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
504 Match(elem, pattern, options);
505 return ss.str();
506 }
507 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)508 string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
509 return Explanation(elem.get(), pattern);
510 }
511 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)512 string Explanation(const Elem& elem, const Pattern& pattern) {
513 return Explanation(&elem, pattern);
514 }
515
// Helper macro for checking a pattern's description and the explanation printed
// when attempting to match (and presumably failing) on a given object.
//
// We use a macro rather than a function because we want good line numbers in
// errors. We use this rather than writing a helper that returns a pair of
// (description, explanation) and doing something like
//
//   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
// while EXPECT_THAT does not. This unified diff makes the errors much easier
// to read.
//
// `pattern` is expanded twice (once for the description, once for the
// explanation); `elem` is expanded once. Every argument is parenthesized
// consistently so that arbitrary expressions can be passed safely.
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc,      \
                                    expected_explanation)              \
  do {                                                                 \
    EXPECT_EQ(Description((pattern)), (expected_desc));                \
    EXPECT_EQ(Explanation((elem), (pattern)), (expected_explanation)); \
  } while (0)
534
// Checks the description and failure explanation produced by layout patterns.
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
  const Layout expected = LayoutUtil::MakeLayout({1, 2});
  const Layout actual = LayoutUtil::MakeLayout({2, 2});
  const Layout* null_layout = nullptr;

  EXPECT_DESC_AND_EXPLANATION(null_layout, m::Layout(), "a layout",
                              "Layout is null");
  EXPECT_DESC_AND_EXPLANATION(actual, m::Layout().EqualTo(&expected),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
}
545
// WithCustomCallTarget() matches only the named target; also checks its
// description and failure explanation.
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
  constexpr char kModuleStr[] = R"(
HloModule test_module

ENTRY test {
  ROOT out = f32[] custom-call(), custom_call_target="test_target"
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  HloInstruction* custom_call = module->entry_computation()->root_instruction();

  auto right_target = match::Op().WithCustomCallTarget("test_target");
  auto wrong_target = match::Op().WithCustomCallTarget("other_target");
  EXPECT_TRUE(Match(custom_call, right_target));
  EXPECT_FALSE(Match(custom_call, wrong_target));

  EXPECT_DESC_AND_EXPLANATION(
      custom_call, wrong_target,
      "an HloInstruction custom call with target 'other_target'",
      "HloInstruction is not a custom call with a target 'other_target'\nin "
      "out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
568
// Checks the description and failure explanation of each ShapePattern
// predicate, including nested subshape patterns.
TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
  // Reference shape f32[1,2]{0,1} and its layout, used as "expected" values.
  const Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
  const Layout layout = shape.layout();

  // Shapes the patterns are matched (and fail to match) against.
  const Shape* null_shape = nullptr;
  const Shape other_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0});
  const Shape f32_2x2 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape nil = ShapeUtil::MakeNil();
  const Shape tuple_of_2x2 = ShapeUtil::MakeTupleShape({f32_2x2});
  const Shape nested_tuple =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})});

  EXPECT_DESC_AND_EXPLANATION(null_shape, m::Shape(), "a shape",
                              "Shape is null");
  EXPECT_DESC_AND_EXPLANATION(other_layout, m::Shape().EqualTo(&shape),
                              "a shape equal to f32[1,2]{0,1}",
                              "Shape not equal to f32[1,2]{0,1}\n"
                              "in f32[1,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(f32_2x2, m::Shape().CompatibleTo(&shape),
                              "a shape compatible with f32[1,2]",
                              "Shape not compatible with f32[1,2]\n"
                              "in f32[2,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
                              "a shape with element type F16",
                              "Shape does not have element type F16\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
                              "a shape that represents a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(nil, m::Shape().IsArray(),
                              "a shape that represents an array",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
                              "a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
                              "a shape that is an effective scalar",
                              "Shape is not an effective scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
                              "a shape that has 42 dimensions",
                              "Shape does not have rank 42\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
                              "a shape that is a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");

  // Conjunctions: the first failing predicate is the one explained.
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
                              "a shape:\n"
                              " * that has 1 dimension AND\n"
                              " * that represents an array",
                              "Shape does not have rank 1\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(nil, m::Shape().IsArray().WithRank(1),
                              "a shape:\n"
                              " * that represents an array AND\n"
                              " * that has 1 dimension",
                              "Shape is not an array\n"
                              "in ()");

  EXPECT_DESC_AND_EXPLANATION(other_layout,
                              m::Shape().WithLayoutEqualTo(&layout),
                              "a shape with\n a layout equal to {0,1}",
                              "Layout {1,0} is not equal to expected {0,1}\n"
                              "in f32[1,2]{1,0}");

  // Subshape patterns: missing index, and nested mismatch.
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeEqualTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape equal to f32[1,2]{0,1}",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(tuple_of_2x2,
                              m::Shape().WithSubshapeEqualTo({0}, &shape),
                              "a shape with subshape at index {0} which is\n"
                              " a shape equal to f32[1,2]{0,1}",
                              "Shape not equal to f32[1,2]{0,1}\n"
                              "in f32[2,2]{1,0}\n"
                              "in subshape at {0}\n"
                              "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeCompatibleTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape compatible with f32[1,2]",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(tuple_of_2x2,
                              m::Shape().WithSubshapeCompatibleTo({0}, &shape),
                              "a shape with subshape at index {0} which is\n"
                              " a shape compatible with f32[1,2]",
                              "Shape not compatible with f32[1,2]\n"
                              "in f32[2,2]{1,0}\n"
                              "in subshape at {0}\n"
                              "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(
      nested_tuple, m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
      "a shape with subshape at index {0,0} which is\n"
      " a shape that represents a scalar",
      "Shape is not a scalar\n"
      "in f32[1,2]{0,1}\n"
      "in subshape at {0,0}\n"
      "in ((f32[1,2]))");
}
672
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)673 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
674 std::unique_ptr<HloInstruction> instr) {
675 instr->SetAndSanitizeName(string(name));
676 return instr;
677 }
678
// Checks the description and failure explanation of each HloInstruction
// pattern predicate.
TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
  std::unique_ptr<HloInstruction> iota =
      SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
                                              /*iota_dimension=*/0));
  std::unique_ptr<HloInstruction> constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  const HloInstruction* null_instr = nullptr;

  // Printed form of `iota`, and the trailer every explanation about it ends
  // with.
  const string iota_text = "i = s32[42]{0} iota(), iota_dimension=0";
  const string in_iota = absl::StrCat("\nin ", iota_text);

  EXPECT_DESC_AND_EXPLANATION(null_instr, m::Op(), "an HloInstruction",
                              "HloInstruction* is null");
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithName("foo"), "an HloInstruction named \"foo\"",
      absl::StrCat("HloInstruction not named \"foo\"", in_iota));
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithOpcode(HloOpcode::kAdd),
      "an HloInstruction with opcode add",
      absl::StrCat("HloInstruction doesn't have opcode add", in_iota));
  EXPECT_DESC_AND_EXPLANATION(
      constant, m::Op().IsNonConstant(),
      "an HloInstruction with any opcode other than constant",
      "HloInstruction has opcode constant, expected anything else\n"
      "in c = s32[] constant(0)");
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithNumOperands(42), "an HloInstruction with 42 operands",
      absl::StrCat("HloInstruction doesn't have 42 operands", in_iota));
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithShape(m::Shape().IsTuple()),
      "an HloInstruction outputting\n"
      " a shape that represents a tuple",
      absl::StrCat("Shape is not a tuple\n"
                   "in s32[42]{0}\n"
                   "in output shape",
                   in_iota));
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
      "an HloInstruction with operand 2 which is:\n"
      " an HloInstruction with opcode add",
      absl::StrCat("desired operand index 2 is out of bounds", in_iota));

  EXPECT_DESC_AND_EXPLANATION(
      SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
                                                HloOpcode::kAdd, constant.get(),
                                                constant.get())),
      m::Op().WithOperand(1, m::Op().IsNonConstant()),
      "an HloInstruction with operand 1 which is:\n"
      " an HloInstruction with any opcode other than constant",
      "HloInstruction has opcode constant, expected anything else\n"
      "in c = s32[] constant(0)\n"
      "in operand 1\n"
      "in a = s32[] add(s32[] c, s32[] c)");
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
      "an HloInstruction with fusion kind kLoop",
      absl::StrCat(
          "HloInstruction does not have fusion kind kLoop; it's not a fusion",
          in_iota));
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithTupleIndex(42),
      "an HloInstruction which is a GTE with index 42",
      absl::StrCat(
          "HloInstruction is not a GTE with index 42; it's not a GTE at all",
          in_iota));
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().IsConstantScalar(),
      "an HloInstruction which is a constant scalar",
      absl::StrCat("HloInstruction is not a constant", in_iota));

  // Constant-value predicates, each against a freshly created constant.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(
                       LiteralUtil::CreateR1<int>({1, 2}))),
      m::Op().IsConstantEffectiveScalar(),
      "an HloInstruction which is a constant effective scalar",
      "HloInstruction is not an effective scalar\n"
      "in c = s32[2]{0} constant({1, 2})");
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
      m::Op().IsConstantScalar(42),
      "an HloInstruction which is a constant scalar with value 42",
      "HloInstruction's constant value 10 did not match expected value 42\n"
      "in c = s32[] constant(10)");
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
      m::Op().IsConstantEffectiveScalar(1.25),
      "an HloInstruction which is a constant effective scalar with value 1.25",
      "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
      "in c = f64[] constant(2.25)");

  // Is(): identity comparison, printed with the instructions' addresses.
  EXPECT_DESC_AND_EXPLANATION(
      constant, m::Op().Is(iota.get()),
      absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
                   " (", iota_text, ")"),
      absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
                   absl::Hex(iota.get()), " (", iota_text,
                   ")\nin c = s32[] constant(0)"));
}
773
// Checks the description and failure explanation of AddAnyOrder(): first when
// only the second matcher fails on both operands, then when the LHS operand
// matches neither matcher.
//
// Operands are owned by named locals declared *before* the instruction that
// points at them. They therefore provably outlive the add, instead of relying
// on the destruction order of temporaries created inside a macro argument
// (which would dangle if the macro ever bound `elem` to a local).
TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  {
    std::unique_ptr<HloInstruction> b =
        SetName("b", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    std::unique_ptr<HloInstruction> c =
        SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    std::unique_ptr<HloInstruction> a =
        SetName("a", HloInstruction::CreateBinary(scalar_s32, HloOpcode::kAdd,
                                                  b.get(), c.get()));
    EXPECT_DESC_AND_EXPLANATION(
        a, m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
        "an HloInstruction:\n"
        " * with opcode add AND\n"
        " * with two operands in either order:\n"
        " - an HloInstruction named \"b\"\n"
        " - an HloInstruction named \"bar\"",
        "HloInstruction's operands (ignoring order) did not match second "
        "matcher. Specifically,\n"
        " - an HloInstruction named \"bar\"\n"
        "does not match LHS:\n"
        " - HloInstruction not named \"bar\"\n"
        " in b = s32[] constant(0)\n"
        "does not match RHS:\n"
        " - HloInstruction not named \"bar\"\n"
        " in c = s32[] constant(0)\n"
        "in a = s32[] add(s32[] b, s32[] c)");
  }
  {
    std::unique_ptr<HloInstruction> p =
        HloInstruction::CreateParameter(0, scalar_s32, "p");
    std::unique_ptr<HloInstruction> c =
        SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
    std::unique_ptr<HloInstruction> a =
        SetName("a", HloInstruction::CreateBinary(scalar_s32, HloOpcode::kAdd,
                                                  p.get(), c.get()));
    EXPECT_DESC_AND_EXPLANATION(
        a, m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
        "an HloInstruction:\n"
        " * with opcode add AND\n"
        " * with two operands in either order:\n"
        " - an HloInstruction which is a constant scalar\n"
        " - an HloInstruction with opcode constant",
        "HloInstruction's LHS operand did not match either of the two matchers. "
        "Specifically,\n"
        " - an HloInstruction which is a constant scalar\n"
        "does not match LHS:\n"
        " - HloInstruction is not a constant\n"
        " in p = s32[] parameter(0)\n"
        "and\n"
        " - an HloInstruction with opcode constant\n"
        "does not match LHS:\n"
        " - HloInstruction doesn't have opcode constant\n"
        " in p = s32[] parameter(0)\n"
        "in a = s32[] add(s32[] p, s32[] c)");
  }
}
829
// Checks how m::AnyOf() describes itself and how it explains a failure in
// which every alternative rejects the instruction.
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  constexpr char kExpectedDescription[] =
      "any of:\n"
      " - an HloInstruction named \"foo\" OR\n"
      " - an HloInstruction named \"bar\"";
  // One "Matcher #N" section is reported per failed alternative, in order.
  constexpr char kExpectedExplanation[] =
      "None of the following matchers succeeded:\n"
      "Matcher #1\n"
      " - an HloInstruction named \"foo\"\n"
      "failed with\n"
      " - HloInstruction not named \"foo\"\n"
      " in c = s32[] constant(0)\n"
      "Matcher #2\n"
      " - an HloInstruction named \"bar\"\n"
      "failed with\n"
      " - HloInstruction not named \"bar\"\n"
      " in c = s32[] constant(0)";

  auto constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  EXPECT_DESC_AND_EXPLANATION(constant,
                              m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
                                                       m::Op().WithName("bar")),
                              kExpectedDescription, kExpectedExplanation);
}
850
// Exercises m::Parameter() both without a parameter number (any parameter)
// and with one (only that parameter), plus its description/explanations.
TEST_F(PatternMatcherTest, Parameter) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto p1 = HloInstruction::CreateParameter(1, f32_scalar, "p1");
  auto constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // A parameter matches the unnumbered pattern and its own number only.
  EXPECT_TRUE(Match(p1.get(), m::Parameter()));
  EXPECT_TRUE(Match(p1.get(), m::Parameter(1)));
  EXPECT_FALSE(Match(p1.get(), m::Parameter(0)));

  // A non-parameter never matches, numbered or not.
  EXPECT_FALSE(Match(constant.get(), m::Parameter()));
  EXPECT_FALSE(Match(constant.get(), m::Parameter(1)));

  // Wrong opcode: the explanation reports the opcode mismatch.
  EXPECT_DESC_AND_EXPLANATION(constant, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // Right opcode, wrong number: the explanation reports the number mismatch.
  auto p0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");
  EXPECT_EQ(Explanation(p0, m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
874
// Exercises WithOneUse() (exactly one use) and WithOneUser() (exactly one
// user instruction, possibly using the operand several times) on a parameter
// whose set of users is changed step by step. Statement order matters: each
// expectation observes the users that exist at that point.
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
  auto param =
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");

  // No users yet: both matchers fail.
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(
      param, m::Op().WithOneUse(),
      "an HloInstruction which has exactly one use",
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)");

  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      param, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)");

  // The reshapes live only inside this scope, so they are destroyed before
  // `add` is created below.
  {
    auto reshape =
        SetName("r", HloInstruction::CreateReshape(
                         ShapeUtil::MakeShape(F32, {1}), param.get()));
    EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse()));
    EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser()));

    auto reshape1 =
        SetName("r1", HloInstruction::CreateReshape(
                          ShapeUtil::MakeShape(F32, {1}), param.get()));
    EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
    EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));

    // A true compile-time constant (the previous `const char*` left the
    // pointer itself mutable despite the k-prefixed constant name).
    constexpr char kMultipleUserExplanation[] =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()),
              kMultipleUserExplanation);
    EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()),
              kMultipleUserExplanation);
  }

  // Once r and r1 are gone they must no longer count as users; the
  // single-user expectation after creating `add` depends on this.
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));

  // One user that uses p0 twice: one user, but two uses.
  auto add = SetName("add", HloInstruction::CreateBinary(
                                ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd,
                                param.get(), param.get()));
  EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser()));
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
  EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
929
// Exercises the comparison matchers: the generic m::Compare(), the
// direction-specific m::Eq()/m::Ne()/m::Le(), nested operand patterns, the
// commutative m::EqAnyOrder(), and the explanation for a direction mismatch.
TEST_F(PatternMatcherTest, Comparison) {
  auto shape = ShapeUtil::MakeShape(F32, {1});
  // XLA comparisons produce an element-wise PRED result. The compare
  // instructions previously used the F32 operand shape, which is malformed
  // HLO. The expected HLO text in the explanation below reflects this shape.
  auto pred_shape = ShapeUtil::MakeShape(PRED, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  auto eq = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(pred_shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  // Positive matches.
  EXPECT_TRUE(Match(eq.get(), m::Compare()));
  EXPECT_TRUE(Match(eq.get(), m::Eq()));
  EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(ne.get(), m::Compare()));
  EXPECT_TRUE(Match(ne.get(), m::Ne()));
  EXPECT_TRUE(Match(
      le.get(),
      m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
                                    m::Add(m::Parameter(0), m::Parameter(1)))));

  // Negative matches: wrong opcode, wrong direction (both ways), and wrong
  // operand order for the order-sensitive matcher.
  EXPECT_FALSE(Match(eq.get(), m::Add()));
  EXPECT_FALSE(Match(eq.get(), m::Ne()));
  EXPECT_FALSE(Match(ne.get(), m::Eq()));
  EXPECT_FALSE(
      Match(le.get(),
            m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));

  EXPECT_DESC_AND_EXPLANATION(
      eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = pred[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
972
973 } // namespace
974 } // namespace xla
975