• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
19 #include "absl/strings/str_replace.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/utility/utility.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
30 namespace xla {
32 // A pattern matcher for HloInstructions, Shapes, and Layouts.
33 //
34 // The Match function's first argument must be HloInstruction*, Shape*, or
35 // Layout*. The second argument is a pattern that will be matched against the
36 // first argument, as described below.
37 //
38 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
39 // functions. By default, the returned patterns will match any HloInstruction,
40 // Shape, or Layout, respectively. However the match can be made more specific
41 // by using the pattern's modifier methods, for example:
42 //
43 //   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
44 //     0, match::Op().WithOpcode(HloOpcode::kConstant))
45 //
46 // This pattern will match Add instructions whose first operand is a constant.
47 //
48 // Each pattern type has the following modifiers, which are described where
49 // nontrivial.
50 //
51 //   Op():
52 //     - Is: is the given HloInstruction* (i.e. pointer equality)
53 //     - WithName
54 //     - WithOpcode
55 //     - WithoutOpcode: anything other than the given opcode
56 //     - WithShape: instr's shape matches the given pattern
57 //     - WithShapeEqualTo: instr's shape is equal to the given Shape
58 //     - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
59 //     - WithNumOperands
60 //     - WithOperand: operand at the given index matches the given pattern
61 //     - IsConstant
62 //     - IsNonConstant
63 //     - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value,
64 //       e.g. IsConstantScalar() or IsConstantScalar(42).
65 //     - WithFusionKind
66 //     - WithTupleIndex: get-tuple-element operations with the given tuple index
67 //     - WithOneUse: Instruction is used as an operand exactly once.
68 //     - WithOneUser: Instruction is used by exactly one other instruction, but
69 //       is possibly used more than once as an operand (e.g. multiply(x,x)).
70 //     - WithComparisonDirection: instr has the given direction
71 //
72 //   Shape():
73 //     - EqualTo
74 //     - CompatibleTo
75 //     - IsScalar/IsEffectiveScalar/IsArray/IsTuple
76 //     - IsDenseArray
//     - WithLayout: shape's layout matches the given pattern (e.g.
//       Layout().WithDenseFormat())
79 //     - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
80 //       Layout, but not the result of Layout().foo())
81 //     - WithSubshape: shape is a tuple whose subshape matches the given pattern
82 //       (e.g. Shape().IsScalar()).
83 //     - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
84 //       (i.e. another Shape, but not the result of Shape().foo())
85 //     - WithElementType: shape is an array/scalar with the given elem type
86 //     - WithRank: shape is an array/scalar with the given rank
87 //
//   Layout():
89 //     - EqualTo
90 //     - WithDenseFormat
91 //
92 // Op(), Shape(), and Layout() may be passed an argument of type
93 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
94 // these pointers. If the pattern is matched, the address of the matched value
95 // will be "captured" and stored at this location.
96 //
97 // For example:
98 //   HloInstruction* foo = ...;
99 //   HloInstruction* matched_operand;
100 //   CHECK(Match(foo,
101 //               match::Op().WithOperand(0, match::Op(&matched_operand))));
102 //
103 // Helpers are provided for most HLO instructions. These helpers can be called
104 // with no arguments, in which case they will match any instruction matching the
105 // opcode. They may also be called with matches for the operands and with an
106 // optional capture. (The capture must be the first argument.) Some examples of
107 // these helpers and their equivalents are provided below.
109 // Example nullary instruction:
110 //   Parameter()                    == Op().WithOpcode(HloOpcode::kParameter)
111 //   Parameter(&a)                  == Op(&a).WithOpcode(HloOpcode::kParameter)
112 //
113 // Example unary instruction:
114 //   Abs()                          == Op().WithOpcode(HloOpcode::kAbs)
//   Abs(Op(&a))                    == Op().WithOpcode(HloOpcode::kAbs)
//                                         .WithOperand(0, Op(&a))
117 //   Abs(&a, Op(&b))                == Op(&a).WithOpcode(HloOpcode::kAbs)
118 //                                           .WithOperand(0, Op(&b))
119 //
120 // Commutative binary instructions have a special form that accepts either order
121 // of args, e.g.:
122 //
123 //   AddAnyOrder(Parameter(1), Abs()) ==
124 //     Op().WithOpcode(HloOpcode::kAdd)
125 //         .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
126 //
127 //   MultiplyAnyOrder(&a, Parameter(), Abs())  // Captures the mul in `a`.
128 //
129 // The following additional helpers are provided.  In all cases, `&a` is
130 // optional.
131 //
132 //   ConstantScalar(&a)               == Op(&a).IsConstantScalar();
133 //   ConstantScalar(&a, v)            == Op(&a).IsConstantScalar(v);
//   ConstantEffectiveScalar(&a)      == Op(&a).IsEffectiveConstantScalar();
//   ConstantEffectiveScalar(&a, v)   == Op(&a).IsEffectiveConstantScalar(v);
136 //   NonConstant(&a)                  == Op(&a).IsNonConstant()
137 //   GetTupleElement(&a, b, index)    == Op(&a).WithTupleIndex(index)
138 //                                             .WithOperand(0, b);
139 //   Parameter(&a, n)                 == Op(&a).WithParameterNum(n);
// Options that control a single pattern-matching call.
//
// Both members carry default initializers so that a default-constructed
// MatchOption (`MatchOption option;`) is fully initialized instead of holding
// indeterminate values. The defaults equal the value-initialized state
// (`MatchOption{}`), and the type is still an aggregate, so existing uses such
// as `MatchOption{true, nullptr}` keep their meaning.
struct MatchOption {
  // If true, actually capture matched item into the user pointer.
  bool capture = false;

  // An explanation for why we failed to match is streamed here, if not-null.
  std::ostream* explain_os = nullptr;
};
149 template <typename Value, typename Pattern>
150 bool Match(Value* value, const Pattern& pattern,
151            MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
152   if (option.capture) {
153     auto new_option = option;
154     new_option.capture = false;
155     if (!pattern.Match(value, new_option)) {
156       return false;
157     }
158   }
159   return pattern.Match(value, option);
160 }
162 namespace match {
164 namespace detail {
166 // Macro for streaming to option.explain_os if it's not null.
167 //
168 //   EXPLAIN << "value of foo(): " << foo()
169 //
170 #pragma push_macro("EXPLAIN")
171 #define EXPLAIN \
172   if (option.explain_os) *option.explain_os
// Number of extra spaces the pretty-printers add each time they nest one
// level deeper.
enum { kIndentInc = 2 };
180 // Writes a newline and then `indent` spaces.
181 //
182 // We follow an unintuitive convention in this file's pretty-printers: Indents
183 // are performed by the caller, not the callee.  For example, if you want to
184 // print
185 //
186 //   foo:
187 //    - bar
188 //
189 // you'd do:
190 //
191 //  Foo::DescribeTo(std::ostream* os, int64 indent) {
192 //    *os << "foo:";
193 //    Indent(os, indent)  // Create a newline at the *current* indent level.
194 //    *os << " - ";
195 //    bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" * ") == 3.
196 //  }
197 //
198 //  Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; }
199 //
200 // Notice that Bar::DescribeTo() does not call Indent; the indenting is
201 // performed by Foo.  This convention allows the caller to decide whether a
202 // matcher is preceded by a newline, which is important e.g. for the AllOf
203 // matcher.
204 //
205 // (Incidentally, indenting in Match's explanations is handled differently.
206 // Indents are a common case in DescribeTo [we're printing a whole tree], but
207 // they're a special case in Match [we're printing only a path through the tree
208 // that encounters a failing node]. Indents in Match only appear when we
209 // encounter a failing disjunction, so we just handle them as a special case
210 // there.)
Indent(std::ostream * os,int64 indent)211 inline void Indent(std::ostream* os, int64 indent) {
212   *os << "\n";
213   for (int64 i = 0; i < indent; ++i) {
214     *os << " ";
215   }
216 }
// Trait that is true iff matcher type T declares a static member
// `kIsTrivialMatcher` whose value is true. Detection uses SFINAE: if
// `T::kIsTrivialMatcher` is missing or false, the partial specialization drops
// out and the primary template (false) is used.
//
// Trivial matchers are printed specially. For example, when printing a
// conjunction no "AND" is emitted after a trivial matcher, giving
//    "a shape compatible with f32[1,2]"
// rather than
//    "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

template <typename T>
struct IsTrivialMatcher<T, typename std::enable_if<T::kIsTrivialMatcher>::type>
    : std::true_type {};
237 template <typename Item, typename... Patterns>
238 class AllOfPattern {
239  public:
240   explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
242   bool Match(const Item* item, MatchOption option) const {
243     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
244     // This invariant is guaranteed by the top-level Match and AnyOf.
245     DCHECK(matched || !option.capture);
246     return matched;
247   }
249   bool Match(Item* item, MatchOption option) const {
250     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
251     // This invariant is guaranteed by the top-level Match and AnyOf.
252     DCHECK(matched || !option.capture);
253     return matched;
254   }
256   void DescribeTo(std::ostream* os, int64 indent = 0) const {
257     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
258   }
260   // Accessor for patterns_.  Please don't use this outside of this file.
261   const std::tuple<Patterns...>& patterns() const { return patterns_; }
263  private:
264   template <typename ItemType, size_t index>
265   bool MatchImpl(ItemType* item, MatchOption option,
266                  std::integral_constant<size_t, index>) const {
267     // We don't need to do any EXPLAINing here; it's all correctly handled by
268     // our sub-matchers (if any fail).
269     return std::get<index>(patterns_).Match(item, option) &&
270            MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
271   }
273   template <typename ItemType>
274   bool MatchImpl(ItemType* item, MatchOption option,
275                  std::integral_constant<size_t, sizeof...(Patterns)>) const {
276     return true;
277   }
279   // Pretty-printing a conjunction has some special cases to make it easy to
280   // read in the simple (common) case.
281   //
282   // If sizeof...(Patterns) == 1, prints as e.g.
283   //
284   //   a shape
285   //
286   // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
287   // shape") prints as
288   //
289   //   a shape compatible with f32[1,2]
290   //
291   // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
292   //
293   //   a shape:
294   //    * compatible with f32[1,2] AND
295   //    * that represents a scalar
296   //
297   // Otherwise prints as:
298   //
299   //   all of:
300   //    * foo AND
301   //    * bar
302   //
303   template <size_t index>
304   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
305                       int64 indent) const {
306     constexpr bool first_is_trivial =
307         IsTrivialMatcher<typename std::remove_reference<decltype(
308             std::get<0>(patterns_))>::type>::value;
309     constexpr bool is_last = index == sizeof...(Patterns) - 1;
310     const auto& submatcher = std::get<index>(patterns_);
312     auto print_bulleted_item = [&] {
313       *os << " * ";
314       submatcher.DescribeTo(os, indent + 3);
315       if (!is_last) {
316         *os << " AND";
317         Indent(os, indent);
318       }
319     };
321     if (index == 0) {
322       if (first_is_trivial || is_last) {
323         submatcher.DescribeTo(os, indent + kIndentInc);
324         if (sizeof...(Patterns) > 2) {
325           *os << ":";
326           Indent(os, indent);
327         }
328       } else {
329         *os << "all of:";
330         Indent(os, indent);
331         print_bulleted_item();
332       }
333     } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
334       *os << " ";
335       submatcher.DescribeTo(os, indent);
336     } else {
337       print_bulleted_item();
338     }
339     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
340   }
342   void DescribeToImpl(std::ostream* os,
343                       std::integral_constant<size_t, sizeof...(Patterns)>,
344                       int64 indent) const {}
346   std::tuple<Patterns...> patterns_;
347 };
349 }  // namespace detail
351 // Returns a pattern that represents the conjunction of all input patterns. All
352 // patterns need to match in order to have the AllOf pattern match.
353 template <typename Item, typename... Patterns>
354 auto AllOf(const Patterns&... patterns) {
355   return detail::AllOfPattern<typename std::remove_const<Item>::type,
356                               Patterns...>(patterns...);
357 }
359 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
360 //
361 // This transformation is necessary for good pretty-printing.
362 template <typename Item, typename... InnerPs, typename... OuterPs>
363 auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
364            const OuterPs&... outer_ps) {
365   // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
366   auto make_all_of = [](const InnerPs&... inner_ps,
367                         const OuterPs&... outer_ps) {
368     return detail::AllOfPattern<typename std::remove_const<Item>::type,
369                                 InnerPs..., OuterPs...>(inner_ps...,
370                                                         outer_ps...);
371   };
372   return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
373                                                  std::make_tuple(outer_ps...)));
374 }
376 namespace detail {
378 template <typename LayoutType, typename Impl>
379 class LayoutPattern;
381 // The base LayoutPattern implementation. Matches only if the layout is not
382 // nullptr.
383 class LayoutPatternBaseImpl {
384  public:
385   bool Match(const ::xla::Layout* layout, MatchOption option) const {
386     if (layout == nullptr) {
387       EXPLAIN << "Layout is null";
388       return false;
389     }
390     return true;
391   }
393   void DescribeTo(std::ostream* os, int64 indent = 0) const {
394     *os << "a layout";
395   }
397   static constexpr bool kIsTrivialMatcher = true;
398 };
400 // A LayoutPattern implementation that matches only if the layout equals a
401 // Layout proto.
402 class LayoutPatternEqualImpl {
403  public:
404   explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
405       : layout_(layout) {}
407   bool Match(const ::xla::Layout* layout, MatchOption option) const {
408     if (!LayoutUtil::Equal(*layout_, *layout)) {
409       EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
410               << " is not equal to expected "
411               << LayoutUtil::HumanString(*layout_);
412       return false;
413     }
414     return true;
415   }
417   void DescribeTo(std::ostream* os, int64 indent = 0) const {
418     *os << "equal to " << LayoutUtil::HumanString(*layout_);
419   }
421  private:
422   const ::xla::Layout* layout_;
423 };
425 // A LayoutPattern implementation that matches only if the layout has a given
426 // format.
427 class LayoutPatternFormatImpl {
428  public:
429   explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
431   bool Match(const ::xla::Layout* layout, MatchOption option) const {
432     if (layout->format() != format_) {
433       EXPLAIN << "Layout has format " << Format_Name(layout->format())
434               << " but expected " << Format_Name(format_);
435       return false;
436     }
437     return true;
438   }
440   void DescribeTo(std::ostream* os, int64 indent = 0) const {
441     *os << "with format " << Format_Name(format_);
442   }
444  private:
445   Format format_;
446 };
448 // A pattern that matches Layouts.
449 template <typename LayoutType, typename Impl>
450 class LayoutPattern {
451  private:
452   template <typename NewImpl>
453   auto AppendImpl(NewImpl new_impl) const {
454     auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
455     return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
456                                                           matched_layout_);
457   }
459  public:
460   explicit constexpr LayoutPattern(const Impl& impl,
461                                    LayoutType** matched_layout)
462       : impl_(impl), matched_layout_(matched_layout) {}
464   // Returns true and captures the layout iff it matches the pattern.
465   bool Match(const ::xla::Layout* layout, MatchOption option) const {
466     if (impl_.Match(layout, option)) {
467       if (option.capture && matched_layout_) {
468         *matched_layout_ = layout;
469       }
470       return true;
471     }
472     return false;
473   }
475   // Returns true and captures the layout iff it matches the pattern.
476   bool Match(::xla::Layout* layout, MatchOption option) const {
477     if (impl_.Match(layout, option)) {
478       if (option.capture && matched_layout_) {
479         *matched_layout_ = layout;
480       }
481       return true;
482     }
483     return false;
484   }
486   void DescribeTo(std::ostream* os, int64 indent = 0) const {
487     impl_.DescribeTo(os, indent);
488   }
490   // Modifies the pattern to match only if the layout equals the given proto.
491   // The layout must outlive the returned pattern.
492   constexpr auto EqualTo(const ::xla::Layout* layout) const {
493     return AppendImpl(LayoutPatternEqualImpl(layout));
494   }
496   // Modifies the pattern to match only if the layout has a dense format.
497   constexpr auto WithDenseFormat() const {
498     return AppendImpl(LayoutPatternFormatImpl(DENSE));
499   }
501  private:
502   Impl impl_;
503   LayoutType** matched_layout_;
504 };
506 template <typename Item, typename... Patterns>
507 class AnyOfPattern {
508  public:
509   explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
511   bool Match(const Item* item, MatchOption option) const {
512     return MatchImpl(item, option);
513   }
515   bool Match(Item* item, MatchOption option) const {
516     return MatchImpl(item, option);
517   }
519   void DescribeTo(std::ostream* os, int64 indent = 0) const {
520     *os << "any of:";
521     Indent(os, indent);
522     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
523   }
525  private:
526   template <typename ItemType>
527   bool MatchImpl(ItemType* item, MatchOption option) const {
528     // If we're generating an explanation, buffer it until we know we failed.
529     absl::optional<std::stringstream> explanation;
530     MatchOption new_option = option;
531     if (option.explain_os) {
532       new_option.explain_os = &explanation.emplace();
533     }
534     bool rv = MatchRecursiveImpl(item, new_option,
535                                  std::integral_constant<size_t, 0>());
536     if (!rv && option.explain_os) {
537       EXPLAIN << "None of the following matchers succeeded:";
538       EXPLAIN << explanation->str();
539     }
540     return rv;
541   }
543   template <typename ItemType, size_t index>
544   bool MatchRecursiveImpl(ItemType* item, MatchOption option,
545                           std::integral_constant<size_t, index>) const {
546     auto new_option = option;
547     new_option.capture = false;
549     absl::optional<std::stringstream> explanation;
550     if (option.explain_os) {
551       new_option.explain_os = &explanation.emplace();
552     }
554     // Try to match the sub-pattern without capturing behavior.
555     if (std::get<index>(patterns_).Match(item, new_option)) {
556       // Capture the branch.
557       if (option.capture) {
558         // TODO(timshen): Currently the behavior can be exponential. Optimize it
559         // with memoization or recording the matched sub-pattern index, if it
560         // takes too long to run.
561         //
562         // Specifically, the "memoization" approach is to create an empty
563         // container with the key (pattern, instruction), and value as whether
564         // matched or not.
565         //
566         // Alternatively, we may run the pattern matching with captures off, but
567         // instead record a "trace" somewhere, indicating how exactly the
568         // pattern matches the input. For example, the trace information for
569         // AnyOf will be a runtime number indicate which sub-pattern is matched.
570         // Then we run another pass to do captures only with the help of the
571         // trace.
572         bool matched = std::get<index>(patterns_).Match(item, option);
573         DCHECK(matched);
574       }
575       return true;
576     }
577     if (option.explain_os) {
578       EXPLAIN << "\nMatcher #" << index + 1;
579       EXPLAIN << "\n - ";
580       std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
581       EXPLAIN << "\nfailed with";
582       EXPLAIN << "\n - ";
583       EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n   "}});
584     }
585     return MatchRecursiveImpl(item, option,
586                               std::integral_constant<size_t, index + 1>());
587   }
589   template <typename ItemType>
590   bool MatchRecursiveImpl(
591       ItemType* item, MatchOption option,
592       std::integral_constant<size_t, sizeof...(Patterns)>) const {
593     return false;
594   }
596   template <size_t index>
597   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
598                       int64 indent) const {
599     *os << " - ";
600     std::get<index>(patterns_).DescribeTo(os, indent + 3);
601     if (index != sizeof...(Patterns) - 1) {
602       *os << " OR";
603       Indent(os, indent);
604     }
605     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
606   }
608   void DescribeToImpl(std::ostream* os,
609                       std::integral_constant<size_t, sizeof...(Patterns)>,
610                       int64 indent) const {}
612   std::tuple<Patterns...> patterns_;
613 };
615 }  // namespace detail
617 // Returns a pattern that represents the logical disjunction of the input
618 // patterns. The returned pattern matches from left to right, and stops on the
619 // first match.
620 template <typename Item, typename... Patterns>
621 auto AnyOf(const Patterns&... patterns) {
622   return detail::AnyOfPattern<typename std::remove_const<Item>::type,
623                               Patterns...>(patterns...);
624 }
626 // Creates a layout pattern that will capture the matched layout in the
627 // argument.
628 inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
629   return detail::LayoutPattern<const ::xla::Layout,
630                                detail::LayoutPatternBaseImpl>(
631       detail::LayoutPatternBaseImpl(), matched_layout);
632 }
634 // Creates a layout pattern that will capture the matched layout in the
635 // argument.
636 inline constexpr auto Layout(::xla::Layout** matched_layout) {
637   return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
638       detail::LayoutPatternBaseImpl(), matched_layout);
639 }
641 namespace detail {
643 template <typename ShapeType, typename Impl>
644 class ShapePattern;
646 // The base ShapePattern implementation. Matches only if the shape is not
647 // nullptr.
648 class ShapePatternBaseImpl {
649  public:
650   bool Match(const ::xla::Shape* shape, MatchOption option) const {
651     if (shape == nullptr) {
652       EXPLAIN << "Shape is null";
653     }
654     return shape != nullptr;
655   }
657   void DescribeTo(std::ostream* os, int64 indent = 0) const {
658     *os << "a shape";
659   }
661   static constexpr bool kIsTrivialMatcher = true;
662 };
664 // A ShapePattern implementation that matches only if the shape equals a Shape
665 // proto.
666 class ShapePatternEqualImpl {
667  public:
668   explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
669       : shape_(shape) {}
671   bool Match(const ::xla::Shape* shape, MatchOption option) const {
672     if (!ShapeUtil::Equal(*shape_, *shape)) {
673       EXPLAIN << "Shape not equal to "
674               << ShapeUtil::HumanStringWithLayout(*shape_);
675       return false;
676     }
677     return true;
678   }
680   void DescribeTo(std::ostream* os, int64 indent = 0) const {
681     *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
682   }
684  private:
685   const ::xla::Shape* shape_;
686 };
688 // A ShapePattern implementation that matches only if the shape is compatible to
689 // a Shape proto.
690 class ShapePatternCompatibleImpl {
691  public:
692   explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
693       : shape_(shape) {}
695   bool Match(const ::xla::Shape* shape, MatchOption option) const {
696     if (!ShapeUtil::Compatible(*shape_, *shape)) {
697       EXPLAIN << "Shape not compatible with "
698               << ShapeUtil::HumanString(*shape_);
699       return false;
700     }
701     return true;
702   }
704   void DescribeTo(std::ostream* os, int64 indent = 0) const {
705     *os << "compatible with " << ShapeUtil::HumanString(*shape_);
706   }
708  private:
709   const ::xla::Shape* shape_;
710 };
712 // A ShapePattern implementation that matches only if the shape has a given
713 // element type.
714 class ShapePatternElementTypeImpl {
715  public:
716   explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
717       : element_type_(element_type) {}
719   bool Match(const ::xla::Shape* shape, MatchOption option) const {
720     if (shape->element_type() != element_type_) {
721       EXPLAIN << "Shape does not have element type "
722               << PrimitiveType_Name(element_type_);
723       return false;
724     }
725     return true;
726   }
728   void DescribeTo(std::ostream* os, int64 indent = 0) const {
729     *os << "with element type " << PrimitiveType_Name(element_type_);
730   }
732  private:
733   PrimitiveType element_type_;
734 };
736 // A ShapePattern implementation that matches only if the shape is scalar.
737 class ShapePatternIsScalarImpl {
738  public:
739   explicit constexpr ShapePatternIsScalarImpl() {}
741   bool Match(const ::xla::Shape* shape, MatchOption option) const {
742     if (!ShapeUtil::IsScalar(*shape)) {
743       EXPLAIN << "Shape is not a scalar";
744       return false;
745     }
746     return true;
747   }
749   void DescribeTo(std::ostream* os, int64 indent = 0) const {
750     *os << "that represents a scalar";
751   }
752 };
754 // A ShapePattern implementation that matches only if the shape is an array
755 class ShapePatternIsArrayImpl {
756  public:
757   explicit constexpr ShapePatternIsArrayImpl() {}
759   bool Match(const ::xla::Shape* shape, MatchOption option) const {
760     if (!shape->IsArray()) {
761       EXPLAIN << "Shape is not an array";
762       return false;
763     }
764     return true;
765   }
767   void DescribeTo(std::ostream* os, int64 indent = 0) const {
768     *os << "that represents an array";
769   }
770 };
772 // A ShapePattern implementation that matches only if the shape is a tuple.
773 class ShapePatternIsTupleImpl {
774  public:
775   explicit constexpr ShapePatternIsTupleImpl() {}
777   bool Match(const ::xla::Shape* shape, MatchOption option) const {
778     if (!shape->IsTuple()) {
779       EXPLAIN << "Shape is not a tuple";
780       return false;
781     }
782     return true;
783   }
785   void DescribeTo(std::ostream* os, int64 indent = 0) const {
786     *os << "that represents a tuple";
787   }
788 };
790 // A ShapePattern implementation that matches only if the shape is an effective
791 // scalar.
792 class ShapePatternEffectiveScalarImpl {
793  public:
794   explicit constexpr ShapePatternEffectiveScalarImpl() {}
796   bool Match(const ::xla::Shape* shape, MatchOption option) const {
797     if (!ShapeUtil::IsEffectiveScalar(*shape)) {
798       EXPLAIN << "Shape is not an effective scalar";
799       return false;
800     }
801     return true;
802   }
804   void DescribeTo(std::ostream* os, int64 indent = 0) const {
805     *os << "that is an effective scalar";
806   }
807 };
809 // A ShapePattern implementation that matches only if the shape has a given
810 // rank.
811 class ShapePatternRankImpl {
812  public:
813   explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
815   bool Match(const ::xla::Shape* shape, MatchOption option) const {
816     if (shape->rank() != rank_) {
817       if (rank_ == 0) {
818         EXPLAIN << "Shape is not a scalar";
819       } else {
820         EXPLAIN << "Shape does not have rank " << rank_;
821       }
822       return false;
823     }
824     return true;
825   }
827   void DescribeTo(std::ostream* os, int64 indent = 0) const {
828     if (rank_ == 0) {
829       *os << "that is a scalar";
830     } else {
831       *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
832     }
833   }
835  private:
836   int64 rank_;
837 };
839 // A ShapePattern implementation that matches only if the shape has a layout
840 // that matches a given pattern.
841 template <typename LayoutType, typename LayoutImpl>
842 class ShapePatternLayoutImpl {
843  public:
844   explicit constexpr ShapePatternLayoutImpl(
845       const LayoutPattern<LayoutType, LayoutImpl>& layout)
846       : layout_(layout) {}
848   bool Match(const ::xla::Shape* shape, MatchOption option) const {
849     return LayoutUtil::HasLayout(*shape) &&
850            layout_.Match(&shape->layout(), option);
851   }
853   bool Match(::xla::Shape* shape, MatchOption option) const {
854     if (!LayoutUtil::HasLayout(*shape)) {
855       EXPLAIN << "Shape does not have a layout";
856       return false;
857     }
858     if (!layout_.Match(shape->mutable_layout(), option)) {
859       EXPLAIN << "\nin layout";
860       return false;
861     }
862     return true;
863   }
865   void DescribeTo(std::ostream* os, int64 indent = 0) const {
866     *os << "with";
867     Indent(os, indent + kIndentInc);
868     layout_.DescribeTo(os, indent + kIndentInc);
869   }
871  private:
872   LayoutPattern<LayoutType, LayoutImpl> layout_;
873 };
875 // A ShapePattern implementation that matches only if the shape has a subshape
876 // that matches a given pattern.
877 template <typename SubshapeType, typename SubshapeImpl>
878 class ShapePatternSubshapeImpl {
879  public:
880   explicit ShapePatternSubshapeImpl(
881       ShapeIndexView index,
882       const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
883       : index_(index), subshape_(subshape) {}
885   bool Match(const ::xla::Shape* shape, MatchOption option) const {
886     return MatchImpl(shape, option);
887   }
889   bool Match(::xla::Shape* shape, MatchOption option) const {
890     return MatchImpl(shape, option);
891   }
893   void DescribeTo(std::ostream* os, int64 indent = 0) const {
894     *os << "with subshape at index " << index_.ToString() << " which is";
895     Indent(os, indent + kIndentInc);
896     subshape_.DescribeTo(os, indent + kIndentInc);
897   }
899  private:
900   ::xla::Shape* GetSubshape(::xla::Shape* shape) const {
901     return ShapeUtil::GetMutableSubshape(shape, index_);
902   }
903   const ::xla::Shape* GetSubshape(const ::xla::Shape* shape) const {
904     return &ShapeUtil::GetSubshape(*shape, index_);
905   }
907   template <typename ShapeType>
908   bool MatchImpl(ShapeType* shape, MatchOption option) const {
909     if (!ShapeUtil::IndexIsValid(*shape, index_)) {
910       EXPLAIN << "No subshape at " << index_.ToString();
911       return false;
912     }
913     if (!subshape_.Match(GetSubshape(shape), option)) {
914       EXPLAIN << "\nin subshape at " << index_.ToString();
915       return false;
916     }
917     return true;
918   }
920   ShapeIndexView index_;
921   ShapePattern<SubshapeType, SubshapeImpl> subshape_;
922 };
924 // A pattern that matches Shapes.
925 template <typename ShapeType, typename Impl>
926 class ShapePattern {
927  private:
928   template <typename NewImpl>
929   auto AppendImpl(NewImpl new_impl) const {
930     auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
931     return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
932                                                          matched_shape_);
933   }
935  public:
936   explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
937       : impl_(impl), matched_shape_(matched_shape) {}
939   // Returns true and captures the shape iff it matches the pattern.
940   bool Match(const ::xla::Shape* shape, MatchOption option) const {
941     if (impl_.Match(shape, option)) {
942       if (option.capture && matched_shape_) {
943         *matched_shape_ = shape;
944       }
945       return true;
946     }
947     if (shape) {
948       EXPLAIN << "\nin "
949               << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
950                                       : ShapeUtil::HumanString(*shape));
951     }
952     return false;
953   }
955   // Returns true and captures the shape iff it matches the pattern.
956   bool Match(::xla::Shape* shape, MatchOption option) const {
957     if (impl_.Match(shape, option)) {
958       if (option.capture && matched_shape_) {
959         *matched_shape_ = shape;
960       }
961       return true;
962     }
963     EXPLAIN << "\nin "
964             << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
965                                     : ShapeUtil::HumanString(*shape));
966     return false;
967   }
969   void DescribeTo(std::ostream* os, int64 indent = 0) const {
970     return impl_.DescribeTo(os, indent);
971   }
973   // Modifies the pattern to match only if the shape equals the given proto.
974   // The layout must outlive the returned pattern.
975   constexpr auto EqualTo(const ::xla::Shape* shape) const {
976     return AppendImpl(ShapePatternEqualImpl(shape));
977   }
979   // Modifies the pattern to match only if the shape is compatible to the given
980   // proto. The layout must outlive the returned pattern.
981   constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
982     return AppendImpl(ShapePatternCompatibleImpl(shape));
983   }
985   // Modifies the pattern to match only if the shape has the given element type.
986   constexpr auto WithElementType(PrimitiveType element_type) const {
987     return AppendImpl(ShapePatternElementTypeImpl(element_type));
988   }
990   // Modifies the pattern to match only if the shape is scalar.
991   constexpr auto IsScalar() const {
992     return AppendImpl(ShapePatternIsScalarImpl());
993   }
995   // Modifies the pattern to match only if the shape is an array.
996   constexpr auto IsArray() const {
997     return AppendImpl(ShapePatternIsArrayImpl());
998   }
1000   // Modifies the pattern to match only if the shape is a tuple.
1001   constexpr auto IsTuple() const {
1002     return AppendImpl(ShapePatternIsTupleImpl());
1003   }
1005   constexpr auto IsEffectiveScalar() const {
1006     return AppendImpl(ShapePatternEffectiveScalarImpl());
1007   }
1009   // Modifies the pattern to match only if the shape has the given rank.
1010   constexpr auto WithRank(int64 rank) const {
1011     return AppendImpl(ShapePatternRankImpl(rank));
1012   }
1014   // Modifies the pattern to match only if the shape has a layout that matches
1015   // the given pattern.
1016   template <typename LayoutType, typename LayoutImpl>
1017   auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
1018     return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1019   }
1021   constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
1022     return WithLayout(Layout().EqualTo(layout));
1023   }
1025   constexpr auto IsDenseArray() const {
1026     return WithLayout(Layout().WithDenseFormat());
1027   }
1029   // Modifies the pattern to match only if the shape has a subshape that matches
1030   // the given pattern.
1031   template <typename SubshapeType, typename SubshapeImpl>
1032   auto WithSubshape(
1033       ShapeIndexView index,
1034       const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
1035     return AppendImpl(
1036         ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1037   }
1039   ShapePattern<ShapeType,
1040                AllOfPattern<::xla::Shape, Impl,
1041                             ShapePatternSubshapeImpl<
1042                                 const ::xla::Shape,
1043                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1044                                              ShapePatternEqualImpl>>>>
1045   WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1046     return WithSubshape(index,
1047                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1048                             ShapePatternBaseImpl(), nullptr)
1049                             .EqualTo(shape));
1050   }
1052   ShapePattern<ShapeType,
1053                AllOfPattern<::xla::Shape, Impl,
1054                             ShapePatternSubshapeImpl<
1055                                 const ::xla::Shape,
1056                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1057                                              ShapePatternCompatibleImpl>>>>
1058   WithSubshapeCompatibleTo(ShapeIndexView index,
1059                            const ::xla::Shape* shape) const {
1060     return WithSubshape(index,
1061                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1062                             ShapePatternBaseImpl(), nullptr)
1063                             .CompatibleTo(shape));
1064   }
1066  private:
1067   Impl impl_;
1068   ShapeType** matched_shape_;
1069 };
1071 }  // namespace detail
1073 // Creates a shape pattern that will capture the matched layout in the argument.
1074 inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
1075   return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1076       detail::ShapePatternBaseImpl(), matched_shape);
1077 }
1079 // Creates a shape pattern that will capture the matched layout in the argument.
1080 inline constexpr auto Shape(::xla::Shape** matched_shape) {
1081   return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1082       detail::ShapePatternBaseImpl(), matched_shape);
1083 }
1085 namespace detail {
1087 // Overloads to get a const or non-const operand out of an instruction.
1088 inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) {
1089   return instr->mutable_operand(idx);
1090 }
1091 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1092                                         int64 idx) {
1093   return instr->operand(idx);
1094 }
1096 // Pretty-printer for HloInstruction.  Sort of like ToShortString, but with
1097 // fewer %s and more shapes.
1098 inline string InstToString(const HloInstruction* inst) {
1099   return inst->ToString(
1100       HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1101 }
1103 template <typename HloInstructionType, typename Impl>
1104 class HloInstructionPattern;
1106 // The base HloInstructionPattern implementation. Matches only if the
1107 // instruction is not nullptr.
1108 class HloInstructionPatternBaseImpl {
1109  public:
1110   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1111     if (inst == nullptr) {
1112       EXPLAIN << "HloInstruction* is null";
1113       return false;
1114     }
1115     return true;
1116   }
1118   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1119     *os << "an HloInstruction";
1120   }
1122   static constexpr bool kIsTrivialMatcher = true;
1123 };
1125 // An HloInstructionPattern implementation that matches only if the instruction
1126 // has a given name.
1127 class HloInstructionPatternNameImpl {
1128  public:
1129   explicit HloInstructionPatternNameImpl(absl::string_view name)
1130       : name_(name) {}
1132   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1133     if (inst->name() != name_) {
1134       EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1135       return false;
1136     }
1137     return true;
1138   }
1140   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1141     *os << "named \"" << name_ << "\"";
1142   }
1144  private:
1145   absl::string_view name_;
1146 };
1148 // An HloInstructionPattern implementation that matches only if the instruction
1149 // equals a particular pointer.
1150 class HloInstructionIsImpl {
1151  public:
1152   explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1154   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1155     if (inst != inst_) {
1156       EXPLAIN << "HloInstruction " << std::hex << std::nouppercase
1157               << std::showbase << reinterpret_cast<uint64>(inst) << " is not "
1158               << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1159               << ")";
1160       return false;
1161     }
1162     return true;
1163   }
1165   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1166     *os << "which is " << std::hex << std::nouppercase << std::showbase
1167         << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1168         << ")";
1169   }
1171  private:
1172   const HloInstruction* inst_;
1173 };
1175 // An HloInstructionPattern implementation that matches only if the instruction
1176 // has a given opcode.
1177 class HloInstructionPatternOpcodeImpl {
1178  public:
1179   explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1180                                                      bool invert)
1181       : opcode_(opcode), invert_(invert) {}
1183   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1184     if (invert_ && inst->opcode() == opcode_) {
1185       EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1186               << ", expected anything else";
1187       return false;
1188     }
1189     if (!invert_ && inst->opcode() != opcode_) {
1190       EXPLAIN << "HloInstruction doesn't have opcode "
1191               << HloOpcodeString(opcode_);
1192       return false;
1193     }
1194     return true;
1195   }
1197   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1198     if (!invert_) {
1199       *os << "with opcode " << HloOpcodeString(opcode_);
1200     } else {
1201       *os << "with any opcode other than " << HloOpcodeString(opcode_);
1202     }
1203   }
1205  private:
1206   HloOpcode opcode_;
1207   bool invert_;
1208 };
1210 // An HloInstructionPattern implementation that matches only if the instruction
1211 // has a given custom call target.
1212 class HloInstructionCustomCallTargetImpl {
1213  public:
1214   explicit HloInstructionCustomCallTargetImpl(
1215       absl::string_view custom_call_target)
1216       : custom_call_target_(custom_call_target) {}
1218   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1219     if (inst->opcode() != HloOpcode::kCustomCall ||
1220         inst->custom_call_target() != custom_call_target_) {
1221       EXPLAIN << "HloInstruction is not a custom call with a target '"
1222               << custom_call_target_ << "'";
1223       return false;
1224     }
1225     return true;
1226   }
1228   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1229     *os << "custom call with target '" << custom_call_target_ << "'";
1230   }
1232  private:
1233   std::string custom_call_target_;
1234 };
1236 // An HloInstructionPattern implementation that matches only if the instruction
1237 // has the given number of operands.
1238 class HloInstructionPatternNumOperandsImpl {
1239  public:
1240   explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands)
1241       : num_operands_(num_operands) {}
1243   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1244     if (inst->operand_count() != num_operands_) {
1245       EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1246       return false;
1247     }
1248     return true;
1249   }
1251   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1252     *os << "with " << num_operands_ << " operand"
1253         << (num_operands_ != 1 ? "s" : "");
1254   }
1256  private:
1257   int64 num_operands_;
1258 };
1260 // An HloInstructionPattern implementation that matches only if the instruction
1261 // has a shape that matches a given pattern.
1262 template <typename ShapeType, typename ShapeImpl>
1263 class HloInstructionPatternShapeImpl {
1264  public:
1265   explicit constexpr HloInstructionPatternShapeImpl(
1266       const ShapePattern<ShapeType, ShapeImpl>& shape)
1267       : shape_(shape) {}
1269   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1270     if (!shape_.Match(&inst->shape(), option)) {
1271       EXPLAIN << "\nin output shape";
1272       return false;
1273     }
1274     return true;
1275   }
1277   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1278     if (!shape_.Match(inst->mutable_shape(), option)) {
1279       EXPLAIN << "\nin output shape";
1280       return false;
1281     }
1282     return true;
1283   }
1285   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1286     *os << "outputting";
1287     Indent(os, indent + kIndentInc);
1288     shape_.DescribeTo(os, indent + kIndentInc);
1289   }
1291  private:
1292   ShapePattern<ShapeType, ShapeImpl> shape_;
1293 };
1295 // An HloInstructionPattern implementation that matches only if the instruction
1296 // has an operand that matches a given pattern.
1297 template <typename OperandType, typename OperandImpl>
1298 class HloInstructionPatternOperandImpl {
1299  public:
1300   explicit constexpr HloInstructionPatternOperandImpl(
1301       int64 operand_index,
1302       const HloInstructionPattern<OperandType, OperandImpl>& operand)
1303       : operand_index_(operand_index), operand_(operand) {}
1305   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1306     return MatchImpl(inst, option);
1307   }
1309   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1310     return MatchImpl(inst, option);
1311   }
1313   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1314     *os << "with operand " << operand_index_ << " which is:";
1315     Indent(os, indent + kIndentInc);
1316     operand_.DescribeTo(os, indent + kIndentInc);
1317   }
1319  private:
1320   template <typename HloInstructionType>
1321   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1322     if (operand_index_ >= inst->operand_count()) {
1323       EXPLAIN << "desired operand index " << operand_index_
1324               << " is out of bounds";
1325       return false;
1326     }
1327     if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1328       EXPLAIN << "\nin operand " << operand_index_;
1329       return false;
1330     }
1331     return true;
1332   }
1334   int64 operand_index_;
1335   HloInstructionPattern<OperandType, OperandImpl> operand_;
1336 };
// Matches a binary instruction whose operands come in any order.
//
// The instruction must have exactly two operands. It matches if op1 matches
// operand 0 and op2 matches operand 1, or if op1 matches operand 1 and op2
// matches operand 0; the first assignment is tried first. All trial matches
// run with capture disabled. Only the winning assignment is re-run with the
// caller's options, so a failed assignment never leaves partial captures.
template <typename OperandType1, typename OperandImpl1, typename OperandType2,
          typename OperandImpl2>
