1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18
19 #include "absl/strings/str_replace.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/utility/utility.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29
30 namespace xla {
31
32 // A pattern matcher for HloInstructions, Shapes, and Layouts.
33 //
34 // The Match function's first argument must be HloInstruction*, Shape*, or
35 // Layout*. The second argument is a pattern that will be matched against the
36 // first argument, as described below.
37 //
38 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
39 // functions. By default, the returned patterns will match any HloInstruction,
// Shape, or Layout, respectively. However, the match can be made more specific
41 // by using the pattern's modifier methods, for example:
42 //
43 // match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
44 // 0, match::Op().WithOpcode(HloOpcode::kConstant))
45 //
46 // This pattern will match Add instructions whose first operand is a constant.
47 //
48 // Each pattern type has the following modifiers, which are described where
49 // nontrivial.
50 //
51 // Op():
52 // - Is: is the given HloInstruction* (i.e. pointer equality)
53 // - WithName
54 // - WithOpcode
55 // - WithoutOpcode: anything other than the given opcode
56 // - WithShape: instr's shape matches the given pattern
57 // - WithShapeEqualTo: instr's shape is equal to the given Shape
58 // - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
59 // - WithNumOperands
60 // - WithOperand: operand at the given index matches the given pattern
61 // - IsConstant
62 // - IsNonConstant
63 // - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value,
64 // e.g. IsConstantScalar() or IsConstantScalar(42).
65 // - WithFusionKind
66 // - WithTupleIndex: get-tuple-element operations with the given tuple index
67 // - WithOneUse: Instruction is used as an operand exactly once.
68 // - WithOneUser: Instruction is used by exactly one other instruction, but
69 // is possibly used more than once as an operand (e.g. multiply(x,x)).
70 // - WithComparisonDirection: instr has the given direction
71 //
72 // Shape():
73 // - EqualTo
74 // - CompatibleTo
75 // - IsScalar/IsEffectiveScalar/IsArray/IsTuple
76 // - IsDenseArray
// - WithLayout: shape's layout matches the given pattern (e.g.
//   Layout().WithDenseFormat())
79 // - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
80 // Layout, but not the result of Layout().foo())
81 // - WithSubshape: shape is a tuple whose subshape matches the given pattern
82 // (e.g. Shape().IsScalar()).
83 // - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
84 // (i.e. another Shape, but not the result of Shape().foo())
85 // - WithElementType: shape is an array/scalar with the given elem type
86 // - WithRank: shape is an array/scalar with the given rank
87 //
88 // Layout():
89 // - EqualTo
90 // - WithDenseFormat
91 //
92 // Op(), Shape(), and Layout() may be passed an argument of type
93 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
94 // these pointers. If the pattern is matched, the address of the matched value
95 // will be "captured" and stored at this location.
96 //
97 // For example:
98 // HloInstruction* foo = ...;
99 // HloInstruction* matched_operand;
100 // CHECK(Match(foo,
101 // match::Op().WithOperand(0, match::Op(&matched_operand))));
102 //
103 // Helpers are provided for most HLO instructions. These helpers can be called
104 // with no arguments, in which case they will match any instruction matching the
// opcode. They may also be called with patterns for the operands and with an
106 // optional capture. (The capture must be the first argument.) Some examples of
107 // these helpers and their equivalents are provided below.
108
109 // Example nullary instruction:
110 // Parameter() == Op().WithOpcode(HloOpcode::kParameter)
111 // Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter)
112 //
113 // Example unary instruction:
114 // Abs() == Op().WithOpcode(HloOpcode::kAbs)
// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs)
//                   .WithOperand(0, Op(&a))
117 // Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs)
118 // .WithOperand(0, Op(&b))
119 //
120 // Commutative binary instructions have a special form that accepts either order
121 // of args, e.g.:
122 //
123 // AddAnyOrder(Parameter(1), Abs()) ==
124 // Op().WithOpcode(HloOpcode::kAdd)
125 // .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
126 //
127 // MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`.
128 //
129 // The following additional helpers are provided. In all cases, `&a` is
130 // optional.
131 //
132 // ConstantScalar(&a) == Op(&a).IsConstantScalar();
133 // ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v);
134 // ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar();
// ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(v)
136 // NonConstant(&a) == Op(&a).IsNonConstant()
137 // GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index)
138 // .WithOperand(0, b);
139 // Parameter(&a, n) == Op(&a).WithParameterNum(n);
140
// Knobs controlling one pattern-matching pass.
struct MatchOption {
  // When true, a successful match stores each matched item through the
  // user-supplied capture pointer (if any); when false, capture pointers are
  // left untouched.
  bool capture;

  // When non-null, an explanation of why the match failed is streamed here.
  std::ostream* explain_os;
};
148
149 template <typename Value, typename Pattern>
150 bool Match(Value* value, const Pattern& pattern,
151 MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
152 if (option.capture) {
153 auto new_option = option;
154 new_option.capture = false;
155 if (!pattern.Match(value, new_option)) {
156 return false;
157 }
158 }
159 return pattern.Match(value, option);
160 }
161
162 namespace match {
163
164 namespace detail {
165
// Macro for streaming to option.explain_os if it's not null.
//
//   EXPLAIN << "value of foo(): " << foo()
//
// A variable named `option` with an `explain_os` member must be in scope where
// the macro is used.
//
// The expansion is `if (!cond) {} else <stream>` rather than a bare
// `if (cond) <stream>`: with the bare form, writing
//   if (x) EXPLAIN << "..."; else Other();
// would silently attach the `else` to the macro's hidden `if` (the classic
// dangling-else hazard). The empty-then form keeps the caller's `else` bound to
// the caller's `if`.
#pragma push_macro("EXPLAIN")
#define EXPLAIN \
  if (!option.explain_os) {} else *option.explain_os
173
// Number of extra spaces the pretty-printers in this file add each time they
// increase the indentation "by one" level.
enum { kIndentInc = 2 };
179
180 // Writes a newline and then `indent` spaces.
181 //
182 // We follow an unintuitive convention in this file's pretty-printers: Indents
183 // are performed by the caller, not the callee. For example, if you want to
184 // print
185 //
186 // foo:
187 // - bar
188 //
189 // you'd do:
190 //
191 // Foo::DescribeTo(std::ostream* os, int64 indent) {
192 // *os << "foo:";
193 // Indent(os, indent) // Create a newline at the *current* indent level.
194 // *os << " - ";
195 // bar.DescribeTo(os, indent + 3); // + 3 because strlen(" * ") == 3.
196 // }
197 //
198 // Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; }
199 //
200 // Notice that Bar::DescribeTo() does not call Indent; the indenting is
201 // performed by Foo. This convention allows the caller to decide whether a
202 // matcher is preceded by a newline, which is important e.g. for the AllOf
203 // matcher.
204 //
205 // (Incidentally, indenting in Match's explanations is handled differently.
206 // Indents are a common case in DescribeTo [we're printing a whole tree], but
207 // they're a special case in Match [we're printing only a path through the tree
208 // that encounters a failing node]. Indents in Match only appear when we
209 // encounter a failing disjunction, so we just handle them as a special case
210 // there.)
Indent(std::ostream * os,int64 indent)211 inline void Indent(std::ostream* os, int64 indent) {
212 *os << "\n";
213 for (int64 i = 0; i < indent; ++i) {
214 *os << " ";
215 }
216 }
217
// Type trait: `IsTrivialMatcher<T>::value` is true iff T declares a static
// member `kIsTrivialMatcher` that evaluates to true. Types lacking the member
// fall back to the primary template via SFINAE.
//
// Trivial matchers get special treatment when pretty-printing. For example,
// when describing a conjunction of matchers, no "AND" is printed after a
// leading trivial matcher, yielding e.g.
//   "a shape compatible with f32[1,2]"
// rather than
//   "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

template <typename T>
struct IsTrivialMatcher<T, typename std::enable_if<T::kIsTrivialMatcher>::type>
    : std::true_type {};
236
// Conjunction of `Patterns`: matches an `Item` only if every sub-pattern
// matches it. Sub-patterns are evaluated in order and evaluation stops at the
// first one that fails (the `&&` in MatchImpl short-circuits).
template <typename Item, typename... Patterns>
class AllOfPattern {
 public:
  explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}

  // Returns true iff every sub-pattern matches `item`.
  bool Match(const Item* item, MatchOption option) const {
    bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
    // This invariant is guaranteed by the top-level Match and AnyOf.
    DCHECK(matched || !option.capture);
    return matched;
  }

  // Non-const overload: forwards a mutable pointer to the sub-patterns, whose
  // own non-const overloads can then capture it.
  bool Match(Item* item, MatchOption option) const {
    bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
    // This invariant is guaranteed by the top-level Match and AnyOf.
    DCHECK(matched || !option.capture);
    return matched;
  }

  // Pretty-prints the conjunction starting at `indent`; see DescribeToImpl for
  // the exact layout.
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
  }

  // Accessor for patterns_. Please don't use this outside of this file.
  const std::tuple<Patterns...>& patterns() const { return patterns_; }

 private:
  // Matches sub-pattern `index`, then recurses on `index + 1` (compile-time
  // recursion over the tuple).
  template <typename ItemType, size_t index>
  bool MatchImpl(ItemType* item, MatchOption option,
                 std::integral_constant<size_t, index>) const {
    // We don't need to do any EXPLAINing here; it's all correctly handled by
    // our sub-matchers (if any fail).
    return std::get<index>(patterns_).Match(item, option) &&
           MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
  }

  // Recursion terminator: every sub-pattern has matched.
  template <typename ItemType>
  bool MatchImpl(ItemType* item, MatchOption option,
                 std::integral_constant<size_t, sizeof...(Patterns)>) const {
    return true;
  }

  // Pretty-printing a conjunction has some special cases to make it easy to
  // read in the simple (common) case.
  //
  // If sizeof...(Patterns) == 1, prints as e.g.
  //
  //   a shape
  //
  // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
  // shape") prints as
  //
  //   a shape compatible with f32[1,2]
  //
  // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
  //
  //   a shape:
  //    * compatible with f32[1,2] AND
  //    * that represents a scalar
  //
  // Otherwise prints as:
  //
  //   all of:
  //    * foo AND
  //    * bar
  //
  // Each instantiation prints sub-pattern `index` and then recurses on
  // `index + 1`.
  template <size_t index>
  void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
                      int64 indent) const {
    // Whether the *first* sub-pattern is trivial decides the overall layout,
    // so it is computed from pattern 0 regardless of `index`.
    constexpr bool first_is_trivial =
        IsTrivialMatcher<typename std::remove_reference<decltype(
            std::get<0>(patterns_))>::type>::value;
    constexpr bool is_last = index == sizeof...(Patterns) - 1;
    const auto& submatcher = std::get<index>(patterns_);

    // Prints " * <submatcher>" and, unless this is the last item, " AND"
    // followed by a newline at the current indent level.
    auto print_bulleted_item = [&] {
      *os << " * ";
      submatcher.DescribeTo(os, indent + 3);  // + 3 == strlen(" * ").
      if (!is_last) {
        *os << " AND";
        Indent(os, indent);
      }
    };

    if (index == 0) {
      if (first_is_trivial || is_last) {
        // Lead with the trivial (or only) matcher as the "header".
        submatcher.DescribeTo(os, indent + kIndentInc);
        if (sizeof...(Patterns) > 2) {
          *os << ":";
          Indent(os, indent);
        }
      } else {
        *os << "all of:";
        Indent(os, indent);
        print_bulleted_item();
      }
    } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
      // Two patterns with a trivial first one: print inline on one line.
      *os << " ";
      submatcher.DescribeTo(os, indent);
    } else {
      print_bulleted_item();
    }
    DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
  }

  // Recursion terminator for DescribeToImpl; prints nothing.
  void DescribeToImpl(std::ostream* os,
                      std::integral_constant<size_t, sizeof...(Patterns)>,
                      int64 indent) const {}

  std::tuple<Patterns...> patterns_;
};
348
349 } // namespace detail
350
351 // Returns a pattern that represents the conjunction of all input patterns. All
352 // patterns need to match in order to have the AllOf pattern match.
353 template <typename Item, typename... Patterns>
354 auto AllOf(const Patterns&... patterns) {
355 return detail::AllOfPattern<typename std::remove_const<Item>::type,
356 Patterns...>(patterns...);
357 }
358
359 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
360 //
361 // This transformation is necessary for good pretty-printing.
362 template <typename Item, typename... InnerPs, typename... OuterPs>
363 auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
364 const OuterPs&... outer_ps) {
365 // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
366 auto make_all_of = [](const InnerPs&... inner_ps,
367 const OuterPs&... outer_ps) {
368 return detail::AllOfPattern<typename std::remove_const<Item>::type,
369 InnerPs..., OuterPs...>(inner_ps...,
370 outer_ps...);
371 };
372 return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
373 std::make_tuple(outer_ps...)));
374 }
375
376 namespace detail {
377
378 template <typename LayoutType, typename Impl>
379 class LayoutPattern;
380
381 // The base LayoutPattern implementation. Matches only if the layout is not
382 // nullptr.
383 class LayoutPatternBaseImpl {
384 public:
385 bool Match(const ::xla::Layout* layout, MatchOption option) const {
386 if (layout == nullptr) {
387 EXPLAIN << "Layout is null";
388 return false;
389 }
390 return true;
391 }
392
393 void DescribeTo(std::ostream* os, int64 indent = 0) const {
394 *os << "a layout";
395 }
396
397 static constexpr bool kIsTrivialMatcher = true;
398 };
399
400 // A LayoutPattern implementation that matches only if the layout equals a
401 // Layout proto.
402 class LayoutPatternEqualImpl {
403 public:
404 explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
405 : layout_(layout) {}
406
407 bool Match(const ::xla::Layout* layout, MatchOption option) const {
408 if (!LayoutUtil::Equal(*layout_, *layout)) {
409 EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
410 << " is not equal to expected "
411 << LayoutUtil::HumanString(*layout_);
412 return false;
413 }
414 return true;
415 }
416
417 void DescribeTo(std::ostream* os, int64 indent = 0) const {
418 *os << "equal to " << LayoutUtil::HumanString(*layout_);
419 }
420
421 private:
422 const ::xla::Layout* layout_;
423 };
424
425 // A LayoutPattern implementation that matches only if the layout has a given
426 // format.
427 class LayoutPatternFormatImpl {
428 public:
429 explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
430
431 bool Match(const ::xla::Layout* layout, MatchOption option) const {
432 if (layout->format() != format_) {
433 EXPLAIN << "Layout has format " << Format_Name(layout->format())
434 << " but expected " << Format_Name(format_);
435 return false;
436 }
437 return true;
438 }
439
440 void DescribeTo(std::ostream* os, int64 indent = 0) const {
441 *os << "with format " << Format_Name(format_);
442 }
443
444 private:
445 Format format_;
446 };
447
448 // A pattern that matches Layouts.
449 template <typename LayoutType, typename Impl>
450 class LayoutPattern {
451 private:
452 template <typename NewImpl>
453 auto AppendImpl(NewImpl new_impl) const {
454 auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
455 return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
456 matched_layout_);
457 }
458
459 public:
460 explicit constexpr LayoutPattern(const Impl& impl,
461 LayoutType** matched_layout)
462 : impl_(impl), matched_layout_(matched_layout) {}
463
464 // Returns true and captures the layout iff it matches the pattern.
465 bool Match(const ::xla::Layout* layout, MatchOption option) const {
466 if (impl_.Match(layout, option)) {
467 if (option.capture && matched_layout_) {
468 *matched_layout_ = layout;
469 }
470 return true;
471 }
472 return false;
473 }
474
475 // Returns true and captures the layout iff it matches the pattern.
476 bool Match(::xla::Layout* layout, MatchOption option) const {
477 if (impl_.Match(layout, option)) {
478 if (option.capture && matched_layout_) {
479 *matched_layout_ = layout;
480 }
481 return true;
482 }
483 return false;
484 }
485
486 void DescribeTo(std::ostream* os, int64 indent = 0) const {
487 impl_.DescribeTo(os, indent);
488 }
489
490 // Modifies the pattern to match only if the layout equals the given proto.
491 // The layout must outlive the returned pattern.
492 constexpr auto EqualTo(const ::xla::Layout* layout) const {
493 return AppendImpl(LayoutPatternEqualImpl(layout));
494 }
495
496 // Modifies the pattern to match only if the layout has a dense format.
497 constexpr auto WithDenseFormat() const {
498 return AppendImpl(LayoutPatternFormatImpl(DENSE));
499 }
500
501 private:
502 Impl impl_;
503 LayoutType** matched_layout_;
504 };
505
// Disjunction of `Patterns`: matches an `Item` if any sub-pattern matches it.
// Sub-patterns are tried left to right and evaluation stops at the first one
// that succeeds. Each sub-pattern is first tried with capturing disabled; only
// the first successful sub-pattern is re-run with the caller's options, so
// failing branches never write captures.
template <typename Item, typename... Patterns>
class AnyOfPattern {
 public:
  explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}

  bool Match(const Item* item, MatchOption option) const {
    return MatchImpl(item, option);
  }

  // Non-const overload: forwards a mutable pointer to the sub-patterns.
  bool Match(Item* item, MatchOption option) const {
    return MatchImpl(item, option);
  }

  // Prints "any of:" followed by one " - <pattern>" line per sub-pattern,
  // separated by " OR".
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    *os << "any of:";
    Indent(os, indent);
    DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
  }

 private:
  // Entry point shared by both Match overloads.
  template <typename ItemType>
  bool MatchImpl(ItemType* item, MatchOption option) const {
    // If we're generating an explanation, buffer it until we know we failed.
    absl::optional<std::stringstream> explanation;
    MatchOption new_option = option;
    if (option.explain_os) {
      new_option.explain_os = &explanation.emplace();
    }
    bool rv = MatchRecursiveImpl(item, new_option,
                                 std::integral_constant<size_t, 0>());
    // Only emit the buffered per-branch explanations if every branch failed.
    if (!rv && option.explain_os) {
      EXPLAIN << "None of the following matchers succeeded:";
      EXPLAIN << explanation->str();
    }
    return rv;
  }

  // Tries sub-pattern `index`; on failure, appends its description and failure
  // explanation to option.explain_os (if set) and moves on to `index + 1`.
  template <typename ItemType, size_t index>
  bool MatchRecursiveImpl(ItemType* item, MatchOption option,
                          std::integral_constant<size_t, index>) const {
    auto new_option = option;
    new_option.capture = false;

    // Each branch gets its own explanation buffer so that its failure message
    // can be re-indented below.
    absl::optional<std::stringstream> explanation;
    if (option.explain_os) {
      new_option.explain_os = &explanation.emplace();
    }

    // Try to match the sub-pattern without capturing behavior.
    if (std::get<index>(patterns_).Match(item, new_option)) {
      // Capture the branch.
      if (option.capture) {
        // TODO(timshen): Currently the behavior can be exponential. Optimize it
        // with memoization or recording the matched sub-pattern index, if it
        // takes too long to run.
        //
        // Specifically, the "memoization" approach is to create an empty
        // container with the key (pattern, instruction), and value as whether
        // matched or not.
        //
        // Alternatively, we may run the pattern matching with captures off, but
        // instead record a "trace" somewhere, indicating how exactly the
        // pattern matches the input. For example, the trace information for
        // AnyOf will be a runtime number indicate which sub-pattern is matched.
        // Then we run another pass to do captures only with the help of the
        // trace.
        bool matched = std::get<index>(patterns_).Match(item, option);
        DCHECK(matched);
      }
      return true;
    }
    if (option.explain_os) {
      EXPLAIN << "\nMatcher #" << index + 1;
      EXPLAIN << "\n - ";
      std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
      EXPLAIN << "\nfailed with";
      EXPLAIN << "\n - ";
      // Indent every line of the branch's own explanation under the bullet.
      EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}});
    }
    return MatchRecursiveImpl(item, option,
                              std::integral_constant<size_t, index + 1>());
  }

  // Recursion terminator: no sub-pattern matched.
  template <typename ItemType>
  bool MatchRecursiveImpl(
      ItemType* item, MatchOption option,
      std::integral_constant<size_t, sizeof...(Patterns)>) const {
    return false;
  }

  // Prints " - <pattern index>", followed by " OR" and a newline unless it is
  // the last sub-pattern, then recurses on `index + 1`.
  template <size_t index>
  void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
                      int64 indent) const {
    *os << " - ";
    std::get<index>(patterns_).DescribeTo(os, indent + 3);
    if (index != sizeof...(Patterns) - 1) {
      *os << " OR";
      Indent(os, indent);
    }
    DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
  }

  // Recursion terminator for DescribeToImpl; prints nothing.
  void DescribeToImpl(std::ostream* os,
                      std::integral_constant<size_t, sizeof...(Patterns)>,
                      int64 indent) const {}

  std::tuple<Patterns...> patterns_;
};
614
615 } // namespace detail
616
617 // Returns a pattern that represents the logical disjunction of the input
618 // patterns. The returned pattern matches from left to right, and stops on the
619 // first match.
620 template <typename Item, typename... Patterns>
621 auto AnyOf(const Patterns&... patterns) {
622 return detail::AnyOfPattern<typename std::remove_const<Item>::type,
623 Patterns...>(patterns...);
624 }
625
626 // Creates a layout pattern that will capture the matched layout in the
627 // argument.
628 inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
629 return detail::LayoutPattern<const ::xla::Layout,
630 detail::LayoutPatternBaseImpl>(
631 detail::LayoutPatternBaseImpl(), matched_layout);
632 }
633
634 // Creates a layout pattern that will capture the matched layout in the
635 // argument.
636 inline constexpr auto Layout(::xla::Layout** matched_layout) {
637 return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
638 detail::LayoutPatternBaseImpl(), matched_layout);
639 }
640
641 namespace detail {
642
643 template <typename ShapeType, typename Impl>
644 class ShapePattern;
645
646 // The base ShapePattern implementation. Matches only if the shape is not
647 // nullptr.
648 class ShapePatternBaseImpl {
649 public:
650 bool Match(const ::xla::Shape* shape, MatchOption option) const {
651 if (shape == nullptr) {
652 EXPLAIN << "Shape is null";
653 }
654 return shape != nullptr;
655 }
656
657 void DescribeTo(std::ostream* os, int64 indent = 0) const {
658 *os << "a shape";
659 }
660
661 static constexpr bool kIsTrivialMatcher = true;
662 };
663
664 // A ShapePattern implementation that matches only if the shape equals a Shape
665 // proto.
666 class ShapePatternEqualImpl {
667 public:
668 explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
669 : shape_(shape) {}
670
671 bool Match(const ::xla::Shape* shape, MatchOption option) const {
672 if (!ShapeUtil::Equal(*shape_, *shape)) {
673 EXPLAIN << "Shape not equal to "
674 << ShapeUtil::HumanStringWithLayout(*shape_);
675 return false;
676 }
677 return true;
678 }
679
680 void DescribeTo(std::ostream* os, int64 indent = 0) const {
681 *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
682 }
683
684 private:
685 const ::xla::Shape* shape_;
686 };
687
688 // A ShapePattern implementation that matches only if the shape is compatible to
689 // a Shape proto.
690 class ShapePatternCompatibleImpl {
691 public:
692 explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
693 : shape_(shape) {}
694
695 bool Match(const ::xla::Shape* shape, MatchOption option) const {
696 if (!ShapeUtil::Compatible(*shape_, *shape)) {
697 EXPLAIN << "Shape not compatible with "
698 << ShapeUtil::HumanString(*shape_);
699 return false;
700 }
701 return true;
702 }
703
704 void DescribeTo(std::ostream* os, int64 indent = 0) const {
705 *os << "compatible with " << ShapeUtil::HumanString(*shape_);
706 }
707
708 private:
709 const ::xla::Shape* shape_;
710 };
711
712 // A ShapePattern implementation that matches only if the shape has a given
713 // element type.
714 class ShapePatternElementTypeImpl {
715 public:
716 explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
717 : element_type_(element_type) {}
718
719 bool Match(const ::xla::Shape* shape, MatchOption option) const {
720 if (shape->element_type() != element_type_) {
721 EXPLAIN << "Shape does not have element type "
722 << PrimitiveType_Name(element_type_);
723 return false;
724 }
725 return true;
726 }
727
728 void DescribeTo(std::ostream* os, int64 indent = 0) const {
729 *os << "with element type " << PrimitiveType_Name(element_type_);
730 }
731
732 private:
733 PrimitiveType element_type_;
734 };
735
736 // A ShapePattern implementation that matches only if the shape is scalar.
737 class ShapePatternIsScalarImpl {
738 public:
739 explicit constexpr ShapePatternIsScalarImpl() {}
740
741 bool Match(const ::xla::Shape* shape, MatchOption option) const {
742 if (!ShapeUtil::IsScalar(*shape)) {
743 EXPLAIN << "Shape is not a scalar";
744 return false;
745 }
746 return true;
747 }
748
749 void DescribeTo(std::ostream* os, int64 indent = 0) const {
750 *os << "that represents a scalar";
751 }
752 };
753
754 // A ShapePattern implementation that matches only if the shape is an array
755 class ShapePatternIsArrayImpl {
756 public:
757 explicit constexpr ShapePatternIsArrayImpl() {}
758
759 bool Match(const ::xla::Shape* shape, MatchOption option) const {
760 if (!shape->IsArray()) {
761 EXPLAIN << "Shape is not an array";
762 return false;
763 }
764 return true;
765 }
766
767 void DescribeTo(std::ostream* os, int64 indent = 0) const {
768 *os << "that represents an array";
769 }
770 };
771
772 // A ShapePattern implementation that matches only if the shape is a tuple.
773 class ShapePatternIsTupleImpl {
774 public:
775 explicit constexpr ShapePatternIsTupleImpl() {}
776
777 bool Match(const ::xla::Shape* shape, MatchOption option) const {
778 if (!shape->IsTuple()) {
779 EXPLAIN << "Shape is not a tuple";
780 return false;
781 }
782 return true;
783 }
784
785 void DescribeTo(std::ostream* os, int64 indent = 0) const {
786 *os << "that represents a tuple";
787 }
788 };
789
790 // A ShapePattern implementation that matches only if the shape is an effective
791 // scalar.
792 class ShapePatternEffectiveScalarImpl {
793 public:
794 explicit constexpr ShapePatternEffectiveScalarImpl() {}
795
796 bool Match(const ::xla::Shape* shape, MatchOption option) const {
797 if (!ShapeUtil::IsEffectiveScalar(*shape)) {
798 EXPLAIN << "Shape is not an effective scalar";
799 return false;
800 }
801 return true;
802 }
803
804 void DescribeTo(std::ostream* os, int64 indent = 0) const {
805 *os << "that is an effective scalar";
806 }
807 };
808
809 // A ShapePattern implementation that matches only if the shape has a given
810 // rank.
811 class ShapePatternRankImpl {
812 public:
813 explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
814
815 bool Match(const ::xla::Shape* shape, MatchOption option) const {
816 if (shape->rank() != rank_) {
817 if (rank_ == 0) {
818 EXPLAIN << "Shape is not a scalar";
819 } else {
820 EXPLAIN << "Shape does not have rank " << rank_;
821 }
822 return false;
823 }
824 return true;
825 }
826
827 void DescribeTo(std::ostream* os, int64 indent = 0) const {
828 if (rank_ == 0) {
829 *os << "that is a scalar";
830 } else {
831 *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
832 }
833 }
834
835 private:
836 int64 rank_;
837 };
838
839 // A ShapePattern implementation that matches only if the shape has a layout
840 // that matches a given pattern.
841 template <typename LayoutType, typename LayoutImpl>
842 class ShapePatternLayoutImpl {
843 public:
844 explicit constexpr ShapePatternLayoutImpl(
845 const LayoutPattern<LayoutType, LayoutImpl>& layout)
846 : layout_(layout) {}
847
848 bool Match(const ::xla::Shape* shape, MatchOption option) const {
849 return LayoutUtil::HasLayout(*shape) &&
850 layout_.Match(&shape->layout(), option);
851 }
852
853 bool Match(::xla::Shape* shape, MatchOption option) const {
854 if (!LayoutUtil::HasLayout(*shape)) {
855 EXPLAIN << "Shape does not have a layout";
856 return false;
857 }
858 if (!layout_.Match(shape->mutable_layout(), option)) {
859 EXPLAIN << "\nin layout";
860 return false;
861 }
862 return true;
863 }
864
865 void DescribeTo(std::ostream* os, int64 indent = 0) const {
866 *os << "with";
867 Indent(os, indent + kIndentInc);
868 layout_.DescribeTo(os, indent + kIndentInc);
869 }
870
871 private:
872 LayoutPattern<LayoutType, LayoutImpl> layout_;
873 };
874
875 // A ShapePattern implementation that matches only if the shape has a subshape
876 // that matches a given pattern.
877 template <typename SubshapeType, typename SubshapeImpl>
878 class ShapePatternSubshapeImpl {
879 public:
880 explicit ShapePatternSubshapeImpl(
881 ShapeIndexView index,
882 const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
883 : index_(index), subshape_(subshape) {}
884
885 bool Match(const ::xla::Shape* shape, MatchOption option) const {
886 return MatchImpl(shape, option);
887 }
888
889 bool Match(::xla::Shape* shape, MatchOption option) const {
890 return MatchImpl(shape, option);
891 }
892
893 void DescribeTo(std::ostream* os, int64 indent = 0) const {
894 *os << "with subshape at index " << index_.ToString() << " which is";
895 Indent(os, indent + kIndentInc);
896 subshape_.DescribeTo(os, indent + kIndentInc);
897 }
898
899 private:
900 ::xla::Shape* GetSubshape(::xla::Shape* shape) const {
901 return ShapeUtil::GetMutableSubshape(shape, index_);
902 }
903 const ::xla::Shape* GetSubshape(const ::xla::Shape* shape) const {
904 return &ShapeUtil::GetSubshape(*shape, index_);
905 }
906
907 template <typename ShapeType>
908 bool MatchImpl(ShapeType* shape, MatchOption option) const {
909 if (!ShapeUtil::IndexIsValid(*shape, index_)) {
910 EXPLAIN << "No subshape at " << index_.ToString();
911 return false;
912 }
913 if (!subshape_.Match(GetSubshape(shape), option)) {
914 EXPLAIN << "\nin subshape at " << index_.ToString();
915 return false;
916 }
917 return true;
918 }
919
920 ShapeIndexView index_;
921 ShapePattern<SubshapeType, SubshapeImpl> subshape_;
922 };
923
924 // A pattern that matches Shapes.
925 template <typename ShapeType, typename Impl>
926 class ShapePattern {
927 private:
928 template <typename NewImpl>
929 auto AppendImpl(NewImpl new_impl) const {
930 auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
931 return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
932 matched_shape_);
933 }
934
935 public:
936 explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
937 : impl_(impl), matched_shape_(matched_shape) {}
938
939 // Returns true and captures the shape iff it matches the pattern.
940 bool Match(const ::xla::Shape* shape, MatchOption option) const {
941 if (impl_.Match(shape, option)) {
942 if (option.capture && matched_shape_) {
943 *matched_shape_ = shape;
944 }
945 return true;
946 }
947 if (shape) {
948 EXPLAIN << "\nin "
949 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
950 : ShapeUtil::HumanString(*shape));
951 }
952 return false;
953 }
954
955 // Returns true and captures the shape iff it matches the pattern.
956 bool Match(::xla::Shape* shape, MatchOption option) const {
957 if (impl_.Match(shape, option)) {
958 if (option.capture && matched_shape_) {
959 *matched_shape_ = shape;
960 }
961 return true;
962 }
963 EXPLAIN << "\nin "
964 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
965 : ShapeUtil::HumanString(*shape));
966 return false;
967 }
968
969 void DescribeTo(std::ostream* os, int64 indent = 0) const {
970 return impl_.DescribeTo(os, indent);
971 }
972
973 // Modifies the pattern to match only if the shape equals the given proto.
974 // The layout must outlive the returned pattern.
975 constexpr auto EqualTo(const ::xla::Shape* shape) const {
976 return AppendImpl(ShapePatternEqualImpl(shape));
977 }
978
979 // Modifies the pattern to match only if the shape is compatible to the given
980 // proto. The layout must outlive the returned pattern.
981 constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
982 return AppendImpl(ShapePatternCompatibleImpl(shape));
983 }
984
985 // Modifies the pattern to match only if the shape has the given element type.
986 constexpr auto WithElementType(PrimitiveType element_type) const {
987 return AppendImpl(ShapePatternElementTypeImpl(element_type));
988 }
989
990 // Modifies the pattern to match only if the shape is scalar.
991 constexpr auto IsScalar() const {
992 return AppendImpl(ShapePatternIsScalarImpl());
993 }
994
995 // Modifies the pattern to match only if the shape is an array.
996 constexpr auto IsArray() const {
997 return AppendImpl(ShapePatternIsArrayImpl());
998 }
999
1000 // Modifies the pattern to match only if the shape is a tuple.
1001 constexpr auto IsTuple() const {
1002 return AppendImpl(ShapePatternIsTupleImpl());
1003 }
1004
1005 constexpr auto IsEffectiveScalar() const {
1006 return AppendImpl(ShapePatternEffectiveScalarImpl());
1007 }
1008
1009 // Modifies the pattern to match only if the shape has the given rank.
1010 constexpr auto WithRank(int64 rank) const {
1011 return AppendImpl(ShapePatternRankImpl(rank));
1012 }
1013
1014 // Modifies the pattern to match only if the shape has a layout that matches
1015 // the given pattern.
1016 template <typename LayoutType, typename LayoutImpl>
1017 auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
1018 return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1019 }
1020
1021 constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
1022 return WithLayout(Layout().EqualTo(layout));
1023 }
1024
1025 constexpr auto IsDenseArray() const {
1026 return WithLayout(Layout().WithDenseFormat());
1027 }
1028
1029 // Modifies the pattern to match only if the shape has a subshape that matches
1030 // the given pattern.
1031 template <typename SubshapeType, typename SubshapeImpl>
1032 auto WithSubshape(
1033 ShapeIndexView index,
1034 const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
1035 return AppendImpl(
1036 ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1037 }
1038
1039 ShapePattern<ShapeType,
1040 AllOfPattern<::xla::Shape, Impl,
1041 ShapePatternSubshapeImpl<
1042 const ::xla::Shape,
1043 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1044 ShapePatternEqualImpl>>>>
1045 WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1046 return WithSubshape(index,
1047 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1048 ShapePatternBaseImpl(), nullptr)
1049 .EqualTo(shape));
1050 }
1051
1052 ShapePattern<ShapeType,
1053 AllOfPattern<::xla::Shape, Impl,
1054 ShapePatternSubshapeImpl<
1055 const ::xla::Shape,
1056 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1057 ShapePatternCompatibleImpl>>>>
1058 WithSubshapeCompatibleTo(ShapeIndexView index,
1059 const ::xla::Shape* shape) const {
1060 return WithSubshape(index,
1061 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1062 ShapePatternBaseImpl(), nullptr)
1063 .CompatibleTo(shape));
1064 }
1065
1066 private:
1067 Impl impl_;
1068 ShapeType** matched_shape_;
1069 };
1070
1071 } // namespace detail
1072
1073 // Creates a shape pattern that will capture the matched layout in the argument.
1074 inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
1075 return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1076 detail::ShapePatternBaseImpl(), matched_shape);
1077 }
1078
1079 // Creates a shape pattern that will capture the matched layout in the argument.
1080 inline constexpr auto Shape(::xla::Shape** matched_shape) {
1081 return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1082 detail::ShapePatternBaseImpl(), matched_shape);
1083 }
1084
1085 namespace detail {
1086
1087 // Overloads to get a const or non-const operand out of an instruction.
1088 inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) {
1089 return instr->mutable_operand(idx);
1090 }
1091 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1092 int64 idx) {
1093 return instr->operand(idx);
1094 }
1095
1096 // Pretty-printer for HloInstruction. Sort of like ToShortString, but with
1097 // fewer %s and more shapes.
1098 inline string InstToString(const HloInstruction* inst) {
1099 return inst->ToString(
1100 HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1101 }
1102
// Forward declaration. HloInstructionPatternOperandImpl and
// HloInstructionPatternBinaryOperandsAnyOrderImpl below hold
// HloInstructionPattern members, but the class itself is defined after them.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern;
1105
1106 // The base HloInstructionPattern implementation. Matches only if the
1107 // instruction is not nullptr.
1108 class HloInstructionPatternBaseImpl {
1109 public:
1110 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1111 if (inst == nullptr) {
1112 EXPLAIN << "HloInstruction* is null";
1113 return false;
1114 }
1115 return true;
1116 }
1117
1118 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1119 *os << "an HloInstruction";
1120 }
1121
1122 static constexpr bool kIsTrivialMatcher = true;
1123 };
1124
1125 // An HloInstructionPattern implementation that matches only if the instruction
1126 // has a given name.
1127 class HloInstructionPatternNameImpl {
1128 public:
1129 explicit HloInstructionPatternNameImpl(absl::string_view name)
1130 : name_(name) {}
1131
1132 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1133 if (inst->name() != name_) {
1134 EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1135 return false;
1136 }
1137 return true;
1138 }
1139
1140 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1141 *os << "named \"" << name_ << "\"";
1142 }
1143
1144 private:
1145 absl::string_view name_;
1146 };
1147
1148 // An HloInstructionPattern implementation that matches only if the instruction
1149 // equals a particular pointer.
1150 class HloInstructionIsImpl {
1151 public:
1152 explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1153
1154 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1155 if (inst != inst_) {
1156 EXPLAIN << "HloInstruction " << std::hex << std::nouppercase
1157 << std::showbase << reinterpret_cast<uint64>(inst) << " is not "
1158 << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1159 << ")";
1160 return false;
1161 }
1162 return true;
1163 }
1164
1165 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1166 *os << "which is " << std::hex << std::nouppercase << std::showbase
1167 << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1168 << ")";
1169 }
1170
1171 private:
1172 const HloInstruction* inst_;
1173 };
1174
1175 // An HloInstructionPattern implementation that matches only if the instruction
1176 // has a given opcode.
1177 class HloInstructionPatternOpcodeImpl {
1178 public:
1179 explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1180 bool invert)
1181 : opcode_(opcode), invert_(invert) {}
1182
1183 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1184 if (invert_ && inst->opcode() == opcode_) {
1185 EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1186 << ", expected anything else";
1187 return false;
1188 }
1189 if (!invert_ && inst->opcode() != opcode_) {
1190 EXPLAIN << "HloInstruction doesn't have opcode "
1191 << HloOpcodeString(opcode_);
1192 return false;
1193 }
1194 return true;
1195 }
1196
1197 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1198 if (!invert_) {
1199 *os << "with opcode " << HloOpcodeString(opcode_);
1200 } else {
1201 *os << "with any opcode other than " << HloOpcodeString(opcode_);
1202 }
1203 }
1204
1205 private:
1206 HloOpcode opcode_;
1207 bool invert_;
1208 };
1209
1210 // An HloInstructionPattern implementation that matches only if the instruction
1211 // has a given custom call target.
1212 class HloInstructionCustomCallTargetImpl {
1213 public:
1214 explicit HloInstructionCustomCallTargetImpl(
1215 absl::string_view custom_call_target)
1216 : custom_call_target_(custom_call_target) {}
1217
1218 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1219 if (inst->opcode() != HloOpcode::kCustomCall ||
1220 inst->custom_call_target() != custom_call_target_) {
1221 EXPLAIN << "HloInstruction is not a custom call with a target '"
1222 << custom_call_target_ << "'";
1223 return false;
1224 }
1225 return true;
1226 }
1227
1228 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1229 *os << "custom call with target '" << custom_call_target_ << "'";
1230 }
1231
1232 private:
1233 std::string custom_call_target_;
1234 };
1235
1236 // An HloInstructionPattern implementation that matches only if the instruction
1237 // has the given number of operands.
1238 class HloInstructionPatternNumOperandsImpl {
1239 public:
1240 explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands)
1241 : num_operands_(num_operands) {}
1242
1243 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1244 if (inst->operand_count() != num_operands_) {
1245 EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1246 return false;
1247 }
1248 return true;
1249 }
1250
1251 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1252 *os << "with " << num_operands_ << " operand"
1253 << (num_operands_ != 1 ? "s" : "");
1254 }
1255
1256 private:
1257 int64 num_operands_;
1258 };
1259
1260 // An HloInstructionPattern implementation that matches only if the instruction
1261 // has a shape that matches a given pattern.
1262 template <typename ShapeType, typename ShapeImpl>
1263 class HloInstructionPatternShapeImpl {
1264 public:
1265 explicit constexpr HloInstructionPatternShapeImpl(
1266 const ShapePattern<ShapeType, ShapeImpl>& shape)
1267 : shape_(shape) {}
1268
1269 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1270 if (!shape_.Match(&inst->shape(), option)) {
1271 EXPLAIN << "\nin output shape";
1272 return false;
1273 }
1274 return true;
1275 }
1276
1277 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1278 if (!shape_.Match(inst->mutable_shape(), option)) {
1279 EXPLAIN << "\nin output shape";
1280 return false;
1281 }
1282 return true;
1283 }
1284
1285 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1286 *os << "outputting";
1287 Indent(os, indent + kIndentInc);
1288 shape_.DescribeTo(os, indent + kIndentInc);
1289 }
1290
1291 private:
1292 ShapePattern<ShapeType, ShapeImpl> shape_;
1293 };
1294
1295 // An HloInstructionPattern implementation that matches only if the instruction
1296 // has an operand that matches a given pattern.
1297 template <typename OperandType, typename OperandImpl>
1298 class HloInstructionPatternOperandImpl {
1299 public:
1300 explicit constexpr HloInstructionPatternOperandImpl(
1301 int64 operand_index,
1302 const HloInstructionPattern<OperandType, OperandImpl>& operand)
1303 : operand_index_(operand_index), operand_(operand) {}
1304
1305 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1306 return MatchImpl(inst, option);
1307 }
1308
1309 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1310 return MatchImpl(inst, option);
1311 }
1312
1313 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1314 *os << "with operand " << operand_index_ << " which is:";
1315 Indent(os, indent + kIndentInc);
1316 operand_.DescribeTo(os, indent + kIndentInc);
1317 }
1318
1319 private:
1320 template <typename HloInstructionType>
1321 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1322 if (operand_index_ >= inst->operand_count()) {
1323 EXPLAIN << "desired operand index " << operand_index_
1324 << " is out of bounds";
1325 return false;
1326 }
1327 if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1328 EXPLAIN << "\nin operand " << operand_index_;
1329 return false;
1330 }
1331 return true;
1332 }
1333
1334 int64 operand_index_;
1335 HloInstructionPattern<OperandType, OperandImpl> operand_;
1336 };
1337
// Matches a binary instruction whose operands come in any order.
//
// Succeeds iff the instruction has exactly two operands and either
//   (op1 matches operand 0 and op2 matches operand 1), or
//   (op1 matches operand 1 and op2 matches operand 0).
// Candidate assignments are first tried with capture disabled. Only when an
// assignment is known to succeed is it re-run with the caller's options, so
// captures reflect the successful assignment. The DCHECKs assert that
// re-running a successful match succeeds again.
template <typename OperandType1, typename OperandImpl1, typename OperandType2,
          typename OperandImpl2>
