1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18
19 #include <type_traits>
20
21 #include "absl/strings/str_replace.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/utility/utility.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31
32 namespace xla {
33
34 // A pattern matcher for HloInstructions, Shapes, and Layouts.
35 //
36 // The Match function's first argument must be HloInstruction*, Shape*, or
37 // Layout*. The second argument is a pattern that will be matched against the
38 // first argument, as described below.
39 //
40 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
41 // functions. By default, the returned patterns will match any HloInstruction,
42 // Shape, or Layout, respectively. However the match can be made more specific
43 // by using the pattern's modifier methods, for example:
44 //
45 // match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
46 // 0, match::Op().WithOpcode(HloOpcode::kConstant))
47 //
48 // This pattern will match Add instructions whose first operand is a constant.
49 //
50 // Each pattern type has the following modifiers, which are described where
51 // nontrivial.
52 //
53 // Op():
54 // - Is: is the given HloInstruction* (i.e. pointer equality)
55 // - WithName
56 // - WithOpcode
57 // - WithoutOpcode: anything other than the given opcode
58 // - WithShape: instr's shape matches the given pattern
59 // - WithShapeEqualTo: instr's shape is equal to the given Shape
60 // - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
61 // - WithNumOperands
62 // - WithOperand: operand at the given index matches the given pattern
63 // - IsConstant
64 // - IsNonConstant
65 // - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value,
66 // e.g. IsConstantScalar() or IsConstantScalar(42).
67 // - WithFusionKind
68 // - WithTupleIndex: get-tuple-element operations with the given tuple index
69 // - WithOneUse: Instruction is used as an operand exactly once.
70 // - WithOneUser: Instruction is used by exactly one other instruction, but
71 // is possibly used more than once as an operand (e.g. multiply(x,x)).
72 // - WithComparisonDirection: instr has the given direction
73 //
74 // Shape():
75 // - EqualTo
76 // - CompatibleTo
77 // - IsScalar/IsEffectiveScalar/IsArray/IsTuple
78 // - IsDenseArray
// - WithLayout: shape's layout matches the given pattern (e.g.
//   Layout().WithDenseFormat())
81 // - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
82 // Layout, but not the result of Layout().foo())
83 // - WithSubshape: shape is a tuple whose subshape matches the given pattern
84 // (e.g. Shape().IsScalar()).
85 // - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
86 // (i.e. another Shape, but not the result of Shape().foo())
87 // - WithElementType: shape is an array/scalar with the given elem type
88 // - WithRank: shape is an array/scalar with the given rank
89 //
90 // Layout():
91 // - EqualTo
92 // - WithDenseFormat
93 //
94 // Op(), Shape(), and Layout() may be passed an argument of type
95 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
96 // these pointers. If the pattern is matched, the address of the matched value
97 // will be "captured" and stored at this location.
98 //
99 // For example:
100 // HloInstruction* foo = ...;
101 // HloInstruction* matched_operand;
102 // CHECK(Match(foo,
103 // match::Op().WithOperand(0, match::Op(&matched_operand))));
104 //
105 // Helpers are provided for most HLO instructions. These helpers can be called
106 // with no arguments, in which case they will match any instruction matching the
107 // opcode. They may also be called with matches for the operands and with an
108 // optional capture. (The capture must be the first argument.) Some examples of
109 // these helpers and their equivalents are provided below.
110
111 // Example nullary instruction:
112 // Parameter() == Op().WithOpcode(HloOpcode::kParameter)
113 // Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter)
114 //
115 // Example unary instruction:
116 // Abs() == Op().WithOpcode(HloOpcode::kAbs)
// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs)
//                    .WithOperand(0, Op(&a))
119 // Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs)
120 // .WithOperand(0, Op(&b))
121 //
122 // Commutative binary instructions have a special form that accepts either order
123 // of args, e.g.:
124 //
125 // AddAnyOrder(Parameter(1), Abs()) ==
126 // Op().WithOpcode(HloOpcode::kAdd)
127 // .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
128 //
129 // MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`.
130 //
131 // The following additional helpers are provided. In all cases, `&a` is
132 // optional.
133 //
134 // ConstantScalar(&a) == Op(&a).IsConstantScalar();
135 // ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v);
136 // ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar();
// ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(v);
138 // NonConstant(&a) == Op(&a).IsNonConstant()
139 // GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index)
140 // .WithOperand(0, b);
141 // Parameter(&a, n) == Op(&a).WithParameterNum(n);
142
// Options controlling a single pattern-matching run.
//
// Both fields carry default member initializers, so a default-constructed
// MatchOption is well-defined (no capture, no explanation) and identical to a
// value-initialized one. The struct remains an aggregate, so brace
// initialization such as `{/*.capture=*/true, /*.explain_os=*/nullptr}` keeps
// working unchanged.
struct MatchOption {
  // If true, actually capture the matched item into the user-provided pointer.
  bool capture = false;

  // If non-null, an explanation of why the match failed is streamed here.
  std::ostream* explain_os = nullptr;
};
150
151 template <typename Value, typename Pattern>
152 bool Match(Value* value, const Pattern& pattern,
153 MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
154 if (option.capture) {
155 auto new_option = option;
156 new_option.capture = false;
157 if (!pattern.Match(value, new_option)) {
158 return false;
159 }
160 }
161 return pattern.Match(value, option);
162 }
163
164 namespace match {
165
166 namespace detail {
167
// Streams to option.explain_os when it is non-null. Otherwise the whole
// streaming expression is skipped and its operands are never evaluated.
//
//   EXPLAIN << "value of foo(): " << foo();
//
// A variable named `option` (a MatchOption) must be in scope at every use.
//
// The expansion is `if (!cond) {} else <stream>` rather than the naive
// `if (cond) <stream>`. With the naive form, an `else` written after an
// unbraced EXPLAIN statement would bind to the `if` hidden inside the macro
// instead of the caller's own `if` (the dangling-else hazard of
// statement-like macros). Here the hidden `if` already owns its `else`.
//
// push_macro saves any prior definition of EXPLAIN so that a matching
// pop_macro can restore it.
#pragma push_macro("EXPLAIN")
#define EXPLAIN             \
  if (!option.explain_os) { \
  } else                    \
    *option.explain_os
175
// Number of extra spaces added each time the pretty-printer's indentation
// grows by one level.
constexpr int kIndentInc = 2;
181
// Starts a new line in `os`, then writes `indent` spaces. A non-positive
// `indent` writes only the newline.
//
// Convention used by the DescribeTo() pretty-printers in this file: the
// *caller* handles indentation, not the callee. To print
//
//   foo:
//    - bar
//
// you would write:
//
//   Foo::DescribeTo(std::ostream* os, int64_t indent) {
//     *os << "foo:";
//     Indent(os, indent);  // New line at the *current* indent level.
//     *os << " - ";
//     bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" - ") == 3.
//   }
//
//   Bar::DescribeTo(std::ostream* os, int64_t indent) { *os << "bar"; }
//
// Bar::DescribeTo() never calls Indent itself; Foo does it on Bar's behalf.
// Leaving this to the caller lets it decide whether a sub-matcher is preceded
// by a newline at all, which matters e.g. for the AllOf matcher.
//
// Match() explanations indent differently. DescribeTo() prints a whole tree,
// so indentation is the common case there. Match() prints only the path to a
// failing node, and only a failing disjunction introduces extra indentation,
// so that case is handled specially where it occurs.
inline void Indent(std::ostream* os, int64_t indent) {
  *os << '\n';
  for (int64_t remaining = indent; remaining > 0; --remaining) {
    *os << ' ';
  }
}
219
// Type trait: IsTrivialMatcher<T>::value is true iff T declares a member
// `kIsTrivialMatcher` that evaluates to true. Types without that member, or
// with it set to false, are non-trivial.
//
// Trivial matchers get special treatment when pretty-printing. For example,
// in a conjunction no "AND" is printed after a leading trivial matcher, which
// yields
//   "a shape compatible with f32[1,2]"
// rather than
//   "a shape AND compatible with f32[1,2]"
template <typename T, typename Enable = void>
struct IsTrivialMatcher : std::false_type {};

// Chosen only when T::kIsTrivialMatcher exists and is true. For any other T,
// substitution fails and the primary template above is used.
template <typename T>
struct IsTrivialMatcher<T, std::enable_if_t<T::kIsTrivialMatcher>>
    : std::true_type {};
238
239 template <typename Item, typename... Patterns>
240 class AllOfPattern {
241 public:
242 explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
243
244 bool Match(const Item* item, MatchOption option) const {
245 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
246 // This invariant is guaranteed by the top-level Match and AnyOf.
247 DCHECK(matched || !option.capture);
248 return matched;
249 }
250
251 bool Match(Item* item, MatchOption option) const {
252 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
253 // This invariant is guaranteed by the top-level Match and AnyOf.
254 DCHECK(matched || !option.capture);
255 return matched;
256 }
257
258 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
259 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
260 }
261
262 // Accessor for patterns_. Please don't use this outside of this file.
263 const std::tuple<Patterns...>& patterns() const { return patterns_; }
264
265 private:
266 template <typename ItemType, size_t index>
267 bool MatchImpl(ItemType* item, MatchOption option,
268 std::integral_constant<size_t, index>) const {
269 // We don't need to do any EXPLAINing here; it's all correctly handled by
270 // our sub-matchers (if any fail).
271 return std::get<index>(patterns_).Match(item, option) &&
272 MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
273 }
274
275 template <typename ItemType>
276 bool MatchImpl(ItemType* item, MatchOption option,
277 std::integral_constant<size_t, sizeof...(Patterns)>) const {
278 return true;
279 }
280
281 // Pretty-printing a conjunction has some special cases to make it easy to
282 // read in the simple (common) case.
283 //
284 // If sizeof...(Patterns) == 1, prints as e.g.
285 //
286 // a shape
287 //
288 // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
289 // shape") prints as
290 //
291 // a shape compatible with f32[1,2]
292 //
293 // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
294 //
295 // a shape:
296 // * compatible with f32[1,2] AND
297 // * that represents a scalar
298 //
299 // Otherwise prints as:
300 //
301 // all of:
302 // * foo AND
303 // * bar
304 //
305 template <size_t index>
306 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
307 int64_t indent) const {
308 constexpr bool first_is_trivial =
309 IsTrivialMatcher<typename std::remove_reference<decltype(std::get<0>(
310 patterns_))>::type>::value;
311 constexpr bool is_last = index == sizeof...(Patterns) - 1;
312 const auto& submatcher = std::get<index>(patterns_);
313
314 auto print_bulleted_item = [&] {
315 *os << " * ";
316 submatcher.DescribeTo(os, indent + 3);
317 if (!is_last) {
318 *os << " AND";
319 Indent(os, indent);
320 }
321 };
322
323 if (index == 0) {
324 if (first_is_trivial || is_last) {
325 submatcher.DescribeTo(os, indent + kIndentInc);
326 if (sizeof...(Patterns) > 2) {
327 *os << ":";
328 Indent(os, indent);
329 }
330 } else {
331 *os << "all of:";
332 Indent(os, indent);
333 print_bulleted_item();
334 }
335 } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
336 *os << " ";
337 submatcher.DescribeTo(os, indent);
338 } else {
339 print_bulleted_item();
340 }
341 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
342 }
343
344 void DescribeToImpl(std::ostream* os,
345 std::integral_constant<size_t, sizeof...(Patterns)>,
346 int64_t indent) const {}
347
348 std::tuple<Patterns...> patterns_;
349 };
350
351 } // namespace detail
352
353 // Returns a pattern that represents the conjunction of all input patterns. All
354 // patterns need to match in order to have the AllOf pattern match.
355 template <typename Item, typename... Patterns>
356 auto AllOf(const Patterns&... patterns) {
357 return detail::AllOfPattern<typename std::remove_const<Item>::type,
358 Patterns...>(patterns...);
359 }
360
361 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
362 //
363 // This transformation is necessary for good pretty-printing.
364 template <typename Item, typename... InnerPs, typename... OuterPs>
365 auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
366 const OuterPs&... outer_ps) {
367 // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
368 auto make_all_of = [](const InnerPs&... inner_ps,
369 const OuterPs&... outer_ps) {
370 return detail::AllOfPattern<typename std::remove_const<Item>::type,
371 InnerPs..., OuterPs...>(inner_ps...,
372 outer_ps...);
373 };
374 return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
375 std::make_tuple(outer_ps...)));
376 }
377
378 namespace detail {
379
380 template <typename LayoutType, typename Impl>
381 class LayoutPattern;
382
383 // The base LayoutPattern implementation. Matches only if the layout is not
384 // nullptr.
385 class LayoutPatternBaseImpl {
386 public:
387 bool Match(const ::xla::Layout* layout, MatchOption option) const {
388 if (layout == nullptr) {
389 EXPLAIN << "Layout is null";
390 return false;
391 }
392 return true;
393 }
394
395 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
396 *os << "a layout";
397 }
398
399 static constexpr bool kIsTrivialMatcher = true;
400 };
401
402 // A LayoutPattern implementation that matches only if the layout equals a
403 // Layout proto.
404 class LayoutPatternEqualImpl {
405 public:
406 explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
407 : layout_(layout) {}
408
409 bool Match(const ::xla::Layout* layout, MatchOption option) const {
410 if (!LayoutUtil::Equal(*layout_, *layout)) {
411 EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
412 << " is not equal to expected "
413 << LayoutUtil::HumanString(*layout_);
414 return false;
415 }
416 return true;
417 }
418
419 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
420 *os << "equal to " << LayoutUtil::HumanString(*layout_);
421 }
422
423 private:
424 const ::xla::Layout* layout_;
425 };
426
427 // A LayoutPattern implementation that matches only if the layout has a given
428 // format.
429 class LayoutPatternFormatImpl {
430 public:
431 explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
432
433 bool Match(const ::xla::Layout* layout, MatchOption option) const {
434 if (layout->format() != format_) {
435 EXPLAIN << "Layout has format " << Format_Name(layout->format())
436 << " but expected " << Format_Name(format_);
437 return false;
438 }
439 return true;
440 }
441
442 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
443 *os << "with format " << Format_Name(format_);
444 }
445
446 private:
447 Format format_;
448 };
449
450 // A pattern that matches Layouts.
451 template <typename LayoutType, typename Impl>
452 class LayoutPattern {
453 private:
454 template <typename NewImpl>
455 auto AppendImpl(NewImpl new_impl) const {
456 auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
457 return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
458 matched_layout_);
459 }
460
461 public:
462 explicit constexpr LayoutPattern(const Impl& impl,
463 LayoutType** matched_layout)
464 : impl_(impl), matched_layout_(matched_layout) {}
465
466 // Returns true and captures the layout iff it matches the pattern.
467 bool Match(const ::xla::Layout* layout, MatchOption option) const {
468 if (impl_.Match(layout, option)) {
469 if (option.capture && matched_layout_) {
470 *matched_layout_ = layout;
471 }
472 return true;
473 }
474 return false;
475 }
476
477 // Returns true and captures the layout iff it matches the pattern.
478 bool Match(::xla::Layout* layout, MatchOption option) const {
479 if (impl_.Match(layout, option)) {
480 if (option.capture && matched_layout_) {
481 *matched_layout_ = layout;
482 }
483 return true;
484 }
485 return false;
486 }
487
488 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
489 impl_.DescribeTo(os, indent);
490 }
491
492 // Modifies the pattern to match only if the layout equals the given proto.
493 // The layout must outlive the returned pattern.
494 constexpr auto EqualTo(const ::xla::Layout* layout) const {
495 return AppendImpl(LayoutPatternEqualImpl(layout));
496 }
497
498 // Modifies the pattern to match only if the layout has a dense format.
499 constexpr auto WithDenseFormat() const {
500 return AppendImpl(LayoutPatternFormatImpl(DENSE));
501 }
502
503 private:
504 Impl impl_;
505 LayoutType** matched_layout_;
506 };
507
508 template <typename Item, typename... Patterns>
509 class AnyOfPattern {
510 public:
511 explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
512
513 bool Match(const Item* item, MatchOption option) const {
514 return MatchImpl(item, option);
515 }
516
517 bool Match(Item* item, MatchOption option) const {
518 return MatchImpl(item, option);
519 }
520
521 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
522 *os << "any of:";
523 Indent(os, indent);
524 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
525 }
526
527 private:
528 template <typename ItemType>
529 bool MatchImpl(ItemType* item, MatchOption option) const {
530 // If we're generating an explanation, buffer it until we know we failed.
531 absl::optional<std::stringstream> explanation;
532 MatchOption new_option = option;
533 if (option.explain_os) {
534 new_option.explain_os = &explanation.emplace();
535 }
536 bool rv = MatchRecursiveImpl(item, new_option,
537 std::integral_constant<size_t, 0>());
538 if (!rv && option.explain_os) {
539 EXPLAIN << "None of the following matchers succeeded:";
540 EXPLAIN << explanation->str();
541 }
542 return rv;
543 }
544
545 template <typename ItemType, size_t index>
546 bool MatchRecursiveImpl(ItemType* item, MatchOption option,
547 std::integral_constant<size_t, index>) const {
548 auto new_option = option;
549 new_option.capture = false;
550
551 absl::optional<std::stringstream> explanation;
552 if (option.explain_os) {
553 new_option.explain_os = &explanation.emplace();
554 }
555
556 // Try to match the sub-pattern without capturing behavior.
557 if (std::get<index>(patterns_).Match(item, new_option)) {
558 // Capture the branch.
559 if (option.capture) {
560 // TODO(timshen): Currently the behavior can be exponential. Optimize it
561 // with memoization or recording the matched sub-pattern index, if it
562 // takes too long to run.
563 //
564 // Specifically, the "memoization" approach is to create an empty
565 // container with the key (pattern, instruction), and value as whether
566 // matched or not.
567 //
568 // Alternatively, we may run the pattern matching with captures off, but
569 // instead record a "trace" somewhere, indicating how exactly the
570 // pattern matches the input. For example, the trace information for
571 // AnyOf will be a runtime number indicate which sub-pattern is matched.
572 // Then we run another pass to do captures only with the help of the
573 // trace.
574 bool matched = std::get<index>(patterns_).Match(item, option);
575 DCHECK(matched);
576 }
577 return true;
578 }
579 if (option.explain_os) {
580 EXPLAIN << "\nMatcher #" << index + 1;
581 EXPLAIN << "\n - ";
582 std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
583 EXPLAIN << "\nfailed with";
584 EXPLAIN << "\n - ";
585 EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}});
586 }
587 return MatchRecursiveImpl(item, option,
588 std::integral_constant<size_t, index + 1>());
589 }
590
591 template <typename ItemType>
592 bool MatchRecursiveImpl(
593 ItemType* item, MatchOption option,
594 std::integral_constant<size_t, sizeof...(Patterns)>) const {
595 return false;
596 }
597
598 template <size_t index>
599 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
600 int64_t indent) const {
601 *os << " - ";
602 std::get<index>(patterns_).DescribeTo(os, indent + 3);
603 if (index != sizeof...(Patterns) - 1) {
604 *os << " OR";
605 Indent(os, indent);
606 }
607 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
608 }
609
610 void DescribeToImpl(std::ostream* os,
611 std::integral_constant<size_t, sizeof...(Patterns)>,
612 int64_t indent) const {}
613
614 std::tuple<Patterns...> patterns_;
615 };
616
617 } // namespace detail
618
619 // Returns a pattern that represents the logical disjunction of the input
620 // patterns. The returned pattern matches from left to right, and stops on the
621 // first match.
622 template <typename Item, typename... Patterns>
623 auto AnyOf(const Patterns&... patterns) {
624 return detail::AnyOfPattern<typename std::remove_const<Item>::type,
625 Patterns...>(patterns...);
626 }
627
628 // Creates a layout pattern that will capture the matched layout in the
629 // argument.
630 inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
631 return detail::LayoutPattern<const ::xla::Layout,
632 detail::LayoutPatternBaseImpl>(
633 detail::LayoutPatternBaseImpl(), matched_layout);
634 }
635
636 // Creates a layout pattern that will capture the matched layout in the
637 // argument.
638 inline constexpr auto Layout(::xla::Layout** matched_layout) {
639 return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
640 detail::LayoutPatternBaseImpl(), matched_layout);
641 }
642
643 namespace detail {
644
645 template <typename ShapeType, typename Impl>
646 class ShapePattern;
647
648 // The base ShapePattern implementation. Matches only if the shape is not
649 // nullptr.
650 class ShapePatternBaseImpl {
651 public:
652 bool Match(const ::xla::Shape* shape, MatchOption option) const {
653 if (shape == nullptr) {
654 EXPLAIN << "Shape is null";
655 }
656 return shape != nullptr;
657 }
658
659 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
660 *os << "a shape";
661 }
662
663 static constexpr bool kIsTrivialMatcher = true;
664 };
665
666 // A ShapePattern implementation that matches only if the shape equals a Shape
667 // proto.
668 class ShapePatternEqualImpl {
669 public:
670 explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
671 : shape_(shape) {}
672
673 bool Match(const ::xla::Shape* shape, MatchOption option) const {
674 if (!ShapeUtil::Equal(*shape_, *shape)) {
675 EXPLAIN << "Shape not equal to "
676 << ShapeUtil::HumanStringWithLayout(*shape_);
677 return false;
678 }
679 return true;
680 }
681
682 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
683 *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
684 }
685
686 private:
687 const ::xla::Shape* shape_;
688 };
689
690 // A ShapePattern implementation that matches only if the shape is compatible to
691 // a Shape proto.
692 class ShapePatternCompatibleImpl {
693 public:
694 explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
695 : shape_(shape) {}
696
697 bool Match(const ::xla::Shape* shape, MatchOption option) const {
698 if (!ShapeUtil::Compatible(*shape_, *shape)) {
699 EXPLAIN << "Shape not compatible with "
700 << ShapeUtil::HumanString(*shape_);
701 return false;
702 }
703 return true;
704 }
705
706 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
707 *os << "compatible with " << ShapeUtil::HumanString(*shape_);
708 }
709
710 private:
711 const ::xla::Shape* shape_;
712 };
713
714 // A ShapePattern implementation that matches only if the shape has a given
715 // element type.
716 class ShapePatternElementTypeImpl {
717 public:
718 explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
719 : element_type_(element_type) {}
720
721 bool Match(const ::xla::Shape* shape, MatchOption option) const {
722 if (shape->element_type() != element_type_) {
723 EXPLAIN << "Shape does not have element type "
724 << PrimitiveType_Name(element_type_);
725 return false;
726 }
727 return true;
728 }
729
730 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
731 *os << "with element type " << PrimitiveType_Name(element_type_);
732 }
733
734 private:
735 PrimitiveType element_type_;
736 };
737
738 // A ShapePattern implementation that matches only if the shape has a given
739 // list of dimensions.
740 class ShapePatternDimsImpl {
741 public:
742 explicit ShapePatternDimsImpl(absl::Span<const int64> dims)
743 : dims_(dims.begin(), dims.end()) {}
744
745 bool Match(const ::xla::Shape* shape, MatchOption option) const {
746 if (shape->dimensions() != dims_) {
747 EXPLAIN << "Shape does not have dimensions [" << absl::StrJoin(dims_, ",")
748 << "]";
749 return false;
750 }
751 return true;
752 }
753
754 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
755 *os << "with dimensions [" << absl::StrJoin(dims_, ",") << "]";
756 }
757
758 private:
759 absl::InlinedVector<int64, 8> dims_;
760 };
761
762 // A ShapePattern implementation that matches only if the shape is scalar.
763 class ShapePatternIsScalarImpl {
764 public:
765 explicit constexpr ShapePatternIsScalarImpl() {}
766
767 bool Match(const ::xla::Shape* shape, MatchOption option) const {
768 if (!ShapeUtil::IsScalar(*shape)) {
769 EXPLAIN << "Shape is not a scalar";
770 return false;
771 }
772 return true;
773 }
774
775 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
776 *os << "that represents a scalar";
777 }
778 };
779
780 // A ShapePattern implementation that matches only if the shape is an array
781 class ShapePatternIsArrayImpl {
782 public:
783 explicit constexpr ShapePatternIsArrayImpl() {}
784
785 bool Match(const ::xla::Shape* shape, MatchOption option) const {
786 if (!shape->IsArray()) {
787 EXPLAIN << "Shape is not an array";
788 return false;
789 }
790 return true;
791 }
792
793 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
794 *os << "that represents an array";
795 }
796 };
797
798 // A ShapePattern implementation that matches only if the shape is a tuple.
799 class ShapePatternIsTupleImpl {
800 public:
801 explicit constexpr ShapePatternIsTupleImpl() {}
802
803 bool Match(const ::xla::Shape* shape, MatchOption option) const {
804 if (!shape->IsTuple()) {
805 EXPLAIN << "Shape is not a tuple";
806 return false;
807 }
808 return true;
809 }
810
811 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
812 *os << "that represents a tuple";
813 }
814 };
815
816 // A ShapePattern implementation that matches only if the shape is an effective
817 // scalar.
818 class ShapePatternEffectiveScalarImpl {
819 public:
820 explicit constexpr ShapePatternEffectiveScalarImpl() {}
821
822 bool Match(const ::xla::Shape* shape, MatchOption option) const {
823 if (!ShapeUtil::IsEffectiveScalar(*shape)) {
824 EXPLAIN << "Shape is not an effective scalar";
825 return false;
826 }
827 return true;
828 }
829
830 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
831 *os << "that is an effective scalar";
832 }
833 };
834
835 // A ShapePattern implementation that matches only if the shape has a given
836 // rank.
837 class ShapePatternRankImpl {
838 public:
839 explicit constexpr ShapePatternRankImpl(int64_t rank) : rank_(rank) {}
840
841 bool Match(const ::xla::Shape* shape, MatchOption option) const {
842 if (shape->rank() != rank_) {
843 if (rank_ == 0) {
844 EXPLAIN << "Shape is not a scalar";
845 } else {
846 EXPLAIN << "Shape does not have rank " << rank_;
847 }
848 return false;
849 }
850 return true;
851 }
852
853 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
854 if (rank_ == 0) {
855 *os << "that is a scalar";
856 } else {
857 *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
858 }
859 }
860
861 private:
862 int64 rank_;
863 };
864
865 // A ShapePattern implementation that matches only if the shape has a layout
866 // that matches a given pattern.
867 template <typename LayoutType, typename LayoutImpl>
868 class ShapePatternLayoutImpl {
869 public:
870 explicit constexpr ShapePatternLayoutImpl(
871 const LayoutPattern<LayoutType, LayoutImpl>& layout)
872 : layout_(layout) {}
873
874 bool Match(const ::xla::Shape* shape, MatchOption option) const {
875 return LayoutUtil::HasLayout(*shape) &&
876 layout_.Match(&shape->layout(), option);
877 }
878
879 bool Match(::xla::Shape* shape, MatchOption option) const {
880 if (!LayoutUtil::HasLayout(*shape)) {
881 EXPLAIN << "Shape does not have a layout";
882 return false;
883 }
884 if (!layout_.Match(shape->mutable_layout(), option)) {
885 EXPLAIN << "\nin layout";
886 return false;
887 }
888 return true;
889 }
890
891 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
892 *os << "with";
893 Indent(os, indent + kIndentInc);
894 layout_.DescribeTo(os, indent + kIndentInc);
895 }
896
897 private:
898 LayoutPattern<LayoutType, LayoutImpl> layout_;
899 };
900
901 // A ShapePattern implementation that matches only if the shape has a subshape
902 // that matches a given pattern.
903 template <typename SubshapeType, typename SubshapeImpl>
904 class ShapePatternSubshapeImpl {
905 public:
906 explicit ShapePatternSubshapeImpl(
907 ShapeIndexView index,
908 const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
909 : index_(index), subshape_(subshape) {}
910
911 bool Match(const ::xla::Shape* shape, MatchOption option) const {
912 return MatchImpl(shape, option);
913 }
914
915 bool Match(::xla::Shape* shape, MatchOption option) const {
916 return MatchImpl(shape, option);
917 }
918
919 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
920 *os << "with subshape at index " << index_.ToString() << " which is";
921 Indent(os, indent + kIndentInc);
922 subshape_.DescribeTo(os, indent + kIndentInc);
923 }
924
925 private:
926 ::xla::Shape* GetSubshape(::xla::Shape* shape) const {
927 return ShapeUtil::GetMutableSubshape(shape, index_);
928 }
929 const ::xla::Shape* GetSubshape(const ::xla::Shape* shape) const {
930 return &ShapeUtil::GetSubshape(*shape, index_);
931 }
932
933 template <typename ShapeType>
934 bool MatchImpl(ShapeType* shape, MatchOption option) const {
935 if (!ShapeUtil::IndexIsValid(*shape, index_)) {
936 EXPLAIN << "No subshape at " << index_.ToString();
937 return false;
938 }
939 if (!subshape_.Match(GetSubshape(shape), option)) {
940 EXPLAIN << "\nin subshape at " << index_.ToString();
941 return false;
942 }
943 return true;
944 }
945
946 ShapeIndexView index_;
947 ShapePattern<SubshapeType, SubshapeImpl> subshape_;
948 };
949
950 // A pattern that matches Shapes.
951 template <typename ShapeType, typename Impl>
952 class ShapePattern {
953 private:
954 template <typename NewImpl>
955 auto AppendImpl(NewImpl new_impl) const {
956 auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
957 return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
958 matched_shape_);
959 }
960
961 public:
962 explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
963 : impl_(impl), matched_shape_(matched_shape) {}
964
965 // Returns true and captures the shape iff it matches the pattern.
966 bool Match(const ::xla::Shape* shape, MatchOption option) const {
967 if (impl_.Match(shape, option)) {
968 if (option.capture && matched_shape_) {
969 *matched_shape_ = shape;
970 }
971 return true;
972 }
973 if (shape) {
974 EXPLAIN << "\nin "
975 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
976 : ShapeUtil::HumanString(*shape));
977 }
978 return false;
979 }
980
981 // Returns true and captures the shape iff it matches the pattern.
982 bool Match(::xla::Shape* shape, MatchOption option) const {
983 if (impl_.Match(shape, option)) {
984 if (option.capture && matched_shape_) {
985 *matched_shape_ = shape;
986 }
987 return true;
988 }
989 EXPLAIN << "\nin "
990 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
991 : ShapeUtil::HumanString(*shape));
992 return false;
993 }
994
995 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
996 return impl_.DescribeTo(os, indent);
997 }
998
999 // Modifies the pattern to match only if the shape equals the given proto.
1000 // The layout must outlive the returned pattern.
1001 constexpr auto EqualTo(const ::xla::Shape* shape) const {
1002 return AppendImpl(ShapePatternEqualImpl(shape));
1003 }
1004
1005 // Modifies the pattern to match only if the shape is compatible to the given
1006 // proto. The layout must outlive the returned pattern.
1007 constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
1008 return AppendImpl(ShapePatternCompatibleImpl(shape));
1009 }
1010
1011 // Modifies the pattern to match only if the shape has the given element type.
1012 constexpr auto WithElementType(PrimitiveType element_type) const {
1013 return AppendImpl(ShapePatternElementTypeImpl(element_type));
1014 }
1015
1016 constexpr auto WithDims(absl::Span<const int64> dims) const {
1017 return AppendImpl(ShapePatternDimsImpl(dims));
1018 }
1019
1020 // Modifies the pattern to match only if the shape is scalar.
1021 constexpr auto IsScalar() const {
1022 return AppendImpl(ShapePatternIsScalarImpl());
1023 }
1024
1025 // Modifies the pattern to match only if the shape is an array.
1026 constexpr auto IsArray() const {
1027 return AppendImpl(ShapePatternIsArrayImpl());
1028 }
1029
1030 // Modifies the pattern to match only if the shape is a tuple.
1031 constexpr auto IsTuple() const {
1032 return AppendImpl(ShapePatternIsTupleImpl());
1033 }
1034
1035 constexpr auto IsEffectiveScalar() const {
1036 return AppendImpl(ShapePatternEffectiveScalarImpl());
1037 }
1038
1039 // Modifies the pattern to match only if the shape has the given rank.
1040 constexpr auto WithRank(int64_t rank) const {
1041 return AppendImpl(ShapePatternRankImpl(rank));
1042 }
1043
1044 // Modifies the pattern to match only if the shape has a layout that matches
1045 // the given pattern.
1046 template <typename LayoutType, typename LayoutImpl>
1047 auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
1048 return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1049 }
1050
1051 constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
1052 return WithLayout(Layout().EqualTo(layout));
1053 }
1054
1055 constexpr auto IsDenseArray() const {
1056 return WithLayout(Layout().WithDenseFormat());
1057 }
1058
1059 // Modifies the pattern to match only if the shape has a subshape that matches
1060 // the given pattern.
1061 template <typename SubshapeType, typename SubshapeImpl>
1062 auto WithSubshape(
1063 ShapeIndexView index,
1064 const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
1065 return AppendImpl(
1066 ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1067 }
1068
1069 ShapePattern<ShapeType,
1070 AllOfPattern<::xla::Shape, Impl,
1071 ShapePatternSubshapeImpl<
1072 const ::xla::Shape,
1073 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1074 ShapePatternEqualImpl>>>>
1075 WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1076 return WithSubshape(index,
1077 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1078 ShapePatternBaseImpl(), nullptr)
1079 .EqualTo(shape));
1080 }
1081
1082 ShapePattern<ShapeType,
1083 AllOfPattern<::xla::Shape, Impl,
1084 ShapePatternSubshapeImpl<
1085 const ::xla::Shape,
1086 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1087 ShapePatternCompatibleImpl>>>>
1088 WithSubshapeCompatibleTo(ShapeIndexView index,
1089 const ::xla::Shape* shape) const {
1090 return WithSubshape(index,
1091 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1092 ShapePatternBaseImpl(), nullptr)
1093 .CompatibleTo(shape));
1094 }
1095
1096 private:
1097 Impl impl_;
1098 ShapeType** matched_shape_;
1099 };
1100
1101 } // namespace detail
1102
1103 // Creates a shape pattern that will capture the matched layout in the argument.
1104 inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
1105 return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1106 detail::ShapePatternBaseImpl(), matched_shape);
1107 }
1108
1109 // Creates a shape pattern that will capture the matched layout in the argument.
1110 inline constexpr auto Shape(::xla::Shape** matched_shape) {
1111 return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1112 detail::ShapePatternBaseImpl(), matched_shape);
1113 }
1114
1115 namespace detail {
1116
1117 // Overloads to get a const or non-const operand out of an instruction.
1118 inline HloInstruction* HloOperand(HloInstruction* instr, int64_t idx) {
1119 return instr->mutable_operand(idx);
1120 }
1121 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1122 int64_t idx) {
1123 return instr->operand(idx);
1124 }
1125
1126 // Pretty-printer for HloInstruction. Sort of like ToShortString, but with
1127 // fewer %s and more shapes.
1128 inline string InstToString(const HloInstruction* inst) {
1129 return inst->ToString(
1130 HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1131 }
1132
1133 template <typename HloInstructionType, typename Impl>
1134 class HloInstructionPattern;
1135
1136 // The base HloInstructionPattern implementation. Matches only if the
1137 // instruction is not nullptr.
1138 class HloInstructionPatternBaseImpl {
1139 public:
1140 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1141 if (inst == nullptr) {
1142 EXPLAIN << "HloInstruction* is null";
1143 return false;
1144 }
1145 return true;
1146 }
1147
1148 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1149 *os << "an HloInstruction";
1150 }
1151
1152 static constexpr bool kIsTrivialMatcher = true;
1153 };
1154
1155 // An HloInstructionPattern implementation that matches only if the instruction
1156 // has a given name.
1157 class HloInstructionPatternNameImpl {
1158 public:
1159 explicit HloInstructionPatternNameImpl(absl::string_view name)
1160 : name_(name) {}
1161
1162 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1163 if (inst->name() != name_) {
1164 EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1165 return false;
1166 }
1167 return true;
1168 }
1169
1170 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1171 *os << "named \"" << name_ << "\"";
1172 }
1173
1174 private:
1175 absl::string_view name_;
1176 };
1177
1178 // An HloInstructionPattern implementation that matches only if the instruction
1179 // equals a particular pointer.
1180 class HloInstructionIsImpl {
1181 public:
1182 explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1183
1184 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1185 if (inst != inst_) {
1186 EXPLAIN << "HloInstruction " << std::hex << std::nouppercase
1187 << std::showbase << reinterpret_cast<uint64>(inst) << " is not "
1188 << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1189 << ")";
1190 return false;
1191 }
1192 return true;
1193 }
1194
1195 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1196 *os << "which is " << std::hex << std::nouppercase << std::showbase
1197 << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1198 << ")";
1199 }
1200
1201 private:
1202 const HloInstruction* inst_;
1203 };
1204
1205 // An HloInstructionPattern implementation that matches only if the instruction
1206 // has a given opcode.
1207 class HloInstructionPatternOpcodeImpl {
1208 public:
1209 explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1210 bool invert)
1211 : opcode_(opcode), invert_(invert) {}
1212
1213 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1214 if (invert_ && inst->opcode() == opcode_) {
1215 EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1216 << ", expected anything else";
1217 return false;
1218 }
1219 if (!invert_ && inst->opcode() != opcode_) {
1220 EXPLAIN << "HloInstruction doesn't have opcode "
1221 << HloOpcodeString(opcode_);
1222 return false;
1223 }
1224 return true;
1225 }
1226
1227 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1228 if (!invert_) {
1229 *os << "with opcode " << HloOpcodeString(opcode_);
1230 } else {
1231 *os << "with any opcode other than " << HloOpcodeString(opcode_);
1232 }
1233 }
1234
1235 private:
1236 HloOpcode opcode_;
1237 bool invert_;
1238 };
1239
1240 // An HloInstructionPattern implementation that matches only if the instruction
1241 // has one of a given list of custom call targets.
1242 class HloInstructionCustomCallTargetImpl {
1243 public:
1244 explicit HloInstructionCustomCallTargetImpl(
1245 absl::Span<const absl::string_view> custom_call_targets)
1246 : custom_call_targets_(custom_call_targets.begin(),
1247 custom_call_targets.end()) {}
1248
1249 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1250 if (inst->opcode() != HloOpcode::kCustomCall ||
1251 !absl::c_linear_search(custom_call_targets_,
1252 inst->custom_call_target())) {
1253 if (custom_call_targets_.size() == 1) {
1254 EXPLAIN << "HloInstruction is not a custom call with a target '"
1255 << custom_call_targets_.front() << "'";
1256 } else {
1257 EXPLAIN << "HloInstruction is not a custom call with a target in {"
1258 << absl::StrJoin(custom_call_targets_, ", ") << "}";
1259 }
1260 return false;
1261 }
1262 return true;
1263 }
1264
1265 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1266 if (custom_call_targets_.size() == 1) {
1267 *os << "custom call with target '" << custom_call_targets_.front() << "'";
1268 } else {
1269 *os << "custom call with target in {"
1270 << absl::StrJoin(custom_call_targets_, ", ") << "}";
1271 }
1272 }
1273
1274 private:
1275 absl::InlinedVector<std::string, 1> custom_call_targets_;
1276 };
1277
1278 // An HloInstructionPattern implementation that matches only if the instruction
1279 // has the given number of operands.
1280 class HloInstructionPatternNumOperandsImpl {
1281 public:
1282 explicit constexpr HloInstructionPatternNumOperandsImpl(int64_t num_operands)
1283 : num_operands_(num_operands) {}
1284
1285 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1286 if (inst->operand_count() != num_operands_) {
1287 EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1288 return false;
1289 }
1290 return true;
1291 }
1292
1293 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1294 *os << "with " << num_operands_ << " operand"
1295 << (num_operands_ != 1 ? "s" : "");
1296 }
1297
1298 private:
1299 int64 num_operands_;
1300 };
1301
1302 // An HloInstructionPattern implementation that matches only if the instruction
1303 // has a shape that matches a given pattern.
1304 template <typename ShapeType, typename ShapeImpl>
1305 class HloInstructionPatternShapeImpl {
1306 public:
1307 explicit constexpr HloInstructionPatternShapeImpl(
1308 const ShapePattern<ShapeType, ShapeImpl>& shape)
1309 : shape_(shape) {}
1310
1311 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1312 if (!shape_.Match(&inst->shape(), option)) {
1313 EXPLAIN << "\nin output shape";
1314 return false;
1315 }
1316 return true;
1317 }
1318
1319 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1320 if (!shape_.Match(inst->mutable_shape(), option)) {
1321 EXPLAIN << "\nin output shape";
1322 return false;
1323 }
1324 return true;
1325 }
1326
1327 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1328 *os << "outputting";
1329 Indent(os, indent + kIndentInc);
1330 shape_.DescribeTo(os, indent + kIndentInc);
1331 }
1332
1333 private:
1334 ShapePattern<ShapeType, ShapeImpl> shape_;
1335 };
1336
1337 // An HloInstructionPattern implementation that matches only if the instruction
1338 // has an operand that matches a given pattern.
1339 template <typename OperandType, typename OperandImpl>
1340 class HloInstructionPatternOperandImpl {
1341 public:
1342 explicit constexpr HloInstructionPatternOperandImpl(
1343 int64_t operand_index,
1344 const HloInstructionPattern<OperandType, OperandImpl>& operand)
1345 : operand_index_(operand_index), operand_(operand) {}
1346
1347 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1348 return MatchImpl(inst, option);
1349 }
1350
1351 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1352 return MatchImpl(inst, option);
1353 }
1354
1355 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1356 *os << "with operand " << operand_index_ << " which is:";
1357 Indent(os, indent + kIndentInc);
1358 operand_.DescribeTo(os, indent + kIndentInc);
1359 }
1360
1361 private:
1362 template <typename HloInstructionType>
1363 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1364 if (operand_index_ >= inst->operand_count()) {
1365 EXPLAIN << "desired operand index " << operand_index_
1366 << " is out of bounds";
1367 return false;
1368 }
1369 if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1370 EXPLAIN << "\nin operand " << operand_index_;
1371 return false;
1372 }
1373 return true;
1374 }
1375
1376 int64 operand_index_;
1377 HloInstructionPattern<OperandType, OperandImpl> operand_;
1378 };
1379
// Matches a binary instruction whose operands come in any order.
//
// Succeeds iff the instruction has exactly two operands and either
//   (op1 matches operand 0 and op2 matches operand 1), or
//   (op1 matches operand 1 and op2 matches operand 0).
// The (0, 1) assignment is tried before (1, 0).
template <typename OperandType1, typename OperandImpl1, typename OperandType2,
          typename OperandImpl2>