class HloInstructionPatternBinaryOperandsAnyOrderImpl {
 public:
  explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
      const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
      const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
      : op1_(op1), op2_(op2) {}

  bool Match(::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  // Describes both sub-patterns as a bulleted list. Each sub-description is
  // indented by 3 extra columns so that it lines up after the " - " bullet.
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    *os << "with two operands in either order:";
    Indent(os, indent);
    *os << " - ";
    op1_.DescribeTo(os, indent + 3);
    Indent(os, indent);
    *os << " - ";
    op2_.DescribeTo(os, indent + 3);
  }

 private:
  // Returns operand `idx` of `inst`, mutable or const to match `inst`.
  HloInstruction* operand(HloInstruction* inst, int64 idx) const {
    return inst->mutable_operand(idx);
  }
  const HloInstruction* operand(const HloInstruction* inst, int64 idx) const {
    return inst->operand(idx);
  }

  template <typename HloInstructionType>
  bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
    // We could implement this using AnyOf and AllOf matchers, but the templates
    // get pretty difficult to debug, since any compile error herein becomes
    // not-an-error via SFINAE.  Also this way lets us give better messages on
    // failure.
    if (inst->operand_count() != 2) {
      EXPLAIN << "HloInstruction did not have two operands";
      return false;
    }