class HloInstructionPatternBinaryOperandsAnyOrderImpl {
 public:
  explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
      const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
      const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
      : op1_(op1), op2_(op2) {}

  bool Match(::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  // Describes the two sub-patterns as a " - " bulleted list at `indent`.
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    *os << "with two operands in either order:";
    Indent(os, indent);
    *os << " - ";
    op1_.DescribeTo(os, indent + 3);  // 3 == width of " - ".
    Indent(os, indent);
    *os << " - ";
    op2_.DescribeTo(os, indent + 3);
  }

 private:
  // Operand accessors that preserve the constness of `inst`.
  HloInstruction* operand(HloInstruction* inst, int64 idx) const {
    return inst->mutable_operand(idx);
  }
  const HloInstruction* operand(const HloInstruction* inst, int64 idx) const {
    return inst->operand(idx);
  }

  template <typename HloInstructionType>
  bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
    // We could implement this using AnyOf and AllOf matchers, but the templates
    // get pretty difficult to debug, since any compile error herein becomes
    // not-an-error via SFINAE. Also this way lets us give better messages on
    // failure.
    if (inst->operand_count() != 2) {
      EXPLAIN << "HloInstruction did not have two operands";
      return false;
    }

    // If we're not generating explanations, this is pretty simple.
    if (!option.explain_os) {
      // Tries op1_ on operand idx1 and op2_ on operand idx2 without capturing.
      // On success, repeats the match with the caller's options so that
      // captures (if requested) are written.
      auto try_match = [&](int64 idx1, int64 idx2) {
        MatchOption new_option = option;
        new_option.capture = false;
        if (op1_.Match(operand(inst, idx1), new_option) &&
            op2_.Match(operand(inst, idx2), new_option)) {
          if (option.capture) {
            bool matched = op1_.Match(operand(inst, idx1), option) &&
                           op2_.Match(operand(inst, idx2), option);
            DCHECK(matched);
          }
          return true;
        }
        return false;
      };
      return try_match(0, 1) || try_match(1, 0);
    }