class HloInstructionPatternBinaryOperandsAnyOrderImpl {
 public:
  explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
      const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
      const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
      : op1_(op1), op2_(op2) {}

  bool Match(::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  // Writes a header line followed by one " - " bullet per operand matcher.
  void DescribeTo(std::ostream* os, int64_t indent = 0) const {
    *os << "with two operands in either order:";
    Indent(os, indent);
    *os << " - ";
    op1_.DescribeTo(os, indent + 3);
    Indent(os, indent);
    *os << " - ";
    op2_.DescribeTo(os, indent + 3);
  }

 private:
  // Const-preserving operand accessors: a mutable instruction yields a mutable
  // operand, a const instruction a const one.
  HloInstruction* operand(HloInstruction* inst, int64_t idx) const {
    return inst->mutable_operand(idx);
  }
  const HloInstruction* operand(const HloInstruction* inst, int64_t idx) const {
    return inst->operand(idx);
  }

  template <typename HloInstructionType>
  bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
    // We could implement this using AnyOf and AllOf matchers, but the templates
    // get pretty difficult to debug, since any compile error herein becomes
    // not-an-error via SFINAE. Also this way lets us give better messages on
    // failure.
    if (inst->operand_count() != 2) {
      EXPLAIN << "HloInstruction did not have two operands";
      return false;
    }

    // If we're not generating explanations, this is pretty simple.
    if (!option.explain_os) {
      // Tries op1 against operand idx1 and op2 against operand idx2. The first
      // attempt runs with capture disabled, so a failed assignment writes no
      // captures. Only once both sides match is it re-run with the caller's
      // capture setting.
      auto try_match = [&](int64_t idx1, int64_t idx2) {
        MatchOption new_option = option;
        new_option.capture = false;
        if (op1_.Match(operand(inst, idx1), new_option) &&
            op2_.Match(operand(inst, idx2), new_option)) {
          if (option.capture) {
            // The capturing re-run must agree with the non-capturing probe.
            bool matched = op1_.Match(operand(inst, idx1), option) &&
                           op2_.Match(operand(inst, idx2), option);
            DCHECK(matched);
          }
          return true;
        }
        return false;
      };
      return try_match(0, 1) || try_match(1, 0);
    }

    // If we are generating explanations, we have some work to do in order to
    // generate a helpful error.
    //
    // First, try all four operand/matcher combinations, recording the
    // failure explanations separately from option.explain_os. matches[i][j]
    // tells us if matcher_i matches operand j.
    bool matches[/*matcher*/ 2][/*operand*/ 2];
    std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        MatchOption new_option = option;
        new_option.capture = false;
        new_option.explain_os = &explanations[i][j];
        matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
                               : op2_.Match(operand(inst, j), new_option);
      }
    }

    // Check if the match succeeded. For i == 0 this is op1->operand 0 and
    // op2->operand 1; for i == 1 it is op1->operand 1 and op2->operand 0.
    for (int i = 0; i < 2; ++i) {
      if (matches[0][i] && matches[1][(i + 1) % 2]) {
        // Rerun the matches with capture enabled if necessary.
        if (option.capture) {
          auto* operand1 = operand(inst, i);
          auto* operand2 = operand(inst, (i + 1) % 2);
          bool matched =
              op1_.Match(operand1, option) && op2_.Match(operand2, option);
          DCHECK(matched);
        }
        return true;
      }
    }

    // Appends a bullet describing matcher `matcher_idx` (0 -> op1, 1 -> op2),
    // followed by the recorded explanation for each operand it failed on,
    // re-indented so it nests under the bullet.
    auto describe_matcher = [&](int matcher_idx) {
      EXPLAIN << "\n - ";
      if (matcher_idx == 0) {
        op1_.DescribeTo(option.explain_os, /*indent=*/3);
      } else {
        CHECK_EQ(matcher_idx, 1);
        op2_.DescribeTo(option.explain_os, /*indent=*/3);
      }
      for (int i = 0; i < 2; ++i) {
        if (matches[matcher_idx][/*operand*/ i]) {
          continue;
        }
        EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
        EXPLAIN << " - ";
        EXPLAIN << absl::StrReplaceAll(
            explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
      }
    };

    // If we failed to match, one of the following is true:
    //  1. op1 (op2) matches neither LHS nor RHS, or
    //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
    // We print different explanations depending on which case we're in.
    // These two cases are exhaustive: if each matcher matches at least one
    // operand and no cross assignment succeeded, both must match the same
    // single operand. The final CHECK relies on this.

    // Case 1.
    bool wrote_explanation = false;
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (!matches[i][0] && !matches[i][1]) {
        EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
                << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
        describe_matcher(i);
        wrote_explanation = true;
      }
    }

    // Case 2. Here `i` is the operand that *both* matchers accepted, so the
    // operand reported as unmatched is the other one: i == 0 -> "RHS",
    // i == 1 -> "LHS".
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (matches[/*matcher*/ 0][/*operand*/ i] &&
          matches[/*matcher*/ 1][/*operand*/ i]) {
        CHECK(!matches[0][(i + 1) % 2]);
        CHECK(!matches[1][(i + 1) % 2]);
        CHECK(!wrote_explanation);
        EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
                << " operand did not match either of the two matchers.  "
                   "Specifically,";
        describe_matcher(0);
        EXPLAIN << "\nand";
        describe_matcher(1);
        wrote_explanation = true;
      }
    }

    CHECK(wrote_explanation);
    return false;
  }

  HloInstructionPattern<OperandType1, OperandImpl1> op1_;
  HloInstructionPattern<OperandType2, OperandImpl2> op2_;
};
1538
1539 // An HloInstructionPattern implementation that matches only if the instruction
1540 // is a fusion node with a particular kind.
1541 class HloInstructionPatternFusionKindImpl {
1542 public:
1543 explicit constexpr HloInstructionPatternFusionKindImpl(
1544 ::xla::HloInstruction::FusionKind kind)
1545 : kind_(kind) {}
1546
1547 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1548 return MatchImpl(inst, option);
1549 }
1550
1551 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1552 return MatchImpl(inst, option);
1553 }
1554
1555 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1556 *os << "with fusion kind " << ToString(kind_);
1557 }
1558
1559 private:
1560 template <typename HloInstructionType>
1561 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1562 if (inst->opcode() != HloOpcode::kFusion) {
1563 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1564 << "; it's not a fusion";
1565 return false;
1566 }
1567 if (inst->fusion_kind() != kind_) {
1568 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1569 return false;
1570 }
1571 return true;
1572 }
1573
1574 ::xla::HloInstruction::FusionKind kind_;
1575 };
1576
1577 // An HloInstructionPattern implementation that matches only if the instruction
1578 // is a kGetTupleElement with a particular tuple index.
1579 class HloInstructionPatternTupleIndexImpl {
1580 public:
1581 explicit constexpr HloInstructionPatternTupleIndexImpl(int64_t tuple_index)
1582 : tuple_index_(tuple_index) {}
1583
1584 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1585 return MatchImpl(inst, option);
1586 }
1587
1588 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1589 return MatchImpl(inst, option);
1590 }
1591
1592 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1593 *os << "which is a GTE with index " << tuple_index_;
1594 }
1595
1596 private:
1597 template <typename HloInstructionType>
1598 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1599 if (inst->opcode() != HloOpcode::kGetTupleElement) {
1600 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1601 << "; it's not a GTE at all";
1602 return false;
1603 }
1604 if (inst->tuple_index() != tuple_index_) {
1605 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1606 return false;
1607 }
1608 return true;
1609 }
1610
1611 int64 tuple_index_;
1612 };
1613
1614 class HloInstructionPatternParameterNumImpl {
1615 public:
1616 explicit constexpr HloInstructionPatternParameterNumImpl(
1617 int64_t parameter_num)
1618 : parameter_num_(parameter_num) {}
1619
1620 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1621 return MatchImpl(inst, option);
1622 }
1623
1624 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1625 return MatchImpl(inst, option);
1626 }
1627
1628 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1629 *os << "which is parameter " << parameter_num_;
1630 }
1631
1632 private:
1633 template <typename HloInstructionType>
1634 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1635 if (inst->opcode() != HloOpcode::kParameter ||
1636 inst->parameter_number() != parameter_num_) {
1637 EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1638 return false;
1639 }
1640 return true;
1641 }
1642
1643 int64 parameter_num_;
1644 };
1645
1646 // Superclass that contains common code used by Op::WithOneUse() and
1647 // Op::WithOneUser().
1648 class HloInstructionPatternOneUseOrUserImpl {
1649 protected:
1650 bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1651 if (inst->user_count() != 1) {
1652 EXPLAIN << "HloInstruction has " << inst->user_count()
1653 << " users, but expected exactly one.";
1654 if (inst->user_count() > 1) {
1655 EXPLAIN << "\nAll users:";
1656 for (const HloInstruction* user : inst->users()) {
1657 EXPLAIN << "\n - " << InstToString(user);
1658 }
1659 }
1660 return false;
1661 }
1662 return true;
1663 }
1664 };
1665
1666 class HloInstructionPatternOneUseImpl
1667 : public HloInstructionPatternOneUseOrUserImpl {
1668 public:
1669 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1670 if (!MatchOneUser(inst, option)) {
1671 return false;
1672 }
1673
1674 int64_t use_count = absl::c_count_if(
1675 inst->users()[0]->operands(),
1676 [&](const HloInstruction* operand) { return operand == inst; });
1677 if (use_count != 1) {
1678 EXPLAIN << "HloInstruction is used " << use_count
1679 << " times by its user, but is expected to be used just once: "
1680 << InstToString(inst->users()[0]);
1681 return false;
1682 }
1683 return true;
1684 }
1685
1686 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1687 *os << "which has exactly one use";
1688 }
1689 };
1690
1691 class HloInstructionPatternOneUserImpl
1692 : public HloInstructionPatternOneUseOrUserImpl {
1693 public:
1694 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1695 return MatchOneUser(inst, option);
1696 }
1697
1698 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1699 *os << "which has exactly one user (but possibly is used multiple times by "
1700 "that instruction)";
1701 }
1702 };
1703
1704 class HloInstructionPatternComparisonDirectionImpl {
1705 public:
1706 explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1707 ComparisonDirection direction)
1708 : direction_(direction) {}
1709
1710 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1711 return MatchImpl(inst, option);
1712 }
1713
1714 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1715 return MatchImpl(inst, option);
1716 }
1717
1718 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1719 *os << "which has comparison direction "
1720 << ComparisonDirectionToString(direction_);
1721 }
1722
1723 private:
1724 template <typename HloInstructionType>
1725 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1726 if (inst->opcode() != HloOpcode::kCompare ||
1727 inst->comparison_direction() != direction_) {
1728 EXPLAIN << "HloInstruction is not comparison "
1729 << ComparisonDirectionToString(direction_);
1730 return false;
1731 }
1732 return true;
1733 }
1734
1735 ComparisonDirection direction_;
1736 };
1737
1738 // Matches a constant scalar or effective scalar, optionally with a given value.
1739 template <typename ScalarTy>
1740 class HloConstantScalarImpl {
1741 public:
1742 explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1743 : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
1744
1745 constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1746 : val_(val), match_effective_scalar_(match_effective_scalar) {}
1747
1748 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1749 return MatchImpl(inst, option);
1750 }
1751
1752 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1753 return MatchImpl(inst, option);
1754 }
1755
1756 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1757 *os << "which is a constant "
1758 << (match_effective_scalar_ ? "effective " : "") << "scalar";
1759 if (val_.has_value()) {
1760 *os << " with value " << *val_;
1761 }
1762 }
1763
1764 private:
1765 template <typename InstTy>
1766 bool MatchImpl(InstTy* inst, MatchOption option) const {
1767 const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1768 if (!const_inst) {
1769 EXPLAIN << "HloInstruction is not a constant";
1770 return false;
1771 }
1772 if (match_effective_scalar_ &&
1773 !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1774 EXPLAIN << "HloInstruction is not an effective scalar";
1775 return false;
1776 }
1777 if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1778 EXPLAIN << "HloInstruction is not a scalar";
1779 return false;
1780 }
1781 if (!val_.has_value()) {
1782 return true;
1783 }
1784
1785 auto const_inst_scalar_or = const_inst->literal().Reshape({});
1786 if (!const_inst_scalar_or.ok()) {
1787 EXPLAIN << "could not convert matched literal to effective scalar";
1788 return false;
1789 }
1790 Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
1791 if (!const_inst_scalar.IsEqualAt({}, *val_)) {
1792 EXPLAIN << "HloInstruction's constant value "
1793 << const_inst_scalar.ToStringWithoutShape()
1794 << " did not match expected value " << *val_;
1795 return false;
1796 }
1797 return true;
1798 }
1799
1800 absl::optional<ScalarTy> val_;
1801 bool match_effective_scalar_;
1802 };
1803
1804 // A pattern that matches HloInstructions.
1805 template <typename HloInstructionType, typename Impl>
1806 class HloInstructionPattern {
1807 private:
1808 template <typename NewImpl>
1809 auto AppendImpl(NewImpl new_impl) const {
1810 auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
1811 return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1812 std::move(new_allof), matched_inst_);
1813 }
1814
1815 public:
1816 explicit constexpr HloInstructionPattern(const Impl& impl,
1817 HloInstructionType** matched_inst)
1818 : impl_(impl), matched_inst_(matched_inst) {}
1819
1820 // Returns true and captures the instruction iff it matches the pattern.
1821 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1822 if (impl_.Match(inst, option)) {
1823 if (option.capture && matched_inst_) {
1824 *matched_inst_ = inst;
1825 }
1826 return true;
1827 }
1828 if (inst != nullptr) {
1829 EXPLAIN << "\nin " << InstToString(inst);
1830 }
1831 return false;
1832 }
1833
1834 // Returns true and captures the instruction iff it matches the pattern.
1835 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1836 if (impl_.Match(inst, option)) {
1837 if (option.capture && matched_inst_) {
1838 *matched_inst_ = inst;
1839 }
1840 return true;
1841 }
1842 EXPLAIN << "\nin " << InstToString(inst);
1843 return false;
1844 }
1845
1846 // Modifies the pattern to match only if the instruction has the given name.
1847 auto WithName(absl::string_view name) const {
1848 return AppendImpl(HloInstructionPatternNameImpl(name));
1849 }
1850
1851 // Modifies the pattern to match only if the instruction has the given opcode.
1852 auto WithOpcode(HloOpcode opcode) const {
1853 return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
1854 }
1855
1856 // Modifies the pattern to match only the custom call with a given target.
1857 auto WithCustomCallTarget(absl::string_view custom_call_target) const {
1858 return AppendImpl(HloInstructionCustomCallTargetImpl({custom_call_target}));
1859 }
1860
1861 // Modifies the pattern to match a custom call with one of the given targets.
1862 auto WithCustomCallTarget(
1863 absl::Span<const absl::string_view> custom_call_targets) const {
1864 return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_targets));
1865 }
1866
1867 auto WithNumOperands(int64_t num_operands) const {
1868 return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
1869 }
1870
1871 // Modifies the pattern to match only if the instruction does not have the
1872 // given opcode.
1873 auto WithoutOpcode(HloOpcode opcode) const {
1874 return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
1875 }
1876
1877 constexpr auto Is(const HloInstruction* instr) const {
1878 return AppendImpl(HloInstructionIsImpl(instr));
1879 }
1880
1881 // Modifies the pattern to match only if the instruction is a constant.
1882 constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }
1883
1884 constexpr auto IsConstantScalar() const {
1885 return AppendImpl(
1886 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1887 }
1888
1889 // This does not check that T has the same type as the instruction, so e.g.
1890 // IsConstantScalar(1.0) may match a constant of shape int32[].
1891 template <typename ScalarTy>
1892 constexpr auto IsConstantScalar(const ScalarTy& val) const {
1893 return AppendImpl(
1894 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1895 }
1896
1897 constexpr auto IsConstantEffectiveScalar() const {
1898 return AppendImpl(
1899 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1900 }
1901
1902 template <typename ScalarTy>
1903 constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
1904 return AppendImpl(
1905 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1906 }
1907
1908 // Modifies the pattern to match only if the instruction is not a constant.
1909 constexpr auto IsNonConstant() const {
1910 return WithoutOpcode(HloOpcode::kConstant);
1911 }
1912
1913 // Modifies the pattern to match only if the instruction has a shape that
1914 // matches the given pattern.
1915 template <typename ShapeType, typename ShapeImpl>
1916 constexpr auto WithShape(
1917 const ShapePattern<ShapeType, ShapeImpl>& shape) const {
1918 return AppendImpl(
1919 HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
1920 }
1921
1922 // Because we only specify the shape's element type and dims, this is
1923 // effectivley checking shape-compatible-to, not shape-equal-to. Perhaps this
1924 // function should be called WithShapeCompatibleTo, but the short name is
1925 // nice, and there's no ambiguity because there's no layout in the args!
1926 constexpr auto WithShape(PrimitiveType ty, absl::Span<const int64> dims) {
1927 return WithShape(Shape().WithElementType(ty).WithDims(dims));
1928 }
1929
1930 // Make this a templated function to work around gcc 4.9.4 template infinite
1931 // recursion bug.
1932 template <typename Dummy = void>
1933 constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
1934 return WithShape(Shape().EqualTo(shape));
1935 }
1936
1937 // Make this a templated function to work around gcc 4.9.4 template infinite
1938 // recursion bug.
1939 template <typename Dummy = void>
1940 constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
1941 return WithShape(Shape().CompatibleTo(shape));
1942 }
1943
1944 // Modifies the pattern to match only if the instruction has an operand that
1945 // matches the given pattern.
1946 template <typename OperandType, typename OperandImpl>
1947 constexpr auto WithOperand(
1948 int64_t operand_index,
1949 const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
1950 return AppendImpl(
1951 HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1952 operand_index, operand));
1953 }
1954
1955 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1956 typename OperandImpl2>
1957 constexpr auto WithBinaryOperandsAnyOrder(
1958 const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1959 const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
1960 return AppendImpl(
1961 HloInstructionPatternBinaryOperandsAnyOrderImpl<
1962 OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
1963 }
1964
1965 // Modifies the pattern to match only if the instruction is a fusion node with
1966 // the given kind.
1967 constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
1968 return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
1969 }
1970
1971 // Modifies the pattern to match only if the instruction is a
1972 // get-tuple-element with the given tuple index.
1973 constexpr auto WithTupleIndex(int64_t tuple_index) const {
1974 return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
1975 }
1976
1977 // Modifies the pattern to match only if the instruction is a parameter
1978 // with the given parameter number.
1979 constexpr auto WithParameterNum(int64_t parameter_num) const {
1980 return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
1981 }
1982
1983 // Modifies the pattern to match if the instruction is used exactly once.
1984 // Does not match if the instruction is used twice by the same user (e.g.
1985 // multiply(x,x)).
1986 constexpr auto WithOneUse() const {
1987 return AppendImpl(HloInstructionPatternOneUseImpl());
1988 }
1989
1990 // Modifies the pattern to match if the instruction is used by exactly one
1991 // other instruction. Will match if the instruction is used twice, so long as
1992 // it's by the same user (e.g. multiply(x,x)).
1993 constexpr auto WithOneUser() const {
1994 return AppendImpl(HloInstructionPatternOneUserImpl());
1995 }
1996
1997 // Modifies the pattern to match only if the instruction has the given
1998 // comparison direction.
1999 auto WithComparisonDirection(ComparisonDirection direction) const {
2000 return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
2001 }
2002
2003 void DescribeTo(std::ostream* os, int64_t indent = 0) const {
2004 impl_.DescribeTo(os, indent);
2005 }
2006
2007 private:
2008 Impl impl_;
2009 HloInstructionType** matched_inst_;
2010 };
2011
2012 } // namespace detail
2013
2014 // Creates an instruction pattern that will capture the matched instruction in
2015 // the argument.
2016 inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
2017 return detail::HloInstructionPattern<const ::xla::HloInstruction,
2018 detail::HloInstructionPatternBaseImpl>(
2019 detail::HloInstructionPatternBaseImpl(), matched_inst);
2020 }
2021
2022 // Creates an instruction pattern that will capture the matched instruction in
2023 // the argument.
2024 inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
2025 return detail::HloInstructionPattern<::xla::HloInstruction,
2026 detail::HloInstructionPatternBaseImpl>(
2027 detail::HloInstructionPatternBaseImpl(), matched_inst);
2028 }
2029
2030 // Helpers for nullary instructions.
2031 #define XLA_NULLOP_PATTERN(NAME) \
2032 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2033 \
2034 template <typename HloInstructionType> \
2035 inline auto NAME(HloInstructionType** matched_inst) { \
2036 return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
2037 }
2038 XLA_NULLOP_PATTERN(Constant)
2039 XLA_NULLOP_PATTERN(Parameter)
2040 XLA_NULLOP_PATTERN(Iota)
2041 XLA_NULLOP_PATTERN(Rng)
2042 XLA_NULLOP_PATTERN(PartitionId)
2043 XLA_NULLOP_PATTERN(ReplicaId)
2044 #undef XLA_NULLOP_PATTERN
2045
2046 // Helpers for unary instructions.
2047 #define XLA_UNOP_PATTERN(NAME) \
2048 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2049 \
2050 template <typename Arg> \
2051 inline auto NAME(Arg&& arg) { \
2052 return Op() \
2053 .WithOpcode(HloOpcode::k##NAME) \
2054 .WithOperand(0, std::forward<Arg>(arg)); \
2055 } \
2056 \
2057 template <typename HloInstructionType, typename Arg> \
2058 inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \
2059 return Op(matched_inst) \
2060 .WithOpcode(HloOpcode::k##NAME) \
2061 .WithOperand(0, std::forward<Arg>(arg)); \
2062 }
2063 XLA_UNOP_PATTERN(Abs)
2064 XLA_UNOP_PATTERN(RoundNearestAfz)
2065 XLA_UNOP_PATTERN(Bitcast)
2066 XLA_UNOP_PATTERN(BitcastConvert)
2067 XLA_UNOP_PATTERN(Broadcast)
2068 XLA_UNOP_PATTERN(Ceil)
2069 XLA_UNOP_PATTERN(Convert)
2070 XLA_UNOP_PATTERN(Copy)
2071 XLA_UNOP_PATTERN(Cos)
2072 XLA_UNOP_PATTERN(AllReduce)
2073 XLA_UNOP_PATTERN(Exp)
2074 XLA_UNOP_PATTERN(Fft)
2075 XLA_UNOP_PATTERN(Floor)
2076 XLA_UNOP_PATTERN(GetTupleElement)
2077 XLA_UNOP_PATTERN(Imag)
2078 XLA_UNOP_PATTERN(Infeed)
2079 XLA_UNOP_PATTERN(IsFinite)
2080 XLA_UNOP_PATTERN(Log)
2081 XLA_UNOP_PATTERN(Not)
2082 XLA_UNOP_PATTERN(Negate)
2083 XLA_UNOP_PATTERN(Real)
2084 XLA_UNOP_PATTERN(Recv)
2085 XLA_UNOP_PATTERN(RecvDone)
2086 XLA_UNOP_PATTERN(ReducePrecision)
2087 XLA_UNOP_PATTERN(Reshape)
2088 XLA_UNOP_PATTERN(Reverse)
2089 XLA_UNOP_PATTERN(Rsqrt)
2090 XLA_UNOP_PATTERN(SendDone)
2091 XLA_UNOP_PATTERN(Sign)
2092 XLA_UNOP_PATTERN(Sin)
2093 XLA_UNOP_PATTERN(Slice)
2094 XLA_UNOP_PATTERN(Sqrt)
2095 XLA_UNOP_PATTERN(Tanh)
2096 XLA_UNOP_PATTERN(Transpose)
2097 #undef XLA_UNOP_PATTERN
2098
2099 // Helpers for binary instructions.
2100 #define XLA_BINOP_PATTERN(NAME) \
2101 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2102 \
2103 template <typename Lhs, typename Rhs> \
2104 inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \
2105 return Op() \
2106 .WithOpcode(HloOpcode::k##NAME) \
2107 .WithOperand(0, std::forward<Lhs>(lhs)) \
2108 .WithOperand(1, std::forward<Rhs>(rhs)); \
2109 } \
2110 \
2111 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2112 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2113 return Op(matched_inst) \
2114 .WithOpcode(HloOpcode::k##NAME) \
2115 .WithOperand(0, std::forward<Lhs>(lhs)) \
2116 .WithOperand(1, std::forward<Rhs>(rhs)); \
2117 }
2118
2119 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
2120 XLA_BINOP_PATTERN(NAME) \
2121 \
2122 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2123 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2124 Rhs&& rhs) { \
2125 return Op(matched_inst) \
2126 .WithOpcode(HloOpcode::k##NAME) \
2127 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2128 std::forward<Rhs>(rhs)); \
2129 } \
2130 template <typename Lhs, typename Rhs> \
2131 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
2132 return NAME##AnyOrder<const HloInstruction>( \
2133 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
2134 }
2135 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
2136 XLA_BINOP_PATTERN(Atan2)
2137 XLA_BINOP_PATTERN(Divide)
2138 XLA_BINOP_PATTERN(Complex)
2139 XLA_BINOP_PATTERN(Compare)
2140 XLA_BINOP_PATTERN(Convolution)
2141 XLA_BINOP_PATTERN(Dot)
2142 XLA_BINOP_PATTERN(Gather)
2143 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
2144 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
2145 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
2146 XLA_BINOP_PATTERN(Outfeed)
2147 XLA_BINOP_PATTERN(Pad)
2148 XLA_BINOP_PATTERN(Power)
2149 XLA_BINOP_PATTERN(Remainder)
2150 XLA_BINOP_PATTERN(Send)
2151 XLA_BINOP_PATTERN(Subtract)
2152 XLA_COMMUTATIVE_BINOP_PATTERN(And)
2153 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
2154 XLA_BINOP_PATTERN(ShiftLeft)
2155 XLA_BINOP_PATTERN(ShiftRightArithmetic)
2156 XLA_BINOP_PATTERN(ShiftRightLogical)
2157 #undef XLA_COMMUTATIVE_BINOP_PATTERN
2158 #undef XLA_BINOP_PATTERN
2159
2160 // Helpers for ternary instructions.
2161 #define XLA_TERNOP_PATTERN(NAME) \
2162 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2163 \
2164 template <typename Arg0, typename Arg1, typename Arg2> \
2165 inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \
2166 return Op() \
2167 .WithOpcode(HloOpcode::k##NAME) \
2168 .WithOperand(0, std::forward<Arg0>(arg0)) \
2169 .WithOperand(1, std::forward<Arg1>(arg1)) \
2170 .WithOperand(2, std::forward<Arg2>(arg2)); \
2171 } \
2172 \
2173 template <typename HloInstructionType, typename Arg0, typename Arg1, \
2174 typename Arg2> \
2175 inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \
2176 Arg1&& arg1, Arg2&& arg2) { \
2177 return Op(matched_inst) \
2178 .WithOpcode(HloOpcode::k##NAME) \
2179 .WithOperand(0, std::forward<Arg0>(arg0)) \
2180 .WithOperand(1, std::forward<Arg1>(arg1)) \
2181 .WithOperand(2, std::forward<Arg2>(arg2)); \
2182 }
2183 XLA_TERNOP_PATTERN(Clamp);
2184 XLA_TERNOP_PATTERN(Scatter);
2185 XLA_TERNOP_PATTERN(Select);
2186 XLA_TERNOP_PATTERN(SelectAndScatter);
2187 #undef XLA_TERNOP_PATTERN
2188
namespace detail {
// Base case of the recursion below: constrains operand `operand_num` of the
// pattern `m` to match `first_arg` and returns the resulting pattern.
template <typename PatternT, typename HeadT>
inline auto WithOperands(PatternT&& m, int64_t operand_num,
                         HeadT&& first_arg) {
  auto constrained = m.WithOperand(operand_num, std::forward<HeadT>(first_arg));
  return constrained;
}

// Constrains consecutive operands of `m`, starting at `operand_num`, to match
// `first_arg` followed by each element of `args`, one pattern per operand.
template <typename PatternT, typename HeadT, typename... TailT>
inline auto WithOperands(PatternT&& m, int64_t operand_num, HeadT&& first_arg,
                         TailT&&... args) {
  auto constrained = m.WithOperand(operand_num, std::forward<HeadT>(first_arg));
  const int64_t next_operand_num = operand_num + 1;
  return WithOperands(std::move(constrained), next_operand_num,
                      std::forward<TailT>(args)...);
}
}  // namespace detail
2204
2205 #define XLA_VARIADIC_OP_PATTERN(NAME) \
2206 inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2207 \
2208 template <typename... Args> \
2209 inline auto NAME(Args&&... args) { \
2210 return detail::WithOperands( \
2211 Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
2212 /*operand_num=*/0, std::forward<Args>(args)...); \
2213 } \
2214 \
2215 template <typename HloInstructionType, typename... Args> \
2216 inline auto NAME(HloInstructionType** matched_inst, Args&&... args) { \
2217 return detail::WithOperands(Op(matched_inst) \
2218 .WithOpcode(HloOpcode::k##NAME) \
2219 .WithNumOperands(sizeof...(Args)), \
2220 /*operand_num=*/0, \
2221 std::forward<Args>(args)...); \
2222 }
2223
2224 // We could implement all ops as "variadic" ops, but it would make the
2225 // already-bad compile errors even worse.
2226 XLA_VARIADIC_OP_PATTERN(AfterAll);
2227 XLA_VARIADIC_OP_PATTERN(Concatenate);
2228 XLA_VARIADIC_OP_PATTERN(Conditional);
2229 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2230 XLA_VARIADIC_OP_PATTERN(DynamicUpdateSlice)
2231 XLA_VARIADIC_OP_PATTERN(Fusion);
2232 XLA_VARIADIC_OP_PATTERN(Map)
2233 XLA_VARIADIC_OP_PATTERN(Reduce);
2234 XLA_VARIADIC_OP_PATTERN(ReduceWindow)
2235 XLA_VARIADIC_OP_PATTERN(Sort);
2236 XLA_VARIADIC_OP_PATTERN(Tuple);
2237
2238 // CustomCall doesn't use the XLA_VARIADIC_OP_PATTERN macro so that you can
2239 // optionally pass a string_view for the custom_call_target before the other
2240 // operands.
2241 inline auto CustomCall() { return Op().WithOpcode(HloOpcode::kCustomCall); }
2242
2243 template <typename HloInstructionType>
2244 auto CustomCall(HloInstructionType** matched_inst) {
2245 return Op(matched_inst).WithOpcode(HloOpcode::kCustomCall);
2246 }
2247
2248 template <
2249 typename Arg0, typename... Args,
2250 typename std::enable_if<
2251 !std::is_convertible<Arg0, absl::string_view>::value &&
2252 !std::is_convertible<Arg0, HloInstruction**>::value &&
2253 !std::is_convertible<Arg0, const HloInstruction**>::value>::type* =
2254 nullptr>
2255 auto CustomCall(Arg0&& arg0, Args&&... args) {
2256 return detail::WithOperands(CustomCall().WithNumOperands(sizeof...(Args) + 1),
2257 /*operand_num=*/0, std::forward<Arg0>(arg0),
2258 std::forward<Args>(args)...);
2259 }
2260 template <typename... Args>
2261 auto CustomCall(absl::string_view custom_call_target, Args&&... args) {
2262 return CustomCall(std::forward<Args>(args)...)
2263 .WithCustomCallTarget(custom_call_target);
2264 }
2265
2266 template <typename HloInstructionType, typename Arg0, typename... Args,
2267 typename std::enable_if<!std::is_convertible<
2268 Arg0, absl::string_view>::value>::type* = nullptr>
2269 auto CustomCall(HloInstructionType** matched_inst, Arg0&& arg0,
2270 Args&&... args) {
2271 return detail::WithOperands(
2272 CustomCall(matched_inst).WithNumOperands(sizeof...(Args) + 1),
2273 /*operand_num=*/0, std::forward<Arg0>(arg0), std::forward<Args>(args)...);
2274 }
2275 template <typename HloInstructionType, typename... Args>
2276 auto CustomCall(HloInstructionType** matched_inst,
2277 absl::string_view custom_call_target, Args&&... args) {
2278 return CustomCall(matched_inst, std::forward<Args>(args)...)
2279 .WithCustomCallTarget(custom_call_target);
2280 }
2281
// Helpers for comparison instructions. For each NAME this defines matchers
// for a kCompare whose direction is ComparisonDirection::k##NAME:
//   NAME()                 - any such compare;
//   NAME(lhs, rhs)         - additionally requires operands 0 and 1 to match;
//   NAME(&inst, lhs, rhs)  - same, capturing the match into `inst`.
#define XLA_COMPARE_PATTERN(NAME)                                       \
  inline auto NAME() {                                                  \
    auto pattern = Op().WithOpcode(HloOpcode::kCompare);                \
    return pattern.WithComparisonDirection(ComparisonDirection::k##NAME); \
  }                                                                     \
                                                                        \
  template <typename LhsT, typename RhsT>                               \
  inline auto NAME(LhsT&& lhs, RhsT&& rhs) {                            \
    auto pattern = Op().WithOpcode(HloOpcode::kCompare);                \
    return pattern.WithOperand(0, std::forward<LhsT>(lhs))              \
        .WithOperand(1, std::forward<RhsT>(rhs))                        \
        .WithComparisonDirection(ComparisonDirection::k##NAME);         \
  }                                                                     \
                                                                        \
  template <typename InstT, typename LhsT, typename RhsT>               \
  inline auto NAME(InstT** matched_inst, LhsT&& lhs, RhsT&& rhs) {      \
    auto pattern = Op(matched_inst).WithOpcode(HloOpcode::kCompare);    \
    return pattern.WithOperand(0, std::forward<LhsT>(lhs))              \
        .WithOperand(1, std::forward<RhsT>(rhs))                        \
        .WithComparisonDirection(ComparisonDirection::k##NAME);         \
  }
2307
2308 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
2309 XLA_COMPARE_PATTERN(NAME) \
2310 \
2311 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2312 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2313 Rhs&& rhs) { \
2314 return Op(matched_inst) \
2315 .WithOpcode(HloOpcode::kCompare) \
2316 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2317 std::forward<Rhs>(rhs)); \
2318 } \
2319 template <typename Lhs, typename Rhs> \
2320 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \
2321 return NAME##AnyOrder<const HloInstruction>( \
2322 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
2323 }
2324
2325 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
2326 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
2327 XLA_COMPARE_PATTERN(Ge);
2328 XLA_COMPARE_PATTERN(Gt);
2329 XLA_COMPARE_PATTERN(Le);
2330 XLA_COMPARE_PATTERN(Lt);
2331
2332 // Helpers for matching non-constant instructions.
2333 inline auto NonConstant() { return Op().IsNonConstant(); }
2334
2335 template <typename HloInstructionType>
2336 inline auto NonConstant(HloInstructionType** matched_inst) {
2337 return Op(matched_inst).IsNonConstant();
2338 }
2339
2340 // Add overloads for GetTupleElement which take a int64 specifying which tuple
2341 // element is selected.
2342 template <typename Arg>
2343 inline auto GetTupleElement(Arg&& arg, int64_t tuple_index) {
2344 return Op()
2345 .WithOpcode(HloOpcode::kGetTupleElement)
2346 .WithOperand(0, std::forward<Arg>(arg))
2347 .WithTupleIndex(tuple_index);
2348 }
2349
2350 template <typename HloInstructionType, typename Arg>
2351 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2352 int64_t tuple_index) {
2353 return Op(matched_inst)
2354 .WithOpcode(HloOpcode::kGetTupleElement)
2355 .WithOperand(0, std::forward<Arg>(arg))
2356 .WithTupleIndex(tuple_index);
2357 }
2358
2359 // Add overloads for Parameter which take an int64 specifying the parameter
2360 // number.
2361 inline auto Parameter(int64_t parameter_num) {
2362 return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2363 }
2364 template <typename HloInstructionType>
2365 inline auto Parameter(HloInstructionType** matched_inst,
2366 int64_t parameter_num) {
2367 return Op(matched_inst)
2368 .WithOpcode(HloOpcode::kParameter)
2369 .WithParameterNum(parameter_num);
2370 }
2371
2372 inline auto ConstantScalar() { return Op().IsConstantScalar(); }
2373
2374 template <typename HloInstructionType>
2375 inline auto ConstantScalar(HloInstructionType** matched_inst) {
2376 return Op(matched_inst).IsConstantScalar();
2377 }
2378
2379 template <typename ScalarTy>
2380 inline auto ConstantScalar(ScalarTy val) {
2381 return Op().IsConstantScalar(val);
2382 }
2383
2384 template <typename HloInstructionType, typename ScalarTy>
2385 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) {
2386 return Op(matched_inst).IsConstantScalar(val);
2387 }
2388
2389 inline auto ConstantEffectiveScalar() {
2390 return Op().IsConstantEffectiveScalar();
2391 }
2392
2393 template <typename HloInstructionType>
2394 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) {
2395 return Op(matched_inst).IsConstantEffectiveScalar();
2396 }
2397
2398 template <typename ScalarTy>
2399 inline auto ConstantEffectiveScalar(ScalarTy val) {
2400 return Op().IsConstantEffectiveScalar(val);
2401 }
2402
2403 template <typename HloInstructionType, typename ScalarTy>
2404 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
2405 ScalarTy val) {
2406 return Op(matched_inst).IsConstantEffectiveScalar(val);
2407 }
2408
2409 } // namespace match
2410
2411 } // namespace xla
2412
2413 #undef EXPLAIN
2414 #pragma pop_macro("EXPLAIN")
2415 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2416