    // If we're not generating explanations, this is pretty simple.
    if (!option.explain_os) {
      // Tries op1 on operand idx1 and op2 on operand idx2, first without
      // capturing. Captures happen only once both sides are known to match.
      auto try_match = [&](int64 idx1, int64 idx2) {
        MatchOption new_option = option;
        new_option.capture = false;
        if (op1_.Match(operand(inst, idx1), new_option) &&
            op2_.Match(operand(inst, idx2), new_option)) {
          if (option.capture) {
            // Re-running the same matchers on the same operands must succeed
            // again; the DCHECK guards that assumption.
            bool matched = op1_.Match(operand(inst, idx1), option) &&
                           op2_.Match(operand(inst, idx2), option);
            DCHECK(matched);
          }
          return true;
        }
        return false;
      };
      return try_match(0, 1) || try_match(1, 0);
    }

    // If we are generating explanations, we have some work to do in order to
    // generate a helpful error.
    //
    // First, try all four operand/matcher combinations, recording the
    // failure explanations separately from option.explain_os. matches[i][j]
    // tells us if matcher_i matches operand j.
    bool matches[/*matcher*/ 2][/*operand*/ 2];
    std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        MatchOption new_option = option;
        new_option.capture = false;
        new_option.explain_os = &explanations[i][j];
        matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
                               : op2_.Match(operand(inst, j), new_option);
      }
    }

    // Check if the match succeeded. Here i is the operand taken by op1; op2
    // takes the other one, (i + 1) % 2.
    for (int i = 0; i < 2; ++i) {
      if (matches[0][i] && matches[1][(i + 1) % 2]) {
        // Rerun the matches with capture enabled if necessary.
        if (option.capture) {
          auto* operand1 = operand(inst, i);
          auto* operand2 = operand(inst, (i + 1) % 2);
          bool matched =
              op1_.Match(operand1, option) && op2_.Match(operand2, option);
          DCHECK(matched);
        }
        return true;
      }
    }

    // Writes matcher `matcher_idx`'s description, followed by the recorded
    // explanation for each operand it failed to match. Those explanations are
    // re-indented to sit under the " - " bullet.
    auto describe_matcher = [&](int matcher_idx) {
      EXPLAIN << "\n - ";
      if (matcher_idx == 0) {
        op1_.DescribeTo(option.explain_os, /*indent=*/3);
      } else {
        CHECK_EQ(matcher_idx, 1);
        op2_.DescribeTo(option.explain_os, /*indent=*/3);
      }
      for (int i = 0; i < 2; ++i) {
        if (matches[matcher_idx][/*operand*/ i]) {
          continue;
        }
        EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
        EXPLAIN << " - ";
        EXPLAIN << absl::StrReplaceAll(
            explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
      }
    };

    // If we failed to match, one of the following is true:
    //  1. op1 (op2) matches neither LHS nor RHS, or
    //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
    // We print different explanations depending on which case we're in.

    // Case 1.
    bool wrote_explanation = false;
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (!matches[i][0] && !matches[i][1]) {
        EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
                << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
        describe_matcher(i);
        wrote_explanation = true;
      }
    }

    // Case 2. Both matchers accept operand i, so the operand left unmatched
    // is the other one: RHS when i == 0, LHS when i == 1.
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (matches[/*matcher*/ 0][/*operand*/ i] &&
          matches[/*matcher*/ 1][/*operand*/ i]) {
        CHECK(!matches[0][(i + 1) % 2]);
        CHECK(!matches[1][(i + 1) % 2]);
        CHECK(!wrote_explanation);
        EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
                << " operand did not match either of the two matchers.  "
                   "Specifically,";
        describe_matcher(0);
        EXPLAIN << "\nand";
        describe_matcher(1);
        wrote_explanation = true;
      }
    }

    // Cases 1 and 2 are exhaustive for a failed match, so some explanation
    // must have been written by now.
    CHECK(wrote_explanation);
    return false;
  }

  HloInstructionPattern<OperandType1, OperandImpl1> op1_;
  HloInstructionPattern<OperandType2, OperandImpl2> op2_;
};
1497 // An HloInstructionPattern implementation that matches only if the instruction
1498 // is a fusion node with a particular kind.
1499 class HloInstructionPatternFusionKindImpl {
1500  public:
1501   explicit constexpr HloInstructionPatternFusionKindImpl(
1502       ::xla::HloInstruction::FusionKind kind)
1503       : kind_(kind) {}
1505   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1506     return MatchImpl(inst, option);
1507   }
1509   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1510     return MatchImpl(inst, option);
1511   }
1513   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1514     *os << "with fusion kind " << ToString(kind_);
1515   }
1517  private:
1518   template <typename HloInstructionType>
1519   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1520     if (inst->opcode() != HloOpcode::kFusion) {
1521       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1522               << "; it's not a fusion";
1523       return false;
1524     }
1525     if (inst->fusion_kind() != kind_) {
1526       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1527       return false;
1528     }
1529     return true;
1530   }
1532   ::xla::HloInstruction::FusionKind kind_;
1533 };
1535 // An HloInstructionPattern implementation that matches only if the instruction
1536 // is a kGetTupleElement with a particular tuple index.
1537 class HloInstructionPatternTupleIndexImpl {
1538  public:
1539   explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
1540       : tuple_index_(tuple_index) {}
1542   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1543     return MatchImpl(inst, option);
1544   }
1546   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1547     return MatchImpl(inst, option);
1548   }
1550   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1551     *os << "which is a GTE with index " << tuple_index_;
1552   }
1554  private:
1555   template <typename HloInstructionType>
1556   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1557     if (inst->opcode() != HloOpcode::kGetTupleElement) {
1558       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1559               << "; it's not a GTE at all";
1560       return false;
1561     }
1562     if (inst->tuple_index() != tuple_index_) {
1563       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1564       return false;
1565     }
1566     return true;
1567   }
1569   int64 tuple_index_;
1570 };
1572 class HloInstructionPatternParameterNumImpl {
1573  public:
1574   explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num)
1575       : parameter_num_(parameter_num) {}
1577   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1578     return MatchImpl(inst, option);
1579   }
1581   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1582     return MatchImpl(inst, option);
1583   }
1585   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1586     *os << "which is parameter " << parameter_num_;
1587   }
1589  private:
1590   template <typename HloInstructionType>
1591   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1592     if (inst->opcode() != HloOpcode::kParameter ||
1593         inst->parameter_number() != parameter_num_) {
1594       EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1595       return false;
1596     }
1597     return true;
1598   }
1600   int64 parameter_num_;
1601 };
1603 // Superclass that contains common code used by Op::WithOneUse() and
1604 // Op::WithOneUser().
1605 class HloInstructionPatternOneUseOrUserImpl {
1606  protected:
1607   bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1608     if (inst->user_count() != 1) {
1609       EXPLAIN << "HloInstruction has " << inst->user_count()
1610               << " users, but expected exactly one.";
1611       if (inst->user_count() > 1) {
1612         EXPLAIN << "\nAll users:";
1613         for (const HloInstruction* user : inst->users()) {
1614           EXPLAIN << "\n - " << InstToString(user);
1615         }
1616       }
1617       return false;
1618     }
1619     return true;
1620   }
1621 };
1623 class HloInstructionPatternOneUseImpl
1624     : public HloInstructionPatternOneUseOrUserImpl {
1625  public:
1626   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1627     if (!MatchOneUser(inst, option)) {
1628       return false;
1629     }
1631     int64 use_count = absl::c_count_if(
1632         inst->users()[0]->operands(),
1633         [&](const HloInstruction* operand) { return operand == inst; });
1634     if (use_count != 1) {
1635       EXPLAIN << "HloInstruction is used " << use_count
1636               << " times by its user, but is expected to be used just once: "
1637               << InstToString(inst->users()[0]);
1638       return false;
1639     }
1640     return true;
1641   }
1643   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1644     *os << "which has exactly one use";
1645   }
1646 };
1648 class HloInstructionPatternOneUserImpl
1649     : public HloInstructionPatternOneUseOrUserImpl {
1650  public:
1651   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1652     return MatchOneUser(inst, option);
1653   }
1655   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1656     *os << "which has exactly one user (but possibly is used multiple times by "
1657            "that instruction)";
1658   }
1659 };
1661 class HloInstructionPatternComparisonDirectionImpl {
1662  public:
1663   explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1664       ComparisonDirection direction)
1665       : direction_(direction) {}
1667   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1668     return MatchImpl(inst, option);
1669   }
1671   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1672     return MatchImpl(inst, option);
1673   }
1675   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1676     *os << "which has comparison direction "
1677         << ComparisonDirectionToString(direction_);
1678   }
1680  private:
1681   template <typename HloInstructionType>
1682   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1683     if (inst->opcode() != HloOpcode::kCompare ||
1684         inst->comparison_direction() != direction_) {
1685       EXPLAIN << "HloInstruction is not comparison "
1686               << ComparisonDirectionToString(direction_);
1687       return false;
1688     }
1689     return true;
1690   }
1692   ComparisonDirection direction_;
1693 };
1695 // Matches a constant scalar or effective scalar, optionally with a given value.
1696 template <typename ScalarTy>
1697 class HloConstantScalarImpl {
1698  public:
1699   explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1700       : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
1702   constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1703       : val_(val), match_effective_scalar_(match_effective_scalar) {}
1705   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1706     return MatchImpl(inst, option);
1707   }
1709   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1710     return MatchImpl(inst, option);
1711   }
1713   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1714     *os << "which is a constant "
1715         << (match_effective_scalar_ ? "effective " : "") << "scalar";
1716     if (val_.has_value()) {
1717       *os << " with value " << *val_;
1718     }
1719   }
1721  private:
1722   template <typename InstTy>
1723   bool MatchImpl(InstTy* inst, MatchOption option) const {
1724     const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1725     if (!const_inst) {
1726       EXPLAIN << "HloInstruction is not a constant";
1727       return false;
1728     }
1729     if (match_effective_scalar_ &&
1730         !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1731       EXPLAIN << "HloInstruction is not an effective scalar";
1732       return false;
1733     }
1734     if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1735       EXPLAIN << "HloInstruction is not a scalar";
1736       return false;
1737     }
1738     if (!val_.has_value()) {
1739       return true;
1740     }
1742     auto const_inst_scalar_or = const_inst->literal().Reshape({});
1743     if (!const_inst_scalar_or.ok()) {
1744       EXPLAIN << "could not convert matched literal to effective scalar";
1745       return false;
1746     }
1747     Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
1748     if (!const_inst_scalar.IsEqualAt({}, *val_)) {
1749       EXPLAIN << "HloInstruction's constant value "
1750               << const_inst_scalar.ToStringWithoutShape()
1751               << " did not match expected value " << *val_;
1752       return false;
1753     }
1754     return true;
1755   }
1757   absl::optional<ScalarTy> val_;
1758   bool match_effective_scalar_;
1759 };
1761 // A pattern that matches HloInstructions.
1762 template <typename HloInstructionType, typename Impl>
1763 class HloInstructionPattern {
1764  private:
1765   template <typename NewImpl>
1766   auto AppendImpl(NewImpl new_impl) const {
1767     auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
1768     return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1769         std::move(new_allof), matched_inst_);
1770   }
1772  public:
1773   explicit constexpr HloInstructionPattern(const Impl& impl,
1774                                            HloInstructionType** matched_inst)
1775       : impl_(impl), matched_inst_(matched_inst) {}
1777   // Returns true and captures the instruction iff it matches the pattern.
1778   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1779     if (impl_.Match(inst, option)) {
1780       if (option.capture && matched_inst_) {
1781         *matched_inst_ = inst;
1782       }
1783       return true;
1784     }
1785     if (inst != nullptr) {
1786       EXPLAIN << "\nin " << InstToString(inst);
1787     }
1788     return false;
1789   }
1791   // Returns true and captures the instruction iff it matches the pattern.
1792   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1793     if (impl_.Match(inst, option)) {
1794       if (option.capture && matched_inst_) {
1795         *matched_inst_ = inst;
1796       }
1797       return true;
1798     }
1799     EXPLAIN << "\nin " << InstToString(inst);
1800     return false;
1801   }
1803   // Modifies the pattern to match only if the instruction has the given name.
1804   auto WithName(absl::string_view name) const {
1805     return AppendImpl(HloInstructionPatternNameImpl(name));
1806   }
1808   // Modifies the pattern to match only if the instruction has the given opcode.
1809   auto WithOpcode(HloOpcode opcode) const {
1810     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
1811   }
1813   // Modifies the pattern to match only the custom call with a given target.
1814   auto WithCustomCallTarget(absl::string_view custom_call_target) const {
1815     return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target));
1816   }
1818   auto WithNumOperands(int64 num_operands) const {
1819     return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
1820   }
1822   // Modifies the pattern to match only if the instruction does not have the
1823   // given opcode.
1824   auto WithoutOpcode(HloOpcode opcode) const {
1825     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
1826   }
1828   constexpr auto Is(const HloInstruction* instr) const {
1829     return AppendImpl(HloInstructionIsImpl(instr));
1830   }
1832   // Modifies the pattern to match only if the instruction is a constant.
1833   constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }
1835   constexpr auto IsConstantScalar() const {
1836     return AppendImpl(
1837         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1838   }
1840   // This does not check that T has the same type as the instruction, so e.g.
1841   // IsConstantScalar(1.0) may match a constant of shape int32[].
1842   template <typename ScalarTy>
1843   constexpr auto IsConstantScalar(const ScalarTy& val) const {
1844     return AppendImpl(
1845         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1846   }
1848   constexpr auto IsConstantEffectiveScalar() const {
1849     return AppendImpl(
1850         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1851   }
1853   template <typename ScalarTy>
1854   constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
1855     return AppendImpl(
1856         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1857   }
1859   // Modifies the pattern to match only if the instruction is not a constant.
1860   constexpr auto IsNonConstant() const {
1861     return WithoutOpcode(HloOpcode::kConstant);
1862   }
1864   // Modifies the pattern to match only if the instruction has a shape that
1865   // matches the given pattern.
1866   template <typename ShapeType, typename ShapeImpl>
1867   constexpr auto WithShape(
1868       const ShapePattern<ShapeType, ShapeImpl>& shape) const {
1869     return AppendImpl(
1870         HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
1871   }
1873   // Make this a templated function to work around gcc 4.9.4 template infinite
1874   // recursion bug.
1875   template <typename Dummy = void>
1876   constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
1877     return WithShape(Shape().EqualTo(shape));
1878   }
1880   // Make this a templated function to work around gcc 4.9.4 template infinite
1881   // recursion bug.
1882   template <typename Dummy = void>
1883   constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
1884     return WithShape(Shape().CompatibleTo(shape));
1885   }
1887   // Modifies the pattern to match only if the instruction has an operand that
1888   // matches the given pattern.
1889   template <typename OperandType, typename OperandImpl>
1890   constexpr auto WithOperand(
1891       int64 operand_index,
1892       const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
1893     return AppendImpl(
1894         HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1895             operand_index, operand));
1896   }
1898   template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1899             typename OperandImpl2>
1900   constexpr auto WithBinaryOperandsAnyOrder(
1901       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1902       const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
1903     return AppendImpl(
1904         HloInstructionPatternBinaryOperandsAnyOrderImpl<
1905             OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
1906   }
1908   // Modifies the pattern to match only if the instruction is a fusion node with
1909   // the given kind.
1910   constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
1911     return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
1912   }
1914   // Modifies the pattern to match only if the instruction is a
1915   // get-tuple-element with the given tuple index.
1916   constexpr auto WithTupleIndex(int64 tuple_index) const {
1917     return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
1918   }
1920   // Modifies the pattern to match only if the instruction is a parameter
1921   // with the given parameter number.
1922   constexpr auto WithParameterNum(int64 parameter_num) const {
1923     return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
1924   }
1926   // Modifies the pattern to match if the instruction is used exactly once.
1927   // Does not match if the instruction is used twice by the same user (e.g.
1928   // multiply(x,x)).
1929   constexpr auto WithOneUse() const {
1930     return AppendImpl(HloInstructionPatternOneUseImpl());
1931   }
1933   // Modifies the pattern to match if the instruction is used by exactly one
1934   // other instruction.  Will match if the instruction is used twice, so long as
1935   // it's by the same user (e.g.  multiply(x,x)).
1936   constexpr auto WithOneUser() const {
1937     return AppendImpl(HloInstructionPatternOneUserImpl());
1938   }
1940   // Modifies the pattern to match only if the instruction has the given
1941   // comparison direction.
1942   auto WithComparisonDirection(ComparisonDirection direction) const {
1943     return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
1944   }
1946   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1947     impl_.DescribeTo(os, indent);
1948   }
1950  private:
1951   Impl impl_;
1952   HloInstructionType** matched_inst_;
1953 };
1955 }  // namespace detail
1957 // Creates an instruction pattern that will capture the matched instruction in
1958 // the argument.
1959 inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
1960   return detail::HloInstructionPattern<const ::xla::HloInstruction,
1961                                        detail::HloInstructionPatternBaseImpl>(
1962       detail::HloInstructionPatternBaseImpl(), matched_inst);
1963 }
1965 // Creates an instruction pattern that will capture the matched instruction in
1966 // the argument.
1967 inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
1968   return detail::HloInstructionPattern<::xla::HloInstruction,
1969                                        detail::HloInstructionPatternBaseImpl>(
1970       detail::HloInstructionPatternBaseImpl(), matched_inst);
1971 }
// Helpers for nullary instructions.
//
// XLA_NULLOP_PATTERN(NAME) defines two functions in this namespace:
//   NAME()          - matches any instruction with opcode HloOpcode::k##NAME.
//   NAME(&matched)  - same, and on a successful capturing match stores the
//                     instruction into `matched` (a `HloInstruction**` or
//                     `const HloInstruction**`, forwarded to Op()).
// Comments cannot go inside the macro body: a `//` comment on a line ending in
// a backslash would swallow the spliced continuation line.
#define XLA_NULLOP_PATTERN(NAME)                                     \
  inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
                                                                     \
  template <typename HloInstructionType>                             \
  inline auto NAME(HloInstructionType** matched_inst) {              \
    return Op(matched_inst).WithOpcode(HloOpcode::k##NAME);          \
  }
XLA_NULLOP_PATTERN(Parameter)
// Helpers for unary instructions.
//
// XLA_UNOP_PATTERN(NAME) defines:
//   NAME()               - any instruction with opcode HloOpcode::k##NAME.
//   NAME(arg)            - additionally requires operand 0 to match the
//                          pattern `arg`.
//   NAME(&matched, arg)  - as NAME(arg), capturing the matched instruction
//                          into `matched`.
// Only operand 0 is constrained; the operand count is not checked.
#define XLA_UNOP_PATTERN(NAME)                                       \
  inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
                                                                     \
  template <typename Arg>                                            \
  inline auto NAME(Arg&& arg) {                                      \
    return Op()                                                      \
        .WithOpcode(HloOpcode::k##NAME)                              \
        .WithOperand(0, std::forward<Arg>(arg));                     \
  }                                                                  \
                                                                     \
  template <typename HloInstructionType, typename Arg>               \
  inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) {   \
    return Op(matched_inst)                                          \
        .WithOpcode(HloOpcode::k##NAME)                              \
        .WithOperand(0, std::forward<Arg>(arg));                     \
  }
XLA_UNOP_PATTERN(RoundNearestAfz)
XLA_UNOP_PATTERN(Bitcast)
XLA_UNOP_PATTERN(BitcastConvert)
XLA_UNOP_PATTERN(Broadcast)
XLA_UNOP_PATTERN(Convert)
XLA_UNOP_PATTERN(AllReduce)
XLA_UNOP_PATTERN(GetTupleElement)
XLA_UNOP_PATTERN(ReducePrecision)
XLA_UNOP_PATTERN(Reshape)
XLA_UNOP_PATTERN(Reverse)
XLA_UNOP_PATTERN(Transpose)
#undef XLA_UNOP_PATTERN
// Helpers for binary instructions.
//
// XLA_BINOP_PATTERN(NAME) defines:
//   NAME()                    - any instruction with opcode HloOpcode::k##NAME.
//   NAME(lhs, rhs)            - additionally requires operands 0 and 1 to
//                               match `lhs` and `rhs`, in that order.
//   NAME(&matched, lhs, rhs)  - as NAME(lhs, rhs), capturing the matched
//                               instruction into `matched`.
#define XLA_BINOP_PATTERN(NAME)                                               \
  inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
                                                                              \
  template <typename Lhs, typename Rhs>                                       \
  inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
    return Op()                                                               \
        .WithOpcode(HloOpcode::k##NAME)                                       \
        .WithOperand(0, std::forward<Lhs>(lhs))                               \
        .WithOperand(1, std::forward<Rhs>(rhs));                              \
  }                                                                           \
                                                                              \
  template <typename HloInstructionType, typename Lhs, typename Rhs>          \
  inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
    return Op(matched_inst)                                                   \
        .WithOpcode(HloOpcode::k##NAME)                                       \
        .WithOperand(0, std::forward<Lhs>(lhs))                               \
        .WithOperand(1, std::forward<Rhs>(rhs));                              \
  }

// XLA_COMMUTATIVE_BINOP_PATTERN(NAME) defines everything XLA_BINOP_PATTERN
// does, plus NAME##AnyOrder(&matched, lhs, rhs) and NAME##AnyOrder(lhs, rhs),
// which accept the two operand patterns in either order
// (WithBinaryOperandsAnyOrder). The non-capturing overload forwards to the
// capturing one with an explicit `const HloInstruction` and a null capture
// target, which is equivalent to using Op().
#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                \
  XLA_BINOP_PATTERN(NAME)                                                  \
                                                                           \
  template <typename HloInstructionType, typename Lhs, typename Rhs>       \
  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
                             Rhs&& rhs) {                                  \
    return Op(matched_inst)                                                \
        .WithOpcode(HloOpcode::k##NAME)                                    \
        .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
                                    std::forward<Rhs>(rhs));               \
  }                                                                        \
  template <typename Lhs, typename Rhs>                                    \
  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
    return NAME##AnyOrder<const HloInstruction>(                           \
        nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
  }
