1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/platform/test.h"
24
25 namespace xla {
26 namespace {
27
// Short alias for the pattern namespace, used by the DescribeTo/Explain tests.
namespace m = match;
// Every test below runs on HloTestBase, which provides
// ParseAndReturnVerifiedModule() for building modules from HLO text.
using PatternMatcherTest = HloTestBase;
30
// Matches a scalar add against a chained Op() pattern (name, opcode, shape,
// layout, operand) and verifies that every capture slot -- including those in
// nested sub-patterns -- is populated on success and untouched on failure.
TEST_F(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
    ENTRY %two_plus_two_computation () -> f32[] {
      %two = f32[] constant(2)
      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  // Start every capture slot at null so the checks below detect a capture
  // that never happened instead of reading an indeterminate pointer.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;
  Layout* matched_layout = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(
              match::Shape(&matched_shape)
                  .WithLayout(match::Layout(&matched_layout).WithDenseFormat()))
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);
  // The nested sub-patterns must have captured too.
  ASSERT_NE(matched_operand, nullptr);
  EXPECT_EQ(matched_operand->opcode(), HloOpcode::kConstant);
  EXPECT_NE(matched_shape, nullptr);
  EXPECT_NE(matched_layout, nullptr);

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));

  // A failed match must leave its capture slot untouched.
  matched_inst = nullptr;
  EXPECT_FALSE(
      Match(root, match::Multiply(&matched_inst, match::Op(), match::Op())));
  EXPECT_EQ(matched_inst, nullptr);
}
70
// Checks the ShapePattern predicates against an F32 scalar (rank-0) shape.
TEST_F(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Initialized so that, if the capturing match fails, EXPECT_EQ reports a
  // clean mismatch instead of reading an indeterminate pointer.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  // A scalar has no subshape at {0}, so this fails even though the element
  // type matches.
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
85
// Checks the ShapePattern predicates against a rank-3 F32 array shape,
// including capture of the shape itself and of its layout.
TEST_F(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Initialized so a failed capture yields a clean EXPECT_EQ mismatch rather
  // than a comparison against an indeterminate pointer.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  // Array shapes have no tuple subshapes.
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithDenseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
104
// Checks tuple-shape predicates and WithSubshape(), including capture of the
// matched subshape and rejection of out-of-range indices.
TEST_F(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  // Must start null: the ASSERT_NE checks below only prove a capture happened
  // if the pointer is known to be unset beforehand.
  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {0}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {0}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));

  // Reset so the next ASSERT_NE verifies the {1} capture rather than passing
  // on the stale pointer left by the {0} capture.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {1}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {1}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));

  // Index {2} is past the end of the tuple; {0, 0} descends into an array.
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
152
// WithFusionKind() accepts only a fusion of exactly the requested kind; the
// non-fusion parameter feeding the fusion never matches.
TEST_F(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    fused_computation {
      ROOT fp0 = f32[] parameter(0)
    }

    ENTRY while.v11 {
      p0 = f32[] parameter(0)
      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  const auto loop_kind = HloInstruction::FusionKind::kLoop;
  const auto input_kind = HloInstruction::FusionKind::kInput;
  auto* fusion = hlo_module->entry_computation()->root_instruction();
  const HloInstruction* fusion_input = fusion->operand(0);

  EXPECT_TRUE(Match(fusion, match::Op().WithFusionKind(loop_kind)));
  EXPECT_FALSE(Match(fusion, match::Op().WithFusionKind(input_kind)));
  EXPECT_FALSE(Match(fusion_input, match::Op().WithFusionKind(loop_kind)));
}
176
// The root is `get-tuple-element(p0), index=1`: both WithTupleIndex() and the
// GetTupleElement() shorthand must accept index 1 and reject every other one.
TEST_F(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY while.v11 {
      p0 = (f32[], f32[], f32[]) parameter(0)
      ROOT gte = f32[] get-tuple-element(p0), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  auto* gte = hlo_module->entry_computation()->root_instruction();
  for (int wrong_index : {0, 2}) {
    EXPECT_FALSE(Match(gte, match::Op().WithTupleIndex(wrong_index)));
  }
  EXPECT_TRUE(Match(gte, match::Op().WithTupleIndex(1)));

  EXPECT_FALSE(Match(gte, match::GetTupleElement(match::Op(), 0)));
  EXPECT_TRUE(Match(gte, match::GetTupleElement(match::Op(), 1)));
}
195
// AnyOf() succeeds iff at least one alternative matches, in either order.
TEST_F(PatternMatcherTest, AnyOf) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* f16_one = hlo_module->entry_computation()->root_instruction();

  const auto is_zero = match::ConstantScalar(0);
  const auto is_one = match::ConstantScalar(1);
  const auto is_two = match::ConstantScalar(2);

  EXPECT_TRUE(Match(f16_one, match::AnyOf<HloInstruction>(is_zero, is_one)));
  EXPECT_TRUE(Match(f16_one, match::AnyOf<HloInstruction>(is_one, is_zero)));
  // Neither alternative equals the constant 1.
  EXPECT_FALSE(Match(f16_one, match::AnyOf<HloInstruction>(is_zero, is_two)));
}
213
// Contrasts ConstantScalar() (rank-0 constants only) with
// ConstantEffectiveScalar() (any constant with a single element), with and
// without an expected value, and checks the capturing overloads.
TEST_F(PatternMatcherTest, ConstantScalar) {
  using match::ConstantEffectiveScalar;
  using match::ConstantScalar;

  // a: s32 scalar; b: s32 single-element rank-2 array; c: two-element array;
  // d, e: f32 scalars with integral and non-integral values.
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      a = s32[] constant(1)
      b = s32[1,1] constant({{2}})
      c = s32[1,2] constant({{2,2}})
      d = f32[] constant(1)
      e = f32[] constant(1.25)
      ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  const HloInstruction* a = root->operand(0);
  const HloInstruction* b = root->operand(1);
  const HloInstruction* c = root->operand(2);
  const HloInstruction* d = root->operand(3);
  const HloInstruction* e = root->operand(4);
  EXPECT_TRUE(Match(a, ConstantScalar()));
  EXPECT_TRUE(Match(a, ConstantScalar(1)));
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
  EXPECT_FALSE(Match(a, ConstantScalar(2)));
  EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
  EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
  EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));

  // Single-element but not rank 0: only the "effective" variant matches.
  EXPECT_FALSE(Match(b, ConstantScalar()));
  EXPECT_FALSE(Match(b, ConstantScalar(2)));
  EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));

  // More than one element: neither variant matches.
  EXPECT_FALSE(Match(c, ConstantScalar()));
  EXPECT_FALSE(Match(c, ConstantScalar(2)));
  EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
  EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));

  EXPECT_TRUE(Match(d, ConstantScalar(1)));
  EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
  EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
  EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));

  EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
  EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
  EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
  EXPECT_FALSE(Match(e, ConstantScalar(1)));
  EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));

  // Capturing overloads record the matched instruction.
  const HloInstruction* instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
  EXPECT_EQ(instr, a);

  instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
  EXPECT_EQ(instr, a);

  instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
  EXPECT_EQ(instr, a);

  instr = nullptr;
  EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
  EXPECT_EQ(instr, a);
}
285
// MultiplyAnyOrder() accepts the operands in either order, captures the
// multiply, and exposes the same chaining API as Op().
TEST_F(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Start null, and reset between matches, so each EXPECT_EQ really verifies
  // that the capture fired.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it.
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
316
// AnyOf() stops at the first alternative that matches: only that
// alternative's capture is written, and the later alternative's capture stays
// null.
TEST_F(PatternMatcherTest, AnyOfShortCircuit) {
  using match::AnyOf;
  using match::Multiply;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* product = hlo_module->entry_computation()->root_instruction();

  // Multiply alternative first: it wins, so the catch-all never captures.
  {
    const HloInstruction* multiply_capture = nullptr;
    const HloInstruction* fallback_capture = nullptr;

    ASSERT_TRUE(Match(product, AnyOf<HloInstruction>(
                                   Multiply(&multiply_capture, Op(), Op()),
                                   Op(&fallback_capture))));
    EXPECT_NE(nullptr, multiply_capture);
    EXPECT_EQ(nullptr, fallback_capture);
  }
  // Catch-all first: it wins, so the Multiply alternative never captures.
  {
    const HloInstruction* multiply_capture = nullptr;
    const HloInstruction* fallback_capture = nullptr;

    ASSERT_TRUE(Match(product, AnyOf<HloInstruction>(
                                   Op(&fallback_capture),
                                   Multiply(&multiply_capture, Op(), Op()))));
    EXPECT_NE(nullptr, fallback_capture);
    EXPECT_EQ(nullptr, multiply_capture);
  }
}
352
// AllOf() matches only when every sub-pattern matches, regardless of order;
// a single failing conjunct makes the whole AllOf fail.
TEST_F(PatternMatcherTest, AllOf) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* f16_one = hlo_module->entry_computation()->root_instruction();

  // Three independent descriptions that each hold for the root on their own.
  auto f16_scalar = ShapeUtil::MakeShape(F16, {});
  auto equal_shape = Constant().WithShapeEqualTo(&f16_scalar);
  auto compatible_shape = Constant().WithShapeCompatibleTo(&f16_scalar);
  auto is_scalar = Constant().WithShape(match::Shape().IsScalar());
  ASSERT_TRUE(Match(f16_one, is_scalar));
  ASSERT_TRUE(Match(f16_one, equal_shape));
  ASSERT_TRUE(Match(f16_one, compatible_shape));

  EXPECT_TRUE(Match(f16_one, AllOf<HloInstruction>(is_scalar, equal_shape,
                                                   compatible_shape)));
  EXPECT_TRUE(Match(f16_one, AllOf<HloInstruction>(equal_shape,
                                                   compatible_shape,
                                                   is_scalar)));

  // The root is not a broadcast, so pairing any of them with this fails.
  auto any_broadcast = Broadcast(Op());
  EXPECT_FALSE(
      Match(f16_one, AllOf<HloInstruction>(any_broadcast, equal_shape)));
  EXPECT_FALSE(
      Match(f16_one, AllOf<HloInstruction>(any_broadcast, compatible_shape)));
  EXPECT_FALSE(Match(f16_one, AllOf<HloInstruction>(any_broadcast, is_scalar)));
}
384
// A sub-pattern that would match on its own must not capture when the
// enclosing AllOf() fails as a whole.
TEST_F(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* constant_42 = hlo_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  // Constant(&captured) holds, but Broadcast(Op()) does not.
  ASSERT_FALSE(Match(constant_42, AllOf<HloInstruction>(Constant(&captured),
                                                        Broadcast(Op()))));
  EXPECT_EQ(nullptr, captured);

  // Control: the same capture fires once the match succeeds.
  ASSERT_TRUE(Match(constant_42, Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
407
// With MatchOption::capture disabled, even a successful match leaves capture
// pointers untouched.
TEST_F(PatternMatcherTest, TestNoCapture) {
  using match::Constant;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* constant_42 = hlo_module->entry_computation()->root_instruction();

  MatchOption no_capture{/*capture=*/false};
  const HloInstruction* captured = nullptr;
  ASSERT_TRUE(Match(constant_42, Constant(&captured), no_capture));
  EXPECT_EQ(nullptr, captured);
}
424
// When AnyOf() alternatives share capture slots, only the alternative that
// actually matched populates them.
TEST_F(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      u = f16[] parameter(0)
      v = f16[] parameter(1)
      ROOT add = f16[] add(u, v)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* sum = hlo_module->entry_computation()->root_instruction();

  const HloInstruction* first = nullptr;
  const HloInstruction* second = nullptr;
  const HloInstruction* third = nullptr;
  // Alternatives, in order:
  //  1. a three-way add -- fails, since neither operand of `add` is an add;
  //  2. a plain binary add -- matches;
  //  3. a catch-all (never reached).
  auto two_way_add = Add(Op(&first), Op(&second));
  auto two_or_three_way_add = AnyOf<HloInstruction>(
      AddAnyOrder(two_way_add, Op(&third)), two_way_add, Op(&first));

  ASSERT_TRUE(Match(sum, two_or_three_way_add));
  EXPECT_NE(nullptr, first);
  EXPECT_NE(nullptr, second);
  EXPECT_EQ(nullptr, third);
}
454
// Concatenate() matches its operand patterns positionally: order and count
// must both agree with the instruction.
TEST_F(PatternMatcherTest, TestConcat) {
  using match::Concatenate;
  using match::ConstantScalar;
  using match::Reshape;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      c1 = u32[] constant(1)
      c2 = u32[] constant(2)
      c3 = u32[] constant(3)
      c4 = u32[] constant(4)
      r1 = u32[1] reshape(c1)
      r2 = u32[1] reshape(c2)
      r3 = u32[1] reshape(c3)
      r4 = u32[1] reshape(c4)
      ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  ASSERT_TRUE(Match(
      root,
      Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)),
                  Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4)))));
  // The remaining checks are independent, so use EXPECT_* to report every
  // failure instead of stopping at the first one.
  // Operand order matters.
  EXPECT_FALSE(Match(
      root,
      Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(1)),
                  Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4)))));
  // Operand count must match: neither a prefix nor a suffix of the operands
  // is accepted.
  EXPECT_FALSE(Match(
      root, Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)),
                        Reshape(ConstantScalar(3)))));
  EXPECT_FALSE(Match(
      root, Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(3)),
                        Reshape(ConstantScalar(4)))));
}
492
// Returns the text `pattern.DescribeTo()` writes, i.e. the human-readable
// description of what the pattern matches.
//
// Uses a write-only std::ostringstream (the stream is never read back through
// an istream interface) and spells out std::string so the helper does not
// depend on a project-level `string` alias.
template <typename Pattern>
std::string Description(const Pattern& pattern) {
  std::ostringstream os;
  pattern.DescribeTo(&os);
  return os.str();
}
499
500 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)501 string Explanation(Elem* elem, const Pattern& pattern) {
502 std::stringstream ss;
503 MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
504 Match(elem, pattern, options);
505 return ss.str();
506 }
507 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)508 string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
509 return Explanation(elem.get(), pattern);
510 }
511 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)512 string Explanation(const Elem& elem, const Pattern& pattern) {
513 return Explanation(&elem, pattern);
514 }
515
// Checks both a pattern's description and the explanation it prints when
// matched (and presumably failing) against a given object.
//
// This is a macro rather than a function so that failures report the caller's
// line number. It is two EXPECT_EQs rather than a single
//
//   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff when a multi-line string comparison
// fails and EXPECT_THAT does not; that diff makes failures far easier to read.
//
// Every argument is parenthesized on expansion, per usual macro hygiene.
#define EXPECT_DESC_AND_EXPLANATION(object, matcher, want_desc,      \
                                    want_explanation)                \
  do {                                                               \
    EXPECT_EQ(Description((matcher)), (want_desc));                  \
    EXPECT_EQ(Explanation((object), (matcher)), (want_explanation)); \
  } while (0)