    // If we are generating explanations, we have some work to do in order to
    // generate a helpful error.
    //
    // First, try all four operand/matcher combinations, recording the
    // failure explanations separately from option.explain_os. matches[i][j]
    // tells us if matcher_i matches operand j.
    bool matches[/*matcher*/ 2][/*operand*/ 2];
    std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        MatchOption new_option = option;
        new_option.capture = false;
        new_option.explain_os = &explanations[i][j];
        matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
                               : op2_.Match(operand(inst, j), new_option);
      }
    }

    // Check if the match succeeded: op1 on operand i and op2 on the other.
    for (int i = 0; i < 2; ++i) {
      if (matches[0][i] && matches[1][(i + 1) % 2]) {
        // Rerun the matches with capture enabled if necessary.
        if (option.capture) {
          auto* operand1 = operand(inst, i);
          auto* operand2 = operand(inst, (i + 1) % 2);
          bool matched =
              op1_.Match(operand1, option) && op2_.Match(operand2, option);
          DCHECK(matched);
        }
        return true;
      }
    }

    // Writes matcher `matcher_idx`'s description followed by the recorded
    // failure explanation for each operand it did not match, re-indented to
    // nest under the bullet.
    auto describe_matcher = [&](int matcher_idx) {
      EXPLAIN << "\n - ";
      if (matcher_idx == 0) {
        op1_.DescribeTo(option.explain_os, /*indent=*/3);
      } else {
        CHECK_EQ(matcher_idx, 1);
        op2_.DescribeTo(option.explain_os, /*indent=*/3);
      }
      for (int i = 0; i < 2; ++i) {
        if (matches[matcher_idx][/*operand*/ i]) {
          continue;
        }
        EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
        EXPLAIN << " - ";
        EXPLAIN << absl::StrReplaceAll(
            explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
      }
    };

    // If we failed to match, one of the following is true:
    //   1. op1 (op2) matches neither LHS nor RHS, or
    //   2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
    // We print different explanations depending on which case we're in.
    // (If neither case held, each matcher would match exactly one operand and
    // they would match different operands, i.e. the match would have
    // succeeded above. The CHECKs below rely on this.)

    // Case 1.
    bool wrote_explanation = false;
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (!matches[i][0] && !matches[i][1]) {
        EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
                << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
        describe_matcher(i);
        wrote_explanation = true;
      }
    }

    // Case 2. Both matchers match operand i, so the other operand
    // ((i + 1) % 2) is the one that matched neither.
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (matches[/*matcher*/ 0][/*operand*/ i] &&
          matches[/*matcher*/ 1][/*operand*/ i]) {
        CHECK(!matches[0][(i + 1) % 2]);
        CHECK(!matches[1][(i + 1) % 2]);
        CHECK(!wrote_explanation);
        EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
                << " operand did not match either of the two matchers.  "
                   "Specifically,";
        describe_matcher(0);
        EXPLAIN << "\nand";
        describe_matcher(1);
        wrote_explanation = true;
      }
    }

    CHECK(wrote_explanation);
    return false;
  }

  HloInstructionPattern<OperandType1, OperandImpl1> op1_;
  HloInstructionPattern<OperandType2, OperandImpl2> op2_;
};
1496
1497 // An HloInstructionPattern implementation that matches only if the instruction
1498 // is a fusion node with a particular kind.
1499 class HloInstructionPatternFusionKindImpl {
1500 public:
1501 explicit constexpr HloInstructionPatternFusionKindImpl(
1502 ::xla::HloInstruction::FusionKind kind)
1503 : kind_(kind) {}
1504
1505 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1506 return MatchImpl(inst, option);
1507 }
1508
1509 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1510 return MatchImpl(inst, option);
1511 }
1512
1513 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1514 *os << "with fusion kind " << ToString(kind_);
1515 }
1516
1517 private:
1518 template <typename HloInstructionType>
1519 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1520 if (inst->opcode() != HloOpcode::kFusion) {
1521 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1522 << "; it's not a fusion";
1523 return false;
1524 }
1525 if (inst->fusion_kind() != kind_) {
1526 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1527 return false;
1528 }
1529 return true;
1530 }
1531
1532 ::xla::HloInstruction::FusionKind kind_;
1533 };
1534
1535 // An HloInstructionPattern implementation that matches only if the instruction
1536 // is a kGetTupleElement with a particular tuple index.
1537 class HloInstructionPatternTupleIndexImpl {
1538 public:
1539 explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
1540 : tuple_index_(tuple_index) {}
1541
1542 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1543 return MatchImpl(inst, option);
1544 }
1545
1546 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1547 return MatchImpl(inst, option);
1548 }
1549
1550 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1551 *os << "which is a GTE with index " << tuple_index_;
1552 }
1553
1554 private:
1555 template <typename HloInstructionType>
1556 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1557 if (inst->opcode() != HloOpcode::kGetTupleElement) {
1558 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1559 << "; it's not a GTE at all";
1560 return false;
1561 }
1562 if (inst->tuple_index() != tuple_index_) {
1563 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1564 return false;
1565 }
1566 return true;
1567 }
1568
1569 int64 tuple_index_;
1570 };
1571
1572 class HloInstructionPatternParameterNumImpl {
1573 public:
1574 explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num)
1575 : parameter_num_(parameter_num) {}
1576
1577 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1578 return MatchImpl(inst, option);
1579 }
1580
1581 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1582 return MatchImpl(inst, option);
1583 }
1584
1585 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1586 *os << "which is parameter " << parameter_num_;
1587 }
1588
1589 private:
1590 template <typename HloInstructionType>
1591 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1592 if (inst->opcode() != HloOpcode::kParameter ||
1593 inst->parameter_number() != parameter_num_) {
1594 EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1595 return false;
1596 }
1597 return true;
1598 }
1599
1600 int64 parameter_num_;
1601 };
1602
1603 // Superclass that contains common code used by Op::WithOneUse() and
1604 // Op::WithOneUser().
1605 class HloInstructionPatternOneUseOrUserImpl {
1606 protected:
1607 bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1608 if (inst->user_count() != 1) {
1609 EXPLAIN << "HloInstruction has " << inst->user_count()
1610 << " users, but expected exactly one.";
1611 if (inst->user_count() > 1) {
1612 EXPLAIN << "\nAll users:";
1613 for (const HloInstruction* user : inst->users()) {
1614 EXPLAIN << "\n - " << InstToString(user);
1615 }
1616 }
1617 return false;
1618 }
1619 return true;
1620 }
1621 };
1622
1623 class HloInstructionPatternOneUseImpl
1624 : public HloInstructionPatternOneUseOrUserImpl {
1625 public:
1626 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1627 if (!MatchOneUser(inst, option)) {
1628 return false;
1629 }
1630
1631 int64 use_count = absl::c_count_if(
1632 inst->users()[0]->operands(),
1633 [&](const HloInstruction* operand) { return operand == inst; });
1634 if (use_count != 1) {
1635 EXPLAIN << "HloInstruction is used " << use_count
1636 << " times by its user, but is expected to be used just once: "
1637 << InstToString(inst->users()[0]);
1638 return false;
1639 }
1640 return true;
1641 }
1642
1643 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1644 *os << "which has exactly one use";
1645 }
1646 };
1647
1648 class HloInstructionPatternOneUserImpl
1649 : public HloInstructionPatternOneUseOrUserImpl {
1650 public:
1651 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1652 return MatchOneUser(inst, option);
1653 }
1654
1655 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1656 *os << "which has exactly one user (but possibly is used multiple times by "
1657 "that instruction)";
1658 }
1659 };
1660
1661 class HloInstructionPatternComparisonDirectionImpl {
1662 public:
1663 explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1664 ComparisonDirection direction)
1665 : direction_(direction) {}
1666
1667 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1668 return MatchImpl(inst, option);
1669 }
1670
1671 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1672 return MatchImpl(inst, option);
1673 }
1674
1675 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1676 *os << "which has comparison direction "
1677 << ComparisonDirectionToString(direction_);
1678 }
1679
1680 private:
1681 template <typename HloInstructionType>
1682 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1683 if (inst->opcode() != HloOpcode::kCompare ||
1684 inst->comparison_direction() != direction_) {
1685 EXPLAIN << "HloInstruction is not comparison "
1686 << ComparisonDirectionToString(direction_);
1687 return false;
1688 }
1689 return true;
1690 }
1691
1692 ComparisonDirection direction_;
1693 };
1694
1695 // Matches a constant scalar or effective scalar, optionally with a given value.
1696 template <typename ScalarTy>
1697 class HloConstantScalarImpl {
1698 public:
1699 explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1700 : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
1701
1702 constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1703 : val_(val), match_effective_scalar_(match_effective_scalar) {}
1704
1705 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1706 return MatchImpl(inst, option);
1707 }
1708
1709 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1710 return MatchImpl(inst, option);
1711 }
1712
1713 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1714 *os << "which is a constant "
1715 << (match_effective_scalar_ ? "effective " : "") << "scalar";
1716 if (val_.has_value()) {
1717 *os << " with value " << *val_;
1718 }
1719 }
1720
1721 private:
1722 template <typename InstTy>
1723 bool MatchImpl(InstTy* inst, MatchOption option) const {
1724 const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1725 if (!const_inst) {
1726 EXPLAIN << "HloInstruction is not a constant";
1727 return false;
1728 }
1729 if (match_effective_scalar_ &&
1730 !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1731 EXPLAIN << "HloInstruction is not an effective scalar";
1732 return false;
1733 }
1734 if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1735 EXPLAIN << "HloInstruction is not a scalar";
1736 return false;
1737 }
1738 if (!val_.has_value()) {
1739 return true;
1740 }
1741
1742 auto const_inst_scalar_or = const_inst->literal().Reshape({});
1743 if (!const_inst_scalar_or.ok()) {
1744 EXPLAIN << "could not convert matched literal to effective scalar";
1745 return false;
1746 }
1747 Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
1748 if (!const_inst_scalar.IsEqualAt({}, *val_)) {
1749 EXPLAIN << "HloInstruction's constant value "
1750 << const_inst_scalar.ToStringWithoutShape()
1751 << " did not match expected value " << *val_;
1752 return false;
1753 }
1754 return true;
1755 }
1756
1757 absl::optional<ScalarTy> val_;
1758 bool match_effective_scalar_;
1759 };
1760
1761 // A pattern that matches HloInstructions.
1762 template <typename HloInstructionType, typename Impl>
1763 class HloInstructionPattern {
1764 private:
1765 template <typename NewImpl>
1766 auto AppendImpl(NewImpl new_impl) const {
1767 auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
1768 return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1769 std::move(new_allof), matched_inst_);
1770 }
1771
1772 public:
  // `impl` performs the actual matching. `matched_inst` may be null; if not,
  // a successful match with capture enabled stores the instruction into it.
  explicit constexpr HloInstructionPattern(const Impl& impl,
                                           HloInstructionType** matched_inst)
      : impl_(impl), matched_inst_(matched_inst) {}