XLA_BINOP_PATTERN(Convolution)
XLA_BINOP_PATTERN(ReduceWindow)
XLA_BINOP_PATTERN(Remainder)
XLA_BINOP_PATTERN(Subtract)
XLA_BINOP_PATTERN(ShiftRightArithmetic)
XLA_BINOP_PATTERN(ShiftRightLogical)
// Helpers for ternary instructions.
//
// XLA_TERNOP_PATTERN(NAME) defines:
//   NAME()                            - any instruction with opcode
//                                       HloOpcode::k##NAME.
//   NAME(arg0, arg1, arg2)            - additionally requires operands 0, 1
//                                       and 2 to match the given patterns.
//   NAME(&matched, arg0, arg1, arg2)  - as above, capturing the matched
//                                       instruction into `matched`.
#define XLA_TERNOP_PATTERN(NAME)                                       \
  inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }   \
                                                                       \
  template <typename Arg0, typename Arg1, typename Arg2>               \
  inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {            \
    return Op()                                                        \
        .WithOpcode(HloOpcode::k##NAME)                                \
        .WithOperand(0, std::forward<Arg0>(arg0))                      \
        .WithOperand(1, std::forward<Arg1>(arg1))                      \
        .WithOperand(2, std::forward<Arg2>(arg2));                     \
  }                                                                    \
                                                                       \
  template <typename HloInstructionType, typename Arg0, typename Arg1, \
            typename Arg2>                                             \
  inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0,     \
                   Arg1&& arg1, Arg2&& arg2) {                         \
    return Op(matched_inst)                                            \
        .WithOpcode(HloOpcode::k##NAME)                                \
        .WithOperand(0, std::forward<Arg0>(arg0))                      \
        .WithOperand(1, std::forward<Arg1>(arg1))                      \
        .WithOperand(2, std::forward<Arg2>(arg2));                     \
  }
XLA_TERNOP_PATTERN(SelectAndScatter);
2131 namespace detail {
2132 template <typename Matcher, typename FirstArg>
2133 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) {
2134   return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
2135 }
2137 template <typename Matcher, typename FirstArg, typename... Args>
2138 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
2139                          Args&&... args) {
2140   return WithOperands(
2141       m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)),
2142       operand_num + 1, std::forward<Args>(args)...);
2143 }
2144 }  // namespace detail
// Helpers for instructions with a variable number of operands.
//
// XLA_VARIADIC_OP_PATTERN(NAME) defines:
//   NAME()                  - any instruction with opcode HloOpcode::k##NAME,
//                             with any number of operands.
//   NAME(args...)           - additionally requires exactly the number of
//                             operands given (WithNumOperands(sizeof...(Args)))
//                             and operand i to match the i-th pattern.
//   NAME(&matched, args...) - as NAME(args...), capturing the matched
//                             instruction into `matched`.
// NOTE(review): NAME(&matched) with no operand patterns selects the capturing
// overload with an empty pack, and detail::WithOperands has no overload taking
// zero operand patterns. Confirm whether that form should be supported.
#define XLA_VARIADIC_OP_PATTERN(NAME)                                         \
  inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
                                                                              \
  template <typename... Args>                                                 \
  inline auto NAME(Args&&... args) {                                          \
    return detail::WithOperands(                                              \
        Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
        /*operand_num=*/0, std::forward<Args>(args)...);                      \
  }                                                                           \
                                                                              \
  template <typename HloInstructionType, typename... Args>                    \
  inline auto NAME(HloInstructionType** matched_inst, Args&&... args) {       \
    return detail::WithOperands(Op(matched_inst)                              \
                                    .WithOpcode(HloOpcode::k##NAME)           \
                                    .WithNumOperands(sizeof...(Args)),        \
                                /*operand_num=*/0,                            \
                                std::forward<Args>(args)...);                 \
  }

// We could implement all ops as "variadic" ops, but it would make the
// already-bad compile errors even worse.
XLA_VARIADIC_OP_PATTERN(Concatenate);
XLA_VARIADIC_OP_PATTERN(Conditional);
// Helpers for comparison instructions.
//
// XLA_COMPARE_PATTERN(NAME) defines functions that match a kCompare
// instruction whose comparison direction is ComparisonDirection::k##NAME:
//   NAME()                    - any such compare.
//   NAME(lhs, rhs)            - additionally requires operands 0 and 1 to
//                               match `lhs` and `rhs`, in that order.
//   NAME(&matched, lhs, rhs)  - as NAME(lhs, rhs), capturing the matched
//                               instruction into `matched`.
#define XLA_COMPARE_PATTERN(NAME)                                             \
  inline auto NAME() {                                                        \
    return Op()                                                               \
        .WithOpcode(HloOpcode::kCompare)                                      \
        .WithComparisonDirection(ComparisonDirection::k##NAME);               \
  }                                                                           \
                                                                              \
  template <typename Lhs, typename Rhs>                                       \
  inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
    return Op()                                                               \
        .WithOpcode(HloOpcode::kCompare)                                      \
        .WithOperand(0, std::forward<Lhs>(lhs))                               \
        .WithOperand(1, std::forward<Rhs>(rhs))                               \
        .WithComparisonDirection(ComparisonDirection::k##NAME);               \
  }                                                                           \
                                                                              \
  template <typename HloInstructionType, typename Lhs, typename Rhs>          \
  inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
    return Op(matched_inst)                                                   \
        .WithOpcode(HloOpcode::kCompare)                                      \
        .WithOperand(0, std::forward<Lhs>(lhs))                               \
        .WithOperand(1, std::forward<Rhs>(rhs))                               \
        .WithComparisonDirection(ComparisonDirection::k##NAME);               \
  }
// Helpers for commutative comparison instructions.
//
// XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) defines everything XLA_COMPARE_PATTERN
// does, plus NAME##AnyOrder(&matched, lhs, rhs) and NAME##AnyOrder(lhs, rhs).
// These match a kCompare whose direction is ComparisonDirection::k##NAME and
// whose two operands match `lhs` and `rhs` in either order. The direction
// check is essential. Without it, NAME##AnyOrder would accept a compare of any
// direction, unlike the ordered NAME(lhs, rhs) overloads. The non-capturing
// overload forwards to the capturing one with a null `const HloInstruction**`.
#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME)                              \
  XLA_COMPARE_PATTERN(NAME)                                                \
                                                                           \
  template <typename HloInstructionType, typename Lhs, typename Rhs>       \
  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
                             Rhs&& rhs) {                                  \
    return Op(matched_inst)                                                \
        .WithOpcode(HloOpcode::kCompare)                                   \
        .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
                                    std::forward<Rhs>(rhs))                \
        .WithComparisonDirection(ComparisonDirection::k##NAME);            \
  }                                                                        \
  template <typename Lhs, typename Rhs>                                    \
  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
    return NAME##AnyOrder<const HloInstruction>(                           \
        nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
  }