534
// Description and failure explanation for LayoutPattern: a null layout, and a
// layout that differs from the EqualTo() target.
TEST_F(PatternMatcherTest, LayoutDescribeToAndExplain) {
  const Layout* null_layout = nullptr;
  auto expected_layout = LayoutUtil::MakeLayout({1, 2});
  auto actual_layout = LayoutUtil::MakeLayout({2, 2});

  EXPECT_DESC_AND_EXPLANATION(null_layout, m::Layout(), "a layout",
                              "Layout is null");
  EXPECT_DESC_AND_EXPLANATION(actual_layout,
                              m::Layout().EqualTo(&expected_layout),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
}
545
// WithCustomCallTarget() matches on the target string, and its description
// and explanation name the expected target.
TEST_F(PatternMatcherTest, CustomCallTargetMatcherDescribeAndExplain) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY test {
      ROOT out = f32[] custom-call(), custom_call_target="test_target"
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  auto* custom_call = hlo_module->entry_computation()->root_instruction();
  const auto wrong_target = match::Op().WithCustomCallTarget("other_target");

  EXPECT_TRUE(
      Match(custom_call, match::Op().WithCustomCallTarget("test_target")));
  EXPECT_FALSE(Match(custom_call, wrong_target));

  EXPECT_DESC_AND_EXPLANATION(
      custom_call, wrong_target,
      "an HloInstruction custom call with target 'other_target'",
      "HloInstruction is not a custom call with a target 'other_target'\nin "
      "out = f32[] custom-call(), custom_call_target=\"test_target\"");
}
568
// Checks DescribeTo() output and failure explanations for each ShapePattern
// predicate, including chained predicates and nested subshape patterns. The
// explanations for nested patterns carry "in subshape at {...}" context lines.
TEST_F(PatternMatcherTest, ShapeDescribeToAndExplain) {
  // f32[1,2] with layout {0,1}; most shapes it is compared against below print
  // with layout {1,0}.
  auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
  auto layout = shape.layout();

  EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
                              "a shape", "Shape is null");
  // Shapes differing only in layout are not EqualTo each other.
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
      m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
      "Shape not equal to f32[1,2]{0,1}\n"
      "in f32[1,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
                              m::Shape().CompatibleTo(&shape),
                              "a shape compatible with f32[1,2]",
                              "Shape not compatible with f32[1,2]\n"
                              "in f32[2,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
                              "a shape with element type F16",
                              "Shape does not have element type F16\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
                              "a shape that represents a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  // MakeNil() is the empty tuple, printed as "()".
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
                              "a shape that represents an array",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
                              "a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
                              "a shape that is an effective scalar",
                              "Shape is not an effective scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
                              "a shape that has 42 dimensions",
                              "Shape does not have rank 42\n"
                              "in f32[1,2]{0,1}");
  // WithRank(0) is described and explained in terms of "scalar".
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
                              "a shape that is a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  // Chained predicates are described as a bulleted AND-list; the explanation
  // names only the first predicate that failed.
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
                              "a shape:\n"
                              " * that has 1 dimension AND\n"
                              " * that represents an array",
                              "Shape does not have rank 1\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
                              m::Shape().IsArray().WithRank(1),
                              "a shape:\n"
                              " * that represents an array AND\n"
                              " * that has 1 dimension",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
      m::Shape().WithLayoutEqualTo(&layout),
      "a shape with\n a layout equal to {0,1}",
      "Layout {1,0} is not equal to expected {0,1}\n"
      "in f32[1,2]{1,0}");
  // Subshape patterns: an index with no subshape fails with "No subshape",
  // and a failing nested pattern appends the subshape path and outer shape.
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeEqualTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape equal to f32[1,2]{0,1}",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
      m::Shape().WithSubshapeEqualTo({0}, &shape),
      "a shape with subshape at index {0} which is\n"
      " a shape equal to f32[1,2]{0,1}",
      "Shape not equal to f32[1,2]{0,1}\n"
      "in f32[2,2]{1,0}\n"
      "in subshape at {0}\n"
      "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeCompatibleTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape compatible with f32[1,2]",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
      m::Shape().WithSubshapeCompatibleTo({0}, &shape),
      "a shape with subshape at index {0} which is\n"
      " a shape compatible with f32[1,2]",
      "Shape not compatible with f32[1,2]\n"
      "in f32[2,2]{1,0}\n"
      "in subshape at {0}\n"
      "in (f32[2,2])");
  // Two-level subshape index into a nested tuple.
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
      m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
      "a shape with subshape at index {0,0} which is\n"
      " a shape that represents a scalar",
      "Shape is not a scalar\n"
      "in f32[1,2]{0,1}\n"
      "in subshape at {0,0}\n"
      "in ((f32[1,2]))");
}
672
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)673 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
674 std::unique_ptr<HloInstruction> instr) {
675 instr->SetAndSanitizeName(string(name));
676 return instr;
677 }
678
// Checks DescribeTo() output and failure explanations for each
// HloInstructionPattern predicate. Explanations end with "in <instruction>"
// context lines; nested shape/operand patterns add their own context lines.
TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
  // Standalone instructions (not added to any computation), named via SetName
  // so the expected strings are stable.
  std::unique_ptr<HloInstruction> iota =
      SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
                                              /*iota_dimension=*/0));
  std::unique_ptr<HloInstruction> constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
                              m::Op(), "an HloInstruction",
                              "HloInstruction* is null");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
                              "an HloInstruction named \"foo\"",
                              "HloInstruction not named \"foo\"\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
                              "an HloInstruction with opcode add",
                              "HloInstruction doesn't have opcode add\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(
      constant, m::Op().IsNonConstant(),
      "an HloInstruction with any opcode other than constant",
      "HloInstruction has opcode constant, expected anything else\n"
      "in c = s32[] constant(0)");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
                              "an HloInstruction with 42 operands",
                              "HloInstruction doesn't have 42 operands\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  // Shape sub-patterns add "in <shape>" and "in output shape" context lines.
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
                              "an HloInstruction outputting\n"
                              " a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in s32[42]{0}\n"
                              "in output shape\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  // WithShape(type, dims) is described as two conjuncts; each case below
  // fails a different one.
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(F32, {42}),
                              "an HloInstruction outputting\n"
                              " a shape:\n"
                              " * with element type F32 AND\n"
                              " * with dimensions [42]",
                              "Shape does not have element type F32\n"
                              "in s32[42]{0}\n"
                              "in output shape\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(S32, {128}),
                              "an HloInstruction outputting\n"
                              " a shape:\n"
                              " * with element type S32 AND\n"
                              " * with dimensions [128]",
                              "Shape does not have dimensions [128]\n"
                              "in s32[42]{0}\n"
                              "in output shape\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  // iota has no operands, so any operand index is out of bounds.
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
      "an HloInstruction with operand 2 which is:\n"
      " an HloInstruction with opcode add",
      "desired operand index 2 is out of bounds\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");

  // A failing operand sub-pattern adds "in operand N" plus the parent
  // instruction. The add is a temporary that uses `constant` for both operands.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
                                                HloOpcode::kAdd, constant.get(),
                                                constant.get())),
      m::Op().WithOperand(1, m::Op().IsNonConstant()),
      "an HloInstruction with operand 1 which is:\n"
      " an HloInstruction with any opcode other than constant",
      "HloInstruction has opcode constant, expected anything else\n"
      "in c = s32[] constant(0)\n"
      "in operand 1\n"
      "in a = s32[] add(s32[] c, s32[] c)");
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
      "an HloInstruction with fusion kind kLoop",
      "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithTupleIndex(42),
      "an HloInstruction which is a GTE with index 42",
      "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
                              "an HloInstruction which is a constant scalar",
                              "HloInstruction is not a constant\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  // A two-element constant is a constant but not an effective scalar.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(
                       LiteralUtil::CreateR1<int>({1, 2}))),
      m::Op().IsConstantEffectiveScalar(),
      "an HloInstruction which is a constant effective scalar",
      "HloInstruction is not an effective scalar\n"
      "in c = s32[2]{0} constant({1, 2})");
  // Value mismatches report both the actual and the expected value.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
      m::Op().IsConstantScalar(42),
      "an HloInstruction which is a constant scalar with value 42",
      "HloInstruction's constant value 10 did not match expected value 42\n"
      "in c = s32[] constant(10)");
  // CreateR0(2.25) takes a double, so the constant is f64[].
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
      m::Op().IsConstantEffectiveScalar(1.25),
      "an HloInstruction which is a constant effective scalar with value 1.25",
      "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
      "in c = f64[] constant(2.25)");
  // Is() compares by pointer identity; the expected strings embed the
  // addresses in hex, so they are built at runtime.
  EXPECT_DESC_AND_EXPLANATION(
      constant, m::Op().Is(iota.get()),
      absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
                   " (i = s32[42]{0} iota(), iota_dimension=0)"),
      absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
                   absl::Hex(iota.get()),
                   " (i = s32[42]{0} iota(), iota_dimension=0)\n"
                   "in c = s32[] constant(0)"));
}
791
// Checks the description and explanation produced by AddAnyOrder(). It covers
// two failure modes: one operand matcher fits neither operand, and the LHS
// operand fits neither matcher.
TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  // The operands b and c are temporaries created inside the macro argument.
  // NOTE(review): they only stay alive for the full-expression in which the
  // macro evaluates `elem`.
  // Case 1: "b" matches the LHS, but the "bar" matcher fits neither operand.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("a", HloInstruction::CreateBinary(
                       scalar_s32, HloOpcode::kAdd,
                       SetName("b", HloInstruction::CreateConstant(
                                        LiteralUtil::CreateR0(0)))
                           .get(),
                       SetName("c", HloInstruction::CreateConstant(
                                        LiteralUtil::CreateR0(0)))
                           .get())),
      m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
      "an HloInstruction:\n"
      " * with opcode add AND\n"
      " * with two operands in either order:\n"
      " - an HloInstruction named \"b\"\n"
      " - an HloInstruction named \"bar\"",
      "HloInstruction's operands (ignoring order) did not match second "
      "matcher. Specifically,\n"
      " - an HloInstruction named \"bar\"\n"
      "does not match LHS:\n"
      " - HloInstruction not named \"bar\"\n"
      " in b = s32[] constant(0)\n"
      "does not match RHS:\n"
      " - HloInstruction not named \"bar\"\n"
      " in c = s32[] constant(0)\n"
      "in a = s32[] add(s32[] b, s32[] c)");

  // Case 2: the LHS is a parameter, so it satisfies neither constant matcher.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("a",
              HloInstruction::CreateBinary(
                  scalar_s32, HloOpcode::kAdd,
                  HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
                  SetName("c", HloInstruction::CreateConstant(
                                   LiteralUtil::CreateR0(0)))
                      .get())),
      m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
      "an HloInstruction:\n"
      " * with opcode add AND\n"
      " * with two operands in either order:\n"
      " - an HloInstruction which is a constant scalar\n"
      " - an HloInstruction with opcode constant",
      "HloInstruction's LHS operand did not match either of the two matchers. "
      "Specifically,\n"
      " - an HloInstruction which is a constant scalar\n"
      "does not match LHS:\n"
      " - HloInstruction is not a constant\n"
      " in p = s32[] parameter(0)\n"
      "and\n"
      " - an HloInstruction with opcode constant\n"
      "does not match LHS:\n"
      " - HloInstruction doesn't have opcode constant\n"
      " in p = s32[] parameter(0)\n"
      "in a = s32[] add(s32[] p, s32[] c)");
}
847
// Checks that m::AnyOf lists every alternative in its description and, when no
// alternative matches, explains separately why each one failed.
TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  auto constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  const auto named_foo_or_bar = m::AnyOf<HloInstruction>(
      m::Op().WithName("foo"), m::Op().WithName("bar"));

  EXPECT_DESC_AND_EXPLANATION(constant, named_foo_or_bar,
                              "any of:\n"
                              " - an HloInstruction named \"foo\" OR\n"
                              " - an HloInstruction named \"bar\"",
                              "None of the following matchers succeeded:\n"
                              "Matcher #1\n"
                              " - an HloInstruction named \"foo\"\n"
                              "failed with\n"
                              " - HloInstruction not named \"foo\"\n"
                              " in c = s32[] constant(0)\n"
                              "Matcher #2\n"
                              " - an HloInstruction named \"bar\"\n"
                              "failed with\n"
                              " - HloInstruction not named \"bar\"\n"
                              " in c = s32[] constant(0)");
}
868
// Checks m::Parameter() with and without an expected parameter number. Also
// checks the explanations for the two ways m::Parameter(1) can reject an
// instruction: wrong opcode, and wrong parameter number.
TEST_F(PatternMatcherTest, Parameter) {
  const auto f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto p1 = HloInstruction::CreateParameter(1, f32_scalar, "p1");
  auto constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // A parameter matches the number-agnostic pattern and its own number only.
  EXPECT_FALSE(Match(p1.get(), m::Parameter(0)));
  EXPECT_TRUE(Match(p1.get(), m::Parameter()));
  EXPECT_TRUE(Match(p1.get(), m::Parameter(1)));

  // A non-parameter never matches, whichever number is requested.
  EXPECT_FALSE(Match(constant.get(), m::Parameter()));
  EXPECT_FALSE(Match(constant.get(), m::Parameter(1)));

  // Rejection on opcode.
  EXPECT_DESC_AND_EXPLANATION(constant, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // Rejection on parameter number.
  auto p0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");
  EXPECT_EQ(Explanation(p0, m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
892
// Checks WithOneUse() and WithOneUser() on a parameter in three situations:
//   * it has no users,
//   * it has two distinct users, and
//   * it has a single user that reads it twice.
// Only WithOneUser() accepts the last case.
TEST_F(PatternMatcherTest, OneUseAndOneUser) {
  auto p0 =
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
  const auto one_use = m::Op().WithOneUse();
  const auto one_user = m::Op().WithOneUser();
  constexpr char kNoUsersExplanation[] =
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)";

  // No users yet: both patterns reject.
  EXPECT_FALSE(Match(p0.get(), one_use));
  EXPECT_DESC_AND_EXPLANATION(p0, one_use,
                              "an HloInstruction which has exactly one use",
                              kNoUsersExplanation);

  EXPECT_FALSE(Match(p0.get(), one_user));
  EXPECT_DESC_AND_EXPLANATION(
      p0, one_user,
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      kNoUsersExplanation);

  // The reshapes live only inside this scope. The assertions after it expect
  // p0 to have exactly one user, which relies on r and r1 no longer being
  // users of p0 once they are destroyed.
  {
    auto r = SetName("r", HloInstruction::CreateReshape(
                              ShapeUtil::MakeShape(F32, {1}), p0.get()));
    EXPECT_TRUE(Match(p0.get(), one_use));
    EXPECT_TRUE(Match(p0.get(), one_user));

    auto r1 = SetName("r1", HloInstruction::CreateReshape(
                                ShapeUtil::MakeShape(F32, {1}), p0.get()));
    EXPECT_FALSE(Match(p0.get(), one_use));
    EXPECT_FALSE(Match(p0.get(), one_user));

    // Both patterns report the same explanation when there are two users.
    constexpr char kTwoUsersExplanation[] =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(p0.get(), one_use), kTwoUsersExplanation);
    EXPECT_EQ(Explanation(p0.get(), one_user), kTwoUsersExplanation);
  }

  // A single user that consumes p0 as both operands: one user, two uses.
  auto double_use = SetName(
      "add", HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {}),
                                          HloOpcode::kAdd, p0.get(), p0.get()));
  EXPECT_TRUE(Match(p0.get(), one_user));
  EXPECT_FALSE(Match(p0.get(), one_use));
  EXPECT_EQ(Explanation(p0.get(), one_use),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
947
// Checks the comparison matchers: the generic m::Compare, the per-direction
// matchers (m::Eq, m::Ne, m::Le), ordered vs. any-order operand patterns, and
// the explanation produced when only the comparison direction is wrong.
TEST_F(PatternMatcherTest, Comparison) {
  auto shape = ShapeUtil::MakeShape(F32, {1});
  auto param0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto param1 = HloInstruction::CreateParameter(1, shape, "param.1");

  // Every comparison built here has param.0 as its LHS.
  auto compare_param0_with = [&](HloInstruction* rhs,
                                 ComparisonDirection direction) {
    return HloInstruction::CreateCompare(shape, param0.get(), rhs, direction);
  };
  auto eq = compare_param0_with(param1.get(), ComparisonDirection::kEq);
  auto ne = compare_param0_with(param1.get(), ComparisonDirection::kNe);
  auto add = HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
                                          param0.get(), param1.get());
  auto le = compare_param0_with(add.get(), ComparisonDirection::kLe);

  // Positive cases.
  EXPECT_TRUE(Match(eq.get(), m::Compare()));
  EXPECT_TRUE(Match(eq.get(), m::Eq()));
  EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(ne.get(), m::Compare()));
  EXPECT_TRUE(Match(ne.get(), m::Ne()));
  EXPECT_TRUE(Match(
      le.get(),
      m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
                                    m::Add(m::Parameter(0), m::Parameter(1)))));

  // Negative cases: wrong opcode, wrong direction, and operands in the wrong
  // order for an order-sensitive matcher.
  EXPECT_FALSE(Match(eq.get(), m::Add()));
  EXPECT_FALSE(Match(eq.get(), m::Ne()));
  EXPECT_FALSE(
      Match(le.get(),
            m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));

  // The direction check fails first, so the user-count clause is never
  // reached.
  EXPECT_DESC_AND_EXPLANATION(
      eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
990
// Checks the m::CustomCall matcher family:
//   * matching with and without a target name and operand sub-patterns,
//   * capturing the matched instruction through mutable and const
//     out-pointers, and
//   * rejecting a wrong target name or swapped operands.
TEST_F(PatternMatcherTest, CustomCallMatchers) {
  constexpr char kModuleStr[] = R"(
HloModule test_module

ENTRY test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT out = f32[] custom-call(p0, p1), custom_call_target="test_target"
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  // Non-capturing forms.
  EXPECT_TRUE(Match(root, m::CustomCall()));
  EXPECT_TRUE(Match(root, m::CustomCall("test_target")));
  EXPECT_TRUE(Match(
      root, m::CustomCall("test_target", m::Parameter(0), m::Parameter(1))));

  // Capturing forms with a mutable out-pointer. The pointer starts as nullptr
  // and is reset before each match, so every EXPECT_EQ shows that this
  // particular match bound it, not an earlier one.
  HloInstruction* instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr)));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target")));
  EXPECT_EQ(instr, root);
  instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&instr, "test_target", m::Parameter(0),
                                        m::Parameter(1))));
  EXPECT_EQ(instr, root);

  // The same capturing forms through a const out-pointer.
  const HloInstruction* const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr)));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target")));
  EXPECT_EQ(const_instr, root);
  const_instr = nullptr;
  EXPECT_TRUE(Match(root, m::CustomCall(&const_instr, "test_target",
                                        m::Parameter(0), m::Parameter(1))));
  EXPECT_EQ(const_instr, root);

  // Mismatches: wrong target name, and operands in the wrong order.
  EXPECT_FALSE(Match(root, m::CustomCall("other_target")));
  EXPECT_FALSE(Match(
      root, m::CustomCall("test_target", m::Parameter(1), m::Parameter(0))));

  // A failed capturing match must leave the out-pointer untouched.
  instr = nullptr;
  EXPECT_FALSE(Match(root, m::CustomCall(&instr, "other_target")));
  EXPECT_EQ(instr, nullptr);
}
1026
1027 } // namespace
1028 } // namespace xla
1029