1776
1777 // Returns true and captures the instruction iff it matches the pattern.
1778 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1779 if (impl_.Match(inst, option)) {
1780 if (option.capture && matched_inst_) {
1781 *matched_inst_ = inst;
1782 }
1783 return true;
1784 }
1785 if (inst != nullptr) {
1786 EXPLAIN << "\nin " << InstToString(inst);
1787 }
1788 return false;
1789 }
1790
1791 // Returns true and captures the instruction iff it matches the pattern.
1792 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1793 if (impl_.Match(inst, option)) {
1794 if (option.capture && matched_inst_) {
1795 *matched_inst_ = inst;
1796 }
1797 return true;
1798 }
1799 EXPLAIN << "\nin " << InstToString(inst);
1800 return false;
1801 }
1802
  // Modifies the pattern to match only if the instruction has the given name.
  auto WithName(absl::string_view name) const {
    return AppendImpl(HloInstructionPatternNameImpl(name));
  }

  // Modifies the pattern to match only if the instruction has the given opcode.
  auto WithOpcode(HloOpcode opcode) const {
    return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, /*invert=*/false));
  }

  // Modifies the pattern to match only the custom call with a given target.
  // Instructions that are not kCustomCall never match.
  auto WithCustomCallTarget(absl::string_view custom_call_target) const {
    return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target));
  }

  // Modifies the pattern to match only if the instruction has exactly
  // `num_operands` operands.
  auto WithNumOperands(int64 num_operands) const {
    return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
  }

  // Modifies the pattern to match only if the instruction does not have the
  // given opcode.
  auto WithoutOpcode(HloOpcode opcode) const {
    return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, /*invert=*/true));
  }

  // Modifies the pattern to match only if the instruction is the very object
  // `instr` points to (pointer identity, not structural equality).
  constexpr auto Is(const HloInstruction* instr) const {
    return AppendImpl(HloInstructionIsImpl(instr));
  }