2228 // Helpers for matching non-constant instructions.
2229 inline auto NonConstant() { return Op().IsNonConstant(); }
2231 template <typename HloInstructionType>
2232 inline auto NonConstant(HloInstructionType** matched_inst) {
2233   return Op(matched_inst).IsNonConstant();
2234 }
2236 // Add overloads for GetTupleElement which take a int64 specifying which tuple
2237 // element is selected.
2238 template <typename Arg>
2239 inline auto GetTupleElement(Arg&& arg, int64 tuple_index) {
2240   return Op()
2241       .WithOpcode(HloOpcode::kGetTupleElement)
2242       .WithOperand(0, std::forward<Arg>(arg))
2243       .WithTupleIndex(tuple_index);
2244 }
2246 template <typename HloInstructionType, typename Arg>
2247 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2248                             int64 tuple_index) {
2249   return Op(matched_inst)
2250       .WithOpcode(HloOpcode::kGetTupleElement)
2251       .WithOperand(0, std::forward<Arg>(arg))
2252       .WithTupleIndex(tuple_index);
2253 }
2255 // Add overloads for Parameter which take an int64 specifying the parameter
2256 // number.
2257 inline auto Parameter(int64 parameter_num) {
2258   return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2259 }
2260 template <typename HloInstructionType>
2261 inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) {
2262   return Op(matched_inst)
2263       .WithOpcode(HloOpcode::kParameter)
2264       .WithParameterNum(parameter_num);
2265 }
2267 inline auto ConstantScalar() { return Op().IsConstantScalar(); }
2269 template <typename HloInstructionType>
2270 inline auto ConstantScalar(HloInstructionType** matched_inst) {
2271   return Op(matched_inst).IsConstantScalar();
2272 }
2274 template <typename ScalarTy>
2275 inline auto ConstantScalar(ScalarTy val) {
2276   return Op().IsConstantScalar(val);
2277 }
2279 template <typename HloInstructionType, typename ScalarTy>
2280 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) {
2281   return Op(matched_inst).IsConstantScalar(val);
2282 }
2284 inline auto ConstantEffectiveScalar() {
2285   return Op().IsConstantEffectiveScalar();
2286 }
2288 template <typename HloInstructionType>
2289 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) {
2290   return Op(matched_inst).IsConstantEffectiveScalar();
2291 }
2293 template <typename ScalarTy>
2294 inline auto ConstantEffectiveScalar(ScalarTy val) {
2295   return Op().IsConstantEffectiveScalar(val);
2296 }
2298 template <typename HloInstructionType, typename ScalarTy>
2299 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
2300                                     ScalarTy val) {
2301   return Op(matched_inst).IsConstantEffectiveScalar(val);
2302 }
2304 }  // namespace match
2306 }  // namespace xla
2308 #undef EXPLAIN
2309 #pragma pop_macro("EXPLAIN")