• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18 
19 #include "absl/strings/str_replace.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/utility/utility.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 
30 namespace xla {
31 
32 // A pattern matcher for HloInstructions, Shapes, and Layouts.
33 //
34 // The Match function's first argument must be HloInstruction*, Shape*, or
35 // Layout*. The second argument is a pattern that will be matched against the
36 // first argument, as described below.
37 //
38 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
39 // functions. By default, the returned patterns will match any HloInstruction,
40 // Shape, or Layout, respectively. However the match can be made more specific
41 // by using the pattern's modifier methods, for example:
42 //
43 //   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
44 //     0, match::Op().WithOpcode(HloOpcode::kConstant))
45 //
46 // This pattern will match Add instructions whose first operand is a constant.
47 //
48 // Each pattern type has the following modifiers, which are described where
49 // nontrivial.
50 //
51 //   Op():
52 //     - Is: is the given HloInstruction* (i.e. pointer equality)
53 //     - WithName
54 //     - WithOpcode
55 //     - WithoutOpcode: anything other than the given opcode
56 //     - WithShape: instr's shape matches the given pattern
57 //     - WithShapeEqualTo: instr's shape is equal to the given Shape
58 //     - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
59 //     - WithNumOperands
60 //     - WithOperand: operand at the given index matches the given pattern
61 //     - IsConstant
62 //     - IsNonConstant
//     - IsConstantScalar/IsConstantEffectiveScalar: Optionally accepts a value,
//       e.g. IsConstantScalar() or IsConstantScalar(42).
65 //     - WithFusionKind
66 //     - WithTupleIndex: get-tuple-element operations with the given tuple index
67 //     - WithOneUse: Instruction is used as an operand exactly once.
68 //     - WithOneUser: Instruction is used by exactly one other instruction, but
69 //       is possibly used more than once as an operand (e.g. multiply(x,x)).
70 //     - WithComparisonDirection: instr has the given direction
71 //
72 //   Shape():
73 //     - EqualTo
74 //     - CompatibleTo
75 //     - IsScalar/IsEffectiveScalar/IsArray/IsTuple
76 //     - IsDenseArray
//     - WithLayout: shape's layout matches the given pattern (e.g.
//       Layout().WithDenseFormat())
79 //     - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
80 //       Layout, but not the result of Layout().foo())
81 //     - WithSubshape: shape is a tuple whose subshape matches the given pattern
82 //       (e.g. Shape().IsScalar()).
83 //     - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
84 //       (i.e. another Shape, but not the result of Shape().foo())
85 //     - WithElementType: shape is an array/scalar with the given elem type
86 //     - WithRank: shape is an array/scalar with the given rank
87 //
88 //  Layout():
89 //     - EqualTo
90 //     - WithDenseFormat
91 //
92 // Op(), Shape(), and Layout() may be passed an argument of type
93 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
94 // these pointers. If the pattern is matched, the address of the matched value
95 // will be "captured" and stored at this location.
96 //
97 // For example:
98 //   HloInstruction* foo = ...;
99 //   HloInstruction* matched_operand;
100 //   CHECK(Match(foo,
101 //               match::Op().WithOperand(0, match::Op(&matched_operand))));
102 //
103 // Helpers are provided for most HLO instructions. These helpers can be called
104 // with no arguments, in which case they will match any instruction matching the
105 // opcode. They may also be called with matches for the operands and with an
106 // optional capture. (The capture must be the first argument.) Some examples of
107 // these helpers and their equivalents are provided below.
108 
109 // Example nullary instruction:
110 //   Parameter()                    == Op().WithOpcode(HloOpcode::kParameter)
111 //   Parameter(&a)                  == Op(&a).WithOpcode(HloOpcode::kParameter)
112 //
113 // Example unary instruction:
114 //   Abs()                          == Op().WithOpcode(HloOpcode::kAbs)
115 //   Abs(Op(&a))                    == Op().WithOpcode(HloOpcode::kAbs)
116 //                                         .WithOperand(0, Op(&a)))
117 //   Abs(&a, Op(&b))                == Op(&a).WithOpcode(HloOpcode::kAbs)
118 //                                           .WithOperand(0, Op(&b))
119 //
120 // Commutative binary instructions have a special form that accepts either order
121 // of args, e.g.:
122 //
123 //   AddAnyOrder(Parameter(1), Abs()) ==
124 //     Op().WithOpcode(HloOpcode::kAdd)
125 //         .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
126 //
127 //   MultiplyAnyOrder(&a, Parameter(), Abs())  // Captures the mul in `a`.
128 //
129 // The following additional helpers are provided.  In all cases, `&a` is
130 // optional.
131 //
132 //   ConstantScalar(&a)               == Op(&a).IsConstantScalar();
133 //   ConstantScalar(&a, v)            == Op(&a).IsConstantScalar(v);
134 //   ConstantEffectiveScalar(&a)      == Op(&a).IsConstantEffectiveScalar();
//   ConstantEffectiveScalar(&a, v)   == Op(&a).IsConstantEffectiveScalar(v);
136 //   NonConstant(&a)                  == Op(&a).IsNonConstant()
137 //   GetTupleElement(&a, b, index)    == Op(&a).WithTupleIndex(index)
138 //                                             .WithOperand(0, b);
139 //   Parameter(&a, n)                 == Op(&a).WithParameterNum(n);
140 
// Options that control a single matching pass.
//
// Both members carry default initializers so that a default-constructed
// MatchOption holds well-defined values (no capture, no explanation stream)
// rather than indeterminate ones.  The defaults equal what value-initialization
// (`MatchOption{}`) already produced, and aggregate initialization such as
// `{/*.capture=*/true, /*.explain_os=*/nullptr}` continues to work.
struct MatchOption {
  // If true, actually capture matched item into the user pointer.
  bool capture = false;

  // An explanation for why we failed to match is streamed here, if not-null.
  std::ostream* explain_os = nullptr;
};
148 
149 template <typename Value, typename Pattern>
150 bool Match(Value* value, const Pattern& pattern,
151            MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
152   if (option.capture) {
153     auto new_option = option;
154     new_option.capture = false;
155     if (!pattern.Match(value, new_option)) {
156       return false;
157     }
158   }
159   return pattern.Match(value, option);
160 }
161 
162 namespace match {
163 
164 namespace detail {
165 
// Macro for streaming to option.explain_os if it's not null.  It expects a
// variable named `option` with an `explain_os` member to be in scope.
//
//   EXPLAIN << "value of foo(): " << foo()
//
// The expansion is an if/else whose stream lives in the `else` arm.  Because
// the hidden `if` already owns an `else`, an `else` written after an unbraced
// use such as
//
//   if (cond) EXPLAIN << "x"; else Other();
//
// binds to the caller's `if (cond)` rather than to the macro's internal test.
#pragma push_macro("EXPLAIN")
#define EXPLAIN                  \
  if (!option.explain_os) {      \
  } else /* NOLINT */            \
    *option.explain_os
173 
// Number of extra spaces added each time a pretty-printer increases the indent
// "by one" level.
enum { kIndentInc = 2 };
179 
180 // Writes a newline and then `indent` spaces.
181 //
182 // We follow an unintuitive convention in this file's pretty-printers: Indents
183 // are performed by the caller, not the callee.  For example, if you want to
184 // print
185 //
186 //   foo:
187 //    - bar
188 //
189 // you'd do:
190 //
191 //  Foo::DescribeTo(std::ostream* os, int64 indent) {
192 //    *os << "foo:";
//    Indent(os, indent);  // Create a newline at the *current* indent level.
//    *os << " - ";
//    bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" - ") == 3.
196 //  }
197 //
198 //  Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; }
199 //
200 // Notice that Bar::DescribeTo() does not call Indent; the indenting is
201 // performed by Foo.  This convention allows the caller to decide whether a
202 // matcher is preceded by a newline, which is important e.g. for the AllOf
203 // matcher.
204 //
205 // (Incidentally, indenting in Match's explanations is handled differently.
206 // Indents are a common case in DescribeTo [we're printing a whole tree], but
207 // they're a special case in Match [we're printing only a path through the tree
208 // that encounters a failing node]. Indents in Match only appear when we
209 // encounter a failing disjunction, so we just handle them as a special case
210 // there.)
Indent(std::ostream * os,int64 indent)211 inline void Indent(std::ostream* os, int64 indent) {
212   *os << "\n";
213   for (int64 i = 0; i < indent; ++i) {
214     *os << " ";
215   }
216 }
217 
// Type trait telling whether matcher type T declares a static member
// kIsTrivialMatcher that evaluates to true.  `value` is false for every other
// type, including types that have no such member.
//
// Trivial matchers are special-cased by the pretty-printers.  For instance, a
// conjunction does not print "and" after a trivial matcher, which gives
//    "a shape compatible with f32[1,2]"
// instead of
//    "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

// Selected through SFINAE only when T::kIsTrivialMatcher exists and is true.
template <typename T>
struct IsTrivialMatcher<T, typename std::enable_if<T::kIsTrivialMatcher>::type>
    : std::true_type {};
236 
237 template <typename Item, typename... Patterns>
238 class AllOfPattern {
239  public:
240   explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
241 
242   bool Match(const Item* item, MatchOption option) const {
243     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
244     // This invariant is guaranteed by the top-level Match and AnyOf.
245     DCHECK(matched || !option.capture);
246     return matched;
247   }
248 
249   bool Match(Item* item, MatchOption option) const {
250     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
251     // This invariant is guaranteed by the top-level Match and AnyOf.
252     DCHECK(matched || !option.capture);
253     return matched;
254   }
255 
256   void DescribeTo(std::ostream* os, int64 indent = 0) const {
257     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
258   }
259 
260   // Accessor for patterns_.  Please don't use this outside of this file.
261   const std::tuple<Patterns...>& patterns() const { return patterns_; }
262 
263  private:
264   template <typename ItemType, size_t index>
265   bool MatchImpl(ItemType* item, MatchOption option,
266                  std::integral_constant<size_t, index>) const {
267     // We don't need to do any EXPLAINing here; it's all correctly handled by
268     // our sub-matchers (if any fail).
269     return std::get<index>(patterns_).Match(item, option) &&
270            MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
271   }
272 
273   template <typename ItemType>
274   bool MatchImpl(ItemType* item, MatchOption option,
275                  std::integral_constant<size_t, sizeof...(Patterns)>) const {
276     return true;
277   }
278 
279   // Pretty-printing a conjunction has some special cases to make it easy to
280   // read in the simple (common) case.
281   //
282   // If sizeof...(Patterns) == 1, prints as e.g.
283   //
284   //   a shape
285   //
286   // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
287   // shape") prints as
288   //
289   //   a shape compatible with f32[1,2]
290   //
291   // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
292   //
293   //   a shape:
294   //    * compatible with f32[1,2] AND
295   //    * that represents a scalar
296   //
297   // Otherwise prints as:
298   //
299   //   all of:
300   //    * foo AND
301   //    * bar
302   //
303   template <size_t index>
304   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
305                       int64 indent) const {
306     constexpr bool first_is_trivial =
307         IsTrivialMatcher<typename std::remove_reference<decltype(
308             std::get<0>(patterns_))>::type>::value;
309     constexpr bool is_last = index == sizeof...(Patterns) - 1;
310     const auto& submatcher = std::get<index>(patterns_);
311 
312     auto print_bulleted_item = [&] {
313       *os << " * ";
314       submatcher.DescribeTo(os, indent + 3);
315       if (!is_last) {
316         *os << " AND";
317         Indent(os, indent);
318       }
319     };
320 
321     if (index == 0) {
322       if (first_is_trivial || is_last) {
323         submatcher.DescribeTo(os, indent + kIndentInc);
324         if (sizeof...(Patterns) > 2) {
325           *os << ":";
326           Indent(os, indent);
327         }
328       } else {
329         *os << "all of:";
330         Indent(os, indent);
331         print_bulleted_item();
332       }
333     } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
334       *os << " ";
335       submatcher.DescribeTo(os, indent);
336     } else {
337       print_bulleted_item();
338     }
339     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
340   }
341 
342   void DescribeToImpl(std::ostream* os,
343                       std::integral_constant<size_t, sizeof...(Patterns)>,
344                       int64 indent) const {}
345 
346   std::tuple<Patterns...> patterns_;
347 };
348 
349 }  // namespace detail
350 
351 // Returns a pattern that represents the conjunction of all input patterns. All
352 // patterns need to match in order to have the AllOf pattern match.
353 template <typename Item, typename... Patterns>
354 auto AllOf(const Patterns&... patterns) {
355   return detail::AllOfPattern<typename std::remove_const<Item>::type,
356                               Patterns...>(patterns...);
357 }
358 
359 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
360 //
361 // This transformation is necessary for good pretty-printing.
362 template <typename Item, typename... InnerPs, typename... OuterPs>
363 auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
364            const OuterPs&... outer_ps) {
365   // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
366   auto make_all_of = [](const InnerPs&... inner_ps,
367                         const OuterPs&... outer_ps) {
368     return detail::AllOfPattern<typename std::remove_const<Item>::type,
369                                 InnerPs..., OuterPs...>(inner_ps...,
370                                                         outer_ps...);
371   };
372   return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
373                                                  std::make_tuple(outer_ps...)));
374 }
375 
376 namespace detail {
377 
378 template <typename LayoutType, typename Impl>
379 class LayoutPattern;
380 
381 // The base LayoutPattern implementation. Matches only if the layout is not
382 // nullptr.
383 class LayoutPatternBaseImpl {
384  public:
385   bool Match(const ::xla::Layout* layout, MatchOption option) const {
386     if (layout == nullptr) {
387       EXPLAIN << "Layout is null";
388       return false;
389     }
390     return true;
391   }
392 
393   void DescribeTo(std::ostream* os, int64 indent = 0) const {
394     *os << "a layout";
395   }
396 
397   static constexpr bool kIsTrivialMatcher = true;
398 };
399 
400 // A LayoutPattern implementation that matches only if the layout equals a
401 // Layout proto.
402 class LayoutPatternEqualImpl {
403  public:
404   explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
405       : layout_(layout) {}
406 
407   bool Match(const ::xla::Layout* layout, MatchOption option) const {
408     if (!LayoutUtil::Equal(*layout_, *layout)) {
409       EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
410               << " is not equal to expected "
411               << LayoutUtil::HumanString(*layout_);
412       return false;
413     }
414     return true;
415   }
416 
417   void DescribeTo(std::ostream* os, int64 indent = 0) const {
418     *os << "equal to " << LayoutUtil::HumanString(*layout_);
419   }
420 
421  private:
422   const ::xla::Layout* layout_;
423 };
424 
425 // A LayoutPattern implementation that matches only if the layout has a given
426 // format.
427 class LayoutPatternFormatImpl {
428  public:
429   explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
430 
431   bool Match(const ::xla::Layout* layout, MatchOption option) const {
432     if (layout->format() != format_) {
433       EXPLAIN << "Layout has format " << Format_Name(layout->format())
434               << " but expected " << Format_Name(format_);
435       return false;
436     }
437     return true;
438   }
439 
440   void DescribeTo(std::ostream* os, int64 indent = 0) const {
441     *os << "with format " << Format_Name(format_);
442   }
443 
444  private:
445   Format format_;
446 };
447 
448 // A pattern that matches Layouts.
449 template <typename LayoutType, typename Impl>
450 class LayoutPattern {
451  private:
452   template <typename NewImpl>
453   auto AppendImpl(NewImpl new_impl) const {
454     auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
455     return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
456                                                           matched_layout_);
457   }
458 
459  public:
460   explicit constexpr LayoutPattern(const Impl& impl,
461                                    LayoutType** matched_layout)
462       : impl_(impl), matched_layout_(matched_layout) {}
463 
464   // Returns true and captures the layout iff it matches the pattern.
465   bool Match(const ::xla::Layout* layout, MatchOption option) const {
466     if (impl_.Match(layout, option)) {
467       if (option.capture && matched_layout_) {
468         *matched_layout_ = layout;
469       }
470       return true;
471     }
472     return false;
473   }
474 
475   // Returns true and captures the layout iff it matches the pattern.
476   bool Match(::xla::Layout* layout, MatchOption option) const {
477     if (impl_.Match(layout, option)) {
478       if (option.capture && matched_layout_) {
479         *matched_layout_ = layout;
480       }
481       return true;
482     }
483     return false;
484   }
485 
486   void DescribeTo(std::ostream* os, int64 indent = 0) const {
487     impl_.DescribeTo(os, indent);
488   }
489 
490   // Modifies the pattern to match only if the layout equals the given proto.
491   // The layout must outlive the returned pattern.
492   constexpr auto EqualTo(const ::xla::Layout* layout) const {
493     return AppendImpl(LayoutPatternEqualImpl(layout));
494   }
495 
496   // Modifies the pattern to match only if the layout has a dense format.
497   constexpr auto WithDenseFormat() const {
498     return AppendImpl(LayoutPatternFormatImpl(DENSE));
499   }
500 
501  private:
502   Impl impl_;
503   LayoutType** matched_layout_;
504 };
505 
506 template <typename Item, typename... Patterns>
507 class AnyOfPattern {
508  public:
509   explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
510 
511   bool Match(const Item* item, MatchOption option) const {
512     return MatchImpl(item, option);
513   }
514 
515   bool Match(Item* item, MatchOption option) const {
516     return MatchImpl(item, option);
517   }
518 
519   void DescribeTo(std::ostream* os, int64 indent = 0) const {
520     *os << "any of:";
521     Indent(os, indent);
522     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
523   }
524 
525  private:
526   template <typename ItemType>
527   bool MatchImpl(ItemType* item, MatchOption option) const {
528     // If we're generating an explanation, buffer it until we know we failed.
529     absl::optional<std::stringstream> explanation;
530     MatchOption new_option = option;
531     if (option.explain_os) {
532       new_option.explain_os = &explanation.emplace();
533     }
534     bool rv = MatchRecursiveImpl(item, new_option,
535                                  std::integral_constant<size_t, 0>());
536     if (!rv && option.explain_os) {
537       EXPLAIN << "None of the following matchers succeeded:";
538       EXPLAIN << explanation->str();
539     }
540     return rv;
541   }
542 
543   template <typename ItemType, size_t index>
544   bool MatchRecursiveImpl(ItemType* item, MatchOption option,
545                           std::integral_constant<size_t, index>) const {
546     auto new_option = option;
547     new_option.capture = false;
548 
549     absl::optional<std::stringstream> explanation;
550     if (option.explain_os) {
551       new_option.explain_os = &explanation.emplace();
552     }
553 
554     // Try to match the sub-pattern without capturing behavior.
555     if (std::get<index>(patterns_).Match(item, new_option)) {
556       // Capture the branch.
557       if (option.capture) {
558         // TODO(timshen): Currently the behavior can be exponential. Optimize it
559         // with memoization or recording the matched sub-pattern index, if it
560         // takes too long to run.
561         //
562         // Specifically, the "memoization" approach is to create an empty
563         // container with the key (pattern, instruction), and value as whether
564         // matched or not.
565         //
566         // Alternatively, we may run the pattern matching with captures off, but
567         // instead record a "trace" somewhere, indicating how exactly the
568         // pattern matches the input. For example, the trace information for
569         // AnyOf will be a runtime number indicate which sub-pattern is matched.
570         // Then we run another pass to do captures only with the help of the
571         // trace.
572         bool matched = std::get<index>(patterns_).Match(item, option);
573         DCHECK(matched);
574       }
575       return true;
576     }
577     if (option.explain_os) {
578       EXPLAIN << "\nMatcher #" << index + 1;
579       EXPLAIN << "\n - ";
580       std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
581       EXPLAIN << "\nfailed with";
582       EXPLAIN << "\n - ";
583       EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n   "}});
584     }
585     return MatchRecursiveImpl(item, option,
586                               std::integral_constant<size_t, index + 1>());
587   }
588 
589   template <typename ItemType>
590   bool MatchRecursiveImpl(
591       ItemType* item, MatchOption option,
592       std::integral_constant<size_t, sizeof...(Patterns)>) const {
593     return false;
594   }
595 
596   template <size_t index>
597   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
598                       int64 indent) const {
599     *os << " - ";
600     std::get<index>(patterns_).DescribeTo(os, indent + 3);
601     if (index != sizeof...(Patterns) - 1) {
602       *os << " OR";
603       Indent(os, indent);
604     }
605     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
606   }
607 
608   void DescribeToImpl(std::ostream* os,
609                       std::integral_constant<size_t, sizeof...(Patterns)>,
610                       int64 indent) const {}
611 
612   std::tuple<Patterns...> patterns_;
613 };
614 
615 }  // namespace detail
616 
617 // Returns a pattern that represents the logical disjunction of the input
618 // patterns. The returned pattern matches from left to right, and stops on the
619 // first match.
620 template <typename Item, typename... Patterns>
621 auto AnyOf(const Patterns&... patterns) {
622   return detail::AnyOfPattern<typename std::remove_const<Item>::type,
623                               Patterns...>(patterns...);
624 }
625 
626 // Creates a layout pattern that will capture the matched layout in the
627 // argument.
628 inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
629   return detail::LayoutPattern<const ::xla::Layout,
630                                detail::LayoutPatternBaseImpl>(
631       detail::LayoutPatternBaseImpl(), matched_layout);
632 }
633 
634 // Creates a layout pattern that will capture the matched layout in the
635 // argument.
636 inline constexpr auto Layout(::xla::Layout** matched_layout) {
637   return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
638       detail::LayoutPatternBaseImpl(), matched_layout);
639 }
640 
641 namespace detail {
642 
643 template <typename ShapeType, typename Impl>
644 class ShapePattern;
645 
646 // The base ShapePattern implementation. Matches only if the shape is not
647 // nullptr.
648 class ShapePatternBaseImpl {
649  public:
650   bool Match(const ::xla::Shape* shape, MatchOption option) const {
651     if (shape == nullptr) {
652       EXPLAIN << "Shape is null";
653     }
654     return shape != nullptr;
655   }
656 
657   void DescribeTo(std::ostream* os, int64 indent = 0) const {
658     *os << "a shape";
659   }
660 
661   static constexpr bool kIsTrivialMatcher = true;
662 };
663 
664 // A ShapePattern implementation that matches only if the shape equals a Shape
665 // proto.
666 class ShapePatternEqualImpl {
667  public:
668   explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
669       : shape_(shape) {}
670 
671   bool Match(const ::xla::Shape* shape, MatchOption option) const {
672     if (!ShapeUtil::Equal(*shape_, *shape)) {
673       EXPLAIN << "Shape not equal to "
674               << ShapeUtil::HumanStringWithLayout(*shape_);
675       return false;
676     }
677     return true;
678   }
679 
680   void DescribeTo(std::ostream* os, int64 indent = 0) const {
681     *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
682   }
683 
684  private:
685   const ::xla::Shape* shape_;
686 };
687 
688 // A ShapePattern implementation that matches only if the shape is compatible to
689 // a Shape proto.
690 class ShapePatternCompatibleImpl {
691  public:
692   explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
693       : shape_(shape) {}
694 
695   bool Match(const ::xla::Shape* shape, MatchOption option) const {
696     if (!ShapeUtil::Compatible(*shape_, *shape)) {
697       EXPLAIN << "Shape not compatible with "
698               << ShapeUtil::HumanString(*shape_);
699       return false;
700     }
701     return true;
702   }
703 
704   void DescribeTo(std::ostream* os, int64 indent = 0) const {
705     *os << "compatible with " << ShapeUtil::HumanString(*shape_);
706   }
707 
708  private:
709   const ::xla::Shape* shape_;
710 };
711 
712 // A ShapePattern implementation that matches only if the shape has a given
713 // element type.
714 class ShapePatternElementTypeImpl {
715  public:
716   explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
717       : element_type_(element_type) {}
718 
719   bool Match(const ::xla::Shape* shape, MatchOption option) const {
720     if (shape->element_type() != element_type_) {
721       EXPLAIN << "Shape does not have element type "
722               << PrimitiveType_Name(element_type_);
723       return false;
724     }
725     return true;
726   }
727 
728   void DescribeTo(std::ostream* os, int64 indent = 0) const {
729     *os << "with element type " << PrimitiveType_Name(element_type_);
730   }
731 
732  private:
733   PrimitiveType element_type_;
734 };
735 
736 // A ShapePattern implementation that matches only if the shape is scalar.
737 class ShapePatternIsScalarImpl {
738  public:
739   explicit constexpr ShapePatternIsScalarImpl() {}
740 
741   bool Match(const ::xla::Shape* shape, MatchOption option) const {
742     if (!ShapeUtil::IsScalar(*shape)) {
743       EXPLAIN << "Shape is not a scalar";
744       return false;
745     }
746     return true;
747   }
748 
749   void DescribeTo(std::ostream* os, int64 indent = 0) const {
750     *os << "that represents a scalar";
751   }
752 };
753 
754 // A ShapePattern implementation that matches only if the shape is an array
755 class ShapePatternIsArrayImpl {
756  public:
757   explicit constexpr ShapePatternIsArrayImpl() {}
758 
759   bool Match(const ::xla::Shape* shape, MatchOption option) const {
760     if (!shape->IsArray()) {
761       EXPLAIN << "Shape is not an array";
762       return false;
763     }
764     return true;
765   }
766 
767   void DescribeTo(std::ostream* os, int64 indent = 0) const {
768     *os << "that represents an array";
769   }
770 };
771 
772 // A ShapePattern implementation that matches only if the shape is a tuple.
773 class ShapePatternIsTupleImpl {
774  public:
775   explicit constexpr ShapePatternIsTupleImpl() {}
776 
777   bool Match(const ::xla::Shape* shape, MatchOption option) const {
778     if (!shape->IsTuple()) {
779       EXPLAIN << "Shape is not a tuple";
780       return false;
781     }
782     return true;
783   }
784 
785   void DescribeTo(std::ostream* os, int64 indent = 0) const {
786     *os << "that represents a tuple";
787   }
788 };
789 
790 // A ShapePattern implementation that matches only if the shape is an effective
791 // scalar.
792 class ShapePatternEffectiveScalarImpl {
793  public:
794   explicit constexpr ShapePatternEffectiveScalarImpl() {}
795 
796   bool Match(const ::xla::Shape* shape, MatchOption option) const {
797     if (!ShapeUtil::IsEffectiveScalar(*shape)) {
798       EXPLAIN << "Shape is not an effective scalar";
799       return false;
800     }
801     return true;
802   }
803 
804   void DescribeTo(std::ostream* os, int64 indent = 0) const {
805     *os << "that is an effective scalar";
806   }
807 };
808 
809 // A ShapePattern implementation that matches only if the shape has a given
810 // rank.
811 class ShapePatternRankImpl {
812  public:
813   explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
814 
815   bool Match(const ::xla::Shape* shape, MatchOption option) const {
816     if (shape->rank() != rank_) {
817       if (rank_ == 0) {
818         EXPLAIN << "Shape is not a scalar";
819       } else {
820         EXPLAIN << "Shape does not have rank " << rank_;
821       }
822       return false;
823     }
824     return true;
825   }
826 
827   void DescribeTo(std::ostream* os, int64 indent = 0) const {
828     if (rank_ == 0) {
829       *os << "that is a scalar";
830     } else {
831       *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
832     }
833   }
834 
835  private:
836   int64 rank_;
837 };
838 
839 // A ShapePattern implementation that matches only if the shape has a layout
840 // that matches a given pattern.
841 template <typename LayoutType, typename LayoutImpl>
842 class ShapePatternLayoutImpl {
843  public:
844   explicit constexpr ShapePatternLayoutImpl(
845       const LayoutPattern<LayoutType, LayoutImpl>& layout)
846       : layout_(layout) {}
847 
848   bool Match(const ::xla::Shape* shape, MatchOption option) const {
849     return LayoutUtil::HasLayout(*shape) &&
850            layout_.Match(&shape->layout(), option);
851   }
852 
853   bool Match(::xla::Shape* shape, MatchOption option) const {
854     if (!LayoutUtil::HasLayout(*shape)) {
855       EXPLAIN << "Shape does not have a layout";
856       return false;
857     }
858     if (!layout_.Match(shape->mutable_layout(), option)) {
859       EXPLAIN << "\nin layout";
860       return false;
861     }
862     return true;
863   }
864 
865   void DescribeTo(std::ostream* os, int64 indent = 0) const {
866     *os << "with";
867     Indent(os, indent + kIndentInc);
868     layout_.DescribeTo(os, indent + kIndentInc);
869   }
870 
871  private:
872   LayoutPattern<LayoutType, LayoutImpl> layout_;
873 };
874 
875 // A ShapePattern implementation that matches only if the shape has a subshape
876 // that matches a given pattern.
877 template <typename SubshapeType, typename SubshapeImpl>
878 class ShapePatternSubshapeImpl {
879  public:
880   explicit ShapePatternSubshapeImpl(
881       ShapeIndexView index,
882       const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
883       : index_(index), subshape_(subshape) {}
884 
885   bool Match(const ::xla::Shape* shape, MatchOption option) const {
886     return MatchImpl(shape, option);
887   }
888 
889   bool Match(::xla::Shape* shape, MatchOption option) const {
890     return MatchImpl(shape, option);
891   }
892 
893   void DescribeTo(std::ostream* os, int64 indent = 0) const {
894     *os << "with subshape at index " << index_.ToString() << " which is";
895     Indent(os, indent + kIndentInc);
896     subshape_.DescribeTo(os, indent + kIndentInc);
897   }
898 
899  private:
900   ::xla::Shape* GetSubshape(::xla::Shape* shape) const {
901     return ShapeUtil::GetMutableSubshape(shape, index_);
902   }
903   const ::xla::Shape* GetSubshape(const ::xla::Shape* shape) const {
904     return &ShapeUtil::GetSubshape(*shape, index_);
905   }
906 
907   template <typename ShapeType>
908   bool MatchImpl(ShapeType* shape, MatchOption option) const {
909     if (!ShapeUtil::IndexIsValid(*shape, index_)) {
910       EXPLAIN << "No subshape at " << index_.ToString();
911       return false;
912     }
913     if (!subshape_.Match(GetSubshape(shape), option)) {
914       EXPLAIN << "\nin subshape at " << index_.ToString();
915       return false;
916     }
917     return true;
918   }
919 
920   ShapeIndexView index_;
921   ShapePattern<SubshapeType, SubshapeImpl> subshape_;
922 };
923 
924 // A pattern that matches Shapes.
925 template <typename ShapeType, typename Impl>
926 class ShapePattern {
927  private:
928   template <typename NewImpl>
929   auto AppendImpl(NewImpl new_impl) const {
930     auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
931     return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
932                                                          matched_shape_);
933   }
934 
935  public:
936   explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
937       : impl_(impl), matched_shape_(matched_shape) {}
938 
939   // Returns true and captures the shape iff it matches the pattern.
940   bool Match(const ::xla::Shape* shape, MatchOption option) const {
941     if (impl_.Match(shape, option)) {
942       if (option.capture && matched_shape_) {
943         *matched_shape_ = shape;
944       }
945       return true;
946     }
947     if (shape) {
948       EXPLAIN << "\nin "
949               << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
950                                       : ShapeUtil::HumanString(*shape));
951     }
952     return false;
953   }
954 
955   // Returns true and captures the shape iff it matches the pattern.
956   bool Match(::xla::Shape* shape, MatchOption option) const {
957     if (impl_.Match(shape, option)) {
958       if (option.capture && matched_shape_) {
959         *matched_shape_ = shape;
960       }
961       return true;
962     }
963     EXPLAIN << "\nin "
964             << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
965                                     : ShapeUtil::HumanString(*shape));
966     return false;
967   }
968 
969   void DescribeTo(std::ostream* os, int64 indent = 0) const {
970     return impl_.DescribeTo(os, indent);
971   }
972 
973   // Modifies the pattern to match only if the shape equals the given proto.
974   // The layout must outlive the returned pattern.
975   constexpr auto EqualTo(const ::xla::Shape* shape) const {
976     return AppendImpl(ShapePatternEqualImpl(shape));
977   }
978 
979   // Modifies the pattern to match only if the shape is compatible to the given
980   // proto. The layout must outlive the returned pattern.
981   constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
982     return AppendImpl(ShapePatternCompatibleImpl(shape));
983   }
984 
985   // Modifies the pattern to match only if the shape has the given element type.
986   constexpr auto WithElementType(PrimitiveType element_type) const {
987     return AppendImpl(ShapePatternElementTypeImpl(element_type));
988   }
989 
990   // Modifies the pattern to match only if the shape is scalar.
991   constexpr auto IsScalar() const {
992     return AppendImpl(ShapePatternIsScalarImpl());
993   }
994 
995   // Modifies the pattern to match only if the shape is an array.
996   constexpr auto IsArray() const {
997     return AppendImpl(ShapePatternIsArrayImpl());
998   }
999 
1000   // Modifies the pattern to match only if the shape is a tuple.
1001   constexpr auto IsTuple() const {
1002     return AppendImpl(ShapePatternIsTupleImpl());
1003   }
1004 
1005   constexpr auto IsEffectiveScalar() const {
1006     return AppendImpl(ShapePatternEffectiveScalarImpl());
1007   }
1008 
1009   // Modifies the pattern to match only if the shape has the given rank.
1010   constexpr auto WithRank(int64 rank) const {
1011     return AppendImpl(ShapePatternRankImpl(rank));
1012   }
1013 
1014   // Modifies the pattern to match only if the shape has a layout that matches
1015   // the given pattern.
1016   template <typename LayoutType, typename LayoutImpl>
1017   auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
1018     return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1019   }
1020 
1021   constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
1022     return WithLayout(Layout().EqualTo(layout));
1023   }
1024 
1025   constexpr auto IsDenseArray() const {
1026     return WithLayout(Layout().WithDenseFormat());
1027   }
1028 
1029   // Modifies the pattern to match only if the shape has a subshape that matches
1030   // the given pattern.
1031   template <typename SubshapeType, typename SubshapeImpl>
1032   auto WithSubshape(
1033       ShapeIndexView index,
1034       const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
1035     return AppendImpl(
1036         ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1037   }
1038 
1039   ShapePattern<ShapeType,
1040                AllOfPattern<::xla::Shape, Impl,
1041                             ShapePatternSubshapeImpl<
1042                                 const ::xla::Shape,
1043                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1044                                              ShapePatternEqualImpl>>>>
1045   WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1046     return WithSubshape(index,
1047                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1048                             ShapePatternBaseImpl(), nullptr)
1049                             .EqualTo(shape));
1050   }
1051 
1052   ShapePattern<ShapeType,
1053                AllOfPattern<::xla::Shape, Impl,
1054                             ShapePatternSubshapeImpl<
1055                                 const ::xla::Shape,
1056                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1057                                              ShapePatternCompatibleImpl>>>>
1058   WithSubshapeCompatibleTo(ShapeIndexView index,
1059                            const ::xla::Shape* shape) const {
1060     return WithSubshape(index,
1061                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1062                             ShapePatternBaseImpl(), nullptr)
1063                             .CompatibleTo(shape));
1064   }
1065 
1066  private:
1067   Impl impl_;
1068   ShapeType** matched_shape_;
1069 };
1070 
1071 }  // namespace detail
1072 
1073 // Creates a shape pattern that will capture the matched layout in the argument.
1074 inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
1075   return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1076       detail::ShapePatternBaseImpl(), matched_shape);
1077 }
1078 
1079 // Creates a shape pattern that will capture the matched layout in the argument.
1080 inline constexpr auto Shape(::xla::Shape** matched_shape) {
1081   return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1082       detail::ShapePatternBaseImpl(), matched_shape);
1083 }
1084 
1085 namespace detail {
1086 
1087 // Overloads to get a const or non-const operand out of an instruction.
1088 inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) {
1089   return instr->mutable_operand(idx);
1090 }
1091 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1092                                         int64 idx) {
1093   return instr->operand(idx);
1094 }
1095 
1096 // Pretty-printer for HloInstruction.  Sort of like ToShortString, but with
1097 // fewer %s and more shapes.
1098 inline string InstToString(const HloInstruction* inst) {
1099   return inst->ToString(
1100       HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1101 }
1102 
1103 template <typename HloInstructionType, typename Impl>
1104 class HloInstructionPattern;
1105 
1106 // The base HloInstructionPattern implementation. Matches only if the
1107 // instruction is not nullptr.
1108 class HloInstructionPatternBaseImpl {
1109  public:
1110   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1111     if (inst == nullptr) {
1112       EXPLAIN << "HloInstruction* is null";
1113       return false;
1114     }
1115     return true;
1116   }
1117 
1118   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1119     *os << "an HloInstruction";
1120   }
1121 
1122   static constexpr bool kIsTrivialMatcher = true;
1123 };
1124 
1125 // An HloInstructionPattern implementation that matches only if the instruction
1126 // has a given name.
1127 class HloInstructionPatternNameImpl {
1128  public:
1129   explicit HloInstructionPatternNameImpl(absl::string_view name)
1130       : name_(name) {}
1131 
1132   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1133     if (inst->name() != name_) {
1134       EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1135       return false;
1136     }
1137     return true;
1138   }
1139 
1140   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1141     *os << "named \"" << name_ << "\"";
1142   }
1143 
1144  private:
1145   absl::string_view name_;
1146 };
1147 
1148 // An HloInstructionPattern implementation that matches only if the instruction
1149 // equals a particular pointer.
1150 class HloInstructionIsImpl {
1151  public:
1152   explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1153 
1154   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1155     if (inst != inst_) {
1156       EXPLAIN << "HloInstruction " << std::hex << std::nouppercase
1157               << std::showbase << reinterpret_cast<uint64>(inst) << " is not "
1158               << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1159               << ")";
1160       return false;
1161     }
1162     return true;
1163   }
1164 
1165   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1166     *os << "which is " << std::hex << std::nouppercase << std::showbase
1167         << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1168         << ")";
1169   }
1170 
1171  private:
1172   const HloInstruction* inst_;
1173 };
1174 
1175 // An HloInstructionPattern implementation that matches only if the instruction
1176 // has a given opcode.
1177 class HloInstructionPatternOpcodeImpl {
1178  public:
1179   explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1180                                                      bool invert)
1181       : opcode_(opcode), invert_(invert) {}
1182 
1183   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1184     if (invert_ && inst->opcode() == opcode_) {
1185       EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1186               << ", expected anything else";
1187       return false;
1188     }
1189     if (!invert_ && inst->opcode() != opcode_) {
1190       EXPLAIN << "HloInstruction doesn't have opcode "
1191               << HloOpcodeString(opcode_);
1192       return false;
1193     }
1194     return true;
1195   }
1196 
1197   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1198     if (!invert_) {
1199       *os << "with opcode " << HloOpcodeString(opcode_);
1200     } else {
1201       *os << "with any opcode other than " << HloOpcodeString(opcode_);
1202     }
1203   }
1204 
1205  private:
1206   HloOpcode opcode_;
1207   bool invert_;
1208 };
1209 
1210 // An HloInstructionPattern implementation that matches only if the instruction
1211 // has a given custom call target.
1212 class HloInstructionCustomCallTargetImpl {
1213  public:
1214   explicit HloInstructionCustomCallTargetImpl(
1215       absl::string_view custom_call_target)
1216       : custom_call_target_(custom_call_target) {}
1217 
1218   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1219     if (inst->opcode() != HloOpcode::kCustomCall ||
1220         inst->custom_call_target() != custom_call_target_) {
1221       EXPLAIN << "HloInstruction is not a custom call with a target '"
1222               << custom_call_target_ << "'";
1223       return false;
1224     }
1225     return true;
1226   }
1227 
1228   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1229     *os << "custom call with target '" << custom_call_target_ << "'";
1230   }
1231 
1232  private:
1233   std::string custom_call_target_;
1234 };
1235 
1236 // An HloInstructionPattern implementation that matches only if the instruction
1237 // has the given number of operands.
1238 class HloInstructionPatternNumOperandsImpl {
1239  public:
1240   explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands)
1241       : num_operands_(num_operands) {}
1242 
1243   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1244     if (inst->operand_count() != num_operands_) {
1245       EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1246       return false;
1247     }
1248     return true;
1249   }
1250 
1251   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1252     *os << "with " << num_operands_ << " operand"
1253         << (num_operands_ != 1 ? "s" : "");
1254   }
1255 
1256  private:
1257   int64 num_operands_;
1258 };
1259 
1260 // An HloInstructionPattern implementation that matches only if the instruction
1261 // has a shape that matches a given pattern.
1262 template <typename ShapeType, typename ShapeImpl>
1263 class HloInstructionPatternShapeImpl {
1264  public:
1265   explicit constexpr HloInstructionPatternShapeImpl(
1266       const ShapePattern<ShapeType, ShapeImpl>& shape)
1267       : shape_(shape) {}
1268 
1269   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1270     if (!shape_.Match(&inst->shape(), option)) {
1271       EXPLAIN << "\nin output shape";
1272       return false;
1273     }
1274     return true;
1275   }
1276 
1277   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1278     if (!shape_.Match(inst->mutable_shape(), option)) {
1279       EXPLAIN << "\nin output shape";
1280       return false;
1281     }
1282     return true;
1283   }
1284 
1285   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1286     *os << "outputting";
1287     Indent(os, indent + kIndentInc);
1288     shape_.DescribeTo(os, indent + kIndentInc);
1289   }
1290 
1291  private:
1292   ShapePattern<ShapeType, ShapeImpl> shape_;
1293 };
1294 
1295 // An HloInstructionPattern implementation that matches only if the instruction
1296 // has an operand that matches a given pattern.
1297 template <typename OperandType, typename OperandImpl>
1298 class HloInstructionPatternOperandImpl {
1299  public:
1300   explicit constexpr HloInstructionPatternOperandImpl(
1301       int64 operand_index,
1302       const HloInstructionPattern<OperandType, OperandImpl>& operand)
1303       : operand_index_(operand_index), operand_(operand) {}
1304 
1305   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1306     return MatchImpl(inst, option);
1307   }
1308 
1309   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1310     return MatchImpl(inst, option);
1311   }
1312 
1313   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1314     *os << "with operand " << operand_index_ << " which is:";
1315     Indent(os, indent + kIndentInc);
1316     operand_.DescribeTo(os, indent + kIndentInc);
1317   }
1318 
1319  private:
1320   template <typename HloInstructionType>
1321   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1322     if (operand_index_ >= inst->operand_count()) {
1323       EXPLAIN << "desired operand index " << operand_index_
1324               << " is out of bounds";
1325       return false;
1326     }
1327     if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1328       EXPLAIN << "\nin operand " << operand_index_;
1329       return false;
1330     }
1331     return true;
1332   }
1333 
1334   int64 operand_index_;
1335   HloInstructionPattern<OperandType, OperandImpl> operand_;
1336 };
1337 
1338 // Matches a binary instruction whose operands come in any order.
1339 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1340           typename OperandImpl2>
1341 class HloInstructionPatternBinaryOperandsAnyOrderImpl {
1342  public:
1343   explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
1344       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1345       const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
1346       : op1_(op1), op2_(op2) {}
1347 
1348   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1349     return MatchImpl(inst, option);
1350   }
1351 
1352   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1353     return MatchImpl(inst, option);
1354   }
1355 
1356   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1357     *os << "with two operands in either order:";
1358     Indent(os, indent);
1359     *os << " - ";
1360     op1_.DescribeTo(os, indent + 3);
1361     Indent(os, indent);
1362     *os << " - ";
1363     op2_.DescribeTo(os, indent + 3);
1364   }
1365 
1366  private:
1367   HloInstruction* operand(HloInstruction* inst, int64 idx) const {
1368     return inst->mutable_operand(idx);
1369   }
1370   const HloInstruction* operand(const HloInstruction* inst, int64 idx) const {
1371     return inst->operand(idx);
1372   }
1373 
1374   template <typename HloInstructionType>
1375   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1376     // We could implement this using AnyOf and AllOf matchers, but the templates
1377     // get pretty difficult to debug, since any compile error herein becomes
1378     // not-an-error via SFINAE.  Also this way lets us give better messages on
1379     // failure.
1380     if (inst->operand_count() != 2) {
1381       EXPLAIN << "HloInstruction did not have two operands";
1382       return false;
1383     }
1384 
1385     // If we're not generating explanations, this is pretty simple.
1386     if (!option.explain_os) {
1387       auto try_match = [&](int64 idx1, int64 idx2) {
1388         MatchOption new_option = option;
1389         new_option.capture = false;
1390         if (op1_.Match(operand(inst, idx1), new_option) &&
1391             op2_.Match(operand(inst, idx2), new_option)) {
1392           if (option.capture) {
1393             bool matched = op1_.Match(operand(inst, idx1), option) &&
1394                            op2_.Match(operand(inst, idx2), option);
1395             DCHECK(matched);
1396           }
1397           return true;
1398         }
1399         return false;
1400       };
1401       return try_match(0, 1) || try_match(1, 0);
1402     }
1403 
1404     // If we are generating explanations, we have some work to do in order to
1405     // generate a helpful error.
1406     //
1407     // First, try all four operand/matcher combinations, recording the
1408     // failure explanations separately from option.explain_os. matches[i][j]
1409     // tells us if matcher_i matches operand j.
1410     bool matches[/*matcher*/ 2][/*operand*/ 2];
1411     std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
1412     for (int i = 0; i < 2; ++i) {
1413       for (int j = 0; j < 2; ++j) {
1414         MatchOption new_option = option;
1415         new_option.capture = false;
1416         new_option.explain_os = &explanations[i][j];
1417         matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
1418                                : op2_.Match(operand(inst, j), new_option);
1419       }
1420     }
1421 
1422     // Check if the match succeeded.
1423     for (int i = 0; i < 2; ++i) {
1424       if (matches[0][i] && matches[1][(i + 1) % 2]) {
1425         // Rerun the matches with capture enabled if necessary.
1426         if (option.capture) {
1427           auto* operand1 = operand(inst, i);
1428           auto* operand2 = operand(inst, (i + 1) % 2);
1429           bool matched =
1430               op1_.Match(operand1, option) && op2_.Match(operand2, option);
1431           DCHECK(matched);
1432         }
1433         return true;
1434       }
1435     }
1436 
1437     auto describe_matcher = [&](int matcher_idx) {
1438       EXPLAIN << "\n - ";
1439       if (matcher_idx == 0) {
1440         op1_.DescribeTo(option.explain_os, /*indent=*/3);
1441       } else {
1442         CHECK_EQ(matcher_idx, 1);
1443         op2_.DescribeTo(option.explain_os, /*indent=*/3);
1444       }
1445       for (int i = 0; i < 2; ++i) {
1446         if (matches[matcher_idx][/*operand*/ i]) {
1447           continue;
1448         }
1449         EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
1450         EXPLAIN << " - ";
1451         EXPLAIN << absl::StrReplaceAll(
1452             explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
1453       }
1454     };
1455 
1456     // If we failed to match, one of the following is true:
1457     //  1. op1 (op2) matches neither LHS nor RHS, or
1458     //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
1459     // We print different explanations depending on which case we're in.
1460 
1461     // Case 1.
1462     bool wrote_explanation = false;
1463     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1464       if (!matches[i][0] && !matches[i][1]) {
1465         EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
1466                 << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
1467         describe_matcher(i);
1468         wrote_explanation = true;
1469       }
1470     }
1471 
1472     // Case 2.
1473     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1474       if (matches[/*matcher*/ 0][/*operand*/ i] &&
1475           matches[/*matcher*/ 1][/*operand*/ i]) {
1476         CHECK(!matches[0][(i + 1) % 2]);
1477         CHECK(!matches[1][(i + 1) % 2]);
1478         CHECK(!wrote_explanation);
1479         EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
1480                 << " operand did not match either of the two matchers.  "
1481                    "Specifically,";
1482         describe_matcher(0);
1483         EXPLAIN << "\nand";
1484         describe_matcher(1);
1485         wrote_explanation = true;
1486       }
1487     }
1488 
1489     CHECK(wrote_explanation);
1490     return false;
1491   }
1492 
1493   HloInstructionPattern<OperandType1, OperandImpl1> op1_;
1494   HloInstructionPattern<OperandType2, OperandImpl2> op2_;
1495 };
1496 
1497 // An HloInstructionPattern implementation that matches only if the instruction
1498 // is a fusion node with a particular kind.
1499 class HloInstructionPatternFusionKindImpl {
1500  public:
1501   explicit constexpr HloInstructionPatternFusionKindImpl(
1502       ::xla::HloInstruction::FusionKind kind)
1503       : kind_(kind) {}
1504 
1505   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1506     return MatchImpl(inst, option);
1507   }
1508 
1509   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1510     return MatchImpl(inst, option);
1511   }
1512 
1513   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1514     *os << "with fusion kind " << ToString(kind_);
1515   }
1516 
1517  private:
1518   template <typename HloInstructionType>
1519   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1520     if (inst->opcode() != HloOpcode::kFusion) {
1521       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1522               << "; it's not a fusion";
1523       return false;
1524     }
1525     if (inst->fusion_kind() != kind_) {
1526       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1527       return false;
1528     }
1529     return true;
1530   }
1531 
1532   ::xla::HloInstruction::FusionKind kind_;
1533 };
1534 
1535 // An HloInstructionPattern implementation that matches only if the instruction
1536 // is a kGetTupleElement with a particular tuple index.
1537 class HloInstructionPatternTupleIndexImpl {
1538  public:
1539   explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
1540       : tuple_index_(tuple_index) {}
1541 
1542   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1543     return MatchImpl(inst, option);
1544   }
1545 
1546   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1547     return MatchImpl(inst, option);
1548   }
1549 
1550   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1551     *os << "which is a GTE with index " << tuple_index_;
1552   }
1553 
1554  private:
1555   template <typename HloInstructionType>
1556   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1557     if (inst->opcode() != HloOpcode::kGetTupleElement) {
1558       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1559               << "; it's not a GTE at all";
1560       return false;
1561     }
1562     if (inst->tuple_index() != tuple_index_) {
1563       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1564       return false;
1565     }
1566     return true;
1567   }
1568 
1569   int64 tuple_index_;
1570 };
1571 
1572 class HloInstructionPatternParameterNumImpl {
1573  public:
1574   explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num)
1575       : parameter_num_(parameter_num) {}
1576 
1577   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1578     return MatchImpl(inst, option);
1579   }
1580 
1581   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1582     return MatchImpl(inst, option);
1583   }
1584 
1585   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1586     *os << "which is parameter " << parameter_num_;
1587   }
1588 
1589  private:
1590   template <typename HloInstructionType>
1591   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1592     if (inst->opcode() != HloOpcode::kParameter ||
1593         inst->parameter_number() != parameter_num_) {
1594       EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1595       return false;
1596     }
1597     return true;
1598   }
1599 
1600   int64 parameter_num_;
1601 };
1602 
1603 // Superclass that contains common code used by Op::WithOneUse() and
1604 // Op::WithOneUser().
1605 class HloInstructionPatternOneUseOrUserImpl {
1606  protected:
1607   bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1608     if (inst->user_count() != 1) {
1609       EXPLAIN << "HloInstruction has " << inst->user_count()
1610               << " users, but expected exactly one.";
1611       if (inst->user_count() > 1) {
1612         EXPLAIN << "\nAll users:";
1613         for (const HloInstruction* user : inst->users()) {
1614           EXPLAIN << "\n - " << InstToString(user);
1615         }
1616       }
1617       return false;
1618     }
1619     return true;
1620   }
1621 };
1622 
1623 class HloInstructionPatternOneUseImpl
1624     : public HloInstructionPatternOneUseOrUserImpl {
1625  public:
1626   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1627     if (!MatchOneUser(inst, option)) {
1628       return false;
1629     }
1630 
1631     int64 use_count = absl::c_count_if(
1632         inst->users()[0]->operands(),
1633         [&](const HloInstruction* operand) { return operand == inst; });
1634     if (use_count != 1) {
1635       EXPLAIN << "HloInstruction is used " << use_count
1636               << " times by its user, but is expected to be used just once: "
1637               << InstToString(inst->users()[0]);
1638       return false;
1639     }
1640     return true;
1641   }
1642 
1643   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1644     *os << "which has exactly one use";
1645   }
1646 };
1647 
1648 class HloInstructionPatternOneUserImpl
1649     : public HloInstructionPatternOneUseOrUserImpl {
1650  public:
1651   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1652     return MatchOneUser(inst, option);
1653   }
1654 
1655   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1656     *os << "which has exactly one user (but possibly is used multiple times by "
1657            "that instruction)";
1658   }
1659 };
1660 
1661 class HloInstructionPatternComparisonDirectionImpl {
1662  public:
1663   explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1664       ComparisonDirection direction)
1665       : direction_(direction) {}
1666 
1667   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1668     return MatchImpl(inst, option);
1669   }
1670 
1671   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1672     return MatchImpl(inst, option);
1673   }
1674 
1675   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1676     *os << "which has comparison direction "
1677         << ComparisonDirectionToString(direction_);
1678   }
1679 
1680  private:
1681   template <typename HloInstructionType>
1682   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1683     if (inst->opcode() != HloOpcode::kCompare ||
1684         inst->comparison_direction() != direction_) {
1685       EXPLAIN << "HloInstruction is not comparison "
1686               << ComparisonDirectionToString(direction_);
1687       return false;
1688     }
1689     return true;
1690   }
1691 
1692   ComparisonDirection direction_;
1693 };
1694 
1695 // Matches a constant scalar or effective scalar, optionally with a given value.
1696 template <typename ScalarTy>
1697 class HloConstantScalarImpl {
1698  public:
1699   explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1700       : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
1701 
1702   constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1703       : val_(val), match_effective_scalar_(match_effective_scalar) {}
1704 
1705   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1706     return MatchImpl(inst, option);
1707   }
1708 
1709   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1710     return MatchImpl(inst, option);
1711   }
1712 
1713   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1714     *os << "which is a constant "
1715         << (match_effective_scalar_ ? "effective " : "") << "scalar";
1716     if (val_.has_value()) {
1717       *os << " with value " << *val_;
1718     }
1719   }
1720 
1721  private:
1722   template <typename InstTy>
1723   bool MatchImpl(InstTy* inst, MatchOption option) const {
1724     const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1725     if (!const_inst) {
1726       EXPLAIN << "HloInstruction is not a constant";
1727       return false;
1728     }
1729     if (match_effective_scalar_ &&
1730         !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1731       EXPLAIN << "HloInstruction is not an effective scalar";
1732       return false;
1733     }
1734     if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1735       EXPLAIN << "HloInstruction is not a scalar";
1736       return false;
1737     }
1738     if (!val_.has_value()) {
1739       return true;
1740     }
1741 
1742     auto const_inst_scalar_or = const_inst->literal().Reshape({});
1743     if (!const_inst_scalar_or.ok()) {
1744       EXPLAIN << "could not convert matched literal to effective scalar";
1745       return false;
1746     }
1747     Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
1748     if (!const_inst_scalar.IsEqualAt({}, *val_)) {
1749       EXPLAIN << "HloInstruction's constant value "
1750               << const_inst_scalar.ToStringWithoutShape()
1751               << " did not match expected value " << *val_;
1752       return false;
1753     }
1754     return true;
1755   }
1756 
1757   absl::optional<ScalarTy> val_;
1758   bool match_effective_scalar_;
1759 };
1760 
1761 // A pattern that matches HloInstructions.
1762 template <typename HloInstructionType, typename Impl>
1763 class HloInstructionPattern {
1764  private:
1765   template <typename NewImpl>
1766   auto AppendImpl(NewImpl new_impl) const {
1767     auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
1768     return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1769         std::move(new_allof), matched_inst_);
1770   }
1771 
1772  public:
1773   explicit constexpr HloInstructionPattern(const Impl& impl,
1774                                            HloInstructionType** matched_inst)
1775       : impl_(impl), matched_inst_(matched_inst) {}
1776 
1777   // Returns true and captures the instruction iff it matches the pattern.
1778   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1779     if (impl_.Match(inst, option)) {
1780       if (option.capture && matched_inst_) {
1781         *matched_inst_ = inst;
1782       }
1783       return true;
1784     }
1785     if (inst != nullptr) {
1786       EXPLAIN << "\nin " << InstToString(inst);
1787     }
1788     return false;
1789   }
1790 
1791   // Returns true and captures the instruction iff it matches the pattern.
1792   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1793     if (impl_.Match(inst, option)) {
1794       if (option.capture && matched_inst_) {
1795         *matched_inst_ = inst;
1796       }
1797       return true;
1798     }
1799     EXPLAIN << "\nin " << InstToString(inst);
1800     return false;
1801   }
1802 
1803   // Modifies the pattern to match only if the instruction has the given name.
1804   auto WithName(absl::string_view name) const {
1805     return AppendImpl(HloInstructionPatternNameImpl(name));
1806   }
1807 
1808   // Modifies the pattern to match only if the instruction has the given opcode.
1809   auto WithOpcode(HloOpcode opcode) const {
1810     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
1811   }
1812 
1813   // Modifies the pattern to match only the custom call with a given target.
1814   auto WithCustomCallTarget(absl::string_view custom_call_target) const {
1815     return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target));
1816   }
1817 
1818   auto WithNumOperands(int64 num_operands) const {
1819     return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
1820   }
1821 
1822   // Modifies the pattern to match only if the instruction does not have the
1823   // given opcode.
1824   auto WithoutOpcode(HloOpcode opcode) const {
1825     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
1826   }
1827 
1828   constexpr auto Is(const HloInstruction* instr) const {
1829     return AppendImpl(HloInstructionIsImpl(instr));
1830   }
1831 
1832   // Modifies the pattern to match only if the instruction is a constant.
1833   constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }
1834 
1835   constexpr auto IsConstantScalar() const {
1836     return AppendImpl(
1837         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1838   }
1839 
1840   // This does not check that T has the same type as the instruction, so e.g.
1841   // IsConstantScalar(1.0) may match a constant of shape int32[].
1842   template <typename ScalarTy>
1843   constexpr auto IsConstantScalar(const ScalarTy& val) const {
1844     return AppendImpl(
1845         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1846   }
1847 
1848   constexpr auto IsConstantEffectiveScalar() const {
1849     return AppendImpl(
1850         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1851   }
1852 
1853   template <typename ScalarTy>
1854   constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
1855     return AppendImpl(
1856         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1857   }
1858 
1859   // Modifies the pattern to match only if the instruction is not a constant.
1860   constexpr auto IsNonConstant() const {
1861     return WithoutOpcode(HloOpcode::kConstant);
1862   }
1863 
1864   // Modifies the pattern to match only if the instruction has a shape that
1865   // matches the given pattern.
1866   template <typename ShapeType, typename ShapeImpl>
1867   constexpr auto WithShape(
1868       const ShapePattern<ShapeType, ShapeImpl>& shape) const {
1869     return AppendImpl(
1870         HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
1871   }
1872 
1873   // Make this a templated function to work around gcc 4.9.4 template infinite
1874   // recursion bug.
1875   template <typename Dummy = void>
1876   constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
1877     return WithShape(Shape().EqualTo(shape));
1878   }
1879 
1880   // Make this a templated function to work around gcc 4.9.4 template infinite
1881   // recursion bug.
1882   template <typename Dummy = void>
1883   constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
1884     return WithShape(Shape().CompatibleTo(shape));
1885   }
1886 
1887   // Modifies the pattern to match only if the instruction has an operand that
1888   // matches the given pattern.
1889   template <typename OperandType, typename OperandImpl>
1890   constexpr auto WithOperand(
1891       int64 operand_index,
1892       const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
1893     return AppendImpl(
1894         HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1895             operand_index, operand));
1896   }
1897 
1898   template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1899             typename OperandImpl2>
1900   constexpr auto WithBinaryOperandsAnyOrder(
1901       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1902       const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
1903     return AppendImpl(
1904         HloInstructionPatternBinaryOperandsAnyOrderImpl<
1905             OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
1906   }
1907 
1908   // Modifies the pattern to match only if the instruction is a fusion node with
1909   // the given kind.
1910   constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
1911     return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
1912   }
1913 
1914   // Modifies the pattern to match only if the instruction is a
1915   // get-tuple-element with the given tuple index.
1916   constexpr auto WithTupleIndex(int64 tuple_index) const {
1917     return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
1918   }
1919 
1920   // Modifies the pattern to match only if the instruction is a parameter
1921   // with the given parameter number.
1922   constexpr auto WithParameterNum(int64 parameter_num) const {
1923     return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
1924   }
1925 
1926   // Modifies the pattern to match if the instruction is used exactly once.
1927   // Does not match if the instruction is used twice by the same user (e.g.
1928   // multiply(x,x)).
1929   constexpr auto WithOneUse() const {
1930     return AppendImpl(HloInstructionPatternOneUseImpl());
1931   }
1932 
1933   // Modifies the pattern to match if the instruction is used by exactly one
1934   // other instruction.  Will match if the instruction is used twice, so long as
1935   // it's by the same user (e.g.  multiply(x,x)).
1936   constexpr auto WithOneUser() const {
1937     return AppendImpl(HloInstructionPatternOneUserImpl());
1938   }
1939 
1940   // Modifies the pattern to match only if the instruction has the given
1941   // comparison direction.
1942   auto WithComparisonDirection(ComparisonDirection direction) const {
1943     return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
1944   }
1945 
1946   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1947     impl_.DescribeTo(os, indent);
1948   }
1949 
1950  private:
1951   Impl impl_;
1952   HloInstructionType** matched_inst_;
1953 };
1954 
1955 }  // namespace detail
1956 
1957 // Creates an instruction pattern that will capture the matched instruction in
1958 // the argument.
1959 inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
1960   return detail::HloInstructionPattern<const ::xla::HloInstruction,
1961                                        detail::HloInstructionPatternBaseImpl>(
1962       detail::HloInstructionPatternBaseImpl(), matched_inst);
1963 }
1964 
1965 // Creates an instruction pattern that will capture the matched instruction in
1966 // the argument.
1967 inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
1968   return detail::HloInstructionPattern<::xla::HloInstruction,
1969                                        detail::HloInstructionPatternBaseImpl>(
1970       detail::HloInstructionPatternBaseImpl(), matched_inst);
1971 }
1972 
1973 // Helpers for nullary instructions.
1974 #define XLA_NULLOP_PATTERN(NAME)                                     \
1975   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
1976                                                                      \
1977   template <typename HloInstructionType>                             \
1978   inline auto NAME(HloInstructionType** matched_inst) {              \
1979     return Op(matched_inst).WithOpcode(HloOpcode::k##NAME);          \
1980   }
1981 XLA_NULLOP_PATTERN(Constant)
1982 XLA_NULLOP_PATTERN(Parameter)
1983 XLA_NULLOP_PATTERN(Iota)
1984 XLA_NULLOP_PATTERN(Rng)
1985 #undef XLA_NULLOP_PATTERN
1986 
1987 // Helpers for unary instructions.
1988 #define XLA_UNOP_PATTERN(NAME)                                       \
1989   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
1990                                                                      \
1991   template <typename Arg>                                            \
1992   inline auto NAME(Arg&& arg) {                                      \
1993     return Op()                                                      \
1994         .WithOpcode(HloOpcode::k##NAME)                              \
1995         .WithOperand(0, std::forward<Arg>(arg));                     \
1996   }                                                                  \
1997                                                                      \
1998   template <typename HloInstructionType, typename Arg>               \
1999   inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) {   \
2000     return Op(matched_inst)                                          \
2001         .WithOpcode(HloOpcode::k##NAME)                              \
2002         .WithOperand(0, std::forward<Arg>(arg));                     \
2003   }
2004 XLA_UNOP_PATTERN(Abs)
2005 XLA_UNOP_PATTERN(RoundNearestAfz)
2006 XLA_UNOP_PATTERN(Bitcast)
2007 XLA_UNOP_PATTERN(BitcastConvert)
2008 XLA_UNOP_PATTERN(Broadcast)
2009 XLA_UNOP_PATTERN(Ceil)
2010 XLA_UNOP_PATTERN(Convert)
2011 XLA_UNOP_PATTERN(Copy)
2012 XLA_UNOP_PATTERN(Cos)
2013 XLA_UNOP_PATTERN(AllReduce)
2014 XLA_UNOP_PATTERN(Exp)
2015 XLA_UNOP_PATTERN(Fft)
2016 XLA_UNOP_PATTERN(Floor)
2017 XLA_UNOP_PATTERN(GetTupleElement)
2018 XLA_UNOP_PATTERN(Imag)
2019 XLA_UNOP_PATTERN(Infeed)
2020 XLA_UNOP_PATTERN(IsFinite)
2021 XLA_UNOP_PATTERN(Log)
2022 XLA_UNOP_PATTERN(Not)
2023 XLA_UNOP_PATTERN(Negate)
2024 XLA_UNOP_PATTERN(Real)
2025 XLA_UNOP_PATTERN(Recv)
2026 XLA_UNOP_PATTERN(RecvDone)
2027 XLA_UNOP_PATTERN(ReducePrecision)
2028 XLA_UNOP_PATTERN(Reshape)
2029 XLA_UNOP_PATTERN(Reverse)
2030 XLA_UNOP_PATTERN(Rsqrt)
2031 XLA_UNOP_PATTERN(SendDone)
2032 XLA_UNOP_PATTERN(Sign)
2033 XLA_UNOP_PATTERN(Sin)
2034 XLA_UNOP_PATTERN(Slice)
2035 XLA_UNOP_PATTERN(Sqrt)
2036 XLA_UNOP_PATTERN(Tanh)
2037 XLA_UNOP_PATTERN(Transpose)
2038 #undef XLA_UNOP_PATTERN
2039 
2040 // Helpers for binary instructions.
2041 #define XLA_BINOP_PATTERN(NAME)                                               \
2042   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
2043                                                                               \
2044   template <typename Lhs, typename Rhs>                                       \
2045   inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
2046     return Op()                                                               \
2047         .WithOpcode(HloOpcode::k##NAME)                                       \
2048         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2049         .WithOperand(1, std::forward<Rhs>(rhs));                              \
2050   }                                                                           \
2051                                                                               \
2052   template <typename HloInstructionType, typename Lhs, typename Rhs>          \
2053   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2054     return Op(matched_inst)                                                   \
2055         .WithOpcode(HloOpcode::k##NAME)                                       \
2056         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2057         .WithOperand(1, std::forward<Rhs>(rhs));                              \
2058   }
2059 
2060 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                \
2061   XLA_BINOP_PATTERN(NAME)                                                  \
2062                                                                            \
2063   template <typename HloInstructionType, typename Lhs, typename Rhs>       \
2064   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2065                              Rhs&& rhs) {                                  \
2066     return Op(matched_inst)                                                \
2067         .WithOpcode(HloOpcode::k##NAME)                                    \
2068         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
2069                                     std::forward<Rhs>(rhs));               \
2070   }                                                                        \
2071   template <typename Lhs, typename Rhs>                                    \
2072   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
2073     return NAME##AnyOrder<const HloInstruction>(                           \
2074         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
2075   }
2076 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
2077 XLA_BINOP_PATTERN(Atan2)
2078 XLA_BINOP_PATTERN(Divide)
2079 XLA_BINOP_PATTERN(Complex)
2080 XLA_BINOP_PATTERN(Compare)
2081 XLA_BINOP_PATTERN(Convolution)
2082 XLA_BINOP_PATTERN(Dot)
2083 XLA_BINOP_PATTERN(Gather)
2084 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
2085 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
2086 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
2087 XLA_BINOP_PATTERN(Outfeed)
2088 XLA_BINOP_PATTERN(Pad)
2089 XLA_BINOP_PATTERN(Power)
2090 XLA_BINOP_PATTERN(ReduceWindow)
2091 XLA_BINOP_PATTERN(Remainder)
2092 XLA_BINOP_PATTERN(Send)
2093 XLA_BINOP_PATTERN(Subtract)
2094 XLA_COMMUTATIVE_BINOP_PATTERN(And)
2095 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
2096 XLA_BINOP_PATTERN(ShiftLeft)
2097 XLA_BINOP_PATTERN(ShiftRightArithmetic)
2098 XLA_BINOP_PATTERN(ShiftRightLogical)
2099 #undef XLA_COMMUTATIVE_BINOP_PATTERN
2100 #undef XLA_BINOP_PATTERN
2101 
2102 // Helpers for ternary instructions.
2103 #define XLA_TERNOP_PATTERN(NAME)                                       \
2104   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }   \
2105                                                                        \
2106   template <typename Arg0, typename Arg1, typename Arg2>               \
2107   inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {            \
2108     return Op()                                                        \
2109         .WithOpcode(HloOpcode::k##NAME)                                \
2110         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2111         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2112         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2113   }                                                                    \
2114                                                                        \
2115   template <typename HloInstructionType, typename Arg0, typename Arg1, \
2116             typename Arg2>                                             \
2117   inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0,     \
2118                    Arg1&& arg1, Arg2&& arg2) {                         \
2119     return Op(matched_inst)                                            \
2120         .WithOpcode(HloOpcode::k##NAME)                                \
2121         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2122         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2123         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2124   }
2125 XLA_TERNOP_PATTERN(Clamp);
2126 XLA_TERNOP_PATTERN(Scatter);
2127 XLA_TERNOP_PATTERN(Select);
2128 XLA_TERNOP_PATTERN(SelectAndScatter);
2129 #undef XLA_TERNOP_PATTERN
2130 
2131 namespace detail {
2132 template <typename Matcher, typename FirstArg>
2133 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) {
2134   return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
2135 }
2136 
2137 template <typename Matcher, typename FirstArg, typename... Args>
2138 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
2139                          Args&&... args) {
2140   return WithOperands(
2141       m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)),
2142       operand_num + 1, std::forward<Args>(args)...);
2143 }
2144 }  // namespace detail
2145 
2146 #define XLA_VARIADIC_OP_PATTERN(NAME)                                         \
2147   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
2148                                                                               \
2149   template <typename... Args>                                                 \
2150   inline auto NAME(Args&&... args) {                                          \
2151     return detail::WithOperands(                                              \
2152         Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
2153         /*operand_num=*/0, std::forward<Args>(args)...);                      \
2154   }                                                                           \
2155                                                                               \
2156   template <typename HloInstructionType, typename... Args>                    \
2157   inline auto NAME(HloInstructionType** matched_inst, Args&&... args) {       \
2158     return detail::WithOperands(Op(matched_inst)                              \
2159                                     .WithOpcode(HloOpcode::k##NAME)           \
2160                                     .WithNumOperands(sizeof...(Args)),        \
2161                                 /*operand_num=*/0,                            \
2162                                 std::forward<Args>(args)...);                 \
2163   }
2164 
2165 // We could implement all ops as "variadic" ops, but it would make the
2166 // already-bad compile errors even worse.
2167 XLA_VARIADIC_OP_PATTERN(AfterAll);
2168 XLA_VARIADIC_OP_PATTERN(Concatenate);
2169 XLA_VARIADIC_OP_PATTERN(Conditional);
2170 XLA_VARIADIC_OP_PATTERN(CustomCall);
2171 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2172 XLA_VARIADIC_OP_PATTERN(Fusion);
2173 XLA_VARIADIC_OP_PATTERN(Map)
2174 XLA_VARIADIC_OP_PATTERN(Reduce);
2175 XLA_VARIADIC_OP_PATTERN(Sort);
2176 XLA_VARIADIC_OP_PATTERN(Tuple);
2177 
2178 // Helpers for comparison instructions.
2179 #define XLA_COMPARE_PATTERN(NAME)                                             \
2180   inline auto NAME() {                                                        \
2181     return Op()                                                               \
2182         .WithOpcode(HloOpcode::kCompare)                                      \
2183         .WithComparisonDirection(ComparisonDirection::k##NAME);               \
2184   }                                                                           \
2185                                                                               \
2186   template <typename Lhs, typename Rhs>                                       \
2187   inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
2188     return Op()                                                               \
2189         .WithOpcode(HloOpcode::kCompare)                                      \
2190         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2191         .WithOperand(1, std::forward<Rhs>(rhs))                               \
2192         .WithComparisonDirection(ComparisonDirection::k##NAME);               \
2193   }                                                                           \
2194                                                                               \
2195   template <typename HloInstructionType, typename Lhs, typename Rhs>          \
2196   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2197     return Op(matched_inst)                                                   \
2198         .WithOpcode(HloOpcode::kCompare)                                      \
2199         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2200         .WithOperand(1, std::forward<Rhs>(rhs))                               \
2201         .WithComparisonDirection(ComparisonDirection::k##NAME);               \
2202   }
2203 
2204 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME)                              \
2205   XLA_COMPARE_PATTERN(NAME)                                                \
2206                                                                            \
2207   template <typename HloInstructionType, typename Lhs, typename Rhs>       \
2208   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2209                              Rhs&& rhs) {                                  \
2210     return Op(matched_inst)                                                \
2211         .WithOpcode(HloOpcode::kCompare)                                   \
2212         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
2213                                     std::forward<Rhs>(rhs));               \
2214   }                                                                        \
2215   template <typename Lhs, typename Rhs>                                    \
2216   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
2217     return NAME##AnyOrder<const HloInstruction>(                           \
2218         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
2219   }
2220 
2221 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
2222 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
2223 XLA_COMPARE_PATTERN(Ge);
2224 XLA_COMPARE_PATTERN(Gt);
2225 XLA_COMPARE_PATTERN(Le);
2226 XLA_COMPARE_PATTERN(Lt);
2227 
2228 // Helpers for matching non-constant instructions.
2229 inline auto NonConstant() { return Op().IsNonConstant(); }
2230 
2231 template <typename HloInstructionType>
2232 inline auto NonConstant(HloInstructionType** matched_inst) {
2233   return Op(matched_inst).IsNonConstant();
2234 }
2235 
2236 // Add overloads for GetTupleElement which take a int64 specifying which tuple
2237 // element is selected.
2238 template <typename Arg>
2239 inline auto GetTupleElement(Arg&& arg, int64 tuple_index) {
2240   return Op()
2241       .WithOpcode(HloOpcode::kGetTupleElement)
2242       .WithOperand(0, std::forward<Arg>(arg))
2243       .WithTupleIndex(tuple_index);
2244 }
2245 
2246 template <typename HloInstructionType, typename Arg>
2247 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2248                             int64 tuple_index) {
2249   return Op(matched_inst)
2250       .WithOpcode(HloOpcode::kGetTupleElement)
2251       .WithOperand(0, std::forward<Arg>(arg))
2252       .WithTupleIndex(tuple_index);
2253 }
2254 
2255 // Add overloads for Parameter which take an int64 specifying the parameter
2256 // number.
2257 inline auto Parameter(int64 parameter_num) {
2258   return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2259 }
2260 template <typename HloInstructionType>
2261 inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) {
2262   return Op(matched_inst)
2263       .WithOpcode(HloOpcode::kParameter)
2264       .WithParameterNum(parameter_num);
2265 }
2266 
2267 inline auto ConstantScalar() { return Op().IsConstantScalar(); }
2268 
2269 template <typename HloInstructionType>
2270 inline auto ConstantScalar(HloInstructionType** matched_inst) {
2271   return Op(matched_inst).IsConstantScalar();
2272 }
2273 
2274 template <typename ScalarTy>
2275 inline auto ConstantScalar(ScalarTy val) {
2276   return Op().IsConstantScalar(val);
2277 }
2278 
2279 template <typename HloInstructionType, typename ScalarTy>
2280 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) {
2281   return Op(matched_inst).IsConstantScalar(val);
2282 }
2283 
2284 inline auto ConstantEffectiveScalar() {
2285   return Op().IsConstantEffectiveScalar();
2286 }
2287 
2288 template <typename HloInstructionType>
2289 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) {
2290   return Op(matched_inst).IsConstantEffectiveScalar();
2291 }
2292 
2293 template <typename ScalarTy>
2294 inline auto ConstantEffectiveScalar(ScalarTy val) {
2295   return Op().IsConstantEffectiveScalar(val);
2296 }
2297 
2298 template <typename HloInstructionType, typename ScalarTy>
2299 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
2300                                     ScalarTy val) {
2301   return Op(matched_inst).IsConstantEffectiveScalar(val);
2302 }
2303 
2304 }  // namespace match
2305 
2306 }  // namespace xla
2307 
2308 #undef EXPLAIN
2309 #pragma pop_macro("EXPLAIN")
2310 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2311