1831
  // Modifies the pattern to match only if the instruction is a constant.
  constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }

  // Modifies the pattern to match only if the instruction is a constant with a
  // scalar shape, whatever its value. The `int` template argument is a
  // placeholder; no value is compared.
  constexpr auto IsConstantScalar() const {
    return AppendImpl(
        HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
  }

  // Modifies the pattern to match only if the instruction is a scalar
  // constant equal to `val`.
  //
  // This does not check that ScalarTy has the same type as the instruction, so
  // e.g. IsConstantScalar(1.0) may match a constant of shape int32[].
  template <typename ScalarTy>
  constexpr auto IsConstantScalar(const ScalarTy& val) const {
    return AppendImpl(
        HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
  }

  // Like IsConstantScalar(), but also accepts constants whose shape is an
  // effective scalar (per ShapeUtil::IsEffectiveScalar).
  constexpr auto IsConstantEffectiveScalar() const {
    return AppendImpl(
        HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
  }

  // Like IsConstantScalar(val), but also accepts effective-scalar shapes. As
  // with IsConstantScalar(val), ScalarTy is not checked against the
  // instruction's element type.
  template <typename ScalarTy>
  constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
    return AppendImpl(
        HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
  }

  // Modifies the pattern to match only if the instruction is not a constant.
  constexpr auto IsNonConstant() const {
    return WithoutOpcode(HloOpcode::kConstant);
  }
1863
  // Modifies the pattern to match only if the instruction has a shape that
  // matches the given pattern.
  template <typename ShapeType, typename ShapeImpl>
  constexpr auto WithShape(
      const ShapePattern<ShapeType, ShapeImpl>& shape) const {
    return AppendImpl(
        HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
  }

  // Modifies the pattern to match only if the instruction's shape matches
  // Shape().EqualTo(shape).
  // Make this a templated function to work around gcc 4.9.4 template infinite
  // recursion bug.
  template <typename Dummy = void>
  constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
    return WithShape(Shape().EqualTo(shape));
  }

  // Modifies the pattern to match only if the instruction's shape matches
  // Shape().CompatibleTo(shape).
  // Make this a templated function to work around gcc 4.9.4 template infinite
  // recursion bug.
  template <typename Dummy = void>
  constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
    return WithShape(Shape().CompatibleTo(shape));
  }
1886
// Modifies the pattern to match only if the instruction's operand at index
// `operand_index` matches the given instruction pattern.
// NOTE(review): behaviour for an out-of-range index is decided by
// HloInstructionPatternOperandImpl (not visible here) — presumably no match.
template <typename OperandType, typename OperandImpl>
constexpr auto WithOperand(
int64 operand_index,
const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
return AppendImpl(
HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
operand_index, operand));
}

// Modifies the pattern to match only if the instruction's operands match
// `op1` and `op2` in either order. This is the building block of the
// NAME##AnyOrder helpers.
// NOTE(review): whether exactly two operands are required is decided by
// HloInstructionPatternBinaryOperandsAnyOrderImpl (not visible here).
template <typename OperandType1, typename OperandImpl1, typename OperandType2,
typename OperandImpl2>
constexpr auto WithBinaryOperandsAnyOrder(
const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
return AppendImpl(
HloInstructionPatternBinaryOperandsAnyOrderImpl<
OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
}
1907
// Modifies the pattern to match only if the instruction is a fusion node with
// the given kind.
constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
}

// Modifies the pattern to match only if the instruction is a
// get-tuple-element with the given tuple index (the index of the tuple
// element it extracts).
constexpr auto WithTupleIndex(int64 tuple_index) const {
return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
}

// Modifies the pattern to match only if the instruction is a parameter
// with the given parameter number.
constexpr auto WithParameterNum(int64 parameter_num) const {
return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
}
1925
// Modifies the pattern to match if the instruction is used exactly once.
// This counts uses, not distinct users, so it does not match if the
// instruction is used twice by the same user (e.g. multiply(x,x)).
constexpr auto WithOneUse() const {
return AppendImpl(HloInstructionPatternOneUseImpl());
}

// Modifies the pattern to match if the instruction is used by exactly one
// other instruction. This counts distinct users, so it will match if the
// instruction is used twice, so long as it's by the same user (e.g.
// multiply(x,x)).
constexpr auto WithOneUser() const {
return AppendImpl(HloInstructionPatternOneUserImpl());
}
1939
// Modifies the pattern to match only if the instruction has the given
// comparison direction (used by the Eq/Ne/Ge/Gt/Le/Lt helpers).
// NOTE(review): unlike the sibling modifiers this is not constexpr — confirm
// whether HloInstructionPatternComparisonDirectionImpl permits it.
auto WithComparisonDirection(ComparisonDirection direction) const {
return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
}

// Writes a human-readable description of this pattern to `os` by delegating
// to the accumulated impl; `indent` is passed through unchanged.
void DescribeTo(std::ostream* os, int64 indent = 0) const {
impl_.DescribeTo(os, indent);
}
1949
1950 private:
1951 Impl impl_;
1952 HloInstructionType** matched_inst_;
1953 };
1954
1955 } // namespace detail
1956
1957 // Creates an instruction pattern that will capture the matched instruction in
1958 // the argument.
1959 inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
1960 return detail::HloInstructionPattern<const ::xla::HloInstruction,
1961 detail::HloInstructionPatternBaseImpl>(
1962 detail::HloInstructionPatternBaseImpl(), matched_inst);
1963 }
1964
1965 // Creates an instruction pattern that will capture the matched instruction in
1966 // the argument.
1967 inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
1968 return detail::HloInstructionPattern<::xla::HloInstruction,
1969 detail::HloInstructionPatternBaseImpl>(
1970 detail::HloInstructionPatternBaseImpl(), matched_inst);
1971 }
1972
1973 // Helpers for nullary instructions.
1974 #define XLA_NULLOP_PATTERN(NAME) \
1975 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
1976 \
1977 template <typename HloInstructionType> \
1978 inline auto NAME(HloInstructionType** matched_inst) { \
1979 return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
1980 }
1981 XLA_NULLOP_PATTERN(Constant)
1982 XLA_NULLOP_PATTERN(Parameter)
1983 XLA_NULLOP_PATTERN(Iota)
1984 XLA_NULLOP_PATTERN(Rng)
1985 #undef XLA_NULLOP_PATTERN
1986
1987 // Helpers for unary instructions.
1988 #define XLA_UNOP_PATTERN(NAME) \
1989 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
1990 \
1991 template <typename Arg> \
1992 inline auto NAME(Arg&& arg) { \
1993 return Op() \
1994 .WithOpcode(HloOpcode::k##NAME) \
1995 .WithOperand(0, std::forward<Arg>(arg)); \
1996 } \
1997 \
1998 template <typename HloInstructionType, typename Arg> \
1999 inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \
2000 return Op(matched_inst) \
2001 .WithOpcode(HloOpcode::k##NAME) \
2002 .WithOperand(0, std::forward<Arg>(arg)); \
2003 }
2004 XLA_UNOP_PATTERN(Abs)
2005 XLA_UNOP_PATTERN(RoundNearestAfz)
2006 XLA_UNOP_PATTERN(Bitcast)
2007 XLA_UNOP_PATTERN(BitcastConvert)
2008 XLA_UNOP_PATTERN(Broadcast)
2009 XLA_UNOP_PATTERN(Ceil)
2010 XLA_UNOP_PATTERN(Convert)
2011 XLA_UNOP_PATTERN(Copy)
2012 XLA_UNOP_PATTERN(Cos)
2013 XLA_UNOP_PATTERN(AllReduce)
2014 XLA_UNOP_PATTERN(Exp)
2015 XLA_UNOP_PATTERN(Fft)
2016 XLA_UNOP_PATTERN(Floor)
2017 XLA_UNOP_PATTERN(GetTupleElement)
2018 XLA_UNOP_PATTERN(Imag)
2019 XLA_UNOP_PATTERN(Infeed)
2020 XLA_UNOP_PATTERN(IsFinite)
2021 XLA_UNOP_PATTERN(Log)
2022 XLA_UNOP_PATTERN(Not)
2023 XLA_UNOP_PATTERN(Negate)
2024 XLA_UNOP_PATTERN(Real)
2025 XLA_UNOP_PATTERN(Recv)
2026 XLA_UNOP_PATTERN(RecvDone)
2027 XLA_UNOP_PATTERN(ReducePrecision)
2028 XLA_UNOP_PATTERN(Reshape)
2029 XLA_UNOP_PATTERN(Reverse)
2030 XLA_UNOP_PATTERN(Rsqrt)
2031 XLA_UNOP_PATTERN(SendDone)
2032 XLA_UNOP_PATTERN(Sign)
2033 XLA_UNOP_PATTERN(Sin)
2034 XLA_UNOP_PATTERN(Slice)
2035 XLA_UNOP_PATTERN(Sqrt)
2036 XLA_UNOP_PATTERN(Tanh)
2037 XLA_UNOP_PATTERN(Transpose)
2038 #undef XLA_UNOP_PATTERN
2039
2040 // Helpers for binary instructions.
2041 #define XLA_BINOP_PATTERN(NAME) \
2042 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2043 \
2044 template <typename Lhs, typename Rhs> \
2045 inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \
2046 return Op() \
2047 .WithOpcode(HloOpcode::k##NAME) \
2048 .WithOperand(0, std::forward<Lhs>(lhs)) \
2049 .WithOperand(1, std::forward<Rhs>(rhs)); \
2050 } \
2051 \
2052 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2053 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2054 return Op(matched_inst) \
2055 .WithOpcode(HloOpcode::k##NAME) \
2056 .WithOperand(0, std::forward<Lhs>(lhs)) \
2057 .WithOperand(1, std::forward<Rhs>(rhs)); \
2058 }
2059
2060 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
2061 XLA_BINOP_PATTERN(NAME) \
2062 \
2063 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2064 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2065 Rhs&& rhs) { \
2066 return Op(matched_inst) \
2067 .WithOpcode(HloOpcode::k##NAME) \
2068 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2069 std::forward<Rhs>(rhs)); \
2070 } \
2071 template <typename Lhs, typename Rhs> \
2072 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
2073 return NAME##AnyOrder<const HloInstruction>( \
2074 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
2075 }
2076 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
2077 XLA_BINOP_PATTERN(Atan2)
2078 XLA_BINOP_PATTERN(Divide)
2079 XLA_BINOP_PATTERN(Complex)
2080 XLA_BINOP_PATTERN(Compare)
2081 XLA_BINOP_PATTERN(Convolution)
2082 XLA_BINOP_PATTERN(Dot)
2083 XLA_BINOP_PATTERN(Gather)
2084 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
2085 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
2086 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
2087 XLA_BINOP_PATTERN(Outfeed)
2088 XLA_BINOP_PATTERN(Pad)
2089 XLA_BINOP_PATTERN(Power)
2090 XLA_BINOP_PATTERN(ReduceWindow)
2091 XLA_BINOP_PATTERN(Remainder)
2092 XLA_BINOP_PATTERN(Send)
2093 XLA_BINOP_PATTERN(Subtract)
2094 XLA_COMMUTATIVE_BINOP_PATTERN(And)
2095 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
2096 XLA_BINOP_PATTERN(ShiftLeft)
2097 XLA_BINOP_PATTERN(ShiftRightArithmetic)
2098 XLA_BINOP_PATTERN(ShiftRightLogical)
2099 #undef XLA_COMMUTATIVE_BINOP_PATTERN
2100 #undef XLA_BINOP_PATTERN
2101
2102 // Helpers for ternary instructions.
2103 #define XLA_TERNOP_PATTERN(NAME) \
2104 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2105 \
2106 template <typename Arg0, typename Arg1, typename Arg2> \
2107 inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \
2108 return Op() \
2109 .WithOpcode(HloOpcode::k##NAME) \
2110 .WithOperand(0, std::forward<Arg0>(arg0)) \
2111 .WithOperand(1, std::forward<Arg1>(arg1)) \
2112 .WithOperand(2, std::forward<Arg2>(arg2)); \
2113 } \
2114 \
2115 template <typename HloInstructionType, typename Arg0, typename Arg1, \
2116 typename Arg2> \
2117 inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \
2118 Arg1&& arg1, Arg2&& arg2) { \
2119 return Op(matched_inst) \
2120 .WithOpcode(HloOpcode::k##NAME) \
2121 .WithOperand(0, std::forward<Arg0>(arg0)) \
2122 .WithOperand(1, std::forward<Arg1>(arg1)) \
2123 .WithOperand(2, std::forward<Arg2>(arg2)); \
2124 }
2125 XLA_TERNOP_PATTERN(Clamp);
2126 XLA_TERNOP_PATTERN(Scatter);
2127 XLA_TERNOP_PATTERN(Select);
2128 XLA_TERNOP_PATTERN(SelectAndScatter);
2129 #undef XLA_TERNOP_PATTERN
2130
2131 namespace detail {
2132 template <typename Matcher, typename FirstArg>
2133 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) {
2134 return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
2135 }
2136
2137 template <typename Matcher, typename FirstArg, typename... Args>
2138 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
2139 Args&&... args) {
2140 return WithOperands(
2141 m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)),
2142 operand_num + 1, std::forward<Args>(args)...);
2143 }
2144 } // namespace detail
2145
2146 #define XLA_VARIADIC_OP_PATTERN(NAME) \
2147 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2148 \
2149 template <typename... Args> \
2150 inline auto NAME(Args&&... args) { \
2151 return detail::WithOperands( \
2152 Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
2153 /*operand_num=*/0, std::forward<Args>(args)...); \
2154 } \
2155 \
2156 template <typename HloInstructionType, typename... Args> \
2157 inline auto NAME(HloInstructionType** matched_inst, Args&&... args) { \
2158 return detail::WithOperands(Op(matched_inst) \
2159 .WithOpcode(HloOpcode::k##NAME) \
2160 .WithNumOperands(sizeof...(Args)), \
2161 /*operand_num=*/0, \
2162 std::forward<Args>(args)...); \
2163 }
2164
2165 // We could implement all ops as "variadic" ops, but it would make the
2166 // already-bad compile errors even worse.
2167 XLA_VARIADIC_OP_PATTERN(AfterAll);
2168 XLA_VARIADIC_OP_PATTERN(Concatenate);
2169 XLA_VARIADIC_OP_PATTERN(Conditional);
2170 XLA_VARIADIC_OP_PATTERN(CustomCall);
2171 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2172 XLA_VARIADIC_OP_PATTERN(Fusion);
2173 XLA_VARIADIC_OP_PATTERN(Map)
2174 XLA_VARIADIC_OP_PATTERN(Reduce);
2175 XLA_VARIADIC_OP_PATTERN(Sort);
2176 XLA_VARIADIC_OP_PATTERN(Tuple);
2177
2178 // Helpers for comparison instructions.
2179 #define XLA_COMPARE_PATTERN(NAME) \
2180 inline auto NAME() { \
2181 return Op() \
2182 .WithOpcode(HloOpcode::kCompare) \
2183 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2184 } \
2185 \
2186 template <typename Lhs, typename Rhs> \
2187 inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \
2188 return Op() \
2189 .WithOpcode(HloOpcode::kCompare) \
2190 .WithOperand(0, std::forward<Lhs>(lhs)) \
2191 .WithOperand(1, std::forward<Rhs>(rhs)) \
2192 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2193 } \
2194 \
2195 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2196 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2197 return Op(matched_inst) \
2198 .WithOpcode(HloOpcode::kCompare) \
2199 .WithOperand(0, std::forward<Lhs>(lhs)) \
2200 .WithOperand(1, std::forward<Rhs>(rhs)) \
2201 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2202 }
2203
2204 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
2205 XLA_COMPARE_PATTERN(NAME) \
2206 \
2207 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2208 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2209 Rhs&& rhs) { \
2210 return Op(matched_inst) \
2211 .WithOpcode(HloOpcode::kCompare) \
2212 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2213 std::forward<Rhs>(rhs)); \
2214 } \
2215 template <typename Lhs, typename Rhs> \
2216 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
2217 return NAME##AnyOrder<const HloInstruction>( \
2218 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
2219 }
2220
2221 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
2222 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
2223 XLA_COMPARE_PATTERN(Ge);
2224 XLA_COMPARE_PATTERN(Gt);
2225 XLA_COMPARE_PATTERN(Le);
2226 XLA_COMPARE_PATTERN(Lt);
2227
2228 // Helpers for matching non-constant instructions.
2229 inline auto NonConstant() { return Op().IsNonConstant(); }
2230
2231 template <typename HloInstructionType>
2232 inline auto NonConstant(HloInstructionType** matched_inst) {
2233 return Op(matched_inst).IsNonConstant();
2234 }
2235
2236 // Add overloads for GetTupleElement which take a int64 specifying which tuple
2237 // element is selected.
2238 template <typename Arg>
2239 inline auto GetTupleElement(Arg&& arg, int64 tuple_index) {
2240 return Op()
2241 .WithOpcode(HloOpcode::kGetTupleElement)
2242 .WithOperand(0, std::forward<Arg>(arg))
2243 .WithTupleIndex(tuple_index);
2244 }
2245
2246 template <typename HloInstructionType, typename Arg>
2247 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2248 int64 tuple_index) {
2249 return Op(matched_inst)
2250 .WithOpcode(HloOpcode::kGetTupleElement)
2251 .WithOperand(0, std::forward<Arg>(arg))
2252 .WithTupleIndex(tuple_index);
2253 }
2254
2255 // Add overloads for Parameter which take an int64 specifying the parameter
2256 // number.
2257 inline auto Parameter(int64 parameter_num) {
2258 return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2259 }
2260 template <typename HloInstructionType>
2261 inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) {
2262 return Op(matched_inst)
2263 .WithOpcode(HloOpcode::kParameter)
2264 .WithParameterNum(parameter_num);
2265 }
2266
2267 inline auto ConstantScalar() { return Op().IsConstantScalar(); }
2268
2269 template <typename HloInstructionType>
2270 inline auto ConstantScalar(HloInstructionType** matched_inst) {
2271 return Op(matched_inst).IsConstantScalar();
2272 }
2273
2274 template <typename ScalarTy>
2275 inline auto ConstantScalar(ScalarTy val) {
2276 return Op().IsConstantScalar(val);
2277 }
2278
2279 template <typename HloInstructionType, typename ScalarTy>
2280 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) {
2281 return Op(matched_inst).IsConstantScalar(val);
2282 }
2283
2284 inline auto ConstantEffectiveScalar() {
2285 return Op().IsConstantEffectiveScalar();
2286 }
2287
2288 template <typename HloInstructionType>
2289 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) {
2290 return Op(matched_inst).IsConstantEffectiveScalar();
2291 }
2292
2293 template <typename ScalarTy>
2294 inline auto ConstantEffectiveScalar(ScalarTy val) {
2295 return Op().IsConstantEffectiveScalar(val);
2296 }
2297
2298 template <typename HloInstructionType, typename ScalarTy>
2299 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
2300 ScalarTy val) {
2301 return Op(matched_inst).IsConstantEffectiveScalar(val);
2302 }
2303
2304 } // namespace match
2305
2306 } // namespace xla
2307
2308 #undef EXPLAIN
2309 #pragma pop_macro("EXPLAIN")
2310 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2311