• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18 
19 #include <type_traits>
20 
21 #include "absl/strings/str_replace.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/utility/utility.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 
32 namespace xla {
33 
34 // A pattern matcher for HloInstructions, Shapes, and Layouts.
35 //
36 // The Match function's first argument must be HloInstruction*, Shape*, or
37 // Layout*. The second argument is a pattern that will be matched against the
38 // first argument, as described below.
39 //
40 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
41 // functions. By default, the returned patterns will match any HloInstruction,
42 // Shape, or Layout, respectively. However the match can be made more specific
43 // by using the pattern's modifier methods, for example:
44 //
45 //   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
46 //     0, match::Op().WithOpcode(HloOpcode::kConstant))
47 //
48 // This pattern will match Add instructions whose first operand is a constant.
49 //
50 // Each pattern type has the following modifiers, which are described where
51 // nontrivial.
52 //
53 //   Op():
54 //     - Is: is the given HloInstruction* (i.e. pointer equality)
55 //     - WithName
56 //     - WithOpcode
57 //     - WithoutOpcode: anything other than the given opcode
58 //     - WithShape: instr's shape matches the given pattern
59 //     - WithShapeEqualTo: instr's shape is equal to the given Shape
60 //     - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
61 //     - WithNumOperands
62 //     - WithOperand: operand at the given index matches the given pattern
63 //     - IsConstant
64 //     - IsNonConstant
65 //     - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value,
66 //       e.g. IsConstantScalar() or IsConstantScalar(42).
67 //     - WithFusionKind
68 //     - WithTupleIndex: get-tuple-element operations with the given tuple index
69 //     - WithOneUse: Instruction is used as an operand exactly once.
70 //     - WithOneUser: Instruction is used by exactly one other instruction, but
71 //       is possibly used more than once as an operand (e.g. multiply(x,x)).
72 //     - WithComparisonDirection: instr has the given direction
73 //
74 //   Shape():
75 //     - EqualTo
76 //     - CompatibleTo
77 //     - IsScalar/IsEffectiveScalar/IsArray/IsTuple
78 //     - IsDenseArray
79 //     - WithLayout: shape's layout matches the given pattern (e.g.
80 //       Layout().WithDenseFormat())
81 //     - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
82 //       Layout, but not the result of Layout().foo())
83 //     - WithSubshape: shape is a tuple whose subshape matches the given pattern
84 //       (e.g. Shape().IsScalar()).
85 //     - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
86 //       (i.e. another Shape, but not the result of Shape().foo())
87 //     - WithElementType: shape is an array/scalar with the given elem type
88 //     - WithRank: shape is an array/scalar with the given rank
89 //
90 //  Layout():
91 //     - EqualTo
92 //     - WithDenseFormat
93 //
94 // Op(), Shape(), and Layout() may be passed an argument of type
95 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
96 // these pointers. If the pattern is matched, the address of the matched value
97 // will be "captured" and stored at this location.
98 //
99 // For example:
100 //   HloInstruction* foo = ...;
101 //   HloInstruction* matched_operand;
102 //   CHECK(Match(foo,
103 //               match::Op().WithOperand(0, match::Op(&matched_operand))));
104 //
105 // Helpers are provided for most HLO instructions. These helpers can be called
106 // with no arguments, in which case they will match any instruction matching the
107 // opcode. They may also be called with matches for the operands and with an
108 // optional capture. (The capture must be the first argument.) Some examples of
109 // these helpers and their equivalents are provided below.
110 
111 // Example nullary instruction:
112 //   Parameter()                    == Op().WithOpcode(HloOpcode::kParameter)
113 //   Parameter(&a)                  == Op(&a).WithOpcode(HloOpcode::kParameter)
114 //
115 // Example unary instruction:
116 //   Abs()                          == Op().WithOpcode(HloOpcode::kAbs)
117 //   Abs(Op(&a))                    == Op().WithOpcode(HloOpcode::kAbs)
118 //                                         .WithOperand(0, Op(&a))
119 //   Abs(&a, Op(&b))                == Op(&a).WithOpcode(HloOpcode::kAbs)
120 //                                           .WithOperand(0, Op(&b))
121 //
122 // Commutative binary instructions have a special form that accepts either order
123 // of args, e.g.:
124 //
125 //   AddAnyOrder(Parameter(1), Abs()) ==
126 //     Op().WithOpcode(HloOpcode::kAdd)
127 //         .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
128 //
129 //   MultiplyAnyOrder(&a, Parameter(), Abs())  // Captures the mul in `a`.
130 //
131 // The following additional helpers are provided.  In all cases, `&a` is
132 // optional.
133 //
134 //   ConstantScalar(&a)               == Op(&a).IsConstantScalar();
135 //   ConstantScalar(&a, v)            == Op(&a).IsConstantScalar(v);
136 //   ConstantEffectiveScalar(&a)      == Op(&a).IsEffectiveConstantScalar();
137 //   ConstantEffectiveScalar(&a, v)   == Op(&a).IsEffectiveConstantScalar(v);
138 //   NonConstant(&a)                  == Op(&a).IsNonConstant()
139 //   GetTupleElement(&a, b, index)    == Op(&a).WithTupleIndex(index)
140 //                                             .WithOperand(0, b);
141 //   Parameter(&a, n)                 == Op(&a).WithParameterNum(n);
142 
// Options controlling a single Match() call.
//
// Members carry default initializers so that a default-constructed
// MatchOption is well defined: no capturing and no explanation. (Value- and
// brace-initialization behave exactly as before; the top-level Match()
// explicitly requests capturing through its default argument.)
struct MatchOption {
  // If true, actually capture matched item into the user pointer.
  bool capture = false;

  // An explanation for why we failed to match is streamed here, if non-null.
  std::ostream* explain_os = nullptr;
};
150 
151 template <typename Value, typename Pattern>
152 bool Match(Value* value, const Pattern& pattern,
153            MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
154   if (option.capture) {
155     auto new_option = option;
156     new_option.capture = false;
157     if (!pattern.Match(value, new_option)) {
158       return false;
159     }
160   }
161   return pattern.Match(value, option);
162 }
163 
164 namespace match {
165 
166 namespace detail {
167 
// Macro for streaming to option.explain_os if it's not null.
//
//   EXPLAIN << "value of foo(): " << foo()
//
// Requires a variable named `option` with an `explain_os` member in scope.
// The stream operands are only evaluated when explain_os is non-null.
//
// The expansion is a complete if/else statement whose `else` branch is the
// stream expression. A bare `if (...) stream` expansion would silently steal
// a trailing `else` written by the caller, e.g. in
//
//   if (cond) EXPLAIN << "a"; else DoSomethingElse();
//
// the `else` must bind to `if (cond)`, which this form guarantees.
#pragma push_macro("EXPLAIN")
#define EXPLAIN             \
  if (!option.explain_os) { \
  } else                    \
    *option.explain_os
175 
// Number of extra spaces the pretty-printers add each time they nest one level
// deeper (i.e. when they "increase the indent by one").
enum { kIndentInc = 2 };
181 
// Writes a newline followed by `indent` spaces. A zero or negative `indent`
// writes only the newline.
//
// We follow an unintuitive convention in this file's pretty-printers: Indents
// are performed by the caller, not the callee.  For example, if you want to
// print
//
//   foo:
//    - bar
//
// you'd do:
//
//  Foo::DescribeTo(std::ostream* os, int64_t indent) {
//    *os << "foo:";
//    Indent(os, indent);  // Create a newline at the *current* indent level.
//    *os << " - ";
//    bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" - ") == 3.
//  }
//
//  Bar::DescribeTo(std::ostream* os, int64_t indent) { *os << "bar"; }
//
// Notice that Bar::DescribeTo() does not call Indent; the indenting is
// performed by Foo.  This convention allows the caller to decide whether a
// matcher is preceded by a newline, which is important e.g. for the AllOf
// matcher.
//
// (Incidentally, indenting in Match's explanations is handled differently.
// Indents are a common case in DescribeTo [we're printing a whole tree], but
// they're a special case in Match [we're printing only a path through the tree
// that encounters a failing node]. Indents in Match only appear when we
// encounter a failing disjunction, so we just handle them as a special case
// there.)
inline void Indent(std::ostream* os, int64_t indent) {
  *os << '\n';
  for (int64_t remaining = indent; remaining > 0; --remaining) {
    *os << ' ';
  }
}
219 
// Trait whose `value` is true iff T declares a static member kIsTrivialMatcher
// that evaluates to true (detected via SFINAE on the second parameter).
//
// Trivial matchers get special treatment.  For example, when printing
// a conjunction of matchers, we don't print "and" after a trivial matcher. This
// yields e.g.
//    "a shape compatible with f32[1,2]"
// rather than
//    "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

template <typename T>
struct IsTrivialMatcher<T,
                        typename std::enable_if<T::kIsTrivialMatcher>::type>
    : std::true_type {};
238 
239 template <typename Item, typename... Patterns>
240 class AllOfPattern {
241  public:
242   explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
243 
244   bool Match(const Item* item, MatchOption option) const {
245     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
246     // This invariant is guaranteed by the top-level Match and AnyOf.
247     DCHECK(matched || !option.capture);
248     return matched;
249   }
250 
251   bool Match(Item* item, MatchOption option) const {
252     bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
253     // This invariant is guaranteed by the top-level Match and AnyOf.
254     DCHECK(matched || !option.capture);
255     return matched;
256   }
257 
258   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
259     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
260   }
261 
262   // Accessor for patterns_.  Please don't use this outside of this file.
263   const std::tuple<Patterns...>& patterns() const { return patterns_; }
264 
265  private:
266   template <typename ItemType, size_t index>
267   bool MatchImpl(ItemType* item, MatchOption option,
268                  std::integral_constant<size_t, index>) const {
269     // We don't need to do any EXPLAINing here; it's all correctly handled by
270     // our sub-matchers (if any fail).
271     return std::get<index>(patterns_).Match(item, option) &&
272            MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
273   }
274 
275   template <typename ItemType>
276   bool MatchImpl(ItemType* item, MatchOption option,
277                  std::integral_constant<size_t, sizeof...(Patterns)>) const {
278     return true;
279   }
280 
281   // Pretty-printing a conjunction has some special cases to make it easy to
282   // read in the simple (common) case.
283   //
284   // If sizeof...(Patterns) == 1, prints as e.g.
285   //
286   //   a shape
287   //
288   // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
289   // shape") prints as
290   //
291   //   a shape compatible with f32[1,2]
292   //
293   // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
294   //
295   //   a shape:
296   //    * compatible with f32[1,2] AND
297   //    * that represents a scalar
298   //
299   // Otherwise prints as:
300   //
301   //   all of:
302   //    * foo AND
303   //    * bar
304   //
305   template <size_t index>
306   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
307                       int64_t indent) const {
308     constexpr bool first_is_trivial =
309         IsTrivialMatcher<typename std::remove_reference<decltype(std::get<0>(
310             patterns_))>::type>::value;
311     constexpr bool is_last = index == sizeof...(Patterns) - 1;
312     const auto& submatcher = std::get<index>(patterns_);
313 
314     auto print_bulleted_item = [&] {
315       *os << " * ";
316       submatcher.DescribeTo(os, indent + 3);
317       if (!is_last) {
318         *os << " AND";
319         Indent(os, indent);
320       }
321     };
322 
323     if (index == 0) {
324       if (first_is_trivial || is_last) {
325         submatcher.DescribeTo(os, indent + kIndentInc);
326         if (sizeof...(Patterns) > 2) {
327           *os << ":";
328           Indent(os, indent);
329         }
330       } else {
331         *os << "all of:";
332         Indent(os, indent);
333         print_bulleted_item();
334       }
335     } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
336       *os << " ";
337       submatcher.DescribeTo(os, indent);
338     } else {
339       print_bulleted_item();
340     }
341     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
342   }
343 
344   void DescribeToImpl(std::ostream* os,
345                       std::integral_constant<size_t, sizeof...(Patterns)>,
346                       int64_t indent) const {}
347 
348   std::tuple<Patterns...> patterns_;
349 };
350 
351 }  // namespace detail
352 
353 // Returns a pattern that represents the conjunction of all input patterns. All
354 // patterns need to match in order to have the AllOf pattern match.
355 template <typename Item, typename... Patterns>
356 auto AllOf(const Patterns&... patterns) {
357   return detail::AllOfPattern<typename std::remove_const<Item>::type,
358                               Patterns...>(patterns...);
359 }
360 
361 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
362 //
363 // This transformation is necessary for good pretty-printing.
364 template <typename Item, typename... InnerPs, typename... OuterPs>
365 auto AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
366            const OuterPs&... outer_ps) {
367   // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
368   auto make_all_of = [](const InnerPs&... inner_ps,
369                         const OuterPs&... outer_ps) {
370     return detail::AllOfPattern<typename std::remove_const<Item>::type,
371                                 InnerPs..., OuterPs...>(inner_ps...,
372                                                         outer_ps...);
373   };
374   return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
375                                                  std::make_tuple(outer_ps...)));
376 }
377 
378 namespace detail {
379 
380 template <typename LayoutType, typename Impl>
381 class LayoutPattern;
382 
383 // The base LayoutPattern implementation. Matches only if the layout is not
384 // nullptr.
385 class LayoutPatternBaseImpl {
386  public:
387   bool Match(const ::xla::Layout* layout, MatchOption option) const {
388     if (layout == nullptr) {
389       EXPLAIN << "Layout is null";
390       return false;
391     }
392     return true;
393   }
394 
395   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
396     *os << "a layout";
397   }
398 
399   static constexpr bool kIsTrivialMatcher = true;
400 };
401 
402 // A LayoutPattern implementation that matches only if the layout equals a
403 // Layout proto.
404 class LayoutPatternEqualImpl {
405  public:
406   explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
407       : layout_(layout) {}
408 
409   bool Match(const ::xla::Layout* layout, MatchOption option) const {
410     if (!LayoutUtil::Equal(*layout_, *layout)) {
411       EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
412               << " is not equal to expected "
413               << LayoutUtil::HumanString(*layout_);
414       return false;
415     }
416     return true;
417   }
418 
419   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
420     *os << "equal to " << LayoutUtil::HumanString(*layout_);
421   }
422 
423  private:
424   const ::xla::Layout* layout_;
425 };
426 
427 // A LayoutPattern implementation that matches only if the layout has a given
428 // format.
429 class LayoutPatternFormatImpl {
430  public:
431   explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
432 
433   bool Match(const ::xla::Layout* layout, MatchOption option) const {
434     if (layout->format() != format_) {
435       EXPLAIN << "Layout has format " << Format_Name(layout->format())
436               << " but expected " << Format_Name(format_);
437       return false;
438     }
439     return true;
440   }
441 
442   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
443     *os << "with format " << Format_Name(format_);
444   }
445 
446  private:
447   Format format_;
448 };
449 
450 // A pattern that matches Layouts.
451 template <typename LayoutType, typename Impl>
452 class LayoutPattern {
453  private:
454   template <typename NewImpl>
455   auto AppendImpl(NewImpl new_impl) const {
456     auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl));
457     return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
458                                                           matched_layout_);
459   }
460 
461  public:
462   explicit constexpr LayoutPattern(const Impl& impl,
463                                    LayoutType** matched_layout)
464       : impl_(impl), matched_layout_(matched_layout) {}
465 
466   // Returns true and captures the layout iff it matches the pattern.
467   bool Match(const ::xla::Layout* layout, MatchOption option) const {
468     if (impl_.Match(layout, option)) {
469       if (option.capture && matched_layout_) {
470         *matched_layout_ = layout;
471       }
472       return true;
473     }
474     return false;
475   }
476 
477   // Returns true and captures the layout iff it matches the pattern.
478   bool Match(::xla::Layout* layout, MatchOption option) const {
479     if (impl_.Match(layout, option)) {
480       if (option.capture && matched_layout_) {
481         *matched_layout_ = layout;
482       }
483       return true;
484     }
485     return false;
486   }
487 
488   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
489     impl_.DescribeTo(os, indent);
490   }
491 
492   // Modifies the pattern to match only if the layout equals the given proto.
493   // The layout must outlive the returned pattern.
494   constexpr auto EqualTo(const ::xla::Layout* layout) const {
495     return AppendImpl(LayoutPatternEqualImpl(layout));
496   }
497 
498   // Modifies the pattern to match only if the layout has a dense format.
499   constexpr auto WithDenseFormat() const {
500     return AppendImpl(LayoutPatternFormatImpl(DENSE));
501   }
502 
503  private:
504   Impl impl_;
505   LayoutType** matched_layout_;
506 };
507 
508 template <typename Item, typename... Patterns>
509 class AnyOfPattern {
510  public:
511   explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
512 
513   bool Match(const Item* item, MatchOption option) const {
514     return MatchImpl(item, option);
515   }
516 
517   bool Match(Item* item, MatchOption option) const {
518     return MatchImpl(item, option);
519   }
520 
521   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
522     *os << "any of:";
523     Indent(os, indent);
524     DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
525   }
526 
527  private:
528   template <typename ItemType>
529   bool MatchImpl(ItemType* item, MatchOption option) const {
530     // If we're generating an explanation, buffer it until we know we failed.
531     absl::optional<std::stringstream> explanation;
532     MatchOption new_option = option;
533     if (option.explain_os) {
534       new_option.explain_os = &explanation.emplace();
535     }
536     bool rv = MatchRecursiveImpl(item, new_option,
537                                  std::integral_constant<size_t, 0>());
538     if (!rv && option.explain_os) {
539       EXPLAIN << "None of the following matchers succeeded:";
540       EXPLAIN << explanation->str();
541     }
542     return rv;
543   }
544 
545   template <typename ItemType, size_t index>
546   bool MatchRecursiveImpl(ItemType* item, MatchOption option,
547                           std::integral_constant<size_t, index>) const {
548     auto new_option = option;
549     new_option.capture = false;
550 
551     absl::optional<std::stringstream> explanation;
552     if (option.explain_os) {
553       new_option.explain_os = &explanation.emplace();
554     }
555 
556     // Try to match the sub-pattern without capturing behavior.
557     if (std::get<index>(patterns_).Match(item, new_option)) {
558       // Capture the branch.
559       if (option.capture) {
560         // TODO(timshen): Currently the behavior can be exponential. Optimize it
561         // with memoization or recording the matched sub-pattern index, if it
562         // takes too long to run.
563         //
564         // Specifically, the "memoization" approach is to create an empty
565         // container with the key (pattern, instruction), and value as whether
566         // matched or not.
567         //
568         // Alternatively, we may run the pattern matching with captures off, but
569         // instead record a "trace" somewhere, indicating how exactly the
570         // pattern matches the input. For example, the trace information for
571         // AnyOf will be a runtime number indicate which sub-pattern is matched.
572         // Then we run another pass to do captures only with the help of the
573         // trace.
574         bool matched = std::get<index>(patterns_).Match(item, option);
575         DCHECK(matched);
576       }
577       return true;
578     }
579     if (option.explain_os) {
580       EXPLAIN << "\nMatcher #" << index + 1;
581       EXPLAIN << "\n - ";
582       std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
583       EXPLAIN << "\nfailed with";
584       EXPLAIN << "\n - ";
585       EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n   "}});
586     }
587     return MatchRecursiveImpl(item, option,
588                               std::integral_constant<size_t, index + 1>());
589   }
590 
591   template <typename ItemType>
592   bool MatchRecursiveImpl(
593       ItemType* item, MatchOption option,
594       std::integral_constant<size_t, sizeof...(Patterns)>) const {
595     return false;
596   }
597 
598   template <size_t index>
599   void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
600                       int64_t indent) const {
601     *os << " - ";
602     std::get<index>(patterns_).DescribeTo(os, indent + 3);
603     if (index != sizeof...(Patterns) - 1) {
604       *os << " OR";
605       Indent(os, indent);
606     }
607     DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
608   }
609 
610   void DescribeToImpl(std::ostream* os,
611                       std::integral_constant<size_t, sizeof...(Patterns)>,
612                       int64_t indent) const {}
613 
614   std::tuple<Patterns...> patterns_;
615 };
616 
617 }  // namespace detail
618 
619 // Returns a pattern that represents the logical disjunction of the input
620 // patterns. The returned pattern matches from left to right, and stops on the
621 // first match.
622 template <typename Item, typename... Patterns>
623 auto AnyOf(const Patterns&... patterns) {
624   return detail::AnyOfPattern<typename std::remove_const<Item>::type,
625                               Patterns...>(patterns...);
626 }
627 
628 // Creates a layout pattern that will capture the matched layout in the
629 // argument.
630 inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) {
631   return detail::LayoutPattern<const ::xla::Layout,
632                                detail::LayoutPatternBaseImpl>(
633       detail::LayoutPatternBaseImpl(), matched_layout);
634 }
635 
636 // Creates a layout pattern that will capture the matched layout in the
637 // argument.
638 inline constexpr auto Layout(::xla::Layout** matched_layout) {
639   return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
640       detail::LayoutPatternBaseImpl(), matched_layout);
641 }
642 
643 namespace detail {
644 
645 template <typename ShapeType, typename Impl>
646 class ShapePattern;
647 
648 // The base ShapePattern implementation. Matches only if the shape is not
649 // nullptr.
650 class ShapePatternBaseImpl {
651  public:
652   bool Match(const ::xla::Shape* shape, MatchOption option) const {
653     if (shape == nullptr) {
654       EXPLAIN << "Shape is null";
655     }
656     return shape != nullptr;
657   }
658 
659   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
660     *os << "a shape";
661   }
662 
663   static constexpr bool kIsTrivialMatcher = true;
664 };
665 
666 // A ShapePattern implementation that matches only if the shape equals a Shape
667 // proto.
668 class ShapePatternEqualImpl {
669  public:
670   explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
671       : shape_(shape) {}
672 
673   bool Match(const ::xla::Shape* shape, MatchOption option) const {
674     if (!ShapeUtil::Equal(*shape_, *shape)) {
675       EXPLAIN << "Shape not equal to "
676               << ShapeUtil::HumanStringWithLayout(*shape_);
677       return false;
678     }
679     return true;
680   }
681 
682   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
683     *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
684   }
685 
686  private:
687   const ::xla::Shape* shape_;
688 };
689 
690 // A ShapePattern implementation that matches only if the shape is compatible to
691 // a Shape proto.
692 class ShapePatternCompatibleImpl {
693  public:
694   explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
695       : shape_(shape) {}
696 
697   bool Match(const ::xla::Shape* shape, MatchOption option) const {
698     if (!ShapeUtil::Compatible(*shape_, *shape)) {
699       EXPLAIN << "Shape not compatible with "
700               << ShapeUtil::HumanString(*shape_);
701       return false;
702     }
703     return true;
704   }
705 
706   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
707     *os << "compatible with " << ShapeUtil::HumanString(*shape_);
708   }
709 
710  private:
711   const ::xla::Shape* shape_;
712 };
713 
714 // A ShapePattern implementation that matches only if the shape has a given
715 // element type.
716 class ShapePatternElementTypeImpl {
717  public:
718   explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
719       : element_type_(element_type) {}
720 
721   bool Match(const ::xla::Shape* shape, MatchOption option) const {
722     if (shape->element_type() != element_type_) {
723       EXPLAIN << "Shape does not have element type "
724               << PrimitiveType_Name(element_type_);
725       return false;
726     }
727     return true;
728   }
729 
730   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
731     *os << "with element type " << PrimitiveType_Name(element_type_);
732   }
733 
734  private:
735   PrimitiveType element_type_;
736 };
737 
738 // A ShapePattern implementation that matches only if the shape has a given
739 // list of dimensions.
740 class ShapePatternDimsImpl {
741  public:
742   explicit ShapePatternDimsImpl(absl::Span<const int64> dims)
743       : dims_(dims.begin(), dims.end()) {}
744 
745   bool Match(const ::xla::Shape* shape, MatchOption option) const {
746     if (shape->dimensions() != dims_) {
747       EXPLAIN << "Shape does not have dimensions [" << absl::StrJoin(dims_, ",")
748               << "]";
749       return false;
750     }
751     return true;
752   }
753 
754   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
755     *os << "with dimensions [" << absl::StrJoin(dims_, ",") << "]";
756   }
757 
758  private:
759   absl::InlinedVector<int64, 8> dims_;
760 };
761 
762 // A ShapePattern implementation that matches only if the shape is scalar.
763 class ShapePatternIsScalarImpl {
764  public:
765   explicit constexpr ShapePatternIsScalarImpl() {}
766 
767   bool Match(const ::xla::Shape* shape, MatchOption option) const {
768     if (!ShapeUtil::IsScalar(*shape)) {
769       EXPLAIN << "Shape is not a scalar";
770       return false;
771     }
772     return true;
773   }
774 
775   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
776     *os << "that represents a scalar";
777   }
778 };
779 
780 // A ShapePattern implementation that matches only if the shape is an array
781 class ShapePatternIsArrayImpl {
782  public:
783   explicit constexpr ShapePatternIsArrayImpl() {}
784 
785   bool Match(const ::xla::Shape* shape, MatchOption option) const {
786     if (!shape->IsArray()) {
787       EXPLAIN << "Shape is not an array";
788       return false;
789     }
790     return true;
791   }
792 
793   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
794     *os << "that represents an array";
795   }
796 };
797 
798 // A ShapePattern implementation that matches only if the shape is a tuple.
799 class ShapePatternIsTupleImpl {
800  public:
801   explicit constexpr ShapePatternIsTupleImpl() {}
802 
803   bool Match(const ::xla::Shape* shape, MatchOption option) const {
804     if (!shape->IsTuple()) {
805       EXPLAIN << "Shape is not a tuple";
806       return false;
807     }
808     return true;
809   }
810 
811   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
812     *os << "that represents a tuple";
813   }
814 };
815 
816 // A ShapePattern implementation that matches only if the shape is an effective
817 // scalar.
818 class ShapePatternEffectiveScalarImpl {
819  public:
820   explicit constexpr ShapePatternEffectiveScalarImpl() {}
821 
822   bool Match(const ::xla::Shape* shape, MatchOption option) const {
823     if (!ShapeUtil::IsEffectiveScalar(*shape)) {
824       EXPLAIN << "Shape is not an effective scalar";
825       return false;
826     }
827     return true;
828   }
829 
830   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
831     *os << "that is an effective scalar";
832   }
833 };
834 
835 // A ShapePattern implementation that matches only if the shape has a given
836 // rank.
837 class ShapePatternRankImpl {
838  public:
839   explicit constexpr ShapePatternRankImpl(int64_t rank) : rank_(rank) {}
840 
841   bool Match(const ::xla::Shape* shape, MatchOption option) const {
842     if (shape->rank() != rank_) {
843       if (rank_ == 0) {
844         EXPLAIN << "Shape is not a scalar";
845       } else {
846         EXPLAIN << "Shape does not have rank " << rank_;
847       }
848       return false;
849     }
850     return true;
851   }
852 
853   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
854     if (rank_ == 0) {
855       *os << "that is a scalar";
856     } else {
857       *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
858     }
859   }
860 
861  private:
862   int64 rank_;
863 };
864 
865 // A ShapePattern implementation that matches only if the shape has a layout
866 // that matches a given pattern.
867 template <typename LayoutType, typename LayoutImpl>
868 class ShapePatternLayoutImpl {
869  public:
870   explicit constexpr ShapePatternLayoutImpl(
871       const LayoutPattern<LayoutType, LayoutImpl>& layout)
872       : layout_(layout) {}
873 
874   bool Match(const ::xla::Shape* shape, MatchOption option) const {
875     return LayoutUtil::HasLayout(*shape) &&
876            layout_.Match(&shape->layout(), option);
877   }
878 
879   bool Match(::xla::Shape* shape, MatchOption option) const {
880     if (!LayoutUtil::HasLayout(*shape)) {
881       EXPLAIN << "Shape does not have a layout";
882       return false;
883     }
884     if (!layout_.Match(shape->mutable_layout(), option)) {
885       EXPLAIN << "\nin layout";
886       return false;
887     }
888     return true;
889   }
890 
891   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
892     *os << "with";
893     Indent(os, indent + kIndentInc);
894     layout_.DescribeTo(os, indent + kIndentInc);
895   }
896 
897  private:
898   LayoutPattern<LayoutType, LayoutImpl> layout_;
899 };
900 
901 // A ShapePattern implementation that matches only if the shape has a subshape
902 // that matches a given pattern.
903 template <typename SubshapeType, typename SubshapeImpl>
904 class ShapePatternSubshapeImpl {
905  public:
906   explicit ShapePatternSubshapeImpl(
907       ShapeIndexView index,
908       const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
909       : index_(index), subshape_(subshape) {}
910 
911   bool Match(const ::xla::Shape* shape, MatchOption option) const {
912     return MatchImpl(shape, option);
913   }
914 
915   bool Match(::xla::Shape* shape, MatchOption option) const {
916     return MatchImpl(shape, option);
917   }
918 
919   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
920     *os << "with subshape at index " << index_.ToString() << " which is";
921     Indent(os, indent + kIndentInc);
922     subshape_.DescribeTo(os, indent + kIndentInc);
923   }
924 
925  private:
926   ::xla::Shape* GetSubshape(::xla::Shape* shape) const {
927     return ShapeUtil::GetMutableSubshape(shape, index_);
928   }
929   const ::xla::Shape* GetSubshape(const ::xla::Shape* shape) const {
930     return &ShapeUtil::GetSubshape(*shape, index_);
931   }
932 
933   template <typename ShapeType>
934   bool MatchImpl(ShapeType* shape, MatchOption option) const {
935     if (!ShapeUtil::IndexIsValid(*shape, index_)) {
936       EXPLAIN << "No subshape at " << index_.ToString();
937       return false;
938     }
939     if (!subshape_.Match(GetSubshape(shape), option)) {
940       EXPLAIN << "\nin subshape at " << index_.ToString();
941       return false;
942     }
943     return true;
944   }
945 
946   ShapeIndexView index_;
947   ShapePattern<SubshapeType, SubshapeImpl> subshape_;
948 };
949 
950 // A pattern that matches Shapes.
951 template <typename ShapeType, typename Impl>
952 class ShapePattern {
953  private:
954   template <typename NewImpl>
955   auto AppendImpl(NewImpl new_impl) const {
956     auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl));
957     return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
958                                                          matched_shape_);
959   }
960 
961  public:
962   explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
963       : impl_(impl), matched_shape_(matched_shape) {}
964 
965   // Returns true and captures the shape iff it matches the pattern.
966   bool Match(const ::xla::Shape* shape, MatchOption option) const {
967     if (impl_.Match(shape, option)) {
968       if (option.capture && matched_shape_) {
969         *matched_shape_ = shape;
970       }
971       return true;
972     }
973     if (shape) {
974       EXPLAIN << "\nin "
975               << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
976                                       : ShapeUtil::HumanString(*shape));
977     }
978     return false;
979   }
980 
981   // Returns true and captures the shape iff it matches the pattern.
982   bool Match(::xla::Shape* shape, MatchOption option) const {
983     if (impl_.Match(shape, option)) {
984       if (option.capture && matched_shape_) {
985         *matched_shape_ = shape;
986       }
987       return true;
988     }
989     EXPLAIN << "\nin "
990             << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
991                                     : ShapeUtil::HumanString(*shape));
992     return false;
993   }
994 
995   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
996     return impl_.DescribeTo(os, indent);
997   }
998 
999   // Modifies the pattern to match only if the shape equals the given proto.
1000   // The layout must outlive the returned pattern.
1001   constexpr auto EqualTo(const ::xla::Shape* shape) const {
1002     return AppendImpl(ShapePatternEqualImpl(shape));
1003   }
1004 
1005   // Modifies the pattern to match only if the shape is compatible to the given
1006   // proto. The layout must outlive the returned pattern.
1007   constexpr auto CompatibleTo(const ::xla::Shape* shape) const {
1008     return AppendImpl(ShapePatternCompatibleImpl(shape));
1009   }
1010 
1011   // Modifies the pattern to match only if the shape has the given element type.
1012   constexpr auto WithElementType(PrimitiveType element_type) const {
1013     return AppendImpl(ShapePatternElementTypeImpl(element_type));
1014   }
1015 
1016   constexpr auto WithDims(absl::Span<const int64> dims) const {
1017     return AppendImpl(ShapePatternDimsImpl(dims));
1018   }
1019 
1020   // Modifies the pattern to match only if the shape is scalar.
1021   constexpr auto IsScalar() const {
1022     return AppendImpl(ShapePatternIsScalarImpl());
1023   }
1024 
1025   // Modifies the pattern to match only if the shape is an array.
1026   constexpr auto IsArray() const {
1027     return AppendImpl(ShapePatternIsArrayImpl());
1028   }
1029 
1030   // Modifies the pattern to match only if the shape is a tuple.
1031   constexpr auto IsTuple() const {
1032     return AppendImpl(ShapePatternIsTupleImpl());
1033   }
1034 
1035   constexpr auto IsEffectiveScalar() const {
1036     return AppendImpl(ShapePatternEffectiveScalarImpl());
1037   }
1038 
1039   // Modifies the pattern to match only if the shape has the given rank.
1040   constexpr auto WithRank(int64_t rank) const {
1041     return AppendImpl(ShapePatternRankImpl(rank));
1042   }
1043 
1044   // Modifies the pattern to match only if the shape has a layout that matches
1045   // the given pattern.
1046   template <typename LayoutType, typename LayoutImpl>
1047   auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
1048     return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1049   }
1050 
1051   constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const {
1052     return WithLayout(Layout().EqualTo(layout));
1053   }
1054 
1055   constexpr auto IsDenseArray() const {
1056     return WithLayout(Layout().WithDenseFormat());
1057   }
1058 
1059   // Modifies the pattern to match only if the shape has a subshape that matches
1060   // the given pattern.
1061   template <typename SubshapeType, typename SubshapeImpl>
1062   auto WithSubshape(
1063       ShapeIndexView index,
1064       const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
1065     return AppendImpl(
1066         ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1067   }
1068 
1069   ShapePattern<ShapeType,
1070                AllOfPattern<::xla::Shape, Impl,
1071                             ShapePatternSubshapeImpl<
1072                                 const ::xla::Shape,
1073                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1074                                              ShapePatternEqualImpl>>>>
1075   WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1076     return WithSubshape(index,
1077                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1078                             ShapePatternBaseImpl(), nullptr)
1079                             .EqualTo(shape));
1080   }
1081 
1082   ShapePattern<ShapeType,
1083                AllOfPattern<::xla::Shape, Impl,
1084                             ShapePatternSubshapeImpl<
1085                                 const ::xla::Shape,
1086                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1087                                              ShapePatternCompatibleImpl>>>>
1088   WithSubshapeCompatibleTo(ShapeIndexView index,
1089                            const ::xla::Shape* shape) const {
1090     return WithSubshape(index,
1091                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1092                             ShapePatternBaseImpl(), nullptr)
1093                             .CompatibleTo(shape));
1094   }
1095 
1096  private:
1097   Impl impl_;
1098   ShapeType** matched_shape_;
1099 };
1100 
1101 }  // namespace detail
1102 
1103 // Creates a shape pattern that will capture the matched layout in the argument.
1104 inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) {
1105   return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1106       detail::ShapePatternBaseImpl(), matched_shape);
1107 }
1108 
1109 // Creates a shape pattern that will capture the matched layout in the argument.
1110 inline constexpr auto Shape(::xla::Shape** matched_shape) {
1111   return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1112       detail::ShapePatternBaseImpl(), matched_shape);
1113 }
1114 
1115 namespace detail {
1116 
1117 // Overloads to get a const or non-const operand out of an instruction.
1118 inline HloInstruction* HloOperand(HloInstruction* instr, int64_t idx) {
1119   return instr->mutable_operand(idx);
1120 }
1121 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1122                                         int64_t idx) {
1123   return instr->operand(idx);
1124 }
1125 
1126 // Pretty-printer for HloInstruction.  Sort of like ToShortString, but with
1127 // fewer %s and more shapes.
1128 inline string InstToString(const HloInstruction* inst) {
1129   return inst->ToString(
1130       HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1131 }
1132 
1133 template <typename HloInstructionType, typename Impl>
1134 class HloInstructionPattern;
1135 
1136 // The base HloInstructionPattern implementation. Matches only if the
1137 // instruction is not nullptr.
1138 class HloInstructionPatternBaseImpl {
1139  public:
1140   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1141     if (inst == nullptr) {
1142       EXPLAIN << "HloInstruction* is null";
1143       return false;
1144     }
1145     return true;
1146   }
1147 
1148   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1149     *os << "an HloInstruction";
1150   }
1151 
1152   static constexpr bool kIsTrivialMatcher = true;
1153 };
1154 
1155 // An HloInstructionPattern implementation that matches only if the instruction
1156 // has a given name.
1157 class HloInstructionPatternNameImpl {
1158  public:
1159   explicit HloInstructionPatternNameImpl(absl::string_view name)
1160       : name_(name) {}
1161 
1162   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1163     if (inst->name() != name_) {
1164       EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1165       return false;
1166     }
1167     return true;
1168   }
1169 
1170   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1171     *os << "named \"" << name_ << "\"";
1172   }
1173 
1174  private:
1175   absl::string_view name_;
1176 };
1177 
1178 // An HloInstructionPattern implementation that matches only if the instruction
1179 // equals a particular pointer.
1180 class HloInstructionIsImpl {
1181  public:
1182   explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1183 
1184   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1185     if (inst != inst_) {
1186       EXPLAIN << "HloInstruction " << std::hex << std::nouppercase
1187               << std::showbase << reinterpret_cast<uint64>(inst) << " is not "
1188               << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1189               << ")";
1190       return false;
1191     }
1192     return true;
1193   }
1194 
1195   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1196     *os << "which is " << std::hex << std::nouppercase << std::showbase
1197         << reinterpret_cast<uint64>(inst_) << " (" << InstToString(inst_)
1198         << ")";
1199   }
1200 
1201  private:
1202   const HloInstruction* inst_;
1203 };
1204 
1205 // An HloInstructionPattern implementation that matches only if the instruction
1206 // has a given opcode.
1207 class HloInstructionPatternOpcodeImpl {
1208  public:
1209   explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1210                                                      bool invert)
1211       : opcode_(opcode), invert_(invert) {}
1212 
1213   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1214     if (invert_ && inst->opcode() == opcode_) {
1215       EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1216               << ", expected anything else";
1217       return false;
1218     }
1219     if (!invert_ && inst->opcode() != opcode_) {
1220       EXPLAIN << "HloInstruction doesn't have opcode "
1221               << HloOpcodeString(opcode_);
1222       return false;
1223     }
1224     return true;
1225   }
1226 
1227   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1228     if (!invert_) {
1229       *os << "with opcode " << HloOpcodeString(opcode_);
1230     } else {
1231       *os << "with any opcode other than " << HloOpcodeString(opcode_);
1232     }
1233   }
1234 
1235  private:
1236   HloOpcode opcode_;
1237   bool invert_;
1238 };
1239 
1240 // An HloInstructionPattern implementation that matches only if the instruction
1241 // has one of a given list of custom call targets.
1242 class HloInstructionCustomCallTargetImpl {
1243  public:
1244   explicit HloInstructionCustomCallTargetImpl(
1245       absl::Span<const absl::string_view> custom_call_targets)
1246       : custom_call_targets_(custom_call_targets.begin(),
1247                              custom_call_targets.end()) {}
1248 
1249   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1250     if (inst->opcode() != HloOpcode::kCustomCall ||
1251         !absl::c_linear_search(custom_call_targets_,
1252                                inst->custom_call_target())) {
1253       if (custom_call_targets_.size() == 1) {
1254         EXPLAIN << "HloInstruction is not a custom call with a target '"
1255                 << custom_call_targets_.front() << "'";
1256       } else {
1257         EXPLAIN << "HloInstruction is not a custom call with a target in {"
1258                 << absl::StrJoin(custom_call_targets_, ", ") << "}";
1259       }
1260       return false;
1261     }
1262     return true;
1263   }
1264 
1265   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1266     if (custom_call_targets_.size() == 1) {
1267       *os << "custom call with target '" << custom_call_targets_.front() << "'";
1268     } else {
1269       *os << "custom call with target in {"
1270           << absl::StrJoin(custom_call_targets_, ", ") << "}";
1271     }
1272   }
1273 
1274  private:
1275   absl::InlinedVector<std::string, 1> custom_call_targets_;
1276 };
1277 
1278 // An HloInstructionPattern implementation that matches only if the instruction
1279 // has the given number of operands.
1280 class HloInstructionPatternNumOperandsImpl {
1281  public:
1282   explicit constexpr HloInstructionPatternNumOperandsImpl(int64_t num_operands)
1283       : num_operands_(num_operands) {}
1284 
1285   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1286     if (inst->operand_count() != num_operands_) {
1287       EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1288       return false;
1289     }
1290     return true;
1291   }
1292 
1293   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1294     *os << "with " << num_operands_ << " operand"
1295         << (num_operands_ != 1 ? "s" : "");
1296   }
1297 
1298  private:
1299   int64 num_operands_;
1300 };
1301 
1302 // An HloInstructionPattern implementation that matches only if the instruction
1303 // has a shape that matches a given pattern.
1304 template <typename ShapeType, typename ShapeImpl>
1305 class HloInstructionPatternShapeImpl {
1306  public:
1307   explicit constexpr HloInstructionPatternShapeImpl(
1308       const ShapePattern<ShapeType, ShapeImpl>& shape)
1309       : shape_(shape) {}
1310 
1311   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1312     if (!shape_.Match(&inst->shape(), option)) {
1313       EXPLAIN << "\nin output shape";
1314       return false;
1315     }
1316     return true;
1317   }
1318 
1319   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1320     if (!shape_.Match(inst->mutable_shape(), option)) {
1321       EXPLAIN << "\nin output shape";
1322       return false;
1323     }
1324     return true;
1325   }
1326 
1327   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1328     *os << "outputting";
1329     Indent(os, indent + kIndentInc);
1330     shape_.DescribeTo(os, indent + kIndentInc);
1331   }
1332 
1333  private:
1334   ShapePattern<ShapeType, ShapeImpl> shape_;
1335 };
1336 
1337 // An HloInstructionPattern implementation that matches only if the instruction
1338 // has an operand that matches a given pattern.
1339 template <typename OperandType, typename OperandImpl>
1340 class HloInstructionPatternOperandImpl {
1341  public:
1342   explicit constexpr HloInstructionPatternOperandImpl(
1343       int64_t operand_index,
1344       const HloInstructionPattern<OperandType, OperandImpl>& operand)
1345       : operand_index_(operand_index), operand_(operand) {}
1346 
1347   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1348     return MatchImpl(inst, option);
1349   }
1350 
1351   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1352     return MatchImpl(inst, option);
1353   }
1354 
1355   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1356     *os << "with operand " << operand_index_ << " which is:";
1357     Indent(os, indent + kIndentInc);
1358     operand_.DescribeTo(os, indent + kIndentInc);
1359   }
1360 
1361  private:
1362   template <typename HloInstructionType>
1363   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1364     if (operand_index_ >= inst->operand_count()) {
1365       EXPLAIN << "desired operand index " << operand_index_
1366               << " is out of bounds";
1367       return false;
1368     }
1369     if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1370       EXPLAIN << "\nin operand " << operand_index_;
1371       return false;
1372     }
1373     return true;
1374   }
1375 
1376   int64 operand_index_;
1377   HloInstructionPattern<OperandType, OperandImpl> operand_;
1378 };
1379 
1380 // Matches a binary instruction whose operands come in any order.
1381 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1382           typename OperandImpl2>
1383 class HloInstructionPatternBinaryOperandsAnyOrderImpl {
1384  public:
1385   explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
1386       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1387       const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
1388       : op1_(op1), op2_(op2) {}
1389 
1390   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1391     return MatchImpl(inst, option);
1392   }
1393 
1394   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1395     return MatchImpl(inst, option);
1396   }
1397 
1398   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1399     *os << "with two operands in either order:";
1400     Indent(os, indent);
1401     *os << " - ";
1402     op1_.DescribeTo(os, indent + 3);
1403     Indent(os, indent);
1404     *os << " - ";
1405     op2_.DescribeTo(os, indent + 3);
1406   }
1407 
1408  private:
1409   HloInstruction* operand(HloInstruction* inst, int64_t idx) const {
1410     return inst->mutable_operand(idx);
1411   }
1412   const HloInstruction* operand(const HloInstruction* inst, int64_t idx) const {
1413     return inst->operand(idx);
1414   }
1415 
1416   template <typename HloInstructionType>
1417   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1418     // We could implement this using AnyOf and AllOf matchers, but the templates
1419     // get pretty difficult to debug, since any compile error herein becomes
1420     // not-an-error via SFINAE.  Also this way lets us give better messages on
1421     // failure.
1422     if (inst->operand_count() != 2) {
1423       EXPLAIN << "HloInstruction did not have two operands";
1424       return false;
1425     }
1426 
1427     // If we're not generating explanations, this is pretty simple.
1428     if (!option.explain_os) {
1429       auto try_match = [&](int64_t idx1, int64_t idx2) {
1430         MatchOption new_option = option;
1431         new_option.capture = false;
1432         if (op1_.Match(operand(inst, idx1), new_option) &&
1433             op2_.Match(operand(inst, idx2), new_option)) {
1434           if (option.capture) {
1435             bool matched = op1_.Match(operand(inst, idx1), option) &&
1436                            op2_.Match(operand(inst, idx2), option);
1437             DCHECK(matched);
1438           }
1439           return true;
1440         }
1441         return false;
1442       };
1443       return try_match(0, 1) || try_match(1, 0);
1444     }
1445 
1446     // If we are generating explanations, we have some work to do in order to
1447     // generate a helpful error.
1448     //
1449     // First, try all four operand/matcher combinations, recording the
1450     // failure explanations separately from option.explain_os. matches[i][j]
1451     // tells us if matcher_i matches operand j.
1452     bool matches[/*matcher*/ 2][/*operand*/ 2];
1453     std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
1454     for (int i = 0; i < 2; ++i) {
1455       for (int j = 0; j < 2; ++j) {
1456         MatchOption new_option = option;
1457         new_option.capture = false;
1458         new_option.explain_os = &explanations[i][j];
1459         matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
1460                                : op2_.Match(operand(inst, j), new_option);
1461       }
1462     }
1463 
1464     // Check if the match succeeded.
1465     for (int i = 0; i < 2; ++i) {
1466       if (matches[0][i] && matches[1][(i + 1) % 2]) {
1467         // Rerun the matches with capture enabled if necessary.
1468         if (option.capture) {
1469           auto* operand1 = operand(inst, i);
1470           auto* operand2 = operand(inst, (i + 1) % 2);
1471           bool matched =
1472               op1_.Match(operand1, option) && op2_.Match(operand2, option);
1473           DCHECK(matched);
1474         }
1475         return true;
1476       }
1477     }
1478 
1479     auto describe_matcher = [&](int matcher_idx) {
1480       EXPLAIN << "\n - ";
1481       if (matcher_idx == 0) {
1482         op1_.DescribeTo(option.explain_os, /*indent=*/3);
1483       } else {
1484         CHECK_EQ(matcher_idx, 1);
1485         op2_.DescribeTo(option.explain_os, /*indent=*/3);
1486       }
1487       for (int i = 0; i < 2; ++i) {
1488         if (matches[matcher_idx][/*operand*/ i]) {
1489           continue;
1490         }
1491         EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
1492         EXPLAIN << " - ";
1493         EXPLAIN << absl::StrReplaceAll(
1494             explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
1495       }
1496     };
1497 
1498     // If we failed to match, one of the following is true:
1499     //  1. op1 (op2) matches neither LHS nor RHS, or
1500     //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
1501     // We print different explanations depending on which case we're in.
1502 
1503     // Case 1.
1504     bool wrote_explanation = false;
1505     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1506       if (!matches[i][0] && !matches[i][1]) {
1507         EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
1508                 << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
1509         describe_matcher(i);
1510         wrote_explanation = true;
1511       }
1512     }
1513 
1514     // Case 2.
1515     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1516       if (matches[/*matcher*/ 0][/*operand*/ i] &&
1517           matches[/*matcher*/ 1][/*operand*/ i]) {
1518         CHECK(!matches[0][(i + 1) % 2]);
1519         CHECK(!matches[1][(i + 1) % 2]);
1520         CHECK(!wrote_explanation);
1521         EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
1522                 << " operand did not match either of the two matchers.  "
1523                    "Specifically,";
1524         describe_matcher(0);
1525         EXPLAIN << "\nand";
1526         describe_matcher(1);
1527         wrote_explanation = true;
1528       }
1529     }
1530 
1531     CHECK(wrote_explanation);
1532     return false;
1533   }
1534 
1535   HloInstructionPattern<OperandType1, OperandImpl1> op1_;
1536   HloInstructionPattern<OperandType2, OperandImpl2> op2_;
1537 };
1538 
1539 // An HloInstructionPattern implementation that matches only if the instruction
1540 // is a fusion node with a particular kind.
1541 class HloInstructionPatternFusionKindImpl {
1542  public:
1543   explicit constexpr HloInstructionPatternFusionKindImpl(
1544       ::xla::HloInstruction::FusionKind kind)
1545       : kind_(kind) {}
1546 
1547   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1548     return MatchImpl(inst, option);
1549   }
1550 
1551   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1552     return MatchImpl(inst, option);
1553   }
1554 
1555   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1556     *os << "with fusion kind " << ToString(kind_);
1557   }
1558 
1559  private:
1560   template <typename HloInstructionType>
1561   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1562     if (inst->opcode() != HloOpcode::kFusion) {
1563       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1564               << "; it's not a fusion";
1565       return false;
1566     }
1567     if (inst->fusion_kind() != kind_) {
1568       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1569       return false;
1570     }
1571     return true;
1572   }
1573 
1574   ::xla::HloInstruction::FusionKind kind_;
1575 };
1576 
1577 // An HloInstructionPattern implementation that matches only if the instruction
1578 // is a kGetTupleElement with a particular tuple index.
1579 class HloInstructionPatternTupleIndexImpl {
1580  public:
1581   explicit constexpr HloInstructionPatternTupleIndexImpl(int64_t tuple_index)
1582       : tuple_index_(tuple_index) {}
1583 
1584   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1585     return MatchImpl(inst, option);
1586   }
1587 
1588   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1589     return MatchImpl(inst, option);
1590   }
1591 
1592   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1593     *os << "which is a GTE with index " << tuple_index_;
1594   }
1595 
1596  private:
1597   template <typename HloInstructionType>
1598   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1599     if (inst->opcode() != HloOpcode::kGetTupleElement) {
1600       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1601               << "; it's not a GTE at all";
1602       return false;
1603     }
1604     if (inst->tuple_index() != tuple_index_) {
1605       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1606       return false;
1607     }
1608     return true;
1609   }
1610 
1611   int64 tuple_index_;
1612 };
1613 
1614 class HloInstructionPatternParameterNumImpl {
1615  public:
1616   explicit constexpr HloInstructionPatternParameterNumImpl(
1617       int64_t parameter_num)
1618       : parameter_num_(parameter_num) {}
1619 
1620   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1621     return MatchImpl(inst, option);
1622   }
1623 
1624   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1625     return MatchImpl(inst, option);
1626   }
1627 
1628   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1629     *os << "which is parameter " << parameter_num_;
1630   }
1631 
1632  private:
1633   template <typename HloInstructionType>
1634   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1635     if (inst->opcode() != HloOpcode::kParameter ||
1636         inst->parameter_number() != parameter_num_) {
1637       EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1638       return false;
1639     }
1640     return true;
1641   }
1642 
1643   int64 parameter_num_;
1644 };
1645 
1646 // Superclass that contains common code used by Op::WithOneUse() and
1647 // Op::WithOneUser().
1648 class HloInstructionPatternOneUseOrUserImpl {
1649  protected:
1650   bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1651     if (inst->user_count() != 1) {
1652       EXPLAIN << "HloInstruction has " << inst->user_count()
1653               << " users, but expected exactly one.";
1654       if (inst->user_count() > 1) {
1655         EXPLAIN << "\nAll users:";
1656         for (const HloInstruction* user : inst->users()) {
1657           EXPLAIN << "\n - " << InstToString(user);
1658         }
1659       }
1660       return false;
1661     }
1662     return true;
1663   }
1664 };
1665 
1666 class HloInstructionPatternOneUseImpl
1667     : public HloInstructionPatternOneUseOrUserImpl {
1668  public:
1669   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1670     if (!MatchOneUser(inst, option)) {
1671       return false;
1672     }
1673 
1674     int64_t use_count = absl::c_count_if(
1675         inst->users()[0]->operands(),
1676         [&](const HloInstruction* operand) { return operand == inst; });
1677     if (use_count != 1) {
1678       EXPLAIN << "HloInstruction is used " << use_count
1679               << " times by its user, but is expected to be used just once: "
1680               << InstToString(inst->users()[0]);
1681       return false;
1682     }
1683     return true;
1684   }
1685 
1686   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1687     *os << "which has exactly one use";
1688   }
1689 };
1690 
1691 class HloInstructionPatternOneUserImpl
1692     : public HloInstructionPatternOneUseOrUserImpl {
1693  public:
1694   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1695     return MatchOneUser(inst, option);
1696   }
1697 
1698   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1699     *os << "which has exactly one user (but possibly is used multiple times by "
1700            "that instruction)";
1701   }
1702 };
1703 
1704 class HloInstructionPatternComparisonDirectionImpl {
1705  public:
1706   explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1707       ComparisonDirection direction)
1708       : direction_(direction) {}
1709 
1710   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1711     return MatchImpl(inst, option);
1712   }
1713 
1714   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1715     return MatchImpl(inst, option);
1716   }
1717 
1718   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1719     *os << "which has comparison direction "
1720         << ComparisonDirectionToString(direction_);
1721   }
1722 
1723  private:
1724   template <typename HloInstructionType>
1725   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1726     if (inst->opcode() != HloOpcode::kCompare ||
1727         inst->comparison_direction() != direction_) {
1728       EXPLAIN << "HloInstruction is not comparison "
1729               << ComparisonDirectionToString(direction_);
1730       return false;
1731     }
1732     return true;
1733   }
1734 
1735   ComparisonDirection direction_;
1736 };
1737 
1738 // Matches a constant scalar or effective scalar, optionally with a given value.
1739 template <typename ScalarTy>
1740 class HloConstantScalarImpl {
1741  public:
1742   explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1743       : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
1744 
1745   constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1746       : val_(val), match_effective_scalar_(match_effective_scalar) {}
1747 
1748   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1749     return MatchImpl(inst, option);
1750   }
1751 
1752   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1753     return MatchImpl(inst, option);
1754   }
1755 
1756   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
1757     *os << "which is a constant "
1758         << (match_effective_scalar_ ? "effective " : "") << "scalar";
1759     if (val_.has_value()) {
1760       *os << " with value " << *val_;
1761     }
1762   }
1763 
1764  private:
1765   template <typename InstTy>
1766   bool MatchImpl(InstTy* inst, MatchOption option) const {
1767     const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1768     if (!const_inst) {
1769       EXPLAIN << "HloInstruction is not a constant";
1770       return false;
1771     }
1772     if (match_effective_scalar_ &&
1773         !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1774       EXPLAIN << "HloInstruction is not an effective scalar";
1775       return false;
1776     }
1777     if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1778       EXPLAIN << "HloInstruction is not a scalar";
1779       return false;
1780     }
1781     if (!val_.has_value()) {
1782       return true;
1783     }
1784 
1785     auto const_inst_scalar_or = const_inst->literal().Reshape({});
1786     if (!const_inst_scalar_or.ok()) {
1787       EXPLAIN << "could not convert matched literal to effective scalar";
1788       return false;
1789     }
1790     Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie();
1791     if (!const_inst_scalar.IsEqualAt({}, *val_)) {
1792       EXPLAIN << "HloInstruction's constant value "
1793               << const_inst_scalar.ToStringWithoutShape()
1794               << " did not match expected value " << *val_;
1795       return false;
1796     }
1797     return true;
1798   }
1799 
1800   absl::optional<ScalarTy> val_;
1801   bool match_effective_scalar_;
1802 };
1803 
1804 // A pattern that matches HloInstructions.
1805 template <typename HloInstructionType, typename Impl>
1806 class HloInstructionPattern {
1807  private:
1808   template <typename NewImpl>
1809   auto AppendImpl(NewImpl new_impl) const {
1810     auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl));
1811     return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1812         std::move(new_allof), matched_inst_);
1813   }
1814 
1815  public:
1816   explicit constexpr HloInstructionPattern(const Impl& impl,
1817                                            HloInstructionType** matched_inst)
1818       : impl_(impl), matched_inst_(matched_inst) {}
1819 
1820   // Returns true and captures the instruction iff it matches the pattern.
1821   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1822     if (impl_.Match(inst, option)) {
1823       if (option.capture && matched_inst_) {
1824         *matched_inst_ = inst;
1825       }
1826       return true;
1827     }
1828     if (inst != nullptr) {
1829       EXPLAIN << "\nin " << InstToString(inst);
1830     }
1831     return false;
1832   }
1833 
1834   // Returns true and captures the instruction iff it matches the pattern.
1835   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1836     if (impl_.Match(inst, option)) {
1837       if (option.capture && matched_inst_) {
1838         *matched_inst_ = inst;
1839       }
1840       return true;
1841     }
1842     EXPLAIN << "\nin " << InstToString(inst);
1843     return false;
1844   }
1845 
1846   // Modifies the pattern to match only if the instruction has the given name.
1847   auto WithName(absl::string_view name) const {
1848     return AppendImpl(HloInstructionPatternNameImpl(name));
1849   }
1850 
1851   // Modifies the pattern to match only if the instruction has the given opcode.
1852   auto WithOpcode(HloOpcode opcode) const {
1853     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
1854   }
1855 
1856   // Modifies the pattern to match only the custom call with a given target.
1857   auto WithCustomCallTarget(absl::string_view custom_call_target) const {
1858     return AppendImpl(HloInstructionCustomCallTargetImpl({custom_call_target}));
1859   }
1860 
1861   // Modifies the pattern to match a custom call with one of the given targets.
1862   auto WithCustomCallTarget(
1863       absl::Span<const absl::string_view> custom_call_targets) const {
1864     return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_targets));
1865   }
1866 
1867   auto WithNumOperands(int64_t num_operands) const {
1868     return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
1869   }
1870 
1871   // Modifies the pattern to match only if the instruction does not have the
1872   // given opcode.
1873   auto WithoutOpcode(HloOpcode opcode) const {
1874     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
1875   }
1876 
1877   constexpr auto Is(const HloInstruction* instr) const {
1878     return AppendImpl(HloInstructionIsImpl(instr));
1879   }
1880 
1881   // Modifies the pattern to match only if the instruction is a constant.
1882   constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); }
1883 
1884   constexpr auto IsConstantScalar() const {
1885     return AppendImpl(
1886         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1887   }
1888 
1889   // This does not check that T has the same type as the instruction, so e.g.
1890   // IsConstantScalar(1.0) may match a constant of shape int32[].
1891   template <typename ScalarTy>
1892   constexpr auto IsConstantScalar(const ScalarTy& val) const {
1893     return AppendImpl(
1894         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1895   }
1896 
1897   constexpr auto IsConstantEffectiveScalar() const {
1898     return AppendImpl(
1899         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1900   }
1901 
1902   template <typename ScalarTy>
1903   constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const {
1904     return AppendImpl(
1905         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1906   }
1907 
1908   // Modifies the pattern to match only if the instruction is not a constant.
1909   constexpr auto IsNonConstant() const {
1910     return WithoutOpcode(HloOpcode::kConstant);
1911   }
1912 
1913   // Modifies the pattern to match only if the instruction has a shape that
1914   // matches the given pattern.
1915   template <typename ShapeType, typename ShapeImpl>
1916   constexpr auto WithShape(
1917       const ShapePattern<ShapeType, ShapeImpl>& shape) const {
1918     return AppendImpl(
1919         HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
1920   }
1921 
1922   // Because we only specify the shape's element type and dims, this is
1923   // effectivley checking shape-compatible-to, not shape-equal-to.  Perhaps this
1924   // function should be called WithShapeCompatibleTo, but the short name is
1925   // nice, and there's no ambiguity because there's no layout in the args!
1926   constexpr auto WithShape(PrimitiveType ty, absl::Span<const int64> dims) {
1927     return WithShape(Shape().WithElementType(ty).WithDims(dims));
1928   }
1929 
1930   // Make this a templated function to work around gcc 4.9.4 template infinite
1931   // recursion bug.
1932   template <typename Dummy = void>
1933   constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const {
1934     return WithShape(Shape().EqualTo(shape));
1935   }
1936 
1937   // Make this a templated function to work around gcc 4.9.4 template infinite
1938   // recursion bug.
1939   template <typename Dummy = void>
1940   constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const {
1941     return WithShape(Shape().CompatibleTo(shape));
1942   }
1943 
1944   // Modifies the pattern to match only if the instruction has an operand that
1945   // matches the given pattern.
1946   template <typename OperandType, typename OperandImpl>
1947   constexpr auto WithOperand(
1948       int64_t operand_index,
1949       const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
1950     return AppendImpl(
1951         HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1952             operand_index, operand));
1953   }
1954 
1955   template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1956             typename OperandImpl2>
1957   constexpr auto WithBinaryOperandsAnyOrder(
1958       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1959       const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const {
1960     return AppendImpl(
1961         HloInstructionPatternBinaryOperandsAnyOrderImpl<
1962             OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
1963   }
1964 
1965   // Modifies the pattern to match only if the instruction is a fusion node with
1966   // the given kind.
1967   constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const {
1968     return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
1969   }
1970 
1971   // Modifies the pattern to match only if the instruction is a
1972   // get-tuple-element with the given tuple index.
1973   constexpr auto WithTupleIndex(int64_t tuple_index) const {
1974     return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
1975   }
1976 
1977   // Modifies the pattern to match only if the instruction is a parameter
1978   // with the given parameter number.
1979   constexpr auto WithParameterNum(int64_t parameter_num) const {
1980     return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
1981   }
1982 
1983   // Modifies the pattern to match if the instruction is used exactly once.
1984   // Does not match if the instruction is used twice by the same user (e.g.
1985   // multiply(x,x)).
1986   constexpr auto WithOneUse() const {
1987     return AppendImpl(HloInstructionPatternOneUseImpl());
1988   }
1989 
1990   // Modifies the pattern to match if the instruction is used by exactly one
1991   // other instruction.  Will match if the instruction is used twice, so long as
1992   // it's by the same user (e.g.  multiply(x,x)).
1993   constexpr auto WithOneUser() const {
1994     return AppendImpl(HloInstructionPatternOneUserImpl());
1995   }
1996 
1997   // Modifies the pattern to match only if the instruction has the given
1998   // comparison direction.
1999   auto WithComparisonDirection(ComparisonDirection direction) const {
2000     return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
2001   }
2002 
2003   void DescribeTo(std::ostream* os, int64_t indent = 0) const {
2004     impl_.DescribeTo(os, indent);
2005   }
2006 
2007  private:
2008   Impl impl_;
2009   HloInstructionType** matched_inst_;
2010 };
2011 
2012 }  // namespace detail
2013 
2014 // Creates an instruction pattern that will capture the matched instruction in
2015 // the argument.
2016 inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) {
2017   return detail::HloInstructionPattern<const ::xla::HloInstruction,
2018                                        detail::HloInstructionPatternBaseImpl>(
2019       detail::HloInstructionPatternBaseImpl(), matched_inst);
2020 }
2021 
2022 // Creates an instruction pattern that will capture the matched instruction in
2023 // the argument.
2024 inline constexpr auto Op(::xla::HloInstruction** matched_inst) {
2025   return detail::HloInstructionPattern<::xla::HloInstruction,
2026                                        detail::HloInstructionPatternBaseImpl>(
2027       detail::HloInstructionPatternBaseImpl(), matched_inst);
2028 }
2029 
2030 // Helpers for nullary instructions.
2031 #define XLA_NULLOP_PATTERN(NAME)                                     \
2032   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2033                                                                      \
2034   template <typename HloInstructionType>                             \
2035   inline auto NAME(HloInstructionType** matched_inst) {              \
2036     return Op(matched_inst).WithOpcode(HloOpcode::k##NAME);          \
2037   }
2038 XLA_NULLOP_PATTERN(Constant)
2039 XLA_NULLOP_PATTERN(Parameter)
2040 XLA_NULLOP_PATTERN(Iota)
2041 XLA_NULLOP_PATTERN(Rng)
2042 XLA_NULLOP_PATTERN(PartitionId)
2043 XLA_NULLOP_PATTERN(ReplicaId)
2044 #undef XLA_NULLOP_PATTERN
2045 
2046 // Helpers for unary instructions.
2047 #define XLA_UNOP_PATTERN(NAME)                                       \
2048   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \
2049                                                                      \
2050   template <typename Arg>                                            \
2051   inline auto NAME(Arg&& arg) {                                      \
2052     return Op()                                                      \
2053         .WithOpcode(HloOpcode::k##NAME)                              \
2054         .WithOperand(0, std::forward<Arg>(arg));                     \
2055   }                                                                  \
2056                                                                      \
2057   template <typename HloInstructionType, typename Arg>               \
2058   inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) {   \
2059     return Op(matched_inst)                                          \
2060         .WithOpcode(HloOpcode::k##NAME)                              \
2061         .WithOperand(0, std::forward<Arg>(arg));                     \
2062   }
2063 XLA_UNOP_PATTERN(Abs)
2064 XLA_UNOP_PATTERN(RoundNearestAfz)
2065 XLA_UNOP_PATTERN(Bitcast)
2066 XLA_UNOP_PATTERN(BitcastConvert)
2067 XLA_UNOP_PATTERN(Broadcast)
2068 XLA_UNOP_PATTERN(Ceil)
2069 XLA_UNOP_PATTERN(Convert)
2070 XLA_UNOP_PATTERN(Copy)
2071 XLA_UNOP_PATTERN(Cos)
2072 XLA_UNOP_PATTERN(AllReduce)
2073 XLA_UNOP_PATTERN(Exp)
2074 XLA_UNOP_PATTERN(Fft)
2075 XLA_UNOP_PATTERN(Floor)
2076 XLA_UNOP_PATTERN(GetTupleElement)
2077 XLA_UNOP_PATTERN(Imag)
2078 XLA_UNOP_PATTERN(Infeed)
2079 XLA_UNOP_PATTERN(IsFinite)
2080 XLA_UNOP_PATTERN(Log)
2081 XLA_UNOP_PATTERN(Not)
2082 XLA_UNOP_PATTERN(Negate)
2083 XLA_UNOP_PATTERN(Real)
2084 XLA_UNOP_PATTERN(Recv)
2085 XLA_UNOP_PATTERN(RecvDone)
2086 XLA_UNOP_PATTERN(ReducePrecision)
2087 XLA_UNOP_PATTERN(Reshape)
2088 XLA_UNOP_PATTERN(Reverse)
2089 XLA_UNOP_PATTERN(Rsqrt)
2090 XLA_UNOP_PATTERN(SendDone)
2091 XLA_UNOP_PATTERN(Sign)
2092 XLA_UNOP_PATTERN(Sin)
2093 XLA_UNOP_PATTERN(Slice)
2094 XLA_UNOP_PATTERN(Sqrt)
2095 XLA_UNOP_PATTERN(Tanh)
2096 XLA_UNOP_PATTERN(Transpose)
2097 #undef XLA_UNOP_PATTERN
2098 
2099 // Helpers for binary instructions.
2100 #define XLA_BINOP_PATTERN(NAME)                                               \
2101   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
2102                                                                               \
2103   template <typename Lhs, typename Rhs>                                       \
2104   inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
2105     return Op()                                                               \
2106         .WithOpcode(HloOpcode::k##NAME)                                       \
2107         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2108         .WithOperand(1, std::forward<Rhs>(rhs));                              \
2109   }                                                                           \
2110                                                                               \
2111   template <typename HloInstructionType, typename Lhs, typename Rhs>          \
2112   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2113     return Op(matched_inst)                                                   \
2114         .WithOpcode(HloOpcode::k##NAME)                                       \
2115         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2116         .WithOperand(1, std::forward<Rhs>(rhs));                              \
2117   }
2118 
2119 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                \
2120   XLA_BINOP_PATTERN(NAME)                                                  \
2121                                                                            \
2122   template <typename HloInstructionType, typename Lhs, typename Rhs>       \
2123   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2124                              Rhs&& rhs) {                                  \
2125     return Op(matched_inst)                                                \
2126         .WithOpcode(HloOpcode::k##NAME)                                    \
2127         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
2128                                     std::forward<Rhs>(rhs));               \
2129   }                                                                        \
2130   template <typename Lhs, typename Rhs>                                    \
2131   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
2132     return NAME##AnyOrder<const HloInstruction>(                           \
2133         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
2134   }
2135 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
2136 XLA_BINOP_PATTERN(Atan2)
2137 XLA_BINOP_PATTERN(Divide)
2138 XLA_BINOP_PATTERN(Complex)
2139 XLA_BINOP_PATTERN(Compare)
2140 XLA_BINOP_PATTERN(Convolution)
2141 XLA_BINOP_PATTERN(Dot)
2142 XLA_BINOP_PATTERN(Gather)
2143 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
2144 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
2145 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
2146 XLA_BINOP_PATTERN(Outfeed)
2147 XLA_BINOP_PATTERN(Pad)
2148 XLA_BINOP_PATTERN(Power)
2149 XLA_BINOP_PATTERN(Remainder)
2150 XLA_BINOP_PATTERN(Send)
2151 XLA_BINOP_PATTERN(Subtract)
2152 XLA_COMMUTATIVE_BINOP_PATTERN(And)
2153 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
2154 XLA_BINOP_PATTERN(ShiftLeft)
2155 XLA_BINOP_PATTERN(ShiftRightArithmetic)
2156 XLA_BINOP_PATTERN(ShiftRightLogical)
2157 #undef XLA_COMMUTATIVE_BINOP_PATTERN
2158 #undef XLA_BINOP_PATTERN
2159 
2160 // Helpers for ternary instructions.
2161 #define XLA_TERNOP_PATTERN(NAME)                                       \
2162   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }   \
2163                                                                        \
2164   template <typename Arg0, typename Arg1, typename Arg2>               \
2165   inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {            \
2166     return Op()                                                        \
2167         .WithOpcode(HloOpcode::k##NAME)                                \
2168         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2169         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2170         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2171   }                                                                    \
2172                                                                        \
2173   template <typename HloInstructionType, typename Arg0, typename Arg1, \
2174             typename Arg2>                                             \
2175   inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0,     \
2176                    Arg1&& arg1, Arg2&& arg2) {                         \
2177     return Op(matched_inst)                                            \
2178         .WithOpcode(HloOpcode::k##NAME)                                \
2179         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2180         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2181         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2182   }
2183 XLA_TERNOP_PATTERN(Clamp);
2184 XLA_TERNOP_PATTERN(Scatter);
2185 XLA_TERNOP_PATTERN(Select);
2186 XLA_TERNOP_PATTERN(SelectAndScatter);
2187 #undef XLA_TERNOP_PATTERN
2188 
namespace detail {
// Applies m.WithOperand(operand_num + i, arg_i) for each trailing argument,
// left to right, and returns the resulting pattern.

// Base case: exactly one operand pattern remains.
template <typename Matcher, typename FirstArg>
inline auto WithOperands(Matcher&& m, int64_t operand_num,
                         FirstArg&& first_arg) {
  auto result = m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
  return result;
}

// Recursive case: attach `first_arg` at `operand_num`, then continue with the
// remaining patterns at the following indices.
template <typename Matcher, typename FirstArg, typename... Args>
inline auto WithOperands(Matcher&& m, int64_t operand_num, FirstArg&& first_arg,
                         Args&&... args) {
  auto with_first =
      m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
  return WithOperands(std::move(with_first), operand_num + 1,
                      std::forward<Args>(args)...);
}
}  // namespace detail
2204 
2205 #define XLA_VARIADIC_OP_PATTERN(NAME)                                         \
2206   inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); }          \
2207                                                                               \
2208   template <typename... Args>                                                 \
2209   inline auto NAME(Args&&... args) {                                          \
2210     return detail::WithOperands(                                              \
2211         Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
2212         /*operand_num=*/0, std::forward<Args>(args)...);                      \
2213   }                                                                           \
2214                                                                               \
2215   template <typename HloInstructionType, typename... Args>                    \
2216   inline auto NAME(HloInstructionType** matched_inst, Args&&... args) {       \
2217     return detail::WithOperands(Op(matched_inst)                              \
2218                                     .WithOpcode(HloOpcode::k##NAME)           \
2219                                     .WithNumOperands(sizeof...(Args)),        \
2220                                 /*operand_num=*/0,                            \
2221                                 std::forward<Args>(args)...);                 \
2222   }
2223 
2224 // We could implement all ops as "variadic" ops, but it would make the
2225 // already-bad compile errors even worse.
2226 XLA_VARIADIC_OP_PATTERN(AfterAll);
2227 XLA_VARIADIC_OP_PATTERN(Concatenate);
2228 XLA_VARIADIC_OP_PATTERN(Conditional);
2229 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2230 XLA_VARIADIC_OP_PATTERN(DynamicUpdateSlice)
2231 XLA_VARIADIC_OP_PATTERN(Fusion);
2232 XLA_VARIADIC_OP_PATTERN(Map)
2233 XLA_VARIADIC_OP_PATTERN(Reduce);
2234 XLA_VARIADIC_OP_PATTERN(ReduceWindow)
2235 XLA_VARIADIC_OP_PATTERN(Sort);
2236 XLA_VARIADIC_OP_PATTERN(Tuple);
2237 
2238 // CustomCall doesn't use the XLA_VARIADIC_OP_PATTERN macro so that you can
2239 // optionally pass a string_view for the custom_call_target before the other
2240 // operands.
2241 inline auto CustomCall() { return Op().WithOpcode(HloOpcode::kCustomCall); }
2242 
2243 template <typename HloInstructionType>
2244 auto CustomCall(HloInstructionType** matched_inst) {
2245   return Op(matched_inst).WithOpcode(HloOpcode::kCustomCall);
2246 }
2247 
2248 template <
2249     typename Arg0, typename... Args,
2250     typename std::enable_if<
2251         !std::is_convertible<Arg0, absl::string_view>::value &&
2252         !std::is_convertible<Arg0, HloInstruction**>::value &&
2253         !std::is_convertible<Arg0, const HloInstruction**>::value>::type* =
2254         nullptr>
2255 auto CustomCall(Arg0&& arg0, Args&&... args) {
2256   return detail::WithOperands(CustomCall().WithNumOperands(sizeof...(Args) + 1),
2257                               /*operand_num=*/0, std::forward<Arg0>(arg0),
2258                               std::forward<Args>(args)...);
2259 }
2260 template <typename... Args>
2261 auto CustomCall(absl::string_view custom_call_target, Args&&... args) {
2262   return CustomCall(std::forward<Args>(args)...)
2263       .WithCustomCallTarget(custom_call_target);
2264 }
2265 
2266 template <typename HloInstructionType, typename Arg0, typename... Args,
2267           typename std::enable_if<!std::is_convertible<
2268               Arg0, absl::string_view>::value>::type* = nullptr>
2269 auto CustomCall(HloInstructionType** matched_inst, Arg0&& arg0,
2270                 Args&&... args) {
2271   return detail::WithOperands(
2272       CustomCall(matched_inst).WithNumOperands(sizeof...(Args) + 1),
2273       /*operand_num=*/0, std::forward<Arg0>(arg0), std::forward<Args>(args)...);
2274 }
2275 template <typename HloInstructionType, typename... Args>
2276 auto CustomCall(HloInstructionType** matched_inst,
2277                 absl::string_view custom_call_target, Args&&... args) {
2278   return CustomCall(matched_inst, std::forward<Args>(args)...)
2279       .WithCustomCallTarget(custom_call_target);
2280 }
2281 
2282 // Helpers for comparison instructions.
2283 #define XLA_COMPARE_PATTERN(NAME)                                             \
2284   inline auto NAME() {                                                        \
2285     return Op()                                                               \
2286         .WithOpcode(HloOpcode::kCompare)                                      \
2287         .WithComparisonDirection(ComparisonDirection::k##NAME);               \
2288   }                                                                           \
2289                                                                               \
2290   template <typename Lhs, typename Rhs>                                       \
2291   inline auto NAME(Lhs&& lhs, Rhs&& rhs) {                                    \
2292     return Op()                                                               \
2293         .WithOpcode(HloOpcode::kCompare)                                      \
2294         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2295         .WithOperand(1, std::forward<Rhs>(rhs))                               \
2296         .WithComparisonDirection(ComparisonDirection::k##NAME);               \
2297   }                                                                           \
2298                                                                               \
2299   template <typename HloInstructionType, typename Lhs, typename Rhs>          \
2300   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \
2301     return Op(matched_inst)                                                   \
2302         .WithOpcode(HloOpcode::kCompare)                                      \
2303         .WithOperand(0, std::forward<Lhs>(lhs))                               \
2304         .WithOperand(1, std::forward<Rhs>(rhs))                               \
2305         .WithComparisonDirection(ComparisonDirection::k##NAME);               \
2306   }
2307 
2308 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME)                              \
2309   XLA_COMPARE_PATTERN(NAME)                                                \
2310                                                                            \
2311   template <typename HloInstructionType, typename Lhs, typename Rhs>       \
2312   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2313                              Rhs&& rhs) {                                  \
2314     return Op(matched_inst)                                                \
2315         .WithOpcode(HloOpcode::kCompare)                                   \
2316         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
2317                                     std::forward<Rhs>(rhs));               \
2318   }                                                                        \
2319   template <typename Lhs, typename Rhs>                                    \
2320   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) {                       \
2321     return NAME##AnyOrder<const HloInstruction>(                           \
2322         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
2323   }
2324 
2325 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
2326 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
2327 XLA_COMPARE_PATTERN(Ge);
2328 XLA_COMPARE_PATTERN(Gt);
2329 XLA_COMPARE_PATTERN(Le);
2330 XLA_COMPARE_PATTERN(Lt);
2331 
2332 // Helpers for matching non-constant instructions.
2333 inline auto NonConstant() { return Op().IsNonConstant(); }
2334 
2335 template <typename HloInstructionType>
2336 inline auto NonConstant(HloInstructionType** matched_inst) {
2337   return Op(matched_inst).IsNonConstant();
2338 }
2339 
2340 // Add overloads for GetTupleElement which take a int64 specifying which tuple
2341 // element is selected.
2342 template <typename Arg>
2343 inline auto GetTupleElement(Arg&& arg, int64_t tuple_index) {
2344   return Op()
2345       .WithOpcode(HloOpcode::kGetTupleElement)
2346       .WithOperand(0, std::forward<Arg>(arg))
2347       .WithTupleIndex(tuple_index);
2348 }
2349 
2350 template <typename HloInstructionType, typename Arg>
2351 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2352                             int64_t tuple_index) {
2353   return Op(matched_inst)
2354       .WithOpcode(HloOpcode::kGetTupleElement)
2355       .WithOperand(0, std::forward<Arg>(arg))
2356       .WithTupleIndex(tuple_index);
2357 }
2358 
2359 // Add overloads for Parameter which take an int64 specifying the parameter
2360 // number.
2361 inline auto Parameter(int64_t parameter_num) {
2362   return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2363 }
2364 template <typename HloInstructionType>
2365 inline auto Parameter(HloInstructionType** matched_inst,
2366                       int64_t parameter_num) {
2367   return Op(matched_inst)
2368       .WithOpcode(HloOpcode::kParameter)
2369       .WithParameterNum(parameter_num);
2370 }
2371 
2372 inline auto ConstantScalar() { return Op().IsConstantScalar(); }
2373 
2374 template <typename HloInstructionType>
2375 inline auto ConstantScalar(HloInstructionType** matched_inst) {
2376   return Op(matched_inst).IsConstantScalar();
2377 }
2378 
2379 template <typename ScalarTy>
2380 inline auto ConstantScalar(ScalarTy val) {
2381   return Op().IsConstantScalar(val);
2382 }
2383 
2384 template <typename HloInstructionType, typename ScalarTy>
2385 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) {
2386   return Op(matched_inst).IsConstantScalar(val);
2387 }
2388 
2389 inline auto ConstantEffectiveScalar() {
2390   return Op().IsConstantEffectiveScalar();
2391 }
2392 
2393 template <typename HloInstructionType>
2394 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) {
2395   return Op(matched_inst).IsConstantEffectiveScalar();
2396 }
2397 
2398 template <typename ScalarTy>
2399 inline auto ConstantEffectiveScalar(ScalarTy val) {
2400   return Op().IsConstantEffectiveScalar(val);
2401 }
2402 
2403 template <typename HloInstructionType, typename ScalarTy>
2404 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
2405                                     ScalarTy val) {
2406   return Op(matched_inst).IsConstantEffectiveScalar(val);
2407 }
2408 
2409 }  // namespace match
2410 
2411 }  // namespace xla
2412 
2413 #undef EXPLAIN
2414 #pragma pop_macro("EXPLAIN")
2415 